/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.RankedFeatureVector;

/**
 * Ranked feature vector whose value for each feature is the information gain
 * (in bits, i.e. log base 2) between the presence of that feature and the
 * class label, computed over an {@link InstanceList}.
 *
 * <p>A feature counts as "present" in an instance when its stored value is
 * strictly greater than zero. Each (instance, label) pair contributes
 * {@code labelWeight * instanceWeight} to the counts.
 */
public class InfoGain
extends RankedFeatureVector {
    /*
     * Legacy hand-off slots. Earlier versions of this class passed results from
     * the static computation to the constructor through these fields, which let
     * two overlapping constructions read each other's values. The class no
     * longer READS them; they are still written with the most recent result in
     * case other package-level code inspects them (NOTE(review): no such reader
     * is visible in this file -- confirm before removing).
     */
    static double staticBaseEntropy;
    static LabelVector staticBaseLabelDistribution;
    double baseEntropy;
    LabelVector baseLabelDistribution;

    /** Immutable bundle of everything one information-gain computation produces. */
    private static final class GainComputation {
        final double[] infogains;
        final double baseEntropy;
        final LabelVector baseLabelDistribution;

        GainComputation(double[] infogains, double baseEntropy, LabelVector baseLabelDistribution) {
            this.infogains = infogains;
            this.baseEntropy = baseEntropy;
            this.baseLabelDistribution = baseLabelDistribution;
        }
    }

    /**
     * Computes per-feature information gain plus the class entropy and class
     * distribution of {@code ilist}.
     *
     * <p>If the total weighted label mass is zero (for example an empty list),
     * all gains are zero, the base entropy is zero, and the label distribution
     * is built from the all-zero count array.
     *
     * @param ilist instances whose data are {@link FeatureVector}s
     * @return the gains (indexed by data-alphabet feature index), the base
     *         entropy, and the base label distribution
     */
    private static GainComputation calcInfoGains(InstanceList ilist) {
        double log2 = Math.log(2.0);
        int numInstances = ilist.size();
        int numClasses = ilist.getTargetAlphabet().size();
        int numFeatures = ilist.getDataAlphabet().size();
        double[] infogains = new double[numFeatures];
        // targetFeatureCount[label][feature]: weighted count of instances of that
        // label in which the feature is present.
        double[][] targetFeatureCount = new double[numClasses][numFeatures];
        double[] featureCountSum = new double[numFeatures];
        double[] targetCount = new double[numClasses];
        double targetCountSum = 0.0;

        for (int i = 0; i < numInstances; ++i) {
            Instance inst = (Instance)ilist.get(i);
            Labeling labeling = inst.getLabeling();
            FeatureVector fv = (FeatureVector)inst.getData();
            double instanceWeight = ilist.getInstanceWeight(i);
            double labelWeightSum = 0.0;
            for (int ll = 0; ll < labeling.numLocations(); ++ll) {
                int li = labeling.indexAtLocation(ll);
                double labelWeight = labeling.valueAtLocation(ll);
                labelWeightSum += labelWeight;
                if (labelWeight == 0.0) {
                    continue;
                }
                double count = labelWeight * instanceWeight;
                for (int fl = 0; fl < fv.numLocations(); ++fl) {
                    int fli = fv.indexAtLocation(fl);
                    // Only strictly positive feature values count as "present".
                    if (fv.valueAtLocation(fl) > 0.0) {
                        targetFeatureCount[li][fli] += count;
                        featureCountSum[fli] += count;
                    }
                }
                targetCount[li] += count;
                targetCountSum += count;
            }
            // Each labeling is expected to be a distribution summing to one.
            assert (Math.abs(labelWeightSum - 1.0) < 1.0E-4);
        }

        LabelAlphabet targetAlphabet = (LabelAlphabet)ilist.getTargetAlphabet();
        if (targetCountSum == 0.0) {
            // No label mass at all: avoid dividing by zero below.
            LabelVector emptyDistribution = new LabelVector(targetAlphabet, targetCount);
            staticBaseEntropy = 0.0;
            staticBaseLabelDistribution = emptyDistribution;
            return new GainComputation(infogains, 0.0, emptyDistribution);
        }
        assert (targetCountSum > 0.0) : targetCountSum;

        // Entropy of the class distribution, H(C).
        double[] classDistribution = new double[numClasses];
        double baseEntropy = 0.0;
        for (int li = 0; li < numClasses; ++li) {
            double p = targetCount[li] / targetCountSum;
            classDistribution[li] = p;
            assert (p <= 1.0) : p;
            if (p != 0.0) {
                baseEntropy -= p * Math.log(p) / log2;
            }
        }
        LabelVector baseLabelDistribution = new LabelVector(targetAlphabet, classDistribution);

        for (int fi = 0; fi < numFeatures; ++fi) {
            // H(C | feature present)
            double featurePresentEntropy = 0.0;
            double norm = featureCountSum[fi];
            if (norm > 0.0) {
                for (int li = 0; li < numClasses; ++li) {
                    double p = targetFeatureCount[li][fi] / norm;
                    assert (p <= 1.00000001) : p;
                    if (p != 0.0) {
                        featurePresentEntropy -= p * Math.log(p) / log2;
                    }
                }
            }
            assert (!Double.isNaN(featurePresentEntropy)) : fi;

            // H(C | feature absent)
            norm = targetCountSum - featureCountSum[fi];
            double featureAbsentEntropy = 0.0;
            if (norm > 0.0) {
                for (int li = 0; li < numClasses; ++li) {
                    double p = (targetCount[li] - targetFeatureCount[li][fi]) / norm;
                    assert (p <= 1.00000001) : p;
                    if (p != 0.0) {
                        featureAbsentEntropy -= p * Math.log(p) / log2;
                    }
                }
            }
            assert (!Double.isNaN(featureAbsentEntropy)) : fi;

            // gain = H(C) - P(present) * H(C|present) - P(absent) * H(C|absent)
            infogains[fi] = baseEntropy - featureCountSum[fi] / targetCountSum * featurePresentEntropy - (targetCountSum - featureCountSum[fi]) / targetCountSum * featureAbsentEntropy;
            assert (!Double.isNaN(infogains[fi])) : fi;
        }

        // Mirror the result into the legacy static slots (write-only from here).
        staticBaseEntropy = baseEntropy;
        staticBaseLabelDistribution = baseLabelDistribution;
        return new GainComputation(infogains, baseEntropy, baseLabelDistribution);
    }

    /**
     * Computes information gain for every feature in the data alphabet of
     * {@code ilist} and records the base class entropy and distribution.
     *
     * @param ilist the instances to compute gains over
     */
    public InfoGain(InstanceList ilist) {
        this(ilist.getDataAlphabet(), InfoGain.calcInfoGains(ilist));
    }

    /**
     * Wraps precomputed gains. The base entropy and base label distribution are
     * not set by this constructor and keep their defaults ({@code 0.0} and
     * {@code null}).
     *
     * @param vocab     alphabet the gains are indexed by
     * @param infogains one gain value per feature
     */
    public InfoGain(Alphabet vocab, double[] infogains) {
        super(vocab, infogains);
    }

    /**
     * Builds an instance from a finished computation, so the per-instance
     * fields come from this construction's own result rather than shared
     * static state.
     */
    private InfoGain(Alphabet vocab, GainComputation computed) {
        super(vocab, computed.infogains);
        this.baseEntropy = computed.baseEntropy;
        this.baseLabelDistribution = computed.baseLabelDistribution;
    }

    /**
     * @return the class-label entropy (in bits) recorded at construction;
     *         {@code 0.0} if this instance was built from precomputed gains
     */
    public double getBaseEntropy() {
        return this.baseEntropy;
    }

    /**
     * @return the class-label distribution recorded at construction;
     *         {@code null} if this instance was built from precomputed gains
     */
    public LabelVector getBaseLabelDistribution() {
        return this.baseLabelDistribution;
    }

    /** Factory that produces an {@link InfoGain} ranking for an instance list. */
    public static class Factory
    implements RankedFeatureVector.Factory {
        @Override
        public RankedFeatureVector newRankedFeatureVector(InstanceList ilist) {
            return new InfoGain(ilist);
        }
    }
}

