/*
 * Decompiled with CFR 0.152.
 */
package dragon.ml.seqmodel.crf;

import dragon.ml.seqmodel.crf.AbstractTrainer;
import dragon.ml.seqmodel.crf.Labeler;
import dragon.ml.seqmodel.crf.ViterbiBasicLabeler;
import dragon.ml.seqmodel.data.DataSequence;
import dragon.ml.seqmodel.data.Dataset;
import dragon.ml.seqmodel.feature.Feature;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;

public class CollinsBasicTrainer
extends AbstractTrainer {
    protected int topSolutions;
    protected double beta;
    protected boolean useUpdated;

    public CollinsBasicTrainer(ModelGraph model, FeatureGenerator featureGenerator) {
        super(model, featureGenerator);
        this.topSolutions = Math.min(3, model.getStateNum());
        this.beta = 0.05;
        this.useUpdated = false;
    }

    /*
     * Unable to fully structure code
     */
    public boolean train(Dataset dataset) {
        dataset.startScan();
        while (dataset.hasNext()) {
            this.model.mapLabelToState(dataset.next());
        }
        if (!this.featureGenerator.train(dataset)) {
            return false;
        }
        featureNum = this.featureGenerator.getFeatureNum();
        this.lambda = new double[featureNum];
        lambdaAvg = new double[featureNum];
        lambdaSum = new double[featureNum];
        MathUtil.initArray(this.lambda, 0.0);
        MathUtil.initArray(lambdaAvg, 0.0);
        MathUtil.initArray(this.lambda, 0.0);
        labeler = this.getLabeler();
        solutions = new DataSequence[this.topSolutions];
        autoStartPos = new int[this.topSolutions];
        trainingCount = 0;
        t = 0;
        while (t < this.maxIteration) {
            numErrs = 0;
            dataset.startScan();
            while (dataset.hasNext()) {
                block14: {
                    if (trainingCount > 0) {
                        MathUtil.copyArray(lambdaSum, lambdaAvg);
                        MathUtil.multiArray(lambdaAvg, 1.0 / (double)trainingCount);
                    }
                    MathUtil.initArray(autoStartPos, 0);
                    manualSeq = dataset.next();
                    autoSeq = manualSeq.copy();
                    labeler.label(autoSeq, this.useUpdated != false ? lambdaAvg : this.lambda);
                    correctScore = this.getSequenceScore(manualSeq, this.useUpdated != false ? lambdaAvg : this.lambda);
                    solutionNum = 0;
                    k = 0;
                    while (k < this.topSolutions) {
                        autoSeq = manualSeq.copy();
                        curScore = labeler.getBestSolution(autoSeq, k);
                        if (curScore < correctScore * (1.0 - this.beta)) break;
                        this.model.mapLabelToState(autoSeq);
                        if (!this.isCorrect(manualSeq, autoSeq)) {
                            solutions[solutionNum] = autoSeq;
                            ++solutionNum;
                        }
                        ++k;
                    }
                    if (solutionNum <= 0) break block14;
                    startPos = this.model.getMarkovOrder() - 1;
                    while (startPos < manualSeq.length()) {
                        block15: {
                            endPos = this.getSegmentEnd(manualSeq, startPos);
                            different = false;
                            s = 0;
                            while (s < solutionNum) {
                                if (autoStartPos[s] != startPos || this.getSegmentEnd(solutions[s], autoStartPos[s]) != endPos || manualSeq.getLabel(endPos) != solutions[s].getLabel(endPos)) {
                                    different = true;
                                    break;
                                }
                                ++s;
                            }
                            if (!different) break block15;
                            ++numErrs;
                            this.updateWeights(manualSeq, startPos, endPos, 1.0, this.lambda);
                            s = 0;
                            ** GOTO lbl69
                            {
                                autoEndPos = this.getSegmentEnd(solutions[s], autoStartPos[s]);
                                this.updateWeights(solutions[s], autoStartPos[s], autoEndPos, -1.0 / (double)solutionNum, this.lambda);
                                autoStartPos[s] = autoEndPos + 1;
                                do {
                                    if (autoStartPos[s] <= endPos) continue block6;
                                    ++s;
lbl69:
                                    // 2 sources

                                } while (s < solutionNum);
                            }
                        }
                        s = 0;
                        ** GOTO lbl78
                        {
                            autoEndPos = this.getSegmentEnd(solutions[s], autoStartPos[s]);
                            autoStartPos[s] = autoEndPos + 1;
                            do {
                                if (autoStartPos[s] <= endPos) continue block8;
                                ++s;
lbl78:
                                // 2 sources

                            } while (s < solutionNum);
                        }
                        startPos = endPos + 1;
                    }
                }
                MathUtil.sumArray(lambdaSum, this.lambda);
                ++trainingCount;
            }
            System.out.println("Iteration " + t + " numErrs " + numErrs);
            if (numErrs == 0) break;
            ++t;
        }
        MathUtil.multiArray(lambdaSum, 1.0 / (double)trainingCount);
        MathUtil.copyArray(lambdaSum, this.lambda);
        return true;
    }

    protected boolean isCorrect(DataSequence manual, DataSequence auto) {
        int i = 0;
        while (i < manual.length()) {
            if (manual.getLabel(i) != auto.getLabel(i)) {
                return false;
            }
            ++i;
        }
        return true;
    }

    protected void updateWeights(DataSequence dataSeq, int startPos, int endPos, double wt, double[] grad) {
        this.featureGenerator.startScanFeaturesAt(dataSeq, startPos, endPos);
        while (this.featureGenerator.hasNext()) {
            Feature feature = this.featureGenerator.next();
            int f2 = feature.getIndex();
            int yp = feature.getLabel();
            int yprev = feature.getPrevLabel();
            if (dataSeq.getLabel(endPos) != yp || yprev >= 0 && yprev != dataSeq.getLabel(startPos - 1)) continue;
            int n = f2;
            grad[n] = grad[n] + wt * feature.getValue();
        }
    }

    protected double getSequenceScore(DataSequence dataSeq, double[] grad) {
        int startPos = this.model.getMarkovOrder() - 1;
        double score = 0.0;
        while (startPos < dataSeq.length()) {
            int endPos = this.getSegmentEnd(dataSeq, startPos);
            this.featureGenerator.startScanFeaturesAt(dataSeq, startPos, endPos);
            while (this.featureGenerator.hasNext()) {
                Feature feature = this.featureGenerator.next();
                int f2 = feature.getIndex();
                int yp = feature.getLabel();
                int yprev = feature.getPrevLabel();
                if (dataSeq.getLabel(endPos) != yp || yprev >= 0 && yprev != dataSeq.getLabel(startPos - 1)) continue;
                score += grad[f2] * feature.getValue();
            }
            startPos = endPos + 1;
        }
        return score;
    }

    protected Labeler getLabeler() {
        return new ViterbiBasicLabeler(this.model, this.featureGenerator);
    }

    protected int getSegmentEnd(DataSequence dataSeq, int start) {
        return start;
    }
}

