/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.Transducer;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.util.Maths;

/**
 * Lattice that computes the (negative) entropy of the label-sequence distribution described
 * by per-position state weights ({@code gammas}) and per-position transition weights
 * ({@code xis}), and optionally reports a per-transition contribution to a
 * {@link Transducer.Incrementor}.
 *
 * <p>All weights are handled in log space: {@code Math.exp(gamma)} and {@code Math.exp(xi)}
 * are used as probability weights, and {@code Double.NEGATIVE_INFINITY} marks an impossible
 * state or transition, which is skipped. The value computed has the form {@code Σ p·log p},
 * i.e. a negative entropy (≤ 0).
 *
 * <p>NOTE(review): {@code gammas}/{@code xis} are presumably the log state/transition
 * marginals of a forward-backward pass over the same transducer — confirm against callers.
 */
public class EntropyLattice {
    /** Number of lattice positions: {@code inputLength + 1}. */
    protected int latticeLength;
    /** Length of the input sequence. */
    protected int inputLength;
    protected Transducer transducer;
    protected int numStates;
    /** Lazily created nodes, indexed {@code [position][stateIndex]}. */
    protected LatticeNode[][] nodes;
    /** Negative entropy computed by the forward pass in the constructor. */
    protected double entropy;

    /**
     * Builds the lattice, runs the forward and backward passes, and, if an incrementor is
     * given, pushes the per-transition contributions to it.
     *
     * @param fvs input sequence; its size fixes the lattice length
     * @param gammas log state weights indexed {@code [position][state]}; needs at least
     *     {@code fvs.size() + 1} rows
     * @param xis log transition weights indexed {@code [position][fromState][toState]};
     *     needs at least {@code fvs.size()} rows
     * @param transducer transducer whose states label the lattice nodes
     * @param incrementor receiver of per-transition contributions, or {@code null} to skip
     * @param scalingFactor multiplier applied to every contribution passed to the incrementor
     */
    public EntropyLattice(FeatureVectorSequence fvs, double[][] gammas, double[][][] xis, Transducer transducer, Transducer.Incrementor incrementor, double scalingFactor) {
        this.inputLength = fvs.size();
        this.latticeLength = this.inputLength + 1;
        this.transducer = transducer;
        this.numStates = transducer.numStates();
        this.nodes = new LatticeNode[this.latticeLength][this.numStates];
        this.entropy = this.forwardLattice(gammas, xis);
        double backwardEntropy = this.backwardLattice(gammas, xis);
        // Both directions compute the same quantity; a mismatch indicates inconsistent inputs.
        assert (Maths.almostEquals(this.entropy, backwardEntropy)) : this.entropy + " " + backwardEntropy;
        if (incrementor != null) {
            this.updateCounts(fvs, gammas, xis, scalingFactor, incrementor);
        }
    }

    /**
     * Returns the negative entropy computed by the forward pass during construction.
     *
     * @return the negative entropy (≤ 0)
     */
    public double getEntropy() {
        return this.entropy;
    }

    /**
     * Runs the forward recursion and returns the resulting negative entropy.
     *
     * <p>For every position {@code ip >= 1} and possible state {@code a}:
     * {@code alpha[ip][a] = Σ_b c·(log c + alpha[ip-1][b])} with
     * {@code log c = xis[ip-1][b][a] - gammas[ip][a]}. The result is
     * {@code Σ_a p_a·(log p_a + alpha[inputLength][a])} where
     * {@code log p_a = gammas[inputLength][a]}.
     *
     * <p>Every node's {@code alpha} is reset before it is accumulated, so calling this method
     * again recomputes the same values instead of adding onto the previous ones.
     *
     * @param gammas log state weights indexed {@code [position][state]}
     * @param xis log transition weights indexed {@code [position][fromState][toState]}
     * @return the negative entropy computed from the forward values
     */
    public double forwardLattice(double[][] gammas, double[][][] xis) {
        for (int a = 0; a < this.numStates; ++a) {
            this.getLatticeNode(0, a).alpha = 0.0;
        }
        for (int ip = 1; ip < this.latticeLength; ++ip) {
            for (int a = 0; a < this.numStates; ++a) {
                LatticeNode node = this.getLatticeNode(ip, a);
                // Reset so repeated calls do not accumulate onto stale values.
                node.alpha = 0.0;
                double gamma = gammas[ip][a];
                if (!(gamma > Double.NEGATIVE_INFINITY)) {
                    continue; // impossible state contributes nothing
                }
                for (int b = 0; b < this.numStates; ++b) {
                    double xi = xis[ip - 1][b][a];
                    if (!(xi > Double.NEGATIVE_INFINITY)) {
                        continue; // impossible transition
                    }
                    // Log of the conditional weight exp(xi)/exp(gamma). Exponentiating the
                    // difference avoids 0/0 = NaN when both exp() calls underflow.
                    double logCond = xi - gamma;
                    node.alpha += Math.exp(logCond) * (logCond + this.getLatticeNode(ip - 1, b).alpha);
                }
            }
        }
        double negEntropy = 0.0;
        for (int a = 0; a < this.numStates; ++a) {
            double gamma = gammas[this.inputLength][a];
            if (!(gamma > Double.NEGATIVE_INFINITY)) {
                continue;
            }
            double gammaProb = Math.exp(gamma);
            negEntropy += gammaProb * gamma;
            negEntropy += gammaProb * this.getLatticeNode(this.inputLength, a).alpha;
        }
        return negEntropy;
    }

    /**
     * Runs the backward recursion and returns the resulting negative entropy.
     *
     * <p>{@code beta[inputLength][*] = 0}; for {@code ip < inputLength} and possible state
     * {@code a}: {@code beta[ip][a] = Σ_b c·(log c + beta[ip+1][b])} with
     * {@code log c = xis[ip][a][b] - gammas[ip][a]}. The result is
     * {@code Σ_a p_a·(log p_a + beta[0][a])} where {@code log p_a = gammas[0][a]}.
     *
     * <p>Every node's {@code beta} is reset before it is accumulated, so calling this method
     * again recomputes the same values instead of adding onto the previous ones.
     *
     * @param gammas log state weights indexed {@code [position][state]}
     * @param xis log transition weights indexed {@code [position][fromState][toState]}
     * @return the negative entropy computed from the backward values
     */
    public double backwardLattice(double[][] gammas, double[][][] xis) {
        for (int a = 0; a < this.numStates; ++a) {
            this.getLatticeNode(this.inputLength, a).beta = 0.0;
        }
        // The last position is the base case. Starting the recursion there would read
        // xis[inputLength] and nodes[inputLength + 1]; the latter is out of bounds.
        for (int ip = this.inputLength - 1; ip >= 0; --ip) {
            for (int a = 0; a < this.numStates; ++a) {
                LatticeNode node = this.getLatticeNode(ip, a);
                // Reset so repeated calls do not accumulate onto stale values.
                node.beta = 0.0;
                double gamma = gammas[ip][a];
                if (!(gamma > Double.NEGATIVE_INFINITY)) {
                    continue; // impossible state contributes nothing
                }
                for (int b = 0; b < this.numStates; ++b) {
                    double xi = xis[ip][a][b];
                    if (!(xi > Double.NEGATIVE_INFINITY)) {
                        continue; // impossible transition
                    }
                    // Log of the conditional weight exp(xi)/exp(gamma), computed stably.
                    double logCond = xi - gamma;
                    node.beta += Math.exp(logCond) * (logCond + this.getLatticeNode(ip + 1, b).beta);
                }
            }
        }
        double negEntropy = 0.0;
        for (int a = 0; a < this.numStates; ++a) {
            double gamma = gammas[0][a];
            if (!(gamma > Double.NEGATIVE_INFINITY)) {
                continue;
            }
            double gammaProb = Math.exp(gamma);
            negEntropy += gammaProb * gamma;
            negEntropy += gammaProb * this.getLatticeNode(0, a).beta;
        }
        return negEntropy;
    }

    /**
     * For every transition allowed by the transducer at each input position, passes
     * {@code (exp(xi)·(xi + alpha[ip][a] + beta[ip+1][b]) - exp(xi)·entropy)·scalingFactor}
     * to {@code incrementor}. Transitions with weight {@code -Infinity} are skipped.
     * Requires the forward and backward passes to have run.
     *
     * @param fvs input sequence used to enumerate transitions
     * @param gammas log state weights (not read here; kept for signature symmetry)
     * @param xis log transition weights indexed {@code [position][fromState][toState]}
     * @param scalingFactor multiplier applied to each contribution
     * @param incrementor receiver of the contributions
     */
    private void updateCounts(FeatureVectorSequence fvs, double[][] gammas, double[][][] xis, double scalingFactor, Transducer.Incrementor incrementor) {
        for (int ip = 0; ip < this.inputLength; ++ip) {
            for (int a = 0; a < this.numStates; ++a) {
                LatticeNode source = this.nodes[ip][a];
                if (source == null) {
                    continue;
                }
                Transducer.State sourceState = this.transducer.getState(a);
                Transducer.TransitionIterator iter = sourceState.transitionIterator(fvs, ip, null, ip);
                while (iter.hasNext()) {
                    int b = iter.next().getIndex();
                    double xi = xis[ip][a][b];
                    if (xi == Double.NEGATIVE_INFINITY) {
                        continue;
                    }
                    double xiProb = Math.exp(xi);
                    LatticeNode dest = this.nodes[ip + 1][b];
                    double constrEntropy = xiProb * (xi + source.alpha + dest.beta);
                    assert (constrEntropy <= 0.0) : "Negative entropy should be negative! " + constrEntropy;
                    double covContribution = constrEntropy - xiProb * this.entropy;
                    assert (!Double.isNaN(covContribution)) : "xi: " + xi + ", nodes[" + ip + "][" + a + "].alpha: " + source.alpha + ", nodes[" + (ip + 1) + "][" + b + "].beta: " + dest.beta;
                    incrementor.incrementTransition(iter, covContribution * scalingFactor);
                }
            }
        }
    }

    /**
     * Returns the node for position {@code ip} and state index {@code si}, creating it on
     * first access with {@code alpha = beta = 0}.
     *
     * @param ip lattice position, {@code 0..inputLength}
     * @param si state index, {@code 0..numStates-1}
     * @return the (possibly newly created) node
     */
    public LatticeNode getLatticeNode(int ip, int si) {
        if (this.nodes[ip][si] == null) {
            this.nodes[ip][si] = new LatticeNode(ip, this.transducer.getState(si));
        }
        return this.nodes[ip][si];
    }

    /** One lattice node: a (position, state) pair with its forward and backward values. */
    public class LatticeNode {
        /** Lattice position of this node. */
        public int ip;
        /** Transducer state this node represents. */
        public Transducer.State state;
        /** Forward value accumulated by {@link EntropyLattice#forwardLattice}. */
        public double alpha;
        /** Backward value accumulated by {@link EntropyLattice#backwardLattice}. */
        public double beta;

        LatticeNode(int ip, Transducer.State state) {
            this.ip = ip;
            this.state = state;
            this.alpha = 0.0;
            this.beta = 0.0;
        }
    }
}

