/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.Segment;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Forward-backward lattice over a {@link Transducer} in which some lattice positions are
 * restricted to (or barred from) a particular state.
 *
 * <p>Constraints are supplied as an {@code int[]} indexed by lattice position, where lattice
 * position {@code p + 1} holds the state reached after consuming input position {@code p}.
 * Each entry is encoded as follows:
 * <ul>
 *   <li>{@code 0}: the position is unconstrained;</li>
 *   <li>{@code k > 0}: the state at this position must be the state with index {@code k - 1};</li>
 *   <li>{@code k < 0}: the state at this position must not be the state with index
 *       {@code -k - 1}.</li>
 * </ul>
 *
 * <p>NOTE(review): the computation relies on freshly created lattice nodes having
 * {@code alpha} and {@code beta} equal to negative infinity; that default lives in
 * {@code SumLatticeDefault.LatticeNode} and is not visible here - confirm.
 */
public class SumLatticeConstrained
extends SumLatticeDefault {
    private static final Logger logger = MalletLogger.getLogger(SumLatticeConstrained.class.getName());

    /**
     * Builds a lattice in which the positions covered by {@code requiredSegment} are forced to
     * the states named in {@code constrainedSequence}. No incrementor and no output alphabet
     * are used.
     *
     * @param t the transducer whose states and transitions define the lattice
     * @param input2 the input sequence
     * @param output2 the output sequence handed to the transition iterators
     * @param requiredSegment the segment whose positions are constrained
     * @param constrainedSequence state names, one per input position; must have the same size
     *        as {@code input2}
     * @throws IllegalArgumentException if the two sequences differ in size, or if the segment's
     *         in-tag is needed for the trailing negative constraint and names no state
     */
    public SumLatticeConstrained(Transducer t, Sequence input2, Sequence output2, Segment requiredSegment, Sequence constrainedSequence) {
        this(t, input2, output2, (Transducer.Incrementor)null, null, SumLatticeConstrained.makeConstraints(t, input2, output2, requiredSegment, constrainedSequence));
    }

    /**
     * Translates a required segment into the constraint array described in the class comment.
     *
     * <p>Every position inside the segment gets a positive constraint for the state named at
     * that position in {@code constrainedSequence}. If a lattice position exists directly
     * after the segment, it gets a negative constraint on the segment's in-tag state.
     *
     * @return an array of length {@code constrainedSequence.size() + 1}
     */
    private static int[] makeConstraints(Transducer t, Sequence inputSequence, Sequence outputSequence, Segment requiredSegment, Sequence constrainedSequence) {
        if (constrainedSequence.size() != inputSequence.size()) {
            throw new IllegalArgumentException("constrainedSequence.size [" + constrainedSequence.size() + "] != inputSequence.size [" + inputSequence.size() + "]");
        }
        // Java zero-initialises the array, so every position starts out unconstrained.
        int[] constraints = new int[constrainedSequence.size() + 1];
        for (int i = requiredSegment.getStart(); i <= requiredSegment.getEnd(); ++i) {
            int si = t.stateIndexOfString((String)constrainedSequence.get(i));
            if (si == -1) {
                logger.warning("Could not find state " + constrainedSequence.get(i) + ". Check that state labels match startTages and inTags, and that all labels are seen in training data.");
            }
            // For an unknown state si + 1 is 0, which leaves the position unconstrained.
            constraints[i + 1] = si + 1;
        }
        if (requiredSegment.getEnd() + 2 < constraints.length) {
            String endTag2 = requiredSegment.getInTag().toString();
            int statei = t.stateIndexOfString(endTag2);
            if (statei == -1) {
                throw new IllegalArgumentException("Could not find state " + endTag2 + ". Check that state labels match startTags and InTags.");
            }
            constraints[requiredSegment.getEnd() + 2] = -(statei + 1);
        }
        // Only build the (potentially long) debug dump when somebody is listening.
        if (logger.isLoggable(Level.FINE)) {
            logger.fine("Segment:\n" + requiredSegment.sequenceToString() + "\nconstrainedSequence:\n" + constrainedSequence + "\nConstraints:\n");
            for (int i = 0; i < constraints.length; ++i) {
                logger.fine(constraints[i] + "\t");
            }
            logger.fine("");
        }
        return constraints;
    }

    /**
     * Decodes one constraint entry against a candidate state.
     *
     * @param constraint the encoded constraint for a lattice position (see class comment)
     * @param stateIndex the index of the state being considered for that position
     * @return {@code true} if the state is forbidden at that position
     */
    private static boolean violatesConstraint(int constraint, int stateIndex) {
        if (constraint > 0) {
            return constraint - 1 != stateIndex;
        }
        if (constraint < 0) {
            return -(constraint + 1) == stateIndex;
        }
        return false;
    }

    /**
     * Runs the constrained forward pass and, if at least one legal path exists, the backward
     * pass, filling {@code nodes}, {@code gammas}, {@code totalWeight} and optionally
     * {@code labelings}.
     *
     * <p>If no legal path exists, {@code totalWeight} is negative infinity and the constructor
     * returns before the backward pass; every gamma then remains negative infinity and the
     * incrementor is never called.
     *
     * @param trans the transducer whose states and transitions define the lattice
     * @param input2 the input sequence; the lattice has {@code input2.size() + 1} positions
     * @param output2 the output sequence handed to the transition iterators; may be
     *        {@code null}
     * @param incrementor receives final-state, transition and initial-state probabilities;
     *        may be {@code null}
     * @param outputAlphabet if non-null, per-position output distributions are accumulated
     *        and stored in {@code labelings}
     * @param constraints one encoded constraint per lattice position (see class comment)
     */
    public SumLatticeConstrained(Transducer trans, Sequence input2, Sequence output2, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet, int[] constraints) {
        this.t = trans;
        this.input = input2;
        this.output = output2;
        this.latticeLength = input2.size() + 1;
        int numStates = this.t.numStates();
        int lastPosition = this.latticeLength - 1;
        // Checked once so the inner loops do not build log strings that are thrown away.
        final boolean fine = logger.isLoggable(Level.FINE);
        this.nodes = new SumLatticeDefault.LatticeNode[this.latticeLength][numStates];
        this.gammas = new double[this.latticeLength][numStates];
        double[][] outputCounts = null;
        if (outputAlphabet != null) {
            outputCounts = new double[this.latticeLength][outputAlphabet.size()];
        }
        for (int i = 0; i < numStates; ++i) {
            for (int ip = 0; ip < this.latticeLength; ++ip) {
                this.gammas[ip][i] = Double.NEGATIVE_INFINITY;
            }
        }

        // ---- Forward pass ----
        logger.fine("Starting Constrained Foward pass");
        boolean atLeastOneInitialState = false;
        for (int i = 0; i < numStates; ++i) {
            double initialWeight = this.t.getState(i).getInitialWeight();
            if (!(initialWeight > Double.NEGATIVE_INFINITY)) continue;
            this.getLatticeNode((int)0, (int)i).alpha = initialWeight;
            atLeastOneInitialState = true;
        }
        if (!atLeastOneInitialState) {
            logger.warning("There are no starting states!");
        }
        for (int ip = 0; ip < lastPosition; ++ip) {
            for (int i = 0; i < numStates; ++i) {
                if (fine) {
                    logger.fine("ip=" + ip + ", i=" + i);
                }
                // A state that is forbidden at this position cannot be the source of any path.
                if (violatesConstraint(constraints[ip], i)) {
                    if (fine) {
                        if (constraints[ip] > 0) {
                            logger.fine("Current state does not match positive constraint. position=" + ip + ", constraint=" + (constraints[ip] - 1) + ", currState=" + i);
                        } else {
                            logger.fine("Current state does not match negative constraint. position=" + ip + ", constraint=" + (constraints[ip] + 1) + ", currState=" + i);
                        }
                    }
                    continue;
                }
                SumLatticeDefault.LatticeNode sourceNode = this.nodes[ip][i];
                if (sourceNode == null || sourceNode.alpha == Double.NEGATIVE_INFINITY) {
                    if (fine) {
                        if (sourceNode == null) {
                            logger.fine("nodes[ip][i] is NULL");
                        } else {
                            logger.fine("nodes[ip][i].alpha is -Inf");
                        }
                        logger.fine("-INFINITE weight or NULL...skipping");
                    }
                    continue;
                }
                Transducer.State s = this.t.getState(i);
                Transducer.TransitionIterator iter2 = s.transitionIterator(input2, ip, output2, ip);
                if (fine) {
                    logger.fine(" Starting Forward transition iteration from state " + s.getName() + " on input " + input2.get(ip).toString() + " and output " + (output2 == null ? "(null)" : output2.get(ip).toString()));
                }
                while (iter2.hasNext()) {
                    Transducer.State destination = iter2.nextState();
                    int destinationIndex = destination.getIndex();
                    boolean legalTransition = true;
                    if (ip + 1 < constraints.length && violatesConstraint(constraints[ip + 1], destinationIndex)) {
                        legalTransition = false;
                        if (fine) {
                            if (constraints[ip + 1] > 0) {
                                logger.fine("Destination state does not match positive constraint. Assigning -infinite weight. position=" + (ip + 1) + ", constraint=" + (constraints[ip + 1] - 1) + ", source =" + i + ", destination=" + destinationIndex);
                            } else {
                                logger.fine("Destination state does not match negative constraint. Assigning -infinite weight. position=" + (ip + 1) + ", constraint=" + (constraints[ip + 1] + 1) + ", destination=" + destinationIndex);
                            }
                        }
                    }
                    if (fine) {
                        logger.fine("Forward Lattice[inputPos=" + ip + "][source=" + s.getName() + "][dest=" + destination.getName() + "]");
                    }
                    // The destination node is created even for an illegal transition; its
                    // alpha is simply never increased by that transition.
                    SumLatticeDefault.LatticeNode destinationNode = this.getLatticeNode(ip + 1, destinationIndex);
                    destinationNode.output = iter2.getOutput();
                    double transitionWeight = iter2.getWeight();
                    if (legalTransition) {
                        if (fine) {
                            logger.fine("transitionWeight=" + transitionWeight + " nodes[" + ip + "][" + i + "].alpha=" + sourceNode.alpha + " destinationNode.alpha=" + destinationNode.alpha);
                        }
                        destinationNode.alpha = Transducer.sumLogProb(destinationNode.alpha, sourceNode.alpha + transitionWeight);
                        if (fine) {
                            logger.fine("Set alpha of latticeNode at ip = " + (ip + 1) + " stateIndex = " + destinationIndex + ", destinationNode.alpha = " + destinationNode.alpha);
                        }
                    } else if (fine) {
                        logger.fine("Illegal transition from state " + i + " to state " + destinationIndex + ". Setting alpha to -Inf");
                    }
                }
            }
        }

        // ---- Total weight: sum over the final states that satisfy the last constraint ----
        this.totalWeight = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numStates; ++i) {
            SumLatticeDefault.LatticeNode finalNode = this.nodes[lastPosition][i];
            if (finalNode == null || violatesConstraint(constraints[lastPosition], i)) continue;
            if (fine) {
                logger.fine("Summing final lattice weight. state=" + i + ", alpha=" + finalNode.alpha + ", final weight = " + this.t.getState(i).getFinalWeight());
            }
            this.totalWeight = Transducer.sumLogProb(this.totalWeight, finalNode.alpha + this.t.getState(i).getFinalWeight());
        }
        if (this.totalWeight == Double.NEGATIVE_INFINITY) {
            // No legal path through the lattice; nothing to normalise.
            return;
        }

        // ---- Backward pass ----
        for (int i = 0; i < numStates; ++i) {
            SumLatticeDefault.LatticeNode finalNode = this.nodes[lastPosition][i];
            if (finalNode == null) continue;
            Transducer.State s = this.t.getState(i);
            // A final state excluded from totalWeight must not contribute backward mass
            // either; otherwise predecessor betas, gammas and transition probabilities
            // would include forbidden paths and no longer normalise.
            finalNode.beta = violatesConstraint(constraints[lastPosition], i) ? Double.NEGATIVE_INFINITY : s.getFinalWeight();
            this.gammas[lastPosition][i] = finalNode.alpha + finalNode.beta - this.totalWeight;
            if (incrementor == null) continue;
            double p = Math.exp(this.gammas[lastPosition][i]);
            assert (p >= 0.0 && p <= 1.0 && !Double.isNaN(p)) : "p=" + p + " gamma=" + this.gammas[lastPosition][i];
            incrementor.incrementFinalState(s, p);
        }
        for (int ip = lastPosition - 1; ip >= 0; --ip) {
            for (int i = 0; i < numStates; ++i) {
                SumLatticeDefault.LatticeNode sourceNode = this.nodes[ip][i];
                // Mirror the forward pass: unreachable or forbidden sources carry no mass,
                // so their gamma stays at negative infinity.
                if (sourceNode == null || sourceNode.alpha == Double.NEGATIVE_INFINITY || violatesConstraint(constraints[ip], i)) continue;
                Transducer.State s = this.t.getState(i);
                Transducer.TransitionIterator iter2 = s.transitionIterator(input2, ip, output2, ip);
                while (iter2.hasNext()) {
                    Transducer.State destination = iter2.nextState();
                    if (fine) {
                        logger.fine("Backward Lattice[inputPos=" + ip + "][source=" + s.getName() + "][dest=" + destination.getName() + "]");
                    }
                    int j = destination.getIndex();
                    SumLatticeDefault.LatticeNode destinationNode = this.nodes[ip + 1][j];
                    if (destinationNode == null) continue;
                    double transitionWeight = iter2.getWeight();
                    assert (!Double.isNaN(transitionWeight));
                    double oldBeta = sourceNode.beta;
                    assert (!Double.isNaN(sourceNode.beta));
                    sourceNode.beta = Transducer.sumLogProb(sourceNode.beta, destinationNode.beta + transitionWeight);
                    assert (!Double.isNaN(sourceNode.beta)) : "dest.beta=" + destinationNode.beta + " trans=" + transitionWeight + " sum=" + (destinationNode.beta + transitionWeight) + " oldBeta=" + oldBeta;
                    assert (!Double.isNaN(sourceNode.alpha));
                    assert (!Double.isNaN(transitionWeight));
                    assert (!Double.isNaN(destinationNode.beta));
                    assert (!Double.isNaN(this.totalWeight));
                    if (incrementor == null && outputAlphabet == null) continue;
                    // Log-probability of taking this transition at this position.
                    double xi = sourceNode.alpha + transitionWeight + destinationNode.beta - this.totalWeight;
                    double p = Math.exp(xi);
                    assert (p > Double.NEGATIVE_INFINITY && !Double.isNaN(p)) : "xis[" + ip + "][" + i + "][" + j + "]=" + xi;
                    if (incrementor != null) {
                        incrementor.incrementTransition(iter2, p);
                    }
                    if (outputAlphabet == null) continue;
                    int outputIndex = outputAlphabet.lookupIndex(iter2.getOutput(), false);
                    assert (outputIndex >= 0);
                    outputCounts[ip][outputIndex] += p;
                }
                this.gammas[ip][i] = sourceNode.alpha + sourceNode.beta - this.totalWeight;
            }
        }
        if (incrementor != null) {
            for (int i = 0; i < numStates; ++i) {
                double p = Math.exp(this.gammas[0][i]);
                assert (p > Double.NEGATIVE_INFINITY && !Double.isNaN(p));
                incrementor.incrementInitialState(this.t.getState(i), p);
            }
        }
        if (outputAlphabet != null) {
            this.labelings = new LabelVector[this.latticeLength];
            for (int ip = lastPosition - 1; ip >= 0; --ip) {
                assert (Math.abs(1.0 - MatrixOps.sum(outputCounts[ip])) < 1.0E-6);
                this.labelings[ip] = new LabelVector(outputAlphabet, outputCounts[ip]);
            }
        }
    }
}

