/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.pr.PRAuxiliaryModel;
import cc.mallet.fst.semi_supervised.pr.SumLatticePR;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;

/**
 * Gradient-based optimizable objective over the parameters of a
 * {@link PRAuxiliaryModel}.
 *
 * <p>Parameter access is delegated to the auxiliary model. The objective value
 * is the negated sum of the {@link SumLatticePR} total weights of all training
 * instances plus {@code model.getValue()}; the gradient is whatever
 * {@code model.getValueGradient} reports after the expectations gathered by the
 * lattices have been merged into the model. Both are cached until a parameter
 * is changed through {@link #setParameter} or {@link #setParameters}.
 *
 * <p>The per-instance lattice computations are split over a fixed-size thread
 * pool. Call {@link #shutdown()} once the object is no longer needed so the
 * pool's threads are released.
 *
 * <p>NOTE(review): the class is declared {@link Serializable} but holds a
 * {@link ThreadPoolExecutor}, which does not implement {@code Serializable};
 * confirm whether instances are ever actually serialized.
 */
public class ConstraintsOptimizableByPR
implements Serializable,
Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(ConstraintsOptimizableByPR.class.getName());
    private static final long serialVersionUID = 1L;

    /** True when {@link #cachedValue} and {@link #cachedGradient} must be recomputed. */
    protected boolean cacheStale;
    /** Number of parameters reported by the auxiliary model at construction time. */
    protected int numParameters;
    /** Size of the thread pool and number of slices the training set is split into. */
    protected int numThreads;
    protected InstanceList trainingSet;
    /** Last computed objective value; the initial number is only a placeholder. */
    protected double cachedValue = -1.23456789E8;
    /** Last computed gradient, of length {@link #numParameters}. */
    protected double[] cachedGradient;
    protected CRF crf;
    protected ThreadPoolExecutor executor;
    /**
     * Cached CRF transition weights, indexed as
     * {@code [instance][position][source state][index of the state returned by the iterator]}.
     * Entries for which the CRF offers no transition hold
     * {@link Double#NEGATIVE_INFINITY}.
     */
    protected double[][][][] cachedDots;
    PRAuxiliaryModel model;

    /**
     * Creates a single-threaded optimizable.
     *
     * @param crf   CRF whose transition weights are cached and passed to the lattices
     * @param ilist training instances
     * @param model auxiliary model whose parameters are being optimized
     */
    public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model) {
        this(crf, ilist, model, 1);
    }

    /**
     * Creates an optimizable that evaluates the objective with {@code numThreads}
     * worker threads and immediately caches the CRF transition weights.
     *
     * @param crf        CRF whose transition weights are cached and passed to the lattices
     * @param ilist      training instances; each instance's data is cast to
     *                   {@link FeatureVectorSequence} when the weights are cached
     * @param model      auxiliary model whose parameters are being optimized
     * @param numThreads number of worker threads; must be positive
     * @throws IllegalArgumentException if {@code numThreads} is not positive
     *                                  (raised by {@link Executors#newFixedThreadPool(int)})
     */
    public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model, int numThreads) {
        this.crf = crf;
        this.trainingSet = ilist;
        this.model = model;
        this.numParameters = model.numParameters();
        this.cachedGradient = new double[this.numParameters];
        this.cacheStale = true;
        this.numThreads = numThreads;
        this.executor = (ThreadPoolExecutor)Executors.newFixedThreadPool(numThreads);
        this.cacheDotProducts();
    }

    /**
     * (Re)builds {@link #cachedDots} from the CRF's current transition weights.
     *
     * <p>For every instance, position and source state, the row of weights is
     * first reset to {@link Double#NEGATIVE_INFINITY} and then overwritten with
     * the weight of each transition the CRF's iterator yields.
     */
    public void cacheDotProducts() {
        int numInstances = this.trainingSet.size();
        int numStates = this.crf.numStates();
        this.cachedDots = new double[numInstances][][][];
        for (int i = 0; i < numInstances; ++i) {
            FeatureVectorSequence input = (FeatureVectorSequence)((Instance)this.trainingSet.get(i)).getData();
            int length = input.size();
            double[][][] dots = new double[length][numStates][numStates];
            this.cachedDots[i] = dots;
            for (int j = 0; j < length; ++j) {
                for (int k = 0; k < numStates; ++k) {
                    double[] row = dots[j][k];
                    // Anything the iterator does not visit stays at -infinity.
                    for (int l = 0; l < numStates; ++l) {
                        row[l] = Double.NEGATIVE_INFINITY;
                    }
                    Transducer.TransitionIterator iter = this.crf.getState(k).transitionIterator(input, j);
                    while (iter.hasNext()) {
                        int target = iter.next().getIndex();
                        row[target] = iter.getWeight();
                    }
                }
            }
        }
    }

    /** Returns the number of parameters of the auxiliary model, as read at construction time. */
    @Override
    public int getNumParameters() {
        return this.numParameters;
    }

    /** Copies the auxiliary model's parameters into {@code params}. */
    @Override
    public void getParameters(double[] params) {
        this.model.getParameters(params);
    }

    /** Returns the auxiliary model's parameter at {@code index}. */
    @Override
    public double getParameter(int index) {
        return this.model.getParameter(index);
    }

    /** Replaces the auxiliary model's parameters and invalidates the cached value and gradient. */
    @Override
    public void setParameters(double[] params) {
        this.cacheStale = true;
        this.model.setParameters(params);
    }

    /** Sets one parameter of the auxiliary model and invalidates the cached value and gradient. */
    @Override
    public void setParameter(int index, double value) {
        this.cacheStale = true;
        this.model.setParameter(index, value);
    }

    /**
     * Recomputes the objective value from scratch.
     *
     * <p>The model's expectations are zeroed, the training set is split into
     * {@link #numThreads} contiguous slices (the last slice absorbs the
     * remainder of the integer division), and each slice is processed by an
     * {@link ExpectationTask} working on its own copy of the model. The
     * expectations collected by the copies are then added back into the model.
     *
     * @return the sum of the task results plus {@code model.getValue()}
     * @throws IllegalStateException if a task fails or the calling thread is
     *                               interrupted while waiting; in both cases
     *                               the expectations would be incomplete, so no
     *                               value is returned
     */
    protected double getExpectationValue() {
        this.model.zeroExpectations();
        int numInstances = this.trainingSet.size();
        int increment = numInstances / this.numThreads;
        List<ExpectationTask> tasks = new ArrayList<ExpectationTask>(this.numThreads);
        int start = 0;
        int end = increment;
        for (int taskIndex = 0; taskIndex < this.numThreads; ++taskIndex) {
            tasks.add(new ExpectationTask(start, end, this.model.copy()));
            start = end;
            // The slice built on the next iteration is the last one when
            // taskIndex == numThreads - 2; extend it to the end of the set.
            end = taskIndex == this.numThreads - 2 ? numInstances : start + increment;
        }
        double value = 0.0;
        try {
            for (Future<Double> result : this.executor.invokeAll(tasks)) {
                value += result.get();
            }
        }
        catch (InterruptedException ie) {
            // Restore the flag so callers further up can observe the interruption.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing constraint expectations", ie);
        }
        catch (ExecutionException ee) {
            throw new IllegalStateException("Expectation task failed", ee);
        }
        this.combine(this.model, tasks);
        return value + this.model.getValue();
    }

    /**
     * Returns the objective value, recomputing it (and the gradient) only if a
     * parameter changed since the last evaluation.
     */
    @Override
    public double getValue() {
        if (this.cacheStale) {
            this.cachedValue = this.getExpectationValue();
            this.model.getValueGradient(this.cachedGradient);
            this.cacheStale = false;
            logger.info("getValue (auxiliary distribution) = " + this.cachedValue);
        }
        return this.cachedValue;
    }

    /**
     * Returns {@code model.getCompleteValueContribution()}, refreshing the
     * cached value first if it is stale.
     */
    public double getCompleteValueContribution() {
        if (this.cacheStale) {
            this.getValue();
        }
        return this.model.getCompleteValueContribution();
    }

    /**
     * Copies the cached gradient into {@code buffer}, refreshing the cache
     * first if it is stale.
     *
     * @param buffer destination array; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void getValueGradient(double[] buffer) {
        if (this.cacheStale) {
            this.getValue();
        }
        System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
    }

    /**
     * Adds the expectations accumulated by every task's model copy to the
     * corresponding constraint of {@code orig}.
     */
    private void combine(PRAuxiliaryModel orig, List<ExpectationTask> tasks) {
        for (ExpectationTask task : tasks) {
            PRAuxiliaryModel copy = task.getModelCopy();
            for (int ci = 0; ci < copy.numConstraints(); ++ci) {
                PRConstraint origConstraint = orig.getConstraint(ci);
                double[] expectation = new double[origConstraint.numDimensions()];
                copy.getConstraint(ci).getExpectations(expectation);
                origConstraint.addExpectations(expectation);
            }
        }
    }

    /** Initiates an orderly shutdown of the worker thread pool. */
    public void shutdown() {
        this.executor.shutdown();
    }

    /** Returns the cached transition weights (the internal array, not a copy). */
    public double[][][][] getCachedDots() {
        return this.cachedDots;
    }

    /** Returns the auxiliary model being optimized. */
    public PRAuxiliaryModel getAuxModel() {
        return this.model;
    }

    /**
     * Processes the training instances in {@code [start, end)} against a
     * private copy of the auxiliary model.
     */
    private class ExpectationTask
    implements Callable<Double> {
        private final int start;
        private final int end;
        private final PRAuxiliaryModel modelCopy;

        public ExpectationTask(int start, int end, PRAuxiliaryModel modelCopy) {
            this.start = start;
            this.end = end;
            this.modelCopy = modelCopy;
        }

        /**
         * Builds a {@link SumLatticePR} for each instance of the slice.
         *
         * @return the negated sum of the lattices' total weights
         */
        @Override
        public Double call() throws Exception {
            double value = 0.0;
            for (int ii = this.start; ii < this.end; ++ii) {
                Instance inst = (Instance)ConstraintsOptimizableByPR.this.trainingSet.get(ii);
                Sequence input = (Sequence)inst.getData();
                value -= new SumLatticePR(ConstraintsOptimizableByPR.this.crf, ii, input, null, this.modelCopy, ConstraintsOptimizableByPR.this.cachedDots[ii], true, null, null, false).getTotalWeight();
            }
            return value;
        }

        /** Returns the model copy this task was given, so its expectations can be merged. */
        public PRAuxiliaryModel getModelCopy() {
            return this.modelCopy;
        }
    }
}

