/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.linguist.acoustic.tiedstate.HTK;

import edu.cmu.sphinx.util.LogMath;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;

/**
 * A Gaussian mixture model whose components have diagonal covariance matrices.
 * <p>
 * Internal representation:
 * <ul>
 * <li>{@link #weights} holds the mixture weights converted with {@code logMath.linearToLog};</li>
 * <li>{@link #covar} holds {@code -1 / (2 * variance)} per Gaussian and coefficient, so the
 * quadratic term of the log density is one multiply-accumulate per coefficient;</li>
 * <li>{@code logPreComputedGaussianFactor} caches half of
 * {@code ncoefs * log(2*pi) + sum(log variance)} per Gaussian. It is NOT refreshed by
 * {@link #setMean}/{@link #setVar}; call {@link #precomputeDistance()} after editing them.</li>
 * </ul>
 * Models can be read and written in a native text format ({@link #save}/{@link #load}), in an
 * HTK-style text format ({@link #saveHTK}/{@link #loadHTK}) and read from a k-means dump
 * ({@link #loadScaleKMeans}). These methods report I/O errors on standard error instead of
 * throwing.
 */
public class GMMDiag {
    /** Total count attached to the model: persisted by save/load, set by loadScaleKMeans. */
    public int nT;
    /** Free-form model name; only assigned through {@link #setNom(String)}. */
    public String nom;
    /** Log-domain arithmetic helper, (re)assigned whenever the weight array is allocated. */
    public LogMath logMath;
    private int ncoefs;
    private int ngauss;
    /** Mixture weights, in the log domain of {@link #logMath}. */
    protected float[] weights;
    /** Means, indexed {@code [gaussian][coefficient]}. */
    protected float[][] means;
    /** {@code -1 / (2 * variance)}, indexed {@code [gaussian][coefficient]}. */
    protected float[][] covar;
    /** Half of the log normalising term per Gaussian; see {@link #precomputeDistance()}. */
    private float[] logPreComputedGaussianFactor;
    /** Weighted per-Gaussian log likelihoods from the last {@link #computeLogLikes(float[])}. */
    protected float[] loglikes;
    /** Lower bound applied to every per-Gaussian log density. */
    private static final float DIST_FLOOR = -Float.MAX_VALUE;

    /** Creates an empty model; it has no storage until one of the {@code load*} methods runs. */
    public GMMDiag() {
    }

    /**
     * Creates a model with uniform weights and zero-filled mean/covariance arrays.
     *
     * @param ngauss number of Gaussian components
     * @param ncoefs dimensionality of every Gaussian
     */
    public GMMDiag(int ngauss, int ncoefs) {
        this.ngauss = ngauss;
        this.ncoefs = ncoefs;
        allocate();
    }

    /** @return the number of Gaussian components */
    public int getNgauss() {
        return ngauss;
    }

    /**
     * @param g Gaussian index
     * @return the mixture weight of Gaussian {@code g}, converted back to the linear domain
     */
    public float getWeight(int g) {
        return (float) logMath.logToLinear(weights[g]);
    }

    /**
     * @param g    Gaussian index
     * @param coef coefficient index
     * @return the variance, recovered from the stored {@code -1 / (2 * variance)} value
     */
    public float getVar(int g, int coef) {
        return -1.0f / (2.0f * covar[g][coef]);
    }

    /**
     * Sets a mixture weight; the weight array is created on demand.
     *
     * @param g      Gaussian index
     * @param weight linear-domain weight (stored via {@code logMath.linearToLog})
     */
    public void setWeight(int g, float weight) {
        if (weights == null) {
            weights = new float[ngauss];
        }
        weights[g] = logMath.linearToLog(weight);
    }

    /**
     * Sets a variance, printing a warning (but still storing it) when it is not positive.
     *
     * @param g        Gaussian index
     * @param coef     coefficient index
     * @param variance the variance to store
     */
    public void setVar(int g, int coef, float variance) {
        if (variance <= 0.0f) {
            System.err.println("WARNING: setVar " + variance);
        }
        covar[g][coef] = -1.0f / (2.0f * variance);
    }

    /** Sets the mean of coefficient {@code coef} of Gaussian {@code g}. */
    public void setMean(int g, int coef, float mean) {
        means[g][coef] = mean;
    }

    /** @return the mean of coefficient {@code coef} of Gaussian {@code g} */
    public float getMean(int g, int coef) {
        return means[g][coef];
    }

    /**
     * Writes the model in the native text format: a {@code "ngauss ncoefs"} header, then for each
     * Gaussian a {@code "gauss <index> <weight>"} line, a line of means and a line of variances,
     * and finally {@link #nT}. Errors are reported on standard error.
     *
     * @param filename destination file
     */
    public void save(String filename) {
        try {
            PrintWriter out = new PrintWriter(new FileWriter(filename));
            try {
                out.println(ngauss + " " + ncoefs);
                for (int g = 0; g < ngauss; g++) {
                    out.println("gauss " + g + ' ' + getWeight(g));
                    printMeanLine(out, g);
                    printVarianceLine(out, g);
                }
                out.println(nT);
                reportWriteError(out, filename);
            } finally {
                out.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads a model written by {@link #save(String)}, replacing all current parameters, and
     * refreshes the cached normalisers. A malformed {@code gauss} line terminates the JVM (as
     * before); I/O errors and premature end of file are reported on standard error.
     *
     * @param filename source file
     */
    public void load(String filename) {
        try {
            BufferedReader in = new BufferedReader(new FileReader(filename));
            try {
                String[] fields = readRequiredLine(in, filename).split(" ");
                ngauss = Integer.parseInt(fields[0]);
                ncoefs = Integer.parseInt(fields[1]);
                allocate();
                for (int g = 0; g < ngauss; g++) {
                    String line = readRequiredLine(in, filename);
                    fields = line.split(" ");
                    if (!fields[0].equals("gauss") || Integer.parseInt(fields[1]) != g) {
                        System.err.println("Error loading GMM " + line + ' ' + g);
                        // NOTE(review): exiting from library code is harsh; kept for compatibility.
                        System.exit(1);
                    }
                    setWeight(g, Float.parseFloat(fields[2]));
                    fields = readRequiredLine(in, filename).split(" ");
                    for (int j = 0; j < ncoefs; j++) {
                        setMean(g, j, Float.parseFloat(fields[j]));
                    }
                    fields = readRequiredLine(in, filename).split(" ");
                    for (int j = 0; j < ncoefs; j++) {
                        setVar(g, j, Float.parseFloat(fields[j]));
                    }
                }
                // The trailing count line is optional.
                String countLine = in.readLine();
                if (countLine != null) {
                    nT = Integer.parseInt(countLine);
                }
            } finally {
                in.close();
            }
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Writes the model as a single-state HTK HMM using the {@code <USER>} parameter kind.
     *
     * @param filename destination file
     * @param hmmName  name written in the {@code ~h} macro
     */
    public void saveHTK(String filename, String hmmName) {
        saveHTK(filename, hmmName, "<USER>");
    }

    /**
     * Opens {@code filename} and writes the HTK global options and regression-tree header.
     *
     * @param filename      destination file
     * @param parameterKind parameter kind inserted between {@code <NULLD>} and {@code <DIAGC>}
     * @return the open writer (the caller must close it), or {@code null} if opening failed
     */
    public PrintWriter saveHTKheader(String filename, String parameterKind) {
        try {
            PrintWriter out = new PrintWriter(new FileWriter(filename));
            writeHTKHeader(out, parameterKind);
            return out;
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Writes the body of one HTK state: {@code <NUMMIXES>} followed by, for each Gaussian, its
     * 1-based {@code <MIXTURE>} line with the linear weight, {@code <RCLASS> 1}, the means and the
     * variances. The writer is not closed.
     *
     * @param out destination writer
     */
    public void saveHTKState(PrintWriter out) {
        out.println("<NUMMIXES> " + ngauss);
        for (int g = 0; g < ngauss; g++) {
            out.println("<MIXTURE> " + (g + 1) + ' ' + getWeight(g));
            out.println("<RCLASS> 1");
            out.println("<MEAN> " + ncoefs);
            printMeanLine(out, g);
            out.println("<VARIANCE> " + ncoefs);
            printVarianceLine(out, g);
        }
    }

    /**
     * Writes a left-to-right {@code <TRANSP>} matrix of size {@code nstates} and closes the HMM
     * with {@code <ENDHMM>}. The writer is not closed.
     * <p>
     * Row 0 moves to state 1 with probability 1 (the same entry row {@link #saveHTK} writes).
     * Each middle row {@code i} stays or advances to {@code i + 1} with probability 0.5. The last
     * row is all zeros. Every row is printed on its own line with exactly {@code nstates} entries.
     *
     * @param nstates number of HMM states, including the first and last rows
     * @param out     destination writer
     */
    public void saveHTKtailer(int nstates, PrintWriter out) {
        out.println("<TRANSP> " + nstates);
        for (int from = 0; from < nstates; from++) {
            for (int to = 0; to < nstates; to++) {
                if (to > 0) {
                    out.print(' ');
                }
                out.print(transitionProbability(from, to, nstates));
            }
            out.println();
        }
        out.println("<ENDHMM>");
    }

    /** Returns the textual transition probability used by {@link #saveHTKtailer}. */
    private static String transitionProbability(int from, int to, int nstates) {
        if (from == 0) {
            return to == 1 ? "1" : "0";
        }
        if (from == nstates - 1) {
            return "0";
        }
        return (to == from || to == from + 1) ? "0.5" : "0";
    }

    /**
     * Writes the model as a complete 3-state HTK HMM (one emitting state) to {@code filename}.
     *
     * @param filename      destination file
     * @param hmmName       name written in the {@code ~h} macro
     * @param parameterKind parameter kind inserted between {@code <NULLD>} and {@code <DIAGC>}
     */
    public void saveHTK(String filename, String hmmName, String parameterKind) {
        try {
            PrintWriter out = new PrintWriter(new FileWriter(filename));
            try {
                writeHTKHeader(out, parameterKind);
                out.println("~h \"" + hmmName + '\"');
                out.println("<BEGINHMM>");
                out.println("<NUMSTATES> 3");
                out.println("<STATE> 2");
                saveHTKState(out);
                out.println("<TRANSP> 3");
                out.println("0 1 0");
                out.println("0 0.7 0.3");
                out.println("0 0 0");
                out.println("<ENDHMM>");
                reportWriteError(out, filename);
            } finally {
                out.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Writes the HTK header lines shared by {@link #saveHTK} and {@link #saveHTKheader}. */
    private void writeHTKHeader(PrintWriter out, String parameterKind) {
        out.println("~o");
        out.println("<HMMSETID> tree");
        out.println("<STREAMINFO> 1 " + ncoefs);
        out.println("<VECSIZE> " + ncoefs + "<NULLD>" + parameterKind + "<DIAGC>");
        out.println("~r \"rtree_1\"");
        out.println("<REGTREE> 1");
        out.println("<TNODE> 1 " + ngauss);
    }

    /**
     * Reads every {@code <MEAN>}/{@code <VARIANCE>} pair from an HTK-style file, replacing the
     * current parameters. The number of Gaussians is the number of lines that contain
     * {@code <MEAN>}, and the dimension comes from the first such line. The values must sit on the
     * single line after each tag. Mixture weights in the file are NOT read: after loading every
     * Gaussian has weight {@code 1 / ngauss}. Errors are reported on standard error.
     *
     * @param filename source file
     */
    public void loadHTK(String filename) {
        try {
            scanHTKDimensions(filename);
            allocate();
            readHTKParameters(filename);
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** First pass: counts {@code <MEAN>} lines and takes the dimension from the first one. */
    private void scanHTKDimensions(String filename) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader(filename));
        try {
            ngauss = 0;
            ncoefs = 0;
            String line;
            while ((line = in.readLine()) != null) {
                if (!line.contains("<MEAN>")) {
                    continue;
                }
                ngauss++;
                if (ncoefs == 0) {
                    StringTokenizer tokens = new StringTokenizer(line);
                    tokens.nextToken(); // skip the leading tag token
                    ncoefs = Integer.parseInt(tokens.nextToken());
                }
            }
        } finally {
            in.close();
        }
    }

    /** Second pass: reads the mean line and the variance line that follow each {@code <MEAN>}. */
    private void readHTKParameters(String filename) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader(filename));
        try {
            int g = 0;
            String line;
            while ((line = in.readLine()) != null) {
                if (!line.contains("<MEAN>")) {
                    continue;
                }
                StringTokenizer tokens = new StringTokenizer(readRequiredLine(in, filename));
                for (int j = 0; tokens.hasMoreTokens(); j++) {
                    setMean(g, j, Float.parseFloat(tokens.nextToken()));
                }
                String tag = readRequiredLine(in, filename);
                if (!tag.contains("<VARIANCE>")) {
                    throw new IOException("Expected <VARIANCE> after mean " + g + " in " + filename
                            + " but found: " + tag);
                }
                tokens = new StringTokenizer(readRequiredLine(in, filename));
                for (int j = 0; tokens.hasMoreTokens(); j++) {
                    setVar(g, j, Float.parseFloat(tokens.nextToken()));
                }
                g++;
            }
        } finally {
            in.close();
        }
    }

    /**
     * Reads a k-means dump with two lines per Gaussian. The first line is
     * {@code "<count> <mean_1> ... <mean_n>"} and the second is {@code "<var_1> ... <var_n>"}.
     * Weights become {@code count / sum(counts)}, and {@link #nT} becomes the truncated sum of the
     * counts. Errors are reported on standard error.
     *
     * @param filename source file
     */
    public void loadScaleKMeans(String filename) {
        try {
            List<String> lines = readAllLines(filename);
            if (lines.isEmpty()) {
                throw new IOException("Empty k-means file " + filename);
            }
            ngauss = lines.size() / 2;
            ncoefs = lines.get(0).split(" ").length - 1;
            allocate();
            float[] counts = new float[ngauss];
            // Accumulate in floating point: truncating per step loses fractional counts
            // and can leave a zero total.
            double total = 0.0;
            for (int g = 0; g < ngauss; g++) {
                String[] fields = lines.get(2 * g).split(" ");
                counts[g] = Float.parseFloat(fields[0]);
                total += counts[g];
                for (int j = 0; j < ncoefs; j++) {
                    setMean(g, j, Float.parseFloat(fields[j + 1]));
                }
                fields = lines.get(2 * g + 1).split(" ");
                for (int j = 0; j < ncoefs; j++) {
                    setVar(g, j, Float.parseFloat(fields[j]));
                }
            }
            nT = (int) total;
            for (int g = 0; g < ngauss; g++) {
                setWeight(g, (float) (counts[g] / total));
            }
            precomputeDistance();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Resets {@link #logMath} and sets every weight to {@code 1 / ngauss}. */
    private void allocateWeights() {
        logMath = LogMath.getLogMath();
        weights = new float[ngauss];
        for (int g = 0; g < ngauss; g++) {
            setWeight(g, 1.0f / ngauss);
        }
    }

    /**
     * Recomputes, for every Gaussian, half of
     * {@code ncoefs * log(2*pi) + sum_j log(variance_j)} using {@code logMath.linearToLog}.
     * Must be called after means or variances change.
     */
    public void precomputeDistance() {
        float logTwoPi = logMath.linearToLog(Math.PI * 2);
        for (int g = 0; g < ngauss; g++) {
            float sum = 0.0f;
            for (int j = 0; j < ncoefs; j++) {
                sum += logMath.linearToLog(getVar(g, j));
            }
            sum += logTwoPi * ncoefs;
            logPreComputedGaussianFactor[g] = sum * 0.5f;
        }
    }

    /**
     * (Re)allocates all parameter arrays for the current {@link #ngauss}/{@link #ncoefs}, with
     * uniform weights. Arrays are always recreated so that reloading a model with different
     * dimensions cannot reuse stale, wrongly-sized storage.
     */
    private void allocate() {
        allocateWeights();
        loglikes = new float[ngauss];
        means = new float[ngauss][ncoefs];
        covar = new float[ngauss][ncoefs];
        logPreComputedGaussianFactor = new float[ngauss];
    }

    /**
     * Computes {@code weight + log density} for every Gaussian and stores it in {@link #loglikes}.
     * NaN or {@code -Infinity} densities are clamped to {@code -Float.MAX_VALUE}.
     * <p>
     * NOTE(review): the quadratic term accumulated here is in natural-log units, while the
     * weights and the cached normaliser come from {@code logMath.linearToLog}. Unless LogMath
     * uses base e these scales differ. Confirm before relying on absolute scores.
     *
     * @param features observation vector; only its first {@code features.length} coefficients
     *                 are used, so it must not be longer than {@link #getNcoefs()}
     */
    public void computeLogLikes(float[] features) {
        for (int g = 0; g < ngauss; g++) {
            float logDensity = 0.0f;
            for (int j = 0; j < features.length; j++) {
                float diff = features[j] - means[g][j];
                logDensity += diff * diff * covar[g][j];
            }
            logDensity -= logPreComputedGaussianFactor[g];
            if (Float.isNaN(logDensity)) {
                // The message says "0", but the value is actually clamped to the floor.
                System.err.println("gs2 is Nan, converting to 0 debug " + g + ' '
                        + logPreComputedGaussianFactor[g] + ' ' + means[g][0] + ' ' + covar[g][0]);
                logDensity = DIST_FLOOR;
            }
            if (logDensity < DIST_FLOOR) {
                logDensity = DIST_FLOOR;
            }
            loglikes[g] = weights[g] + logDensity;
        }
    }

    /**
     * Combines the per-Gaussian values from the last {@link #computeLogLikes(float[])} with
     * {@code logMath.addAsLinear}.
     *
     * @return the mixture log likelihood (requires at least one Gaussian)
     */
    public float getLogLike() {
        float total = loglikes[0];
        for (int g = 1; g < ngauss; g++) {
            total = logMath.addAsLinear(total, loglikes[g]);
        }
        return total;
    }

    /**
     * @return the index of the Gaussian with the highest value from the last
     *         {@link #computeLogLikes(float[])}; the lowest index wins ties
     */
    public int getWinningGauss() {
        int best = 0;
        for (int g = 1; g < ngauss; g++) {
            if (loglikes[g] > loglikes[best]) {
                best = g;
            }
        }
        return best;
    }

    /** @return the dimensionality of every Gaussian */
    public int getNcoefs() {
        return ncoefs;
    }

    /**
     * Builds the marginal model over a subset of coefficients. Weights are kept, and the means
     * and covariance entries of the selected coefficients are copied in order.
     *
     * @param keep {@code keep[j]} selects coefficient {@code j}; must have at least
     *             {@link #getNcoefs()} entries
     * @return a new model whose dimension is the number of {@code true} entries in {@code keep}
     */
    public GMMDiag getMarginal(boolean[] keep) {
        int kept = 0;
        for (boolean selected : keep) {
            if (selected) {
                kept++;
            }
        }
        GMMDiag marginal = new GMMDiag(ngauss, kept);
        int dst = 0;
        for (int j = 0; j < ncoefs; j++) {
            if (!keep[j]) {
                continue;
            }
            for (int g = 0; g < ngauss; g++) {
                marginal.means[g][dst] = means[g][j];
                // Copy the stored form directly (as merge/getGauss do) to avoid a lossy round trip.
                marginal.covar[g][dst] = covar[g][j];
            }
            dst++;
        }
        for (int g = 0; g < ngauss; g++) {
            marginal.setWeight(g, getWeight(g));
        }
        marginal.precomputeDistance();
        return marginal;
    }

    /**
     * Concatenates this model's Gaussians with {@code other}'s. This model's weights are scaled
     * by {@code ratio} and {@code other}'s by {@code 1 - ratio}.
     *
     * @param other model to append; must have the same dimension
     * @param ratio share of the total weight given to this model
     * @return a new model with {@code getNgauss() + other.getNgauss()} Gaussians
     * @throws IllegalArgumentException if the dimensions differ
     */
    public GMMDiag merge(GMMDiag other, float ratio) {
        if (other.getNcoefs() != ncoefs) {
            throw new IllegalArgumentException("Cannot merge GMMs of dimension " + ncoefs
                    + " and " + other.getNcoefs());
        }
        GMMDiag merged = new GMMDiag(ngauss + other.getNgauss(), ncoefs);
        for (int g = 0; g < ngauss; g++) {
            System.arraycopy(means[g], 0, merged.means[g], 0, ncoefs);
            System.arraycopy(covar[g], 0, merged.covar[g], 0, ncoefs);
            merged.setWeight(g, getWeight(g) * ratio);
        }
        for (int g = 0; g < other.getNgauss(); g++) {
            int dst = ngauss + g;
            System.arraycopy(other.means[g], 0, merged.means[dst], 0, ncoefs);
            System.arraycopy(other.covar[g], 0, merged.covar[dst], 0, ncoefs);
            merged.setWeight(dst, other.getWeight(g) * (1.0f - ratio));
        }
        merged.precomputeDistance();
        return merged;
    }

    /**
     * @param g Gaussian index
     * @return a new single-Gaussian model (weight 1) with a copy of Gaussian {@code g}
     */
    public GMMDiag getGauss(int g) {
        GMMDiag single = new GMMDiag(1, ncoefs);
        System.arraycopy(means[g], 0, single.means[0], 0, ncoefs);
        System.arraycopy(covar[g], 0, single.covar[0], 0, ncoefs);
        single.setWeight(0, 1.0f);
        single.precomputeDistance();
        return single;
    }

    /** Sets the free-form model name {@link #nom}. */
    public void setNom(String name) {
        nom = name;
    }

    /**
     * Approximate equality: both models must have the same number of Gaussians and the same
     * dimension, and every weight, mean and variance must agree within 1% relative difference.
     *
     * @param other model to compare with
     * @return {@code true} if the models are approximately equal
     */
    public boolean isEqual(GMMDiag other) {
        if (getNgauss() != other.getNgauss()) {
            return false;
        }
        if (getNcoefs() != other.getNcoefs()) {
            return false;
        }
        for (int g = 0; g < getNgauss(); g++) {
            if (isDiff(getWeight(g), other.getWeight(g))) {
                return false;
            }
            for (int j = 0; j < getNcoefs(); j++) {
                if (isDiff(getMean(g, j), other.getMean(g, j))) {
                    return false;
                }
                if (isDiff(getVar(g, j), other.getVar(g, j))) {
                    return false;
                }
            }
        }
        return true;
    }

    /** @return {@code true} if {@code b} differs from {@code a} by more than 1% of {@code a} */
    private boolean isDiff(float a, float b) {
        return (double) Math.abs(1.0f - b / a) > 0.01;
    }

    /** Returns one line per Gaussian with the mean and variance of its first coefficient. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int g = 0; g < getNgauss(); g++) {
            sb.append(getMean(g, 0)).append(' ').append(getVar(g, 0)).append('\n');
        }
        return sb.toString();
    }

    /** Prints the means of Gaussian {@code g} on one line, each followed by a space. */
    private void printMeanLine(PrintWriter out, int g) {
        for (int j = 0; j < ncoefs; j++) {
            out.print(means[g][j] + " ");
        }
        out.println();
    }

    /** Prints the variances of Gaussian {@code g} on one line, each followed by a space. */
    private void printVarianceLine(PrintWriter out, int g) {
        for (int j = 0; j < ncoefs; j++) {
            out.print(getVar(g, j) + " ");
        }
        out.println();
    }

    /** Reports a write failure; {@link PrintWriter} records errors instead of throwing. */
    private static void reportWriteError(PrintWriter out, String filename) {
        if (out.checkError()) {
            System.err.println("Error writing GMM to " + filename);
        }
    }

    /** Returns the next line, throwing instead of returning {@code null} at end of file. */
    private static String readRequiredLine(BufferedReader in, String filename) throws IOException {
        String line = in.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of file in " + filename);
        }
        return line;
    }

    /** Reads every line of {@code filename} into a list, closing the file afterwards. */
    private static List<String> readAllLines(String filename) throws IOException {
        List<String> lines = new ArrayList<String>();
        BufferedReader in = new BufferedReader(new FileReader(filename));
        try {
            String line;
            while ((line = in.readLine()) != null) {
                lines.add(line);
            }
        } finally {
            in.close();
        }
        return lines;
    }
}

