/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification.featureselection;

import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.featureselection.AbstractFeatureSelector;
import dragon.ir.index.IndexReader;
import dragon.matrix.SparseMatrix;
import dragon.matrix.vector.DoubleVector;
import dragon.nlp.Token;
import dragon.nlp.compare.IndexComparator;
import dragon.nlp.compare.WeightComparator;
import dragon.util.MathUtil;
import dragon.util.SortedArray;
import java.io.Serializable;

/**
 * Feature selector that scores every term by its information gain (IG) with respect to the
 * class labels of the training documents and keeps the highest-scoring fraction of terms.
 *
 * <p>For a term {@code t} and class variable {@code C} the score is
 * {@code IG(t) = H(C) - P(t) * H(C | t) - P(!t) * H(C | !t)}, where {@code P(t)} is the fraction
 * of training documents that contain {@code t} and entropies use the natural logarithm.
 * Only terms with a strictly positive score are candidates for selection.
 *
 * <p>Only {@link IndexReader} input is supported because term-to-document postings are
 * required; the {@link SparseMatrix} overload prints a notice and returns {@code null}.
 */
public class InfoGainFeatureSelector
extends AbstractFeatureSelector
implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Fraction of the collection's term count to select (e.g. 0.1 keeps roughly 10%). */
    private double topPercentage;

    /**
     * Creates an information-gain selector.
     *
     * @param topPercentage fraction of the collection's terms to keep; the resulting count is
     *     truncated toward zero, capped at the number of positive-gain terms, and never
     *     allowed to go below zero
     */
    public InfoGainFeatureSelector(double topPercentage) {
        this.topPercentage = topPercentage;
    }

    /**
     * Unsupported input form: information gain needs each term's posting list, which this
     * selector reads from an {@link IndexReader}.
     *
     * @param doctermMatrix ignored
     * @param trainingSet ignored
     * @return always {@code null}, after printing a notice to standard output
     */
    protected int[] getSelectedFeatures(SparseMatrix doctermMatrix, DocClassSet trainingSet) {
        System.out.println("InfoGainSelector does not accept SparseMatrix as input. Please use IndexReader as input instead.");
        return null;
    }

    /**
     * Selects the terms with the highest information gain.
     *
     * @param indexReader index supplying collection statistics and term postings
     * @param trainingSet labelled training documents
     * @return indices of the selected terms, in the order imposed by {@link IndexComparator};
     *     never {@code null} (possibly empty)
     */
    protected int[] getSelectedFeatures(IndexReader indexReader, DocClassSet trainingSet) {
        SortedArray rankedTerms = this.computeTermIG(indexReader, trainingSet);
        int selectedNum = (int)(this.topPercentage * (double)indexReader.getCollection().getTermNum());
        // Cap at the number of positive-gain candidates and never go negative, so an
        // oversized or negative topPercentage cannot produce an invalid size.
        selectedNum = Math.max(0, Math.min(rankedTerms.size(), selectedNum));
        // rankedTerms is expected to be ordered best-first (see computeTermIG), so the first
        // selectedNum entries are the winners; they are then re-sorted by term index.
        SortedArray selectedList = new SortedArray(selectedNum, new IndexComparator());
        for (int rank = 0; rank < selectedNum; rank++) {
            selectedList.add(rankedTerms.get(rank));
        }
        int[] featureMap = new int[selectedList.size()];
        for (int k = 0; k < featureMap.length; k++) {
            featureMap[k] = ((Token)selectedList.get(k)).getIndex();
        }
        return featureMap;
    }

    /**
     * Computes the information gain of every term that occurs in at least one training
     * document.
     *
     * @param indexReader index supplying collection statistics and term postings
     * @param trainingSet labelled training documents
     * @return tokens (term index, gain as weight) for every term with strictly positive gain,
     *     switched to a {@code WeightComparator(true)} ordering, presumably descending weight
     *     (TODO confirm against WeightComparator)
     */
    private SortedArray computeTermIG(IndexReader indexReader, DocClassSet trainingSet) {
        int trainingDocNum = 0;
        for (int c = 0; c < trainingSet.getClassNum(); c++) {
            trainingDocNum += trainingSet.getDocClass(c).getDocNum();
        }
        // NOTE(review): getClassPrior is inherited and not visible here; presumably P(c)
        // estimated from the training set. Scaling it by the training-set size recovers
        // the per-class document counts used below.
        DoubleVector classPrior = this.getClassPrior(trainingSet);
        double classEntropy = this.calEntropy(classPrior);
        DoubleVector classVector = classPrior.copy();
        classVector.multiply(trainingDocNum);
        int[] docLabels = this.buildDocLabels(indexReader, trainingSet);

        int termNum = indexReader.getCollection().getTermNum();
        SortedArray termList = new SortedArray(termNum, new IndexComparator());
        DoubleVector termVector = new DoubleVector(termNum);
        DoubleVector classDistrWiTerm = new DoubleVector(classPrior.size());
        DoubleVector classDistrWoTerm = new DoubleVector(classPrior.size());
        for (int term = 0; term < termNum; term++) {
            int[] arrDocIndex = indexReader.getTermDocIndexList(term);
            if (arrDocIndex == null || arrDocIndex.length == 0) {
                continue;
            }
            // Class histogram of training docs containing the term, and (by subtraction
            // from the per-class totals) of training docs not containing it.
            classDistrWiTerm.assign(0.0);
            classDistrWoTerm.assign(classVector);
            int docCount = 0;
            for (int j = 0; j < arrDocIndex.length; j++) {
                int docLabel = docLabels[arrDocIndex[j]];
                if (docLabel >= 0) {
                    classDistrWiTerm.add(docLabel, 1.0);
                    classDistrWoTerm.add(docLabel, -1.0);
                    docCount++;
                }
            }
            if (docCount == 0) {
                continue; // the term never occurs in a training document
            }
            // IG = H(C) - P(t) H(C|t) - P(!t) H(C|!t): each conditional entropy is
            // weighted by the probability of its condition.
            classDistrWiTerm.multiply(1.0 / (double)docCount);
            double condEntropy = (double)docCount / (double)trainingDocNum * this.calEntropy(classDistrWiTerm);
            int docCountWoTerm = trainingDocNum - docCount;
            if (docCountWoTerm > 0) {
                // Skipped when every training doc contains the term: P(!t) = 0, and the
                // normalisation would otherwise divide by zero and produce NaN.
                classDistrWoTerm.multiply(1.0 / (double)docCountWoTerm);
                condEntropy += (double)docCountWoTerm / (double)trainingDocNum * this.calEntropy(classDistrWoTerm);
            }
            termVector.set(term, classEntropy - condEntropy);
        }

        for (int term = 0; term < termVector.size(); term++) {
            double gain = termVector.get(term);
            // "gain > 0" (rather than "!(gain <= 0)") also rejects NaN scores.
            if (gain > 0.0) {
                Token curTerm = new Token(term, 0);
                curTerm.setWeight(gain);
                termList.add(curTerm);
            }
        }
        termList.setComparator(new WeightComparator(true));
        return termList;
    }

    /**
     * Maps every document of the collection to the index of its training class.
     *
     * @param indexReader index supplying the collection's document count
     * @param trainingSet labelled training documents
     * @return array indexed by document index holding the class index. Documents not in the
     *     training set keep the fill value from {@code MathUtil.initArray(..., -1)}, presumably
     *     -1. A document listed under several classes keeps the last class that lists it.
     */
    private int[] buildDocLabels(IndexReader indexReader, DocClassSet trainingSet) {
        int[] docLabels = new int[indexReader.getCollection().getDocNum()];
        MathUtil.initArray(docLabels, -1);
        for (int c = 0; c < trainingSet.getClassNum(); c++) {
            DocClass docClass = trainingSet.getDocClass(c);
            for (int j = 0; j < docClass.getDocNum(); j++) {
                docLabels[docClass.getDoc(j).getIndex()] = c;
            }
        }
        return docLabels;
    }

    /**
     * Returns the Shannon entropy (natural log) of a probability vector.
     *
     * <p>Entries that are not strictly positive contribute nothing, following the
     * {@code 0 * log 0 = 0} convention. This also absorbs tiny negative values that can arise
     * from floating-point count subtraction, which would otherwise make {@code Math.log}
     * return NaN.
     *
     * @param probVector probability (or normalised count) vector
     * @return the entropy, {@code -sum(p * ln p)} over the positive entries
     */
    private double calEntropy(DoubleVector probVector) {
        double entropy = 0.0;
        for (int i = 0; i < probVector.size(); i++) {
            double p = probVector.get(i);
            if (p > 0.0) {
                entropy -= p * Math.log(p);
            }
        }
        return entropy;
    }
}

