/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.pr.CachedDotTransitionIterator;
import cc.mallet.fst.semi_supervised.pr.PRAuxiliaryModel;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.util.logging.Logger;

public class SumLatticePR
implements SumLattice {
    private static Logger logger = MalletLogger.getLogger(SumLatticePR.class.getName());
    protected double totalWeight;
    protected int latticeLength;
    protected double[][] gammas;
    protected double[][][] xis;
    protected LabelVector[] labelings;
    protected Transducer transducer;
    protected LatticeNode[][] nodes;
    private Sequence input;

    public SumLatticePR(Transducer trans, int index, Sequence input2, Sequence output2, PRAuxiliaryModel auxModel, double[][][] cachedDots, boolean incrementConstraints, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet, boolean saveXis) {
        Transducer.State destination;
        CachedDotTransitionIterator iter2;
        int i;
        int ip;
        assert (output2 == null || input2.size() == output2.size());
        this.input = input2;
        this.transducer = trans;
        this.latticeLength = input2.size() + 1;
        int numStates = this.transducer.numStates();
        this.nodes = new LatticeNode[this.latticeLength][numStates];
        this.gammas = new double[this.latticeLength][numStates];
        if (saveXis) {
            this.xis = new double[this.latticeLength][numStates][numStates];
        }
        double[][] outputCounts = null;
        if (outputAlphabet != null) {
            outputCounts = new double[this.latticeLength][outputAlphabet.size()];
        }
        for (int i2 = 0; i2 < numStates; ++i2) {
            for (ip = 0; ip < this.latticeLength; ++ip) {
                this.gammas[ip][i2] = Double.NEGATIVE_INFINITY;
            }
            if (!saveXis) continue;
            for (int j = 0; j < numStates; ++j) {
                for (int ip2 = 0; ip2 < this.latticeLength; ++ip2) {
                    this.xis[ip2][i2][j] = Double.NEGATIVE_INFINITY;
                }
            }
        }
        boolean atLeastOneInitialState = false;
        for (i = 0; i < numStates; ++i) {
            double initialWeight = this.transducer.getState(i).getInitialWeight();
            if (!(initialWeight > Double.NEGATIVE_INFINITY)) continue;
            this.getLatticeNode((int)0, (int)i).alpha = initialWeight;
            atLeastOneInitialState = true;
        }
        if (!atLeastOneInitialState) {
            logger.warning("There are no starting states!");
        }
        for (ip = 0; ip < this.latticeLength - 1; ++ip) {
            for (int i3 = 0; i3 < numStates; ++i3) {
                if (this.nodes[ip][i3] == null || this.nodes[ip][i3].alpha == Double.NEGATIVE_INFINITY) continue;
                Transducer.State s = this.transducer.getState(i3);
                iter2 = new CachedDotTransitionIterator((CRF.State)s, input2, ip, null, cachedDots[ip][i3]);
                auxModel.preProcess(index, ip, input2);
                while (iter2.hasNext()) {
                    destination = iter2.next();
                    LatticeNode destinationNode = this.getLatticeNode(ip + 1, destination.getIndex());
                    destinationNode.output = iter2.getOutput();
                    double transitionWeight = iter2.getWeight();
                    destinationNode.alpha = Transducer.sumLogProb(destinationNode.alpha, this.nodes[ip][i3].alpha + (transitionWeight += auxModel.getWeight(index, ip, input2, iter2)));
                }
            }
        }
        this.totalWeight = Double.NEGATIVE_INFINITY;
        for (i = 0; i < numStates; ++i) {
            if (this.nodes[this.latticeLength - 1][i] == null) continue;
            this.totalWeight = Transducer.sumLogProb(this.totalWeight, this.nodes[this.latticeLength - 1][i].alpha + this.transducer.getState(i).getFinalWeight());
        }
        if (this.totalWeight == Double.NEGATIVE_INFINITY) {
            return;
        }
        for (i = 0; i < numStates; ++i) {
            if (this.nodes[this.latticeLength - 1][i] == null) continue;
            Transducer.State s = this.transducer.getState(i);
            this.nodes[this.latticeLength - 1][i].beta = s.getFinalWeight();
            this.gammas[this.latticeLength - 1][i] = this.nodes[this.latticeLength - 1][i].alpha + this.nodes[this.latticeLength - 1][i].beta - this.totalWeight;
            if (incrementor == null) continue;
            double p = Math.exp(this.gammas[this.latticeLength - 1][i]);
            assert (p >= 0.0 && p <= 1.000001) : "p=" + p + ", gamma=" + this.gammas[this.latticeLength - 1][i];
            incrementor.incrementFinalState(s, p);
        }
        for (ip = this.latticeLength - 2; ip >= 0; --ip) {
            for (int i4 = 0; i4 < numStates; ++i4) {
                if (this.nodes[ip][i4] == null || this.nodes[ip][i4].alpha == Double.NEGATIVE_INFINITY) continue;
                Transducer.State s = this.transducer.getState(i4);
                iter2 = new CachedDotTransitionIterator((CRF.State)s, input2, ip, null, cachedDots[ip][i4]);
                auxModel.preProcess(index, ip, input2);
                while (iter2.hasNext()) {
                    destination = iter2.next();
                    int j = destination.getIndex();
                    LatticeNode destinationNode = this.nodes[ip + 1][j];
                    if (destinationNode == null) continue;
                    double transitionWeight = iter2.getWeight();
                    this.nodes[ip][i4].beta = Transducer.sumLogProb(this.nodes[ip][i4].beta, destinationNode.beta + (transitionWeight += auxModel.getWeight(index, ip, input2, iter2)));
                    double xi = this.nodes[ip][i4].alpha + transitionWeight + this.nodes[ip + 1][j].beta - this.totalWeight;
                    if (saveXis) {
                        this.xis[ip][i4][j] = xi;
                    }
                    if (incrementor == null && auxModel.numParameters() <= 0 && outputAlphabet == null) continue;
                    double p = Math.exp(xi);
                    assert (p >= 0.0 && p <= 1.000001) : "p=" + p + ", xis[" + ip + "][" + i4 + "][" + j + "]=" + xi;
                    if (incrementor != null) {
                        incrementor.incrementTransition(iter2, p);
                    }
                    if (incrementConstraints) {
                        auxModel.incrementTransition(index, ip, input2, iter2, p);
                    }
                    if (outputAlphabet == null) continue;
                    int outputIndex = outputAlphabet.lookupIndex(iter2.getOutput(), false);
                    assert (outputIndex >= 0);
                    double[] dArray = outputCounts[ip];
                    int n = outputIndex;
                    dArray[n] = dArray[n] + p;
                }
                this.gammas[ip][i4] = this.nodes[ip][i4].alpha + this.nodes[ip][i4].beta - this.totalWeight;
            }
        }
        if (incrementor != null) {
            for (i = 0; i < numStates; ++i) {
                double p = Math.exp(this.gammas[0][i]);
                assert (p >= 0.0 && p <= 1.000001) : "p=" + p;
                incrementor.incrementInitialState(this.transducer.getState(i), p);
            }
        }
        if (outputAlphabet != null) {
            this.labelings = new LabelVector[this.latticeLength];
            for (ip = this.latticeLength - 2; ip >= 0; --ip) {
                assert (Math.abs(1.0 - MatrixOps.sum(outputCounts[ip])) < 1.0E-6);
                this.labelings[ip] = new LabelVector(outputAlphabet, outputCounts[ip]);
            }
        }
    }

    protected LatticeNode getLatticeNode(int ip, int stateIndex) {
        if (this.nodes[ip][stateIndex] == null) {
            this.nodes[ip][stateIndex] = new LatticeNode(ip, this.transducer.getState(stateIndex));
        }
        return this.nodes[ip][stateIndex];
    }

    @Override
    public double[][][] getXis() {
        return this.xis;
    }

    @Override
    public double[][] getGammas() {
        return this.gammas;
    }

    @Override
    public double getTotalWeight() {
        assert (!Double.isNaN(this.totalWeight));
        return this.totalWeight;
    }

    @Override
    public double getGammaWeight(int inputPosition, Transducer.State s) {
        return this.gammas[inputPosition][s.getIndex()];
    }

    public double getGammaWeight(int inputPosition, int stateIndex) {
        return this.gammas[inputPosition][stateIndex];
    }

    @Override
    public double getGammaProbability(int inputPosition, Transducer.State s) {
        return Math.exp(this.gammas[inputPosition][s.getIndex()]);
    }

    public double getGammaProbability(int inputPosition, int stateIndex) {
        return Math.exp(this.gammas[inputPosition][stateIndex]);
    }

    @Override
    public double getXiProbability(int ip, Transducer.State s1, Transducer.State s2) {
        if (this.xis == null) {
            throw new IllegalStateException("xis were not saved.");
        }
        int i = s1.getIndex();
        int j = s2.getIndex();
        return Math.exp(this.xis[ip][i][j]);
    }

    @Override
    public double getXiWeight(int ip, Transducer.State s1, Transducer.State s2) {
        if (this.xis == null) {
            throw new IllegalStateException("xis were not saved.");
        }
        int i = s1.getIndex();
        int j = s2.getIndex();
        return this.xis[ip][i][j];
    }

    @Override
    public int length() {
        return this.latticeLength;
    }

    @Override
    public double getAlpha(int ip, Transducer.State s) {
        LatticeNode node = this.getLatticeNode(ip, s.getIndex());
        return node.alpha;
    }

    @Override
    public double getBeta(int ip, Transducer.State s) {
        LatticeNode node = this.getLatticeNode(ip, s.getIndex());
        return node.beta;
    }

    @Override
    public LabelVector getLabelingAtPosition(int outputPosition) {
        if (this.labelings != null) {
            return this.labelings[outputPosition];
        }
        return null;
    }

    @Override
    public Transducer getTransducer() {
        return this.transducer;
    }

    @Override
    public Sequence getInput() {
        return this.input;
    }

    protected class LatticeNode {
        int inputPosition;
        Transducer.State state;
        Object output;
        double alpha = Double.NEGATIVE_INFINITY;
        double beta = Double.NEGATIVE_INFINITY;

        LatticeNode(int inputPosition, Transducer.State state) {
            this.inputPosition = inputPosition;
            this.state = state;
            assert (this.alpha == Double.NEGATIVE_INFINITY);
        }
    }
}

