/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.classify.MaxEnt;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.logging.Logger;

/**
 * Objective function for Dirichlet-multinomial regression (DMR).
 *
 * <p>Each training instance with a non-null target contributes one term. Its data
 * {@link FeatureVector} holds the features. Its target {@link FeatureVector} holds per-label
 * counts. The Dirichlet parameter of label {@code k} is {@code exp(score_k)}, where
 * {@code score_k} is the unnormalized {@link MaxEnt} score for that label. The objective is the
 * sum over instances of the Dirichlet-multinomial log likelihood of the counts, plus a Gaussian
 * log prior on every parameter. The default (intercept) feature has its own prior variance.
 *
 * <p>Parameters are stored label-major: the weight of feature {@code f} for label {@code k}
 * is at index {@code k * numFeatures + f}.
 */
public class DMROptimizable implements Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(DMROptimizable.class.getName());
    private static final Logger progressLogger =
            MalletProgressMessageLogger.getLogger(DMROptimizable.class.getName() + "-pl");

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
    static final double DEFAULT_LARGE_GAUSSIAN_PRIOR_VARIANCE = 100.0;
    static final double DEFAULT_GAUSSIAN_PRIOR_MEAN = 0.0;

    /** Below this count, digamma differences are computed as an explicit finite sum. */
    private static final double EXACT_DIGAMMA_SUM_LIMIT = 20.0;

    MaxEnt classifier;
    InstanceList trainingList;
    int numGetValueCalls = 0;
    int numGetValueGradientCalls = 0;
    int numIterations = Integer.MAX_VALUE;
    NumberFormat formatter = null;
    /** Mean of the Gaussian prior on every parameter. */
    double gaussianPriorMean = DEFAULT_GAUSSIAN_PRIOR_MEAN;
    /** Prior variance for the regular (non-default) feature weights. */
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    /** Prior variance for each label's default (intercept) feature weight. */
    double defaultFeatureGaussianPriorVariance = DEFAULT_LARGE_GAUSSIAN_PRIOR_VARIANCE;
    double[] parameters;
    double[] cachedGradient;
    double cachedValue;
    boolean cachedValueStale;
    boolean cachedGradientStale;
    int numLabels;
    int numFeatures;
    int defaultFeatureIndex;

    /** Creates an empty instance; its fields must be populated before the objective is used. */
    public DMROptimizable() {
    }

    /**
     * Creates the objective over {@code instances}.
     *
     * @param instances training data. Each instance's data is a {@link FeatureVector} of features.
     *     Its target, when non-null, is a {@link FeatureVector} of label counts.
     * @param initialClassifier if non-null, its parameter array and default feature index are
     *     used as the starting point. Otherwise a new {@link MaxEnt} is built over a freshly
     *     allocated (all-zero) parameter array.
     */
    public DMROptimizable(InstanceList instances, MaxEnt initialClassifier) {
        this.trainingList = instances;
        Alphabet alphabet = instances.getDataAlphabet();
        Alphabet labelAlphabet = instances.getTargetAlphabet();
        this.numLabels = labelAlphabet.size();
        // One extra column per label for the default (intercept) feature.
        this.numFeatures = alphabet.size() + 1;
        this.defaultFeatureIndex = this.numFeatures - 1;
        this.parameters = new double[this.numLabels * this.numFeatures];
        this.cachedGradient = new double[this.numLabels * this.numFeatures];
        if (initialClassifier != null) {
            this.classifier = initialClassifier;
            this.parameters = this.classifier.getParameters();
            this.defaultFeatureIndex = this.classifier.getDefaultFeatureIndex();
            assert (initialClassifier.getInstancePipe() == instances.getPipe());
        } else {
            this.classifier = new MaxEnt(instances.getPipe(), this.parameters);
        }
        this.formatter = new DecimalFormat("0.###E0");
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        logger.fine("Number of instances in training list = " + this.trainingList.size());
        // Diagnostic pass only: report NaN feature values without modifying the data.
        for (Instance instance : this.trainingList) {
            FeatureVector multinomialValues = (FeatureVector) instance.getTarget();
            if (multinomialValues == null) {
                continue;
            }
            FeatureVector features = (FeatureVector) instance.getData();
            assert (features.getAlphabet() == alphabet);
            boolean hasNaN = false;
            for (int i = 0; i < features.numLocations(); ++i) {
                if (Double.isNaN(features.valueAtLocation(i))) {
                    logger.info("NaN for feature " + alphabet.lookupObject(features.indexAtLocation(i)).toString());
                    hasNaN = true;
                }
            }
            if (hasNaN) {
                logger.info("NaN in instance: " + instance.getName());
            }
        }
    }

    /** Sets the Gaussian prior variance of the default (intercept) feature weights. */
    public void setInterceptGaussianPriorVariance(double sigmaSquared) {
        this.defaultFeatureGaussianPriorVariance = sigmaSquared;
    }

    /** Sets the Gaussian prior variance of the regular (non-default) feature weights. */
    public void setRegularGaussianPriorVariance(double sigmaSquared) {
        this.gaussianPriorVariance = sigmaSquared;
    }

    /** Returns the classifier whose unnormalized scores define the Dirichlet parameters. */
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    @Override
    public double getParameter(int index) {
        return this.parameters[index];
    }

    /** Sets one parameter and marks the cached value and gradient as stale. */
    @Override
    public void setParameter(int index, double v) {
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        this.parameters[index] = v;
    }

    @Override
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * Copies the current parameters into {@code buff}.
     *
     * @param buff destination array. Its length must equal {@link #getNumParameters()}.
     * @throws IllegalArgumentException if {@code buff} is null or has the wrong length. This
     *     method returns void, so a replacement array could never reach the caller.
     */
    @Override
    public void getParameters(double[] buff) {
        if (buff == null || buff.length != this.parameters.length) {
            throw new IllegalArgumentException("Parameter buffer must have length "
                    + this.parameters.length + " but was "
                    + (buff == null ? "null" : String.valueOf(buff.length)));
        }
        System.arraycopy(this.parameters, 0, buff, 0, this.parameters.length);
    }

    /**
     * Copies {@code buff} into the parameters and marks the cached value and gradient as stale.
     * If the length differs, a new parameter array is allocated.
     * NOTE(review): the classifier was built over the original array, so a reallocation
     * presumably stops the classifier from seeing updates. Confirm callers never change the
     * length.
     */
    @Override
    public void setParameters(double[] buff) {
        assert (buff != null);
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        if (buff.length != this.parameters.length) {
            this.parameters = new double[buff.length];
        }
        System.arraycopy(buff, 0, this.parameters, 0, buff.length);
    }

    /**
     * Returns the log likelihood of all observed label-count vectors plus the Gaussian log prior.
     *
     * <p>The result is cached until a parameter changes through {@link #setParameter} or
     * {@link #setParameters}. Each instance contributes
     * {@code sum_k [lgamma(a_k + n_k) - lgamma(a_k)] - [lgamma(A + N) - lgamma(A)]}, where
     * {@code a_k = exp(score_k)}, {@code A = sum_k a_k} and {@code N = sum_k n_k}.
     */
    @Override
    public double getValue() {
        if (!this.cachedValueStale) {
            return this.cachedValue;
        }
        ++this.numGetValueCalls;
        this.cachedValue = 0.0;
        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];
        for (Instance instance : this.trainingList) {
            FeatureVector multinomialValues = (FeatureVector) instance.getTarget();
            if (multinomialValues == null) {
                continue;
            }
            double sumScores = computeExpScores(instance, scores);
            // Reset for every instance so each one is counted exactly once, as in the gradient.
            double instanceValue = 0.0;
            double totalLength = 0.0;
            for (int i = 0; i < multinomialValues.numLocations(); ++i) {
                int label = multinomialValues.indexAtLocation(i);
                double count = multinomialValues.valueAtLocation(i);
                instanceValue += Dirichlet.logGammaStirling(scores[label] + count) - Dirichlet.logGammaStirling(scores[label]);
                totalLength += count;
            }
            instanceValue -= Dirichlet.logGammaStirling(sumScores + totalLength) - Dirichlet.logGammaStirling(sumScores);
            if (Double.isNaN(instanceValue)) {
                logger.fine("DCMMaxEntTrainer: Instance " + instance.getName() + " has NaN value.");
                for (int label : multinomialValues.getIndices()) {
                    logger.fine("log(scores)= " + Math.log(scores[label]) + " scores = " + scores[label]);
                }
            }
            if (Double.isInfinite(instanceValue)) {
                // Stops early and caches the negated infinite value.
                // NOTE(review): a -Infinity log likelihood is reported as +Infinity here.
                // Confirm this sign is intended.
                logger.warning("Instance " + instance.getSource() + " has infinite value; skipping value and gradient");
                this.cachedValue -= instanceValue;
                this.cachedValueStale = false;
                return -instanceValue;
            }
            this.cachedValue += instanceValue;
        }
        double prior = 0.0;
        for (int label = 0; label < this.numLabels; ++label) {
            for (int feature = 0; feature < this.numFeatures - 1; ++feature) {
                double param = this.parameters[label * this.numFeatures + feature];
                prior -= (param - this.gaussianPriorMean) * (param - this.gaussianPriorMean) / (2.0 * this.gaussianPriorVariance);
            }
            double defaultParam = this.parameters[label * this.numFeatures + this.defaultFeatureIndex];
            prior -= (defaultParam - this.gaussianPriorMean) * (defaultParam - this.gaussianPriorMean) / (2.0 * this.defaultFeatureGaussianPriorVariance);
        }
        double labelProbability = this.cachedValue;
        this.cachedValue += prior;
        this.cachedValueStale = false;
        progressLogger.info("Value (likelihood=" + this.formatter.format(labelProbability) + " prior=" + this.formatter.format(prior) + ") = " + this.formatter.format(this.cachedValue));
        return this.cachedValue;
    }

    /**
     * Writes the gradient of {@link #getValue()} with respect to every parameter into
     * {@code buffer}.
     *
     * <p>For label {@code k} and feature value {@code x_f}, an instance contributes
     * {@code x_f * a_k * ([psi(a_k + n_k) - psi(a_k)] - [psi(A + N) - psi(A)])}. The default
     * feature uses {@code x_f = 1}. Gaussian prior terms are then subtracted. Any
     * {@code -Infinity} entries are replaced with 0.
     *
     * @param buffer destination array whose length equals {@link #getNumParameters()}
     */
    @Override
    public void getValueGradient(double[] buffer) {
        assert (buffer != null && buffer.length == this.parameters.length);
        MatrixOps.setAll(this.cachedGradient, 0.0);
        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];
        for (Instance instance : this.trainingList) {
            FeatureVector multinomialValues = (FeatureVector) instance.getTarget();
            if (multinomialValues == null) {
                continue;
            }
            double sumScores = computeExpScores(instance, scores);
            FeatureVector features = (FeatureVector) instance.getData();
            double totalLength = totalCount(multinomialValues);
            double digammaDifferenceForSums = Dirichlet.digamma(sumScores + totalLength) - Dirichlet.digamma(sumScores);

            // psi(a_k + n_k) - psi(a_k) for each observed label. It depends only on the instance,
            // so it is computed once here rather than once per feature location.
            int numObserved = multinomialValues.numLocations();
            double[] labelDigammaDifferences = new double[numObserved];
            for (int labelLoc = 0; labelLoc < numObserved; ++labelLoc) {
                int label = multinomialValues.indexAtLocation(labelLoc);
                labelDigammaDifferences[labelLoc] = digammaDifference(scores[label], multinomialValues.valueAtLocation(labelLoc));
            }

            for (int loc = 0; loc < features.numLocations(); ++loc) {
                double featureValue = features.valueAtLocation(loc);
                if (featureValue == 0.0) {
                    continue;
                }
                accumulateFeatureGradient(features.indexAtLocation(loc), featureValue, scores,
                        digammaDifferenceForSums, multinomialValues, labelDigammaDifferences);
            }
            // The default (intercept) feature is implicitly present with value 1.0.
            accumulateFeatureGradient(this.defaultFeatureIndex, 1.0, scores,
                    digammaDifferenceForSums, multinomialValues, labelDigammaDifferences);
        }
        ++this.numGetValueGradientCalls;
        for (int label = 0; label < this.numLabels; ++label) {
            for (int feature = 0; feature < this.numFeatures - 1; ++feature) {
                int n = label * this.numFeatures + feature;
                this.cachedGradient[n] -= (this.parameters[n] - this.gaussianPriorMean) / this.gaussianPriorVariance;
            }
            int n = label * this.numFeatures + this.defaultFeatureIndex;
            this.cachedGradient[n] -= (this.parameters[n] - this.gaussianPriorMean) / this.defaultFeatureGaussianPriorVariance;
        }
        MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0);
        System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
    }

    /**
     * Fills {@code scores} with {@code exp} of the classifier's unnormalized score for each label
     * of {@code instance} and returns their sum. Very negative scores can underflow to 0.0.
     */
    private double computeExpScores(Instance instance, double[] scores) {
        this.classifier.getUnnormalizedClassificationScores(instance, scores);
        double sum = 0.0;
        for (int i = 0; i < scores.length; ++i) {
            scores[i] = Math.exp(scores[i]);
            sum += scores[i];
        }
        return sum;
    }

    /** Returns the sum of the values stored at each location of {@code counts}. */
    private static double totalCount(FeatureVector counts) {
        double total = 0.0;
        for (int i = 0; i < counts.numLocations(); ++i) {
            total += counts.valueAtLocation(i);
        }
        return total;
    }

    /**
     * Returns {@code psi(alpha + count) - psi(alpha)}.
     *
     * <p>For counts below {@link #EXACT_DIGAMMA_SUM_LIMIT}, the recurrence
     * {@code psi(z + 1) = psi(z) + 1/z} gives this as the finite sum of {@code 1 / (alpha + i)}
     * for {@code i = 0, 1, ...} while {@code i < count}. The sum is exact for integer counts.
     */
    private static double digammaDifference(double alpha, double count) {
        if (count < EXACT_DIGAMMA_SUM_LIMIT) {
            double diff = 0.0;
            for (int i = 0; i < count; ++i) {
                diff += 1.0 / (alpha + i);
            }
            return diff;
        }
        return Dirichlet.digamma(alpha + count) - Dirichlet.digamma(alpha);
    }

    /**
     * Adds one (instance, feature) contribution to {@link #cachedGradient} for the weights of
     * {@code featureIndex} across all labels.
     */
    private void accumulateFeatureGradient(int featureIndex, double featureValue, double[] alphas,
            double digammaDifferenceForSums, FeatureVector multinomialValues,
            double[] labelDigammaDifferences) {
        // Every label's Dirichlet parameter appears in the normalizer A = sum_k a_k.
        for (int label = 0; label < this.numLabels; ++label) {
            this.cachedGradient[label * this.numFeatures + featureIndex] -= featureValue * alphas[label] * digammaDifferenceForSums;
        }
        // Only labels with observed counts receive the per-label term.
        for (int labelLoc = 0; labelLoc < labelDigammaDifferences.length; ++labelLoc) {
            int label = multinomialValues.indexAtLocation(labelLoc);
            this.cachedGradient[label * this.numFeatures + featureIndex] += featureValue * alphas[label] * labelDigammaDifferences[labelLoc];
        }
    }
}

