/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.trainer;

import edu.cmu.sphinx.frontend.Data;
import edu.cmu.sphinx.frontend.DataEndSignal;
import edu.cmu.sphinx.frontend.DataProcessingException;
import edu.cmu.sphinx.frontend.DataStartSignal;
import edu.cmu.sphinx.frontend.FrontEnd;
import edu.cmu.sphinx.frontend.Signal;
import edu.cmu.sphinx.frontend.util.StreamCepstrumSource;
import edu.cmu.sphinx.linguist.acoustic.HMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMM;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.TrainerScore;
import edu.cmu.sphinx.trainer.Edge;
import edu.cmu.sphinx.trainer.Learner;
import edu.cmu.sphinx.trainer.Node;
import edu.cmu.sphinx.trainer.Utterance;
import edu.cmu.sphinx.trainer.UtteranceGraph;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Component;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Learner that scores the feature frames of one utterance against an {@link UtteranceGraph} with
 * a forward (alpha) and backward (beta) pass.
 *
 * <p>The first call to {@link #getScore()} reads every frame from the front end and runs the
 * forward pass on each. Every following call walks the stored frames from last to first, runs the
 * backward pass on that frame and returns its scores with gammas set. When all frames have been
 * returned, {@code getScore()} returns {@code null} and clears its stored scores, so the next call
 * starts a new forward pass.
 *
 * <p>All probabilities are in the log domain ({@link #LOG_ZERO} / {@link #LOG_ONE}).
 *
 * <p>NOTE(review): instances keep per-utterance mutable state without synchronization; presumably
 * meant for single-threaded use — confirm.
 */
public class BaumWelchLearner implements Learner {

    /** Log-domain value used for "impossible" (identical to the former literal -3.4028235E38f). */
    private static final float LOG_ZERO = -Float.MAX_VALUE;
    /** Log-domain value used for "certain". */
    private static final float LOG_ONE = 0.0f;

    @S4Component(type=FrontEnd.class)
    public static final String FRONT_END = "frontend";
    @S4Component(type=StreamCepstrumSource.class)
    public static final String DATA_SOURCE = "source";

    private static final Logger logger = Logger.getLogger("edu.cmu.sphinx.trainer.BaumWelch");

    private FrontEnd frontEnd;
    private StreamCepstrumSource dataSource;
    private LogMath logMath;
    /** Content frame most recently read by {@link #getFeature()}. */
    private Data curFeature;
    private UtteranceGraph graph;
    /** Per-frame scores produced by the forward pass; {@code null} between utterances. */
    private TrainerScore[][] scoreArray;
    /** Number of frames read by the most recent forward pass. */
    private int lastFeatureIndex;
    /** Frame the backward pass is currently on; counts down from {@code lastFeatureIndex}. */
    private int currentFeatureIndex;
    /** Sized to the number of graph nodes; only its length is used. */
    private float[] betas;
    /** Per-node output log-probabilities for the frame being processed. */
    private float[] outputProbs;
    /** Component scores of the last emitting state scored by {@link #calculateScores(int)}. */
    private float[] componentScores;
    /** Per-node alphas (forward pass) or betas (backward pass) for the frame being processed. */
    private float[] probCurrentFrame;
    /** Gamma sum of the first frame processed by the backward pass, used as a reference. */
    private float totalLogScore;

    /**
     * Looks up the front end and the cepstrum data source and connects the source to the front
     * end.
     */
    @Override
    public void newProperties(PropertySheet propertySheet) throws PropertyException {
        logMath = LogMath.getLogMath();
        dataSource = (StreamCepstrumSource) propertySheet.getComponent(DATA_SOURCE);
        frontEnd = (FrontEnd) propertySheet.getComponent(FRONT_END);
        frontEnd.setDataSource(dataSource);
    }

    /** Returns the front end frames are read from. */
    protected FrontEnd getFrontEnd() {
        return frontEnd;
    }

    /**
     * Opens the file named by {@code utterance.toString()} and hands the stream to the data
     * source. The stream is not closed here; the data source reads from it afterwards.
     *
     * @throws IOException if the file cannot be opened
     */
    @Override
    public void setUtterance(Utterance utterance) throws IOException {
        FileInputStream stream = new FileInputStream(utterance.toString());
        dataSource.setInputStream(stream, false);
    }

    /**
     * Reads the next content frame into {@link #curFeature}, skipping one leading
     * {@link DataStartSignal}.
     *
     * @return {@code true} if {@code curFeature} now holds a content frame; {@code false} at end
     *     of data, on a {@link DataEndSignal}, or when the front end reports an error
     * @throws Error if the front end delivers any other {@link Signal}
     */
    private boolean getFeature() {
        try {
            curFeature = frontEnd.getData();
            if (curFeature instanceof DataStartSignal) {
                curFeature = frontEnd.getData();
            }
            if (curFeature == null || curFeature instanceof DataEndSignal) {
                return false;
            }
            if (curFeature instanceof Signal) {
                throw new Error("Can't score non-content feature");
            }
            return true;
        } catch (DataProcessingException e) {
            logger.log(Level.SEVERE, "DataProcessingException " + e, e);
            return false;
        }
    }

    @Override
    public void start() {
    }

    @Override
    public void stop() {
    }

    /**
     * Prepares a new computation: opens the utterance's data and installs its graph.
     *
     * @throws IOException if the utterance's file cannot be opened
     */
    @Override
    public void initializeComputation(Utterance utterance, UtteranceGraph utteranceGraph)
            throws IOException {
        setUtterance(utterance);
        setGraph(utteranceGraph);
    }

    /** Sets the graph the utterance is scored against. */
    @Override
    public void setGraph(UtteranceGraph utteranceGraph) {
        graph = utteranceGraph;
    }

    /**
     * Runs the forward pass over every frame the front end delivers and then seeds
     * {@link #probCurrentFrame} with the initial betas for the backward pass.
     *
     * @return one {@code TrainerScore[]} per frame, in frame order, each indexed by graph node
     */
    private TrainerScore[][] prepareScore() {
        int numStates = graph.size();
        betas = new float[numStates];
        outputProbs = new float[numStates];
        probCurrentFrame = new float[numStates];

        // Initial alphas: only the initial node and its non-emitting successors are reachable.
        Arrays.fill(probCurrentFrame, LOG_ZERO);
        Node initialNode = graph.getInitialNode();
        probCurrentFrame[graph.indexOf(initialNode)] = LOG_ONE;
        initialNode.startOutgoingEdgeIterator();
        while (initialNode.hasMoreOutgoingEdges()) {
            Edge edge = initialNode.nextOutgoingEdge();
            Node next = edge.getDestination();
            int index = graph.indexOf(next);
            if (!next.isType("STATE")) {
                probCurrentFrame[index] = LOG_ONE;
                continue;
            }
            HMMState state = (HMMState) next.getObject();
            if (!state.isEmitting()) {
                probCurrentFrame[index] = LOG_ONE;
            }
            // NOTE(review): with -ea this fires for every STATE successor of the initial node,
            // including the non-emitting ones handled just above — confirm the intent.
            assert false;
        }

        List<TrainerScore[]> scoreList = new ArrayList<TrainerScore[]>();
        lastFeatureIndex = 0;
        while (getFeature()) {
            // Each frame gets its own array. Reusing one array would make every list entry alias
            // the last frame's scores, which the backward pass reads back frame by frame.
            TrainerScore[] frameScores = new TrainerScore[numStates];
            forwardPass(frameScores);
            scoreList.add(frameScores);
            ++lastFeatureIndex;
        }
        logger.info("Feature frames read: " + lastFeatureIndex);

        // Initial betas: only the final node and its non-emitting predecessors.
        Arrays.fill(probCurrentFrame, LOG_ZERO);
        Node finalNode = graph.getFinalNode();
        probCurrentFrame[graph.indexOf(finalNode)] = LOG_ONE;
        finalNode.startIncomingEdgeIterator();
        while (finalNode.hasMoreIncomingEdges()) {
            Edge edge = finalNode.nextIncomingEdge();
            Node previous = edge.getSource();
            int index = graph.indexOf(previous);
            if (!previous.isType("STATE")) {
                probCurrentFrame[index] = LOG_ONE;
                // Predecessors of the final node are expected to be STATE nodes.
                assert false;
                continue;
            }
            HMMState state = (HMMState) previous.getObject();
            if (!state.isEmitting()) {
                probCurrentFrame[index] = LOG_ONE;
            }
        }
        return scoreList.toArray(new TrainerScore[scoreList.size()][]);
    }

    /**
     * Returns the scores of the next frame, walking backwards from the last frame.
     *
     * <p>The first call runs the forward pass over the whole utterance. Each call runs the
     * backward pass for one frame and sets every node's gamma. The log-sum of the gammas from the
     * first frame processed is published via {@code TrainerScore.setLogLikelihood}. Later frames
     * are compared against it and a warning is logged when they diverge.
     *
     * @return the frame's scores indexed by graph node, or {@code null} once all frames have been
     *     returned (internal state is then reset)
     */
    @Override
    public TrainerScore[] getScore() {
        if (scoreArray == null) {
            scoreArray = prepareScore();
            currentFeatureIndex = lastFeatureIndex;
        }
        --currentFeatureIndex;
        if (currentFeatureIndex < 0) {
            scoreArray = null;
            return null;
        }
        TrainerScore[] frameScores = scoreArray[currentFeatureIndex];
        assert frameScores.length == betas.length;
        backwardPass(frameScores);
        float gammaSum = LOG_ZERO;
        for (int i = 0; i < betas.length; ++i) {
            frameScores[i].setGamma();
            gammaSum = logMath.addAsLinear(gammaSum, frameScores[i].getGamma());
        }
        if (currentFeatureIndex == lastFeatureIndex - 1) {
            TrainerScore.setLogLikelihood(gammaSum);
            totalLogScore = gammaSum;
        } else if (Math.abs(totalLogScore - gammaSum) > Math.abs(totalLogScore)) {
            logger.warning("log probabilities differ: " + totalLogScore + " and " + gammaSum);
        }
        return frameScores;
    }

    /**
     * Scores {@link #curFeature} against the state at graph node {@code index}.
     *
     * <p>As a side effect, sets {@link #componentScores} to the state's component scores, or to
     * {@code null} for nodes without an emitting state.
     *
     * @return the state's output log-probability, or {@link #LOG_ONE} for non-emitting or
     *     state-less nodes
     */
    private float calculateScores(int index) {
        SenoneHMMState state = (SenoneHMMState) graph.getNode(index).getObject();
        if (state == null || !state.isEmitting()) {
            componentScores = null;
            return LOG_ONE;
        }
        componentScores = state.calculateComponentScore(curFeature);
        float logScore = state.getScore(curFeature);
        // Only single-component mixtures are expected here.
        assert componentScores.length == 1;
        return logScore;
    }

    /**
     * Advances the alphas by one frame.
     *
     * <p>It fills {@code score} with a fresh {@code TrainerScore} per node and updates
     * {@link #probCurrentFrame} in place.
     */
    private void forwardPass(TrainerScore[] score) {
        int numNodes = graph.size();
        for (int i = 0; i < numNodes; ++i) {
            outputProbs[i] = calculateScores(i);
            score[i] = new TrainerScore(curFeature, outputProbs[i],
                    (HMMState) graph.getNode(i).getObject(), componentScores);
            score[i].setAlpha(probCurrentFrame[i]);
        }
        float[] probPreviousFrame = probCurrentFrame;
        probCurrentFrame = new float[numNodes];

        // Emitting states consume this frame, so they are fed from the previous frame's alphas.
        for (int index = 0; index < numNodes; ++index) {
            Node node = graph.getNode(index);
            if (!node.isType("STATE")) {
                continue;
            }
            SenoneHMMState state = (SenoneHMMState) node.getObject();
            SenoneHMM hmm = (SenoneHMM) state.getHMM();
            if (!state.isEmitting()) {
                continue;
            }
            float alpha = LOG_ZERO;
            node.startIncomingEdgeIterator();
            while (node.hasMoreIncomingEdges()) {
                Node previousNode = node.nextIncomingEdge().getSource();
                int previousIndex = graph.indexOf(previousNode);
                HMMState previousState = (HMMState) previousNode.getObject();
                float logTransition;
                if (previousState != null) {
                    assert !previousState.isEmitting() || previousState.getHMM() == hmm;
                    logTransition = !previousState.isEmitting()
                            ? LOG_ONE
                            : hmm.getTransitionProbability(previousState.getState(), state.getState());
                } else {
                    logTransition = LOG_ONE;
                }
                alpha = logMath.addAsLinear(alpha, probPreviousFrame[previousIndex] + logTransition);
            }
            probCurrentFrame[index] = alpha + outputProbs[index];
            score[index].setAlpha(probCurrentFrame[index]);
        }

        // Non-emitting nodes consume no frame, so they are fed from this frame's alphas.
        // NOTE(review): a non-emitting predecessor with a higher index has not been updated yet and
        // still holds the freshly allocated 0.0f (log 1); this depends on graph node order.
        for (int index = 0; index < numNodes; ++index) {
            Node node = graph.getNode(index);
            HMMState state = null;
            SenoneHMM hmm = null;
            if (node.isType("STATE")) {
                state = (HMMState) node.getObject();
                hmm = (SenoneHMM) state.getHMM();
                if (state.isEmitting()) {
                    continue;
                }
            } else if (graph.isInitialNode(node)) {
                // The initial node is forced to log(0) once frames start being consumed.
                score[index].setAlpha(LOG_ZERO);
                probCurrentFrame[index] = LOG_ZERO;
                continue;
            }
            // Accumulated in place (not in a local) to keep the exact original semantics.
            probCurrentFrame[index] = LOG_ZERO;
            node.startIncomingEdgeIterator();
            while (node.hasMoreIncomingEdges()) {
                Node previousNode = node.nextIncomingEdge().getSource();
                int previousIndex = graph.indexOf(previousNode);
                float logTransition;
                if (previousNode.isType("STATE")) {
                    HMMState previousState = (HMMState) previousNode.getObject();
                    assert !previousState.isEmitting() || previousState.getHMM() == hmm;
                    // NOTE(review): hmm and state are null for non-STATE nodes, so an emitting
                    // predecessor of such a node would throw NullPointerException here.
                    logTransition = !previousState.isEmitting()
                            ? LOG_ONE
                            : hmm.getTransitionProbability(previousState.getState(), state.getState());
                } else {
                    logTransition = LOG_ONE;
                }
                probCurrentFrame[index] = logMath.addAsLinear(probCurrentFrame[index],
                        probCurrentFrame[previousIndex] + logTransition);
            }
            score[index].setAlpha(probCurrentFrame[index]);
        }
    }

    /**
     * Sets the betas on {@code score} and moves the backward pass back by one frame.
     *
     * <p>Each node's beta is the value held in {@link #probCurrentFrame} on entry.
     * {@code probCurrentFrame} is then recomputed using this frame's output probabilities.
     */
    private void backwardPass(TrainerScore[] score) {
        int numNodes = graph.size();
        for (int i = 0; i < numNodes; ++i) {
            outputProbs[i] = score[i].getScore();
            score[i].setBeta(probCurrentFrame[i]);
        }
        float[] probNextFrame = probCurrentFrame;
        probCurrentFrame = new float[numNodes];

        // Emitting states: sum over successors of beta * transition * successor output prob.
        for (int index = 0; index < numNodes; ++index) {
            Node node = graph.getNode(index);
            if (!node.isType("STATE")) {
                continue;
            }
            HMMState state = (HMMState) node.getObject();
            SenoneHMM hmm = (SenoneHMM) state.getHMM();
            if (!state.isEmitting()) {
                continue;
            }
            float beta = LOG_ZERO;
            node.startOutgoingEdgeIterator();
            while (node.hasMoreOutgoingEdges()) {
                Node nextNode = node.nextOutgoingEdge().getDestination();
                int nextIndex = graph.indexOf(nextNode);
                HMMState nextState = (HMMState) nextNode.getObject();
                float logTransition;
                if (nextState != null) {
                    assert !nextState.isEmitting() || nextState.getHMM() == hmm;
                    // Leaving this HMM carries no transition cost.
                    logTransition = nextState.getHMM() != hmm
                            ? LOG_ONE
                            : hmm.getTransitionProbability(state.getState(), nextState.getState());
                } else {
                    logTransition = LOG_ONE;
                }
                beta = logMath.addAsLinear(beta,
                        probNextFrame[nextIndex] + logTransition + outputProbs[nextIndex]);
            }
            probCurrentFrame[index] = beta;
            score[index].setBeta(beta);
        }

        // Non-emitting nodes, visited from the highest index down, read their successors' values
        // from the current array. Accumulated in place to keep the exact original semantics.
        for (int index = numNodes - 1; index >= 0; --index) {
            Node node = graph.getNode(index);
            HMMState state = null;
            if (node.isType("STATE")) {
                state = (HMMState) node.getObject();
                if (state.isEmitting()) {
                    continue;
                }
            } else if (graph.isFinalNode(node)) {
                score[index].setBeta(LOG_ZERO);
                probCurrentFrame[index] = LOG_ZERO;
                continue;
            }
            probCurrentFrame[index] = LOG_ZERO;
            node.startOutgoingEdgeIterator();
            while (node.hasMoreOutgoingEdges()) {
                Node nextNode = node.nextOutgoingEdge().getDestination();
                int nextIndex = graph.indexOf(nextNode);
                if (nextNode.isType("STATE")) {
                    HMMState nextState = (HMMState) nextNode.getObject();
                    assert nextState.isEmitting() || nextState == state;
                }
                // Edges out of non-emitting nodes carry no transition cost here.
                float logTransition = LOG_ONE;
                probCurrentFrame[index] = logMath.addAsLinear(probCurrentFrame[index],
                        probCurrentFrame[nextIndex] + logTransition);
            }
            score[index].setBeta(probCurrentFrame[index]);
        }
    }
}

