/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.Multinomial;
import cc.mallet.types.RankedFeatureVector;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

/**
 * Multinomial naive Bayes classifier.
 *
 * <p>The model consists of a class prior plus one feature multinomial per class index, both
 * held in log space as {@link Multinomial.Logged}. {@link #classify} scores each class as its
 * prior log-probability plus the feature-value-weighted feature log-probabilities, then
 * normalizes the scores into a {@link LabelVector}.
 *
 * <p>Class or feature indices beyond what the stored multinomials cover (e.g. because the
 * target or data alphabet grew after training) are ignored rather than treated as errors.
 */
public class NaiveBayes extends Classifier implements Serializable {
    /** Class prior, in log space. */
    Multinomial.Logged prior;
    /** Feature multinomials in log space, indexed by class (label) index. */
    Multinomial.Logged[] p;
    private static final long serialVersionUID = 1L;
    /** Version tag written by {@code writeObject} and verified by {@code readObject}. */
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Creates a classifier from log-space parameters.
     *
     * @param instancePipe pipe passed to {@link Classifier}; its data and target alphabets are
     *     consulted by {@link #printWords} and the data alphabet by {@link #classify}'s assertion
     * @param prior log-space class prior
     * @param classIndex2FeatureProb log-space feature multinomials, one per class index; the
     *     array is stored by reference, not copied
     */
    public NaiveBayes(Pipe instancePipe, Multinomial.Logged prior, Multinomial.Logged[] classIndex2FeatureProb) {
        super(instancePipe);
        this.prior = prior;
        this.p = classIndex2FeatureProb;
    }

    /** Wraps each multinomial in a {@link Multinomial.Logged}, preserving array order. */
    private static Multinomial.Logged[] logMultinomials(Multinomial[] m) {
        Multinomial.Logged[] ml = new Multinomial.Logged[m.length];
        for (int i = 0; i < m.length; ++i) {
            ml[i] = new Multinomial.Logged(m[i]);
        }
        return ml;
    }

    /**
     * Creates a classifier from plain multinomials by wrapping the prior and each per-class
     * multinomial in {@link Multinomial.Logged}.
     *
     * @param dataPipe pipe passed to {@link Classifier}
     * @param prior class prior
     * @param classIndex2FeatureProb feature multinomials, one per class index
     */
    public NaiveBayes(Pipe dataPipe, Multinomial prior, Multinomial[] classIndex2FeatureProb) {
        this(dataPipe, new Multinomial.Logged(prior), NaiveBayes.logMultinomials(classIndex2FeatureProb));
    }

    /**
     * Returns the per-class log feature multinomials.
     *
     * @return the internal array itself (not a copy); mutations affect this classifier
     */
    public Multinomial.Logged[] getMultinomials() {
        return this.p;
    }

    /** Returns the log-space class prior. */
    public Multinomial.Logged getPriors() {
        return this.prior;
    }

    /**
     * Prints, for each trained class, up to {@code numToPrint} features and their
     * probabilities to standard output, in {@link RankedFeatureVector} rank order.
     *
     * @param numToPrint maximum number of features printed per class; clamped to the size of
     *     the data alphabet (non-positive values print no features)
     */
    public void printWords(int numToPrint) {
        Alphabet alphabet = this.instancePipe.getDataAlphabet();
        int numFeatures = alphabet.size();
        // The target alphabet may have grown past the trained classes (classify() guards
        // against the same thing); indexing p[li] for such li would throw.
        int numLabels = Math.min(this.instancePipe.getTargetAlphabet().size(), this.p.length);
        double[] probs = new double[numFeatures];
        int limit = Math.min(numToPrint, numFeatures);
        for (int li = 0; li < numLabels; ++li) {
            Arrays.fill(probs, 0.0);
            this.p[li].addProbabilities(probs);
            RankedFeatureVector rfv = new RankedFeatureVector(alphabet, probs);
            System.out.println("\nFeature probabilities " + this.instancePipe.getTargetAlphabet().lookupObject(li));
            for (int i = 0; i < limit; ++i) {
                System.out.println(rfv.getObjectAtRank(i) + " " + rfv.getValueAtRank(i));
            }
        }
    }

    /**
     * Computes a normalized class distribution for an instance whose data is a
     * {@link FeatureVector}.
     *
     * <p>Each class score is its prior log-probability plus, for every present feature,
     * {@code value * log p(feature | class)}. Scores are shifted by their maximum before
     * exponentiation (so {@code exp} works near zero) and then normalized to sum to one. If
     * every class score is {@code -Infinity}, a uniform distribution is returned instead of NaNs.
     *
     * @param instance instance whose data is a {@link FeatureVector}
     * @return classification carrying a {@link LabelVector} over the label alphabet
     */
    @Override
    public Classification classify(Instance instance) {
        int numClasses = this.getLabelAlphabet().size();
        double[] scores = new double[numClasses];
        FeatureVector fv = (FeatureVector) instance.getData();
        // Feature indices must refer to the alphabet the multinomials were estimated over.
        assert (this.instancePipe == null || fv.getAlphabet() == this.instancePipe.getDataAlphabet());

        this.prior.addLogProbabilities(scores);

        int fvisize = fv.numLocations();
        for (int fvi = 0; fvi < fvisize; ++fvi) {
            int fi = fv.indexAtLocation(fvi);
            double value = fv.valueAtLocation(fvi);
            for (int ci = 0; ci < numClasses; ++ci) {
                // The label or data alphabet may have grown since training (e.g. a never-seen
                // feature); such classes/features have no parameters and are ignored.
                if (ci >= this.p.length || fi >= this.p[ci].size()) {
                    continue;
                }
                scores[ci] += value * this.p[ci].logProbability(fi);
            }
        }

        double maxScore = Double.NEGATIVE_INFINITY;
        for (double score : scores) {
            if (score > maxScore) {
                maxScore = score;
            }
        }

        if (maxScore == Double.NEGATIVE_INFINITY) {
            // Every class is impossible (or scores are NaN); -Inf - -Inf would be NaN for every
            // entry, so fall back to a uniform distribution. No-op when numClasses == 0.
            Arrays.fill(scores, 1.0 / numClasses);
        } else {
            double sum = 0.0;
            for (int ci = 0; ci < numClasses; ++ci) {
                scores[ci] = Math.exp(scores[ci] - maxScore);
                sum += scores[ci];
            }
            for (int ci = 0; ci < numClasses; ++ci) {
                scores[ci] /= sum;
            }
        }
        return new Classification(instance, this, new LabelVector(this.getLabelAlphabet(), scores));
    }

    /**
     * Returns the log-probability of the instance's features under one class's feature
     * multinomial: the sum over present features of {@code value * log p(feature | class)}.
     * The class prior is not included.
     *
     * <p>Mirrors {@link #classify}: a class index with no trained multinomial, or a feature
     * index outside that multinomial, contributes nothing. This matters because
     * {@link #dataLogLikelihood} passes label indices produced by {@link #classify}.
     */
    private double dataLogProbability(Instance instance, int labelIndex) {
        if (labelIndex >= this.p.length) {
            return 0.0;
        }
        Multinomial.Logged featureLogProbs = this.p[labelIndex];
        FeatureVector fv = (FeatureVector) instance.getData();
        int fvisize = fv.numLocations();
        double logProb = 0.0;
        for (int fvi = 0; fvi < fvisize; ++fvi) {
            int fi = fv.indexAtLocation(fvi);
            if (fi >= featureLogProbs.size()) {
                continue;  // feature added to the alphabet after training
            }
            logProb += fv.valueAtLocation(fvi) * featureLogProbs.logProbability(fi);
        }
        return logProb;
    }

    /**
     * Returns the instance-weighted log-likelihood of the feature data.
     *
     * <p>Labeled instances are scored under their best label only. Unlabeled instances are
     * scored under every class, weighted by this classifier's own predicted distribution
     * (classes with zero predicted weight are skipped).
     *
     * @param ilist instances to score
     * @return the summed weighted data log-likelihood (0 for an empty list)
     */
    public double dataLogLikelihood(InstanceList ilist) {
        double logLikelihood = 0.0;
        for (int ii = 0; ii < ilist.size(); ++ii) {
            double instanceWeight = ilist.getInstanceWeight(ii);
            Instance inst = (Instance) ilist.get(ii);
            Labeling labeling = inst.getLabeling();
            if (labeling != null) {
                logLikelihood += instanceWeight * this.dataLogProbability(inst, labeling.getBestIndex());
                continue;
            }
            Labeling predicted = this.classify(inst).getLabeling();
            for (int lpos = 0; lpos < predicted.numLocations(); ++lpos) {
                double labelWeight = predicted.valueAtLocation(lpos);
                if (labelWeight == 0.0) {
                    continue;
                }
                int li = predicted.indexAtLocation(lpos);
                logLikelihood += instanceWeight * labelWeight * this.dataLogProbability(inst, li);
            }
        }
        return logLikelihood;
    }

    /**
     * Returns the instance-weighted log-likelihood of the true labels under this classifier's
     * predicted distributions. Unlabeled instances are skipped.
     *
     * <p>A single-location labeling contributes {@code log P(bestLabel)}; otherwise each
     * non-zero label weight {@code w} contributes {@code w * log P(label)}.
     *
     * @param ilist instances to score
     * @return the summed weighted label log-likelihood (may be {@code -Infinity} if a true
     *     label receives zero predicted probability)
     */
    public double labelLogLikelihood(InstanceList ilist) {
        double logLikelihood = 0.0;
        for (int ii = 0; ii < ilist.size(); ++ii) {
            double instanceWeight = ilist.getInstanceWeight(ii);
            Instance inst = (Instance) ilist.get(ii);
            Labeling labeling = inst.getLabeling();
            if (labeling == null) {
                continue;
            }
            Labeling predicted = this.classify(inst).getLabeling();
            if (labeling.numLocations() == 1) {
                logLikelihood += instanceWeight * Math.log(predicted.value(labeling.getBestIndex()));
                continue;
            }
            for (int lpos = 0; lpos < labeling.numLocations(); ++lpos) {
                double labelWeight = labeling.valueAtLocation(lpos);
                if (labelWeight == 0.0) {
                    continue;
                }
                int li = labeling.indexAtLocation(lpos);
                logLikelihood += instanceWeight * labelWeight * Math.log(predicted.value(li));
            }
        }
        return logLikelihood;
    }

    /** Serializes the version tag, the instance pipe, the prior, and the per-class multinomials. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.getInstancePipe());
        out.writeObject(this.prior);
        out.writeObject(this.p);
    }

    /**
     * Restores state written by {@code writeObject}.
     *
     * @throws ClassNotFoundException if the stored version tag differs from
     *     {@link #CURRENT_SERIAL_VERSION}, or a stored class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched NaiveBayes versions: wanted "
                    + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe) in.readObject();
        this.prior = (Multinomial.Logged) in.readObject();
        this.p = (Multinomial.Logged[]) in.readObject();
    }
}

