/*
 * Decompiled with CFR 0.152.
 */
package dragon.ml.seqmodel.crf;

import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.ml.seqmodel.crf.LBFGSBasicTrainer;
import dragon.ml.seqmodel.data.DataSequence;
import dragon.ml.seqmodel.data.Dataset;
import dragon.ml.seqmodel.feature.Feature;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;

/**
 * L-BFGS trainer for a segment-level (semi-Markov style) sequence model.
 *
 * <p>The forward and backward recursions enumerate every candidate segment of length
 * {@code 1..maxSegmentLength}. Each candidate is scored with {@code computeTransMatrix} over the
 * segment's start and end positions. A training sequence whose labelled segmentation contains a
 * segment longer than {@code maxSegmentLength} cannot be produced by these recursions, so it is
 * skipped entirely (it contributes neither to the objective nor to the gradient).
 */
public class LBFGSSegmentTrainer
extends LBFGSBasicTrainer {
    /** Longest candidate segment enumerated by the forward/backward recursions. */
    private int maxSegmentLength;

    /**
     * Creates a segment trainer.
     *
     * @param model            the model graph, passed to the base trainer
     * @param featureGenerator feature source, passed to the base trainer
     * @param maxSegmentLength longest segment length the recursions consider
     */
    public LBFGSSegmentTrainer(ModelGraph model, FeatureGenerator featureGenerator, int maxSegmentLength) {
        super(model, featureGenerator);
        this.maxSegmentLength = maxSegmentLength;
    }

    /**
     * Computes the penalized training log-likelihood and its gradient.
     *
     * <p>If {@code doScaling} is set, the work is delegated to
     * {@link #computeFunctionGradientLL(Dataset, double[], double[])}, which does all forward and
     * backward arithmetic in log space. This method works with the potentials directly.
     *
     * @param diter  training sequences; scanned from the start
     * @param lambda current feature weights
     * @param grad   receives the gradient of the returned value with respect to {@code lambda}
     * @return log-likelihood of the usable sequences minus the L2 penalty on {@code lambda}
     */
    protected double computeFunctionGradient(Dataset diter, double[] lambda, double[] grad) {
        try {
            if (this.doScaling) {
                return this.computeFunctionGradientLL(diter, lambda, grad);
            }
            int stateNum = this.model.getStateNum();
            double[][] alpha_Y = null;
            double[][] beta_Y = null;
            double[] expF = new double[this.featureGenerator.getFeatureNum()];
            DoubleFlatDenseMatrix Mi_YY = new DoubleFlatDenseMatrix(stateNum, stateNum);
            double logli = this.initGradientWithPrior(lambda, grad);
            // alpha_Y[i - base] is the forward vector at position i; row 0 (position -1) is the start.
            final int base = -1;
            diter.startScan();
            while (diter.hasNext()) {
                DataSequence dataSeq = diter.next();
                int len = dataSeq.length();
                // Skip before touching grad: an empty sequence contributes nothing, and one with an
                // over-long labelled segment cannot be scored. Checking later would leave partial
                // empirical counts in grad for a sequence excluded from the objective.
                if (len == 0 || !this.hasValidSegmentation(dataSeq)) {
                    continue;
                }
                for (int f = 0; f < lambda.length; ++f) {
                    expF[f] = 0.0;
                }
                alpha_Y = ensureRows(alpha_Y, len - base, stateNum);
                beta_Y = ensureRows(beta_Y, len, stateNum);

                // Backward pass: beta_Y[i] is built from every candidate segment (i+1 .. i+ell)
                // that follows position i.
                MathUtil.initArray(beta_Y[len - 1], 1.0);
                for (int i = len - 2; i >= 0; --i) {
                    MathUtil.initArray(beta_Y[i], 0.0);
                    for (int ell = 1; ell <= this.maxSegmentLength && i + ell < len; ++ell) {
                        this.computeTransMatrix(lambda, dataSeq, i + 1, i + ell, Mi_YY, true);
                        this.genStateVector(Mi_YY, beta_Y[i + ell], beta_Y[i], false);
                    }
                }

                // Forward pass, accumulating empirical counts (into grad) and expected counts.
                double thisSeqLogli = 0.0;
                MathUtil.initArray(alpha_Y[0], 1.0);
                int segmentStart = 0;
                int segmentEnd = -1;
                for (int i = 0; i < len; ++i) {
                    if (segmentEnd < i) {
                        segmentStart = i;
                        segmentEnd = dataSeq.getSegmentEnd(i);
                    }
                    double[] alphaCur = alpha_Y[i - base];
                    double[] betaCur = beta_Y[i];
                    MathUtil.initArray(alphaCur, 0.0);
                    for (int ell = 1; ell <= this.maxSegmentLength && i - ell >= base; ++ell) {
                        int start = i - ell + 1;
                        double[] alphaPrev = alpha_Y[i - ell - base];
                        this.computeTransMatrix(lambda, dataSeq, start, i, Mi_YY, true);
                        this.featureGenerator.startScanFeaturesAt(dataSeq, start, i);
                        // True when the candidate segment coincides with the labelled segment.
                        boolean isSegment = start == segmentStart && i == segmentEnd;
                        while (this.featureGenerator.hasNext()) {
                            Feature feature = this.featureGenerator.next();
                            int f = feature.getIndex();
                            int yp = feature.getLabel();
                            int yprev = feature.getPrevLabel();
                            double val = feature.getValue();
                            if (isSegment && dataSeq.getLabel(i) == yp
                                    && ((i - ell >= 0 && yprev == dataSeq.getLabel(i - ell)) || yprev < 0)) {
                                grad[f] += val;
                                thisSeqLogli += val * lambda[f];
                            }
                            // A negative previous label means the feature does not constrain the
                            // previous state, so every previous state contributes.
                            if (yprev < 0) {
                                for (int y = 0; y < Mi_YY.rows(); ++y) {
                                    expF[f] += val * alphaPrev[y] * Mi_YY.getDouble(y, yp) * betaCur[yp];
                                }
                            } else {
                                expF[f] += val * alphaPrev[yprev] * Mi_YY.getDouble(yprev, yp) * betaCur[yp];
                            }
                        }
                        this.genStateVector(Mi_YY, alphaPrev, alphaCur, true);
                    }
                }

                double Zx = MathUtil.sumArray(alpha_Y[len - 1 - base]);
                thisSeqLogli -= Math.log(Zx);
                logli += thisSeqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= expF[f] / Zx;
                }
            }
            return logli;
        }
        catch (Exception e) {
            e.printStackTrace();
            // Training cannot continue; report failure instead of exiting with status 0 (success).
            System.exit(1);
            return 0.0;
        }
    }

    /**
     * Log-space variant of {@link #computeFunctionGradient(Dataset, double[], double[])}.
     *
     * <p>Forward/backward vectors and expected feature counts are kept as logarithms and combined
     * with {@code MathUtil.logSumExp}.
     *
     * @param diter  training sequences; scanned from the start
     * @param lambda current feature weights
     * @param grad   receives the gradient of the returned value with respect to {@code lambda}
     * @return log-likelihood of the usable sequences minus the L2 penalty on {@code lambda}
     */
    protected double computeFunctionGradientLL(Dataset diter, double[] lambda, double[] grad) {
        try {
            int stateNum = this.model.getStateNum();
            double[][] alpha_Y = null;
            double[][] beta_Y = null;
            double[] expF = new double[this.featureGenerator.getFeatureNum()];
            DoubleFlatDenseMatrix Mi_YY = new DoubleFlatDenseMatrix(stateNum, stateNum);
            double logli = this.initGradientWithPrior(lambda, grad);
            // alpha_Y[i - base] is the forward vector at position i; row 0 (position -1) is the start.
            final int base = -1;
            diter.startScan();
            while (diter.hasNext()) {
                DataSequence dataSeq = diter.next();
                int len = dataSeq.length();
                // Skip before touching grad (see computeFunctionGradient for the rationale).
                if (len == 0 || !this.hasValidSegmentation(dataSeq)) {
                    continue;
                }
                for (int f = 0; f < lambda.length; ++f) {
                    expF[f] = MathUtil.LOG0;
                }
                alpha_Y = ensureRows(alpha_Y, len - base, stateNum);
                beta_Y = ensureRows(beta_Y, len, stateNum);

                // Backward pass in log space: beta_Y[i] is built from every candidate segment
                // (i+1 .. i+ell) that follows position i.
                MathUtil.initArray(beta_Y[len - 1], 0.0);
                for (int i = len - 2; i >= 0; --i) {
                    MathUtil.initArray(beta_Y[i], MathUtil.LOG0);
                    for (int ell = 1; ell <= this.maxSegmentLength && i + ell < len; ++ell) {
                        this.computeTransMatrix(lambda, dataSeq, i + 1, i + ell, Mi_YY, false);
                        this.genStateVectorLog(Mi_YY, beta_Y[i + ell], beta_Y[i], false);
                    }
                }

                // Forward pass in log space, accumulating empirical and expected counts.
                double thisSeqLogli = 0.0;
                MathUtil.initArray(alpha_Y[0], 0.0);
                int segmentStart = 0;
                int segmentEnd = -1;
                for (int i = 0; i < len; ++i) {
                    if (segmentEnd < i) {
                        segmentStart = i;
                        segmentEnd = dataSeq.getSegmentEnd(i);
                    }
                    double[] alphaCur = alpha_Y[i - base];
                    double[] betaCur = beta_Y[i];
                    MathUtil.initArray(alphaCur, MathUtil.LOG0);
                    for (int ell = 1; ell <= this.maxSegmentLength && i - ell >= base; ++ell) {
                        int start = i - ell + 1;
                        double[] alphaPrev = alpha_Y[i - ell - base];
                        this.computeTransMatrix(lambda, dataSeq, start, i, Mi_YY, false);
                        // Scan features over the same span (start..i) that computeTransMatrix just
                        // scored, as the non-log path does; scanning from i - ell is off by one.
                        this.featureGenerator.startScanFeaturesAt(dataSeq, start, i);
                        // True when the candidate segment coincides with the labelled segment.
                        boolean isSegment = start == segmentStart && i == segmentEnd;
                        while (this.featureGenerator.hasNext()) {
                            Feature feature = this.featureGenerator.next();
                            int f = feature.getIndex();
                            int yp = feature.getLabel();
                            int yprev = feature.getPrevLabel();
                            double val = feature.getValue();
                            if (isSegment && dataSeq.getLabel(i) == yp
                                    && ((i - ell >= 0 && yprev == dataSeq.getLabel(i - ell)) || yprev < 0)) {
                                grad[f] += val;
                                thisSeqLogli += val * lambda[f];
                            }
                            // Loop-invariant for the per-state accumulation below.
                            double logVal = MathUtil.log(val);
                            // A negative previous label means the feature does not constrain the
                            // previous state, so every previous state contributes.
                            if (yprev < 0) {
                                for (int y = 0; y < Mi_YY.rows(); ++y) {
                                    expF[f] = MathUtil.logSumExp(expF[f], alphaPrev[y] + Mi_YY.getDouble(y, yp) + logVal + betaCur[yp]);
                                }
                            } else {
                                expF[f] = MathUtil.logSumExp(expF[f], alphaPrev[yprev] + Mi_YY.getDouble(yprev, yp) + logVal + betaCur[yp]);
                            }
                        }
                        this.genStateVectorLog(Mi_YY, alphaPrev, alphaCur, true);
                    }
                }

                double lZx = MathUtil.logSumExp(alpha_Y[len - 1 - base]);
                thisSeqLogli -= lZx;
                logli += thisSeqLogli;
                for (int f = 0; f < grad.length; ++f) {
                    grad[f] -= MathUtil.exp(expF[f] - lZx);
                }
            }
            return logli;
        }
        catch (Exception e) {
            e.printStackTrace();
            // Training cannot continue; report failure instead of exiting with status 0 (success).
            System.exit(1);
            return 0.0;
        }
    }

    /**
     * Seeds {@code grad} with the L2-penalty gradient {@code -lambda * invSigmaSquare}.
     *
     * @return the penalty term {@code -sum(lambda^2) * invSigmaSquare / 2}
     */
    private double initGradientWithPrior(double[] lambda, double[] grad) {
        double logli = 0.0;
        for (int f = 0; f < lambda.length; ++f) {
            grad[f] = -1.0 * lambda[f] * this.invSigmaSquare;
            logli -= lambda[f] * lambda[f] * this.invSigmaSquare / 2.0;
        }
        return logli;
    }

    /**
     * Returns whether every labelled segment of {@code dataSeq} has length at most
     * {@code maxSegmentLength}. Segments are walked exactly as in the forward passes.
     */
    private boolean hasValidSegmentation(DataSequence dataSeq) {
        int segmentStart = 0;
        int segmentEnd = -1;
        for (int i = 0; i < dataSeq.length(); ++i) {
            if (segmentEnd < i) {
                segmentStart = i;
                segmentEnd = dataSeq.getSegmentEnd(i);
            }
            if (segmentEnd - segmentStart + 1 > this.maxSegmentLength) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code buffer} if it has at least {@code minRows} rows. Otherwise returns a new
     * buffer of {@code 2 * minRows} rows of {@code stateNum} columns; the slack avoids
     * reallocating for every slightly longer sequence.
     */
    private static double[][] ensureRows(double[][] buffer, int minRows, int stateNum) {
        if (buffer != null && buffer.length >= minRows) {
            return buffer;
        }
        return new double[2 * minRows][stateNum];
    }
}

