/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification.featureselection;

import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.featureselection.AbstractFeatureSelector;
import dragon.ir.index.IndexReader;
import dragon.matrix.IntDenseMatrix;
import dragon.matrix.SparseMatrix;
import dragon.matrix.vector.DoubleVector;
import dragon.nlp.Token;
import dragon.nlp.compare.IndexComparator;
import dragon.nlp.compare.WeightComparator;
import dragon.util.SortedArray;
import java.io.Serializable;

/**
 * Feature selector that scores every term by its mutual information with the
 * document classes of a training set and keeps the best-scoring fraction of
 * the vocabulary.
 *
 * <p>For each term a per-class score {@code log(N(term, class) * N / (N(class) * N(term)))}
 * is computed (natural logarithm). The per-class scores are then collapsed into
 * one term weight, either as the class-prior-weighted sum ({@code avgMode == true})
 * or as the maximum over all classes ({@code avgMode == false}).
 */
public class MutualInfoFeatureSelector extends AbstractFeatureSelector implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Fraction of the vocabulary size used as the upper bound on selected terms. */
    private final double topPercentage;

    /** {@code true}: prior-weighted sum of per-class scores; {@code false}: maximum per-class score. */
    private final boolean avgMode;

    /**
     * @param topPercentage fraction of the total number of terms to keep; multiplied by the
     *                      vocabulary size and truncated to an int
     * @param avgMode       whether a term's weight is the prior-weighted sum of its per-class
     *                      scores ({@code true}) or the maximum per-class score ({@code false})
     */
    public MutualInfoFeatureSelector(double topPercentage, boolean avgMode) {
        this.topPercentage = topPercentage;
        this.avgMode = avgMode;
    }

    /**
     * Selects features using term statistics obtained from an index reader.
     *
     * @param indexReader source of the term distribution and of the vocabulary size
     * @param trainingSet labelled training documents grouped by class
     * @return indices of the selected terms
     */
    protected int[] getSelectedFeatures(IndexReader indexReader, DocClassSet trainingSet) {
        DoubleVector classPrior = this.getClassPrior(trainingSet);
        int docNum = this.sumTrainingDocs(trainingSet);
        IntDenseMatrix termDistri = this.getTermDistribution(indexReader, trainingSet);
        SortedArray ranked = this.computeTermMI(termDistri, classPrior, docNum);
        return this.selectTopRankedTerms(ranked, indexReader.getCollection().getTermNum());
    }

    /**
     * Selects features using term statistics obtained from a document-term matrix.
     *
     * @param doctermMatrix matrix whose column count is used as the vocabulary size
     * @param trainingSet   labelled training documents grouped by class
     * @return indices of the selected terms
     */
    protected int[] getSelectedFeatures(SparseMatrix doctermMatrix, DocClassSet trainingSet) {
        DoubleVector classPrior = this.getClassPrior(trainingSet);
        int docNum = this.sumTrainingDocs(trainingSet);
        IntDenseMatrix termDistri = this.getTermDistribution(doctermMatrix, trainingSet);
        SortedArray ranked = this.computeTermMI(termDistri, classPrior, docNum);
        return this.selectTopRankedTerms(ranked, doctermMatrix.columns());
    }

    /** Adds up the document counts of all classes in the training set. */
    private int sumTrainingDocs(DocClassSet trainingSet) {
        int docNum = 0;
        for (int classIndex = 0; classIndex < trainingSet.getClassNum(); classIndex++) {
            docNum += trainingSet.getDocClass(classIndex).getDocNum();
        }
        return docNum;
    }

    /**
     * Takes the leading entries of the ranked term list and returns their term indices.
     *
     * @param ranked         terms as returned by {@link #computeTermMI}
     * @param vocabularySize total number of terms; only used in the fractional multiplication,
     *                       hence accepted as a double
     * @return term indices of the first {@code min(ranked.size(), (int)(topPercentage * vocabularySize))}
     *         entries, read back from a list built with an {@code IndexComparator}
     */
    private int[] selectTopRankedTerms(SortedArray ranked, double vocabularySize) {
        int keepCount = Math.min(ranked.size(), (int) (this.topPercentage * vocabularySize));
        // NOTE(review): presumably the IndexComparator keeps this list ordered by term index,
        // so the returned array is ascending - confirm against SortedArray/IndexComparator.
        SortedArray keptTerms = new SortedArray(keepCount, new IndexComparator());
        for (int rank = 0; rank < keepCount; rank++) {
            keptTerms.add(ranked.get(rank));
        }
        int[] selected = new int[keptTerms.size()];
        for (int pos = 0; pos < selected.length; pos++) {
            selected[pos] = ((Token) keptTerms.get(pos)).getIndex();
        }
        return selected;
    }

    /**
     * Computes one mutual-information weight per term that has a positive column sum.
     *
     * @param termDistri matrix indexed as (class, term); NOTE(review): presumably the number of
     *                   training documents of the class that contain the term - confirm against
     *                   {@code getTermDistribution}
     * @param classPrior prior of each class; scaled by {@code docNum} to obtain per-class totals
     * @param docNum     total number of training documents
     * @return list of {@code Token}s (index = term index, weight = score) whose comparator has been
     *         switched to {@code WeightComparator(true)}; NOTE(review): presumably this re-sorts
     *         the list with the highest weight first - confirm against SortedArray.setComparator
     */
    private SortedArray computeTermMI(IntDenseMatrix termDistri, DoubleVector classPrior, int docNum) {
        // Per-class totals: prior * number of documents.
        DoubleVector classTotals = classPrior.copy();
        classTotals.multiply(docNum);

        // Per-term totals: column sums of the distribution matrix.
        DoubleVector termTotals = new DoubleVector(termDistri.columns());
        for (int term = 0; term < termDistri.columns(); term++) {
            termTotals.set(term, termDistri.getColumnSum(term));
        }

        double total = docNum;
        // Scratch vector reused for every term; fully overwritten before it is read.
        DoubleVector miByClass = new DoubleVector(classTotals.size());
        SortedArray termList = new SortedArray(termTotals.size(), new IndexComparator());
        for (int term = 0; term < termTotals.size(); term++) {
            double termTotal = termTotals.get(term);
            if (termTotal <= 0.0) {
                // Terms that never occur in the training data get no score at all.
                continue;
            }
            for (int cls = 0; cls < classTotals.size(); cls++) {
                miByClass.set(cls, this.calMutualInformation(
                        termDistri.getInt(cls, term), classTotals.get(cls), termTotals.get(term), total));
            }
            Token scoredTerm = new Token(term, 0);
            if (this.avgMode) {
                scoredTerm.setWeight(miByClass.dotProduct(classPrior));
            } else {
                scoredTerm.setWeight(miByClass.getMaxValue());
            }
            termList.add(scoredTerm);
        }
        termList.setComparator(new WeightComparator(true));
        return termList;
    }

    /**
     * Pointwise mutual information of two events from their counts.
     *
     * @param t1t2occur joint count of both events
     * @param t1sum     count of the first event
     * @param t2sum     count of the second event
     * @param total     total number of observations
     * @return {@code ln(t1t2occur * total / (t1sum * t2sum))}, or {@code 0.0} if any of the three
     *         counts is zero
     */
    private double calMutualInformation(double t1t2occur, double t1sum, double t2sum, double total) {
        if (t1t2occur == 0.0 || t1sum == 0.0 || t2sum == 0.0) {
            return 0.0;
        }
        return Math.log(t1t2occur * total / (t1sum * t2sum));
    }
}

