/*
 * Decompiled with CFR 0.152.
 */
package dragon.ml.seqmodel.crf;

import dragon.matrix.DoubleDenseMatrix;
import dragon.ml.seqmodel.crf.AbstractCRF;
import dragon.ml.seqmodel.crf.Trainer;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;

/**
 * Base class for CRF trainers. It holds the training options shared by concrete trainers:
 * the scaling flag and the iteration limit. It also provides helpers that propagate a
 * state vector through a transition matrix, either in ordinary arithmetic or through
 * {@code MathUtil.logSumExp}.
 */
public abstract class AbstractTrainer
extends AbstractCRF
implements Trainer {
    /**
     * Tolerance constant exposed to subclasses. It is not read anywhere in this class.
     * NOTE(review): presumably an optimizer convergence tolerance; confirm in subclasses.
     */
    protected static double xtol = 1.0E-16;
    /** Whether scaling is requested; defaults to {@code true}. */
    protected boolean doScaling = true;
    /** Maximum number of training iterations; defaults to 100. */
    protected int maxIteration = 100;

    /**
     * Creates a trainer bound to the given model graph and feature generator.
     *
     * @param model      the model graph, forwarded to {@link AbstractCRF}
     * @param featureGen the feature generator, forwarded to {@link AbstractCRF}
     */
    public AbstractTrainer(ModelGraph model, FeatureGenerator featureGen) {
        super(model, featureGen);
    }

    /**
     * Reports the current scaling option.
     *
     * @return {@code true} if scaling is enabled
     */
    public boolean needScaling() {
        return doScaling;
    }

    /**
     * Enables or disables scaling.
     *
     * @param option the new value of the scaling flag
     */
    public void setScalingOption(boolean option) {
        doScaling = option;
    }

    /**
     * Returns the configured iteration limit.
     *
     * @return the maximum number of iterations
     */
    public int getMaxIteration() {
        return maxIteration;
    }

    /**
     * Sets the iteration limit. The value is stored as-is, with no validation.
     *
     * @param maxIteration the new maximum number of iterations
     */
    public void setMaxIteration(int maxIteration) {
        this.maxIteration = maxIteration;
    }

    /**
     * Adds the product of {@code transMatrix} (or its transpose) and {@code oldStateVector}
     * into {@code newStateVector}.
     *
     * <p>Only some entries are visited. For each column {@code col}, rows are enumerated
     * through the inherited {@code edgeGen}: the scan starts at {@code edgeGen.first(col)},
     * advances with {@code edgeGen.next(col, row)}, and stops once the row index reaches
     * {@code transMatrix.rows()}. NOTE(review): presumably {@code edgeGen} yields the rows
     * that have an edge into {@code col}; confirm in {@link AbstractCRF}.
     *
     * <p>For each visited entry, the update is:
     * <ul>
     * <li>without transposition: {@code new[row] += M[row][col] * old[col]}</li>
     * <li>with transposition: {@code new[col] += M[row][col] * old[row]}</li>
     * </ul>
     * The existing contents of {@code newStateVector} are accumulated into, not cleared.
     *
     * @param transMatrix    the transition matrix {@code M}
     * @param oldStateVector the input state vector
     * @param newStateVector the output vector, updated in place
     * @param transpose2     whether to multiply by the transpose of {@code M}
     */
    protected void genStateVector(DoubleDenseMatrix transMatrix, double[] oldStateVector, double[] newStateVector, boolean transpose2) {
        for (int col = 0; col < transMatrix.columns(); col++) {
            for (int row = edgeGen.first(col); row < transMatrix.rows(); row = edgeGen.next(col, row)) {
                int target = transpose2 ? col : row;
                int source = transpose2 ? row : col;
                newStateVector[target] += transMatrix.getDouble(row, col) * oldStateVector[source];
            }
        }
    }

    /**
     * Log-space counterpart of
     * {@link #genStateVector(DoubleDenseMatrix, double[], double[], boolean)}.
     *
     * <p>The same {@code edgeGen}-driven entries are visited. For each one, the target slot
     * is replaced with {@code MathUtil.logSumExp(new[target], M[row][col] + old[source])}.
     * The target/source index choice follows {@code transpose2} exactly as in the linear
     * version. NOTE(review): this assumes {@code logSumExp(a, b)} computes
     * {@code log(exp(a) + exp(b))}, in which case callers would pre-fill
     * {@code newStateVector} with the log-space identity; confirm.
     *
     * @param transMatrix    the transition matrix, presumably holding log-values
     * @param oldStateVector the input state vector, presumably holding log-values
     * @param newStateVector the output vector, updated in place
     * @param transpose2     whether to use the transpose of {@code transMatrix}
     */
    protected void genStateVectorLog(DoubleDenseMatrix transMatrix, double[] oldStateVector, double[] newStateVector, boolean transpose2) {
        for (int col = 0; col < transMatrix.columns(); col++) {
            for (int row = edgeGen.first(col); row < transMatrix.rows(); row = edgeGen.next(col, row)) {
                int target = transpose2 ? col : row;
                int source = transpose2 ? row : col;
                // Read the current value before touching the matrix, preserving the original evaluation order.
                double current = newStateVector[target];
                newStateVector[target] = MathUtil.logSumExp(current, transMatrix.getDouble(row, col) + oldStateVector[source]);
            }
        }
    }
}

