/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFCacheStaleIndicator;
import cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood;
import cc.mallet.fst.CRFOptimizableByGradientValues;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.fst.ThreadedOptimizable;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.CRFOptimizableByGE;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;
import java.util.ArrayList;

/**
 * Trains a {@link CRF} by jointly maximizing the label likelihood of a labeled training set and a
 * generalized-expectation (GE) objective evaluated on an unlabeled set.
 *
 * <p>Optionally, the CRF is first trained on the labeled data alone (see
 * {@link #setInitSupervised(boolean)}). The sum of the likelihood and GE objectives is then
 * optimized with {@link LimitedMemoryBFGS}. When more than one thread is configured, the threaded
 * trainers/optimizables used here are shut down before
 * {@link #train(InstanceList, InstanceList, int)} returns, including when training throws.
 */
public class CRFTrainerByLikelihoodAndGE
extends TransducerTrainer {
    /** Whether to run likelihood-only training on the labeled set before the joint phase. */
    private boolean initSupervised;
    /** Convergence flag of the most recent L-BFGS run that completed without throwing. */
    private boolean converged;
    /** Weight handed to the GE objective. */
    private double geWeight;
    /** Gaussian prior variance applied to the likelihood term (and the supervised phase). */
    private double gpv;
    /** Iteration budget for the optional supervised pre-training phase. */
    private int supIterations;
    /** Worker thread count; 1 selects the single-threaded implementations. */
    private int numThreads;
    /**
     * Value reported by {@link #getIteration()}. NOTE(review): this trainer never advances it,
     * so {@link #getIteration()} always returns 0.
     */
    private int iteration;
    private final CRF crf;
    private final ArrayList<GEConstraint> constraints;
    private final StateLabelMap map;

    /**
     * Creates a trainer with defaults: GE weight 1.0, Gaussian prior variance 10.0, one thread,
     * no supervised pre-training, and a supervised iteration budget of {@code Integer.MAX_VALUE}.
     *
     * @param crf the model to train
     * @param constraints GE constraints, passed to the GE objective
     * @param map state-to-label map, passed to the GE objective
     */
    public CRFTrainerByLikelihoodAndGE(CRF crf, ArrayList<GEConstraint> constraints, StateLabelMap map) {
        this.crf = crf;
        this.constraints = constraints;
        this.map = map;
        this.iteration = 0;
        this.converged = false;
        this.geWeight = 1.0;
        this.initSupervised = false;
        this.gpv = 10.0;
        this.numThreads = 1;
        this.supIterations = Integer.MAX_VALUE;
    }

    /** Sets the weight passed to the GE objective. */
    public void setGEWeight(double weight) {
        this.geWeight = weight;
    }

    /** Sets the Gaussian prior variance used by the likelihood term and the supervised phase. */
    public void setGaussianPriorVariance(double gpv) {
        this.gpv = gpv;
    }

    /** Enables or disables likelihood-only pre-training before the joint optimization. */
    public void setInitSupervised(boolean flag) {
        this.initSupervised = flag;
    }

    /** Sets the iteration budget for the supervised pre-training phase. */
    public void setSupervisedIterations(int iterations) {
        this.supIterations = iterations;
    }

    /**
     * Sets the number of worker threads. A value of 1 uses the single-threaded likelihood
     * implementations; larger values use the threaded variants.
     *
     * @param numThreads number of threads, at least 1
     * @throws IllegalArgumentException if {@code numThreads < 1}
     */
    public void setNumThreads(int numThreads) {
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be >= 1, got " + numThreads);
        }
        this.numThreads = numThreads;
    }

    @Override
    public Transducer getTransducer() {
        return this.crf;
    }

    @Override
    public int getIteration() {
        return this.iteration;
    }

    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /**
     * Trains the CRF on labeled and unlabeled data.
     *
     * <p>If supervised pre-training is enabled, the CRF is first trained on {@code trainingSet}
     * alone and the evaluators are run. The combined likelihood + GE objective is then optimized
     * with L-BFGS twice, calling {@code reset()} on the optimizer in between. Exceptions thrown
     * by either L-BFGS run are printed and otherwise ignored (best-effort).
     *
     * @param trainingSet labeled instances for the likelihood term
     * @param unlabeledSet instances for the GE term
     * @param numIterations iteration budget passed to each L-BFGS run
     * @return the convergence flag of the last L-BFGS run that completed without throwing,
     *     or {@code false} if neither did
     */
    public boolean train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) {
        if (this.initSupervised) {
            this.trainSupervised(trainingSet);
            this.runEvaluators();
        }
        Optimizable.ByGradientValue optLikelihood = this.createLikelihoodOptimizable(trainingSet);
        CRFOptimizableByGE ge = null;
        try {
            ge = new CRFOptimizableByGE(this.crf, this.constraints, unlabeledSet, this.map, this.numThreads, this.geWeight);
            // NOTE(review): infinite variance presumably disables the prior on the GE term so the
            // prior is applied once, via the likelihood term — confirm against CRFOptimizableByGE.
            ge.setGaussianPriorVariance(Double.POSITIVE_INFINITY);
            CRFOptimizableByGradientValues combined = new CRFOptimizableByGradientValues(
                    this.crf, new Optimizable.ByGradientValue[]{optLikelihood, ge});
            LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(combined);
            // Do not let a flag from an earlier train() call leak through if both runs fail.
            this.converged = false;
            this.optimizeOnce(optimizer, numIterations);
            // Second run after reset(); presumably restarts L-BFGS with fresh history from the
            // current parameters (e.g. after a line-search failure) — confirm in LimitedMemoryBFGS.
            optimizer.reset();
            this.optimizeOnce(optimizer, numIterations);
        } finally {
            this.releaseThreads(optLikelihood, ge);
        }
        return this.converged;
    }

    /** Runs likelihood-only training on the labeled set, shutting down threaded trainers. */
    private void trainSupervised(InstanceList trainingSet) {
        if (this.numThreads == 1) {
            CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(this.crf);
            trainer.setAddNoFactors(true);
            trainer.setGaussianPriorVariance(this.gpv);
            trainer.train(trainingSet, this.supIterations);
            return;
        }
        CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(this.crf, this.numThreads);
        try {
            trainer.setAddNoFactors(true);
            trainer.setGaussianPriorVariance(this.gpv);
            trainer.train(trainingSet, this.supIterations);
        } finally {
            trainer.shutdown();
        }
    }

    /** Builds the likelihood objective: plain for one thread, batch + threaded wrapper otherwise. */
    private Optimizable.ByGradientValue createLikelihoodOptimizable(InstanceList trainingSet) {
        if (this.numThreads == 1) {
            CRFOptimizableByLabelLikelihood likelihood = new CRFOptimizableByLabelLikelihood(this.crf, trainingSet);
            likelihood.setGaussianPriorVariance(this.gpv);
            return likelihood;
        }
        CRFOptimizableByBatchLabelLikelihood likelihood = new CRFOptimizableByBatchLabelLikelihood(this.crf, trainingSet, this.numThreads);
        ThreadedOptimizable threaded = new ThreadedOptimizable(likelihood, trainingSet, this.crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(this.crf));
        likelihood.setGaussianPriorVariance(this.gpv);
        return threaded;
    }

    /** Runs one L-BFGS pass, recording its convergence flag; failures are printed, not rethrown. */
    private void optimizeOnce(LimitedMemoryBFGS optimizer, int numIterations) {
        try {
            this.converged = optimizer.optimize(numIterations);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Shuts down the threaded objectives created by {@link #train(InstanceList, InstanceList, int)}. */
    private void releaseThreads(Optimizable.ByGradientValue optLikelihood, CRFOptimizableByGE ge) {
        try {
            if (optLikelihood instanceof ThreadedOptimizable) {
                ((ThreadedOptimizable)optLikelihood).shutdown();
            }
        } finally {
            // Matches the original guard: the GE objective is shut down only in threaded mode.
            if (ge != null && this.numThreads > 1) {
                ge.shutdown();
            }
        }
    }

    /**
     * Not supported: this trainer needs an unlabeled set.
     *
     * @throws RuntimeException always
     */
    @Override
    public boolean train(InstanceList trainingSet, int numIterations) {
        throw new RuntimeException("Must use train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) instead");
    }
}

