/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer;

import edu.cmu.sphinx.frontend.FloatData;
import edu.cmu.sphinx.linguist.acoustic.HMM;
import edu.cmu.sphinx.linguist.acoustic.HMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.GaussianMixture;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.GaussianWeights;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.HMMManager;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.MixtureComponent;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Pool;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Senone;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMM;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.Buffer;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.TrainerScore;
import edu.cmu.sphinx.util.LogMath;
import java.io.IOException;
import java.util.HashMap;
import java.util.logging.Logger;

/**
 * Manages the parameter pools of a tied-state acoustic model while it is being trained.
 * <p>
 * The manager takes the means, variances, mixture weights, transition matrices and senones from a
 * {@link Loader}, keeps one {@link Buffer} accumulator per pooled parameter, accumulates statistics
 * from {@link TrainerScore}s ({@code accumulate}), normalizes the accumulators ({@code normalize})
 * and finally writes the new estimates back into the model pools ({@code update}).
 */
class HMMPoolManager {

    /**
     * Log-domain value used for probability zero ({@code -Float.MAX_VALUE}). Transition-matrix
     * entries holding this value are treated as impossible transitions: they are never accumulated
     * and never floored.
     */
    private static final float LOG_ZERO = -Float.MAX_VALUE;

    /**
     * Senone id meaning "all models": statistics are accumulated into every senone of the senone
     * pool and into every transition of every HMM known to the HMM manager.
     */
    private static final int ALL_MODELS = -1;

    private final HMMManager hmmManager;

    /**
     * Maps every pooled parameter array to its index inside the pool it came from. Arrays do not
     * override equals/hashCode, so lookups are effectively by object identity.
     */
    private final HashMap<Object, Integer> indexMap;

    private final Pool<float[]> meansPool;
    private final Pool<float[]> variancePool;
    private final Pool<float[][]> matrixPool;
    private final GaussianWeights mixtureWeights;
    private final Pool<Senone> senonePool;

    // Accumulators parallel to the pools above; rebuilt by createBuffers().
    private Pool<Buffer> meansBufferPool;
    private Pool<Buffer> varianceBufferPool;
    private Pool<Buffer[]> matrixBufferPool;
    private Pool<Buffer> mixtureWeightsBufferPool;

    private final LogMath logMath;

    /*
     * Floors applied when new estimates are written back. The two log-domain floors are LOG_ZERO,
     * which disables flooring. Note that 0.0f (the Java field default) would mean log(1) and would
     * clamp every mixture weight and every allowed transition to probability one.
     */
    private final float logMixtureWeightFloor = LOG_ZERO;
    private final float logTransitionProbabilityFloor = LOG_ZERO;
    // NOTE(review): a variance floor of 0 lets a dimension collapse to zero variance; a small
    // positive floor is probably wanted - confirm the value with the model owners.
    private final float varianceFloor = 0.0f;

    /** Running total of the log likelihood since the last reset. */
    private float logLikelihood;
    /** Log likelihood subtracted from every log posterior; currently always reset to 0. */
    private float currentLogLikelihood;

    private static final Logger logger = Logger.getLogger("edu.cmu.sphinx.linguist.acoustic.HMMPoolManager");

    /**
     * Loads the model through {@code loader} and builds empty accumulators for all of its pools.
     *
     * @param loader the loader supplying the model pools
     * @throws IOException if {@code loader.load()} fails
     */
    protected HMMPoolManager(Loader loader) throws IOException {
        loader.load();
        hmmManager = loader.getHMMManager();
        indexMap = new HashMap<Object, Integer>();
        meansPool = loader.getMeansPool();
        variancePool = loader.getVariancePool();
        mixtureWeights = loader.getMixtureWeights();
        matrixPool = loader.getTransitionMatrixPool();
        senonePool = loader.getSenonePool();
        logMath = LogMath.getLogMath();
        createBuffers();
        logLikelihood = 0.0f;
    }

    /** Discards all accumulated statistics and resets the running log likelihood. */
    protected void resetBuffers() {
        createBuffers();
        logLikelihood = 0.0f;
    }

    /**
     * (Re)creates the accumulators for means, variances, transition matrices and mixture weights.
     * The means/variances buffers are filled with {@code accumulate}; the matrix and weight buffers
     * are filled with {@code logAccumulate}.
     */
    protected void createBuffers() {
        meansBufferPool = create1DPoolBuffer(meansPool, false);
        varianceBufferPool = create1DPoolBuffer(variancePool, false);
        matrixBufferPool = create2DPoolBuffer(matrixPool, true);
        mixtureWeightsBufferPool = createWeightsPoolBuffer(mixtureWeights);
    }

    /**
     * Creates one buffer per vector of {@code pool} and records each vector's pool index.
     *
     * @param pool      pool of parameter vectors
     * @param logDomain passed to every new {@link Buffer}; {@code true} for pools whose buffers are
     *                  filled with {@code logAccumulate}
     * @return a pool of buffers with the same name and indices as {@code pool}
     */
    private Pool<Buffer> create1DPoolBuffer(Pool<float[]> pool, boolean logDomain) {
        Pool<Buffer> bufferPool = new Pool<Buffer>(pool.getName());
        for (int i = 0; i < pool.size(); i++) {
            float[] vector = pool.get(i);
            indexMap.put(vector, i);
            bufferPool.put(i, new Buffer(vector.length, logDomain, i));
        }
        return bufferPool;
    }

    /**
     * Creates one log-domain buffer of {@code getGauPerState()} entries per (stream, state) pair.
     * The buffer for stream {@code s} and state {@code j} is stored at {@code s * statesNum + j}.
     *
     * @param weights the mixture weights of the model
     * @return the pool of mixture-weight buffers
     */
    private Pool<Buffer> createWeightsPoolBuffer(GaussianWeights weights) {
        Pool<Buffer> bufferPool = new Pool<Buffer>(weights.getName());
        int statesNum = weights.getStatesNum();
        int streamsNum = weights.getStreamsNum();
        int gaussiansPerState = weights.getGauPerState();
        for (int stream = 0; stream < streamsNum; stream++) {
            for (int state = 0; state < statesNum; state++) {
                int id = stream * statesNum + state;
                bufferPool.put(id, new Buffer(gaussiansPerState, true, id));
            }
        }
        return bufferPool;
    }

    /**
     * Creates one buffer per row of every matrix in {@code pool} and records each matrix's pool
     * index.
     *
     * @param pool      pool of matrices
     * @param logDomain passed to every new {@link Buffer}
     * @return a pool of buffer rows with the same name and indices as {@code pool}
     */
    private Pool<Buffer[]> create2DPoolBuffer(Pool<float[][]> pool, boolean logDomain) {
        Pool<Buffer[]> bufferPool = new Pool<Buffer[]>(pool.getName());
        for (int i = 0; i < pool.size(); i++) {
            float[][] matrix = pool.get(i);
            indexMap.put(matrix, i);
            Buffer[] rows = new Buffer[matrix.length];
            for (int row = 0; row < matrix.length; row++) {
                rows[row] = new Buffer(matrix[row].length, logDomain, row);
            }
            bufferPool.put(i, rows);
        }
        return bufferPool;
    }

    /**
     * Returns the pool index recorded for {@code key} when the buffers were created.
     *
     * @throws IllegalStateException if {@code key} is not an array from one of the managed pools
     */
    private int poolIndexOf(Object key) {
        Integer index = indexMap.get(key);
        if (index == null) {
            throw new IllegalStateException("Parameter array is not registered in any managed pool");
        }
        return index;
    }

    /**
     * Accumulates the statistics of {@code score[index]} without transition look-ahead scores.
     *
     * @param index position of the score to accumulate
     * @param score scores of the current frame
     */
    protected void accumulate(int index, TrainerScore[] score) {
        accumulate(index, score, null);
    }

    /**
     * Accumulates the statistics of {@code score[index]}.
     * <p>
     * When the score has no state, statistics are accumulated only if its senone id is
     * {@link #ALL_MODELS}; then means, variances, mixture weights and transitions of every model
     * are updated. For an emitting state, only its mixture weights and transitions are accumulated;
     * non-emitting states are ignored.
     *
     * @param index     position of the score to accumulate
     * @param score     scores of the current frame
     * @param nextScore scores of the next frame, used for transitions; may be {@code null}
     */
    protected void accumulate(int index, TrainerScore[] score, TrainerScore[] nextScore) {
        TrainerScore current = score[index];
        currentLogLikelihood = 0.0f;
        logLikelihood -= score[0].getScalingFactor();
        SenoneHMMState state = (SenoneHMMState) current.getState();
        if (state == null) {
            int senoneId = current.getSenoneID();
            if (senoneId == ALL_MODELS) {
                accumulateMean(senoneId, current);
                accumulateVariance(senoneId, current);
                accumulateMixture(senoneId, current);
                accumulateTransition(senoneId, index, score, nextScore);
            }
        } else if (state.isEmitting()) {
            int senoneId = senonePool.indexOf(state.getSenone());
            accumulateMixture(senoneId, current);
            accumulateTransition(senoneId, index, score, nextScore);
        }
    }

    /** Returns the posterior of mixture component {@code component} in the linear domain. */
    private double linearComponentGamma(TrainerScore score, int component) {
        float logGamma = score.getComponentGamma()[component] - currentLogLikelihood;
        return logMath.logToLinear(logGamma);
    }

    /**
     * Adds gamma-weighted feature vectors to the mean accumulators of {@code senoneId}
     * (or of every senone for {@link #ALL_MODELS}).
     */
    private void accumulateMean(int senoneId, TrainerScore score) {
        if (senoneId == ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateMean(i, score);
            }
            return;
        }
        GaussianMixture mixture = (GaussianMixture) senonePool.get(senoneId);
        MixtureComponent[] components = mixture.getMixtureComponents();
        float[] feature = ((FloatData) score.getData()).getValues();
        for (int i = 0; i < components.length; i++) {
            int meanIndex = poolIndexOf(components[i].getMean());
            assert meanIndex >= 0;
            // NOTE(review): this only holds when mean i of senone s sits at pool index s, i.e.
            // apparently only for single-component senones - confirm before training mixtures.
            assert meanIndex == senoneId;
            double gamma = linearComponentGamma(score, i);
            double[] weighted = new double[feature.length];
            for (int j = 0; j < feature.length; j++) {
                weighted[j] = (double) feature[j] * gamma;
            }
            meansBufferPool.get(meanIndex).accumulate(weighted, gamma);
        }
    }

    /**
     * Adds gamma-weighted squared feature values (x^2) to the variance accumulators of
     * {@code senoneId} (or of every senone for {@link #ALL_MODELS}). Raw second moments are
     * accumulated so that {@code updateVariances} can apply var = E[x^2] - E[x]^2 with the
     * re-estimated means.
     */
    private void accumulateVariance(int senoneId, TrainerScore score) {
        if (senoneId == ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateVariance(i, score);
            }
            return;
        }
        GaussianMixture mixture = (GaussianMixture) senonePool.get(senoneId);
        MixtureComponent[] components = mixture.getMixtureComponents();
        float[] feature = ((FloatData) score.getData()).getValues();
        for (int i = 0; i < components.length; i++) {
            int varianceIndex = poolIndexOf(components[i].getVariance());
            double gamma = linearComponentGamma(score, i);
            double[] weightedSquares = new double[feature.length];
            for (int j = 0; j < feature.length; j++) {
                double x = feature[j];
                weightedSquares[j] = x * x * gamma;
            }
            varianceBufferPool.get(varianceIndex).accumulate(weightedSquares, gamma);
        }
    }

    /**
     * Log-accumulates the component posteriors into the mixture-weight buffer of
     * {@code senoneId} (or of every senone for {@link #ALL_MODELS}).
     */
    private void accumulateMixture(int senoneId, TrainerScore score) {
        if (senoneId == ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateMixture(i, score);
            }
            return;
        }
        Buffer buffer = mixtureWeightsBufferPool.get(senoneId);
        float[] logGammas = score.getComponentGamma();
        for (int i = 0; i < mixtureWeights.getGauPerState(); i++) {
            buffer.logAccumulate(logGammas[i] - currentLogLikelihood, i, logMath);
        }
    }

    /**
     * Log-accumulates every allowed transition out of the state of {@code score[index]}.
     * The destination score is taken from {@code nextScore} at the same offset the destination
     * state has from the source state; each contribution is alpha + beta + transition + score.
     */
    private void accumulateStateTransition(int index, TrainerScore[] score, TrainerScore[] nextScore) {
        HMMState state = score[index].getState();
        if (state == null) {
            return;
        }
        int stateId = state.getState();
        SenoneHMM hmm = (SenoneHMM) state.getHMM();
        float[][] matrix = hmm.getTransitionMatrix();
        Buffer[] rows = matrixBufferPool.get(poolIndexOf(matrix));
        float[] transitions = matrix[stateId];
        float alpha = score[index].getAlpha();
        for (int dest = 0; dest < transitions.length; dest++) {
            if (transitions[dest] == LOG_ZERO) {
                continue; // impossible transition
            }
            TrainerScore next = nextScore[index + (dest - stateId)];
            assert next.getState() == null || next.getState().getHMM() == hmm;
            float logProb = alpha + next.getBeta() + transitions[dest] + next.getScore();
            rows[stateId].logAccumulate(logProb - currentLogLikelihood, dest, logMath);
        }
    }

    /** Log-accumulates {@code logScore} into every allowed transition out of {@code stateId}. */
    private void accumulateStateTransition(int stateId, SenoneHMM hmm, float logScore) {
        float[][] matrix = hmm.getTransitionMatrix();
        float[] transitions = matrix[stateId];
        Buffer[] rows = matrixBufferPool.get(poolIndexOf(matrix));
        for (int dest = 0; dest < transitions.length; dest++) {
            if (transitions[dest] == LOG_ZERO) {
                continue; // impossible transition
            }
            rows[stateId].logAccumulate(logScore, dest, logMath);
        }
    }

    /**
     * Accumulates transition statistics: for {@link #ALL_MODELS} every state of every HMM receives
     * {@code score[index].getScore()}; otherwise the state of {@code score[index]} is processed,
     * but only when {@code nextScore} is available.
     */
    private void accumulateTransition(int senoneId, int index, TrainerScore[] score, TrainerScore[] nextScore) {
        if (senoneId == ALL_MODELS) {
            for (HMM hmm : hmmManager) {
                for (int state = 0; state < hmm.getOrder(); state++) {
                    accumulateStateTransition(state, (SenoneHMM) hmm, score[index].getScore());
                }
            }
        } else if (nextScore != null) {
            accumulateStateTransition(index, score, nextScore);
        }
    }

    /** Currently a no-op; the log likelihood is updated inside {@code accumulate}. */
    protected void updateLogLikelihood() {
    }

    /**
     * Normalizes every used accumulator.
     *
     * @return the log likelihood accumulated since the last reset
     */
    protected float normalize() {
        normalizePool(meansBufferPool);
        normalizePool(varianceBufferPool);
        logNormalizePool(mixtureWeightsBufferPool);
        logNormalize2DPool(matrixBufferPool, matrixPool);
        return logLikelihood;
    }

    /** Normalizes every used buffer of a linear-domain pool. */
    private void normalizePool(Pool<Buffer> pool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer buffer = pool.get(i);
            if (buffer.wasUsed()) {
                buffer.normalize();
            }
        }
    }

    /** Log-normalizes every used buffer of a log-domain pool. */
    private void logNormalizePool(Pool<Buffer> pool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer buffer = pool.get(i);
            if (buffer.wasUsed()) {
                buffer.logNormalize();
            }
        }
    }

    /**
     * Log-normalizes every used row buffer, passing the matching row of the model matrix in
     * {@code matrices} to {@code logNormalizeNonZero}.
     */
    private void logNormalize2DPool(Pool<Buffer[]> pool, Pool<float[][]> matrices) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer[] rows = pool.get(i);
            float[][] matrix = matrices.get(i);
            for (int row = 0; row < rows.length; row++) {
                if (rows[row].wasUsed()) {
                    rows[row].logNormalizeNonZero(matrix[row]);
                }
            }
        }
    }

    /**
     * Writes the new estimates back into the model. Means are updated before variances because
     * the variance update reads the new means; mixture components are then recomputed from the
     * new means and variances.
     */
    protected void update() {
        updateMeans();
        updateVariances();
        recomputeMixtureComponents();
        updateMixtureWeights();
        updateTransitionMatrices();
    }

    /** Copies {@code src} into {@code dest}; both must have the same length. */
    private void copyVector(float[] src, float[] dest) {
        assert src.length == dest.length;
        System.arraycopy(src, 0, dest, 0, src.length);
    }

    /** Copies every used mean accumulator into the means pool, logging unused entries. */
    private void updateMeans() {
        assert meansPool.size() == meansBufferPool.size();
        for (int i = 0; i < meansPool.size(); i++) {
            Buffer buffer = meansBufferPool.get(i);
            if (buffer.wasUsed()) {
                copyVector(buffer.getValues(), meansPool.get(i));
            } else {
                logger.info("Senone " + i + " not used.");
            }
        }
    }

    /**
     * Computes var = E[x^2] - mean^2 for every used variance accumulator, floors it at
     * {@code varianceFloor} and copies it into the variance pool. Assumes {@code normalize()} was
     * called first so the buffer holds the gamma-weighted average of x^2 (TODO confirm against
     * Buffer.normalize), and that updateMeans has already stored the new means.
     */
    private void updateVariances() {
        assert variancePool.size() == varianceBufferPool.size();
        for (int i = 0; i < variancePool.size(); i++) {
            Buffer buffer = varianceBufferPool.get(i);
            if (!buffer.wasUsed()) {
                continue;
            }
            float[] mean = meansPool.get(i);
            float[] newVariance = buffer.getValues();
            assert mean.length == newVariance.length;
            for (int j = 0; j < mean.length; j++) {
                newVariance[j] -= mean[j] * mean[j];
                if (newVariance[j] < varianceFloor) {
                    newVariance[j] = varianceFloor;
                }
            }
            copyVector(newVariance, variancePool.get(i));
        }
    }

    /** Asks every mixture component to recompute its precomputed distance terms. */
    private void recomputeMixtureComponents() {
        for (int i = 0; i < senonePool.size(); i++) {
            GaussianMixture mixture = (GaussianMixture) senonePool.get(i);
            for (MixtureComponent component : mixture.getMixtureComponents()) {
                component.precomputeDistance();
            }
        }
    }

    /**
     * Copies every used mixture-weight buffer into the model, renormalizing it when flooring
     * changed any value.
     */
    private void updateMixtureWeights() {
        int statesNum = mixtureWeights.getStatesNum();
        int streamsNum = mixtureWeights.getStreamsNum();
        assert statesNum * streamsNum == mixtureWeightsBufferPool.size();
        for (int stream = 0; stream < streamsNum; stream++) {
            for (int state = 0; state < statesNum; state++) {
                Buffer buffer = mixtureWeightsBufferPool.get(stream * statesNum + state);
                if (!buffer.wasUsed()) {
                    continue;
                }
                if (buffer.logFloor(logMixtureWeightFloor)) {
                    buffer.logNormalizeToSum(logMath);
                }
                mixtureWeights.put(state, stream, buffer.getValues());
            }
        }
    }

    /**
     * Floors, renormalizes and copies every used transition-matrix row back into the model.
     * Impossible transitions ({@link #LOG_ZERO}) are left untouched.
     */
    private void updateTransitionMatrices() {
        assert matrixPool.size() == matrixBufferPool.size();
        for (int i = 0; i < matrixPool.size(); i++) {
            float[][] matrix = matrixPool.get(i);
            Buffer[] rows = matrixBufferPool.get(i);
            for (int row = 0; row < matrix.length; row++) {
                Buffer buffer = rows[row];
                if (!buffer.wasUsed()) {
                    continue;
                }
                for (int col = 0; col < matrix[row].length; col++) {
                    float value = buffer.getValue(col);
                    if (value == LOG_ZERO) {
                        continue;
                    }
                    assert matrix[row][col] != LOG_ZERO;
                    if (value < logTransitionProbabilityFloor) {
                        buffer.setValue(col, logTransitionProbabilityFloor);
                    }
                }
                buffer.logNormalizeToSum(logMath);
                copyVector(buffer.getValues(), matrix[row]);
            }
        }
    }
}

