/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification;

import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.NBClassifier;
import dragon.ir.index.IRDoc;
import dragon.ir.index.IRTerm;
import dragon.ir.index.IndexReader;
import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.matrix.IntRow;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;

/**
 * Naive Bayes classifier that additionally learns from unlabeled documents.
 *
 * <p>Training first fits the model on the labeled set, then repeatedly
 * (1) classifies the unlabeled set with the current model and (2) re-estimates
 * the model from the labeled documents plus the freshly self-labeled ones. The
 * loop stops when the change of the training score drops to
 * {@code convergeThreshold} or below, or after {@code runNum} rounds.
 *
 * <p>Unlabeled documents may come from the classifier's own index
 * ({@link #setUnlabeledData(DocClass)}) or from a separate index
 * ({@link #setUnlabeledData(IndexReader, DocClass)}); in the latter case their
 * term vectors are translated into this classifier's term space and cached in
 * memory.
 */
public class NigamActiveLearning extends NBClassifier {
    /** Key prefix marking documents whose term vectors are cached in {@link #externalUnlabeled}. */
    private static final String EXTERNAL_KEY_PREFIX = "external_unlabeled";
    /** Shared empty term list used where an index lookup yields {@code null}. */
    private static final int[] NO_TERMS = new int[0];

    private IntRow[] externalUnlabeled;
    private DocClass unlabeledSet;
    private DocClass unlabeledSetBackup;
    private int externalDocOffset;
    private double convergeThreshold;
    private double unlabeledRate;
    private int runNum;

    /**
     * Creates a classifier from a stored model file. The iteration limit,
     * convergence threshold and unlabeled rate keep their zero defaults.
     *
     * @param modelFile model file handed to the superclass constructor
     */
    public NigamActiveLearning(String modelFile) {
        super(modelFile);
    }

    /**
     * Creates a trainable classifier over the given index.
     *
     * @param indexReader   index holding the labeled (and testing) documents
     * @param unlabeledRate fraction of the testing documents that
     *                      {@link #classify(DocClassSet, DocClass)} adds to the
     *                      unlabeled pool; values {@code <= 0} disable that step
     */
    public NigamActiveLearning(IndexReader indexReader, double unlabeledRate) {
        super(indexReader);
        // External documents are numbered after the last document of the main index.
        this.externalDocOffset = indexReader.getCollection().getDocNum();
        this.runNum = 15;
        this.convergeThreshold = 1.0E-4;
        this.unlabeledRate = unlabeledRate;
    }

    /**
     * Registers unlabeled documents that live in a different index.
     *
     * <p>Each document's terms are mapped into this classifier's term space;
     * terms unknown to the classifier's index are dropped, and documents left
     * without any known term are skipped entirely.
     *
     * <p>Note: the documents kept are modified in place. Their index is
     * rewritten to {@code externalDocOffset + n} and their key is prefixed with
     * {@code "external_unlabeled"}, so the same {@code docSet} should not be
     * passed to this method a second time.
     *
     * @param newIndexReader index the documents in {@code docSet} belong to
     * @param docSet         unlabeled documents from {@code newIndexReader}
     */
    public void setUnlabeledData(IndexReader newIndexReader, DocClass docSet) {
        int[] termMap = this.getTermMap(newIndexReader, this.indexReader);
        this.externalUnlabeled = new IntRow[docSet.getDocNum()];
        this.unlabeledSet = new DocClass(0);
        int kept = 0;
        for (int i = 0; i < this.externalUnlabeled.length; ++i) {
            IRDoc curDoc = docSet.getDoc(i);
            int[] arrIndex = newIndexReader.getTermIndexList(curDoc.getIndex());
            int[] arrFreq = newIndexReader.getTermFrequencyList(curDoc.getIndex());
            if (arrIndex == null) {
                continue;
            }

            // First pass: count the terms that exist in the classifier's index.
            int termNum = 0;
            for (int j = 0; j < arrIndex.length; ++j) {
                if (termMap[arrIndex[j]] >= 0) {
                    ++termNum;
                }
            }
            if (termNum == 0) {
                continue;
            }

            // Second pass: copy the surviving terms with their translated indexes.
            int[] arrNewIndex = new int[termNum];
            int[] arrNewFreq = new int[termNum];
            int filled = 0;
            for (int j = 0; j < arrIndex.length; ++j) {
                int newIndex = termMap[arrIndex[j]];
                if (newIndex >= 0) {
                    arrNewIndex[filled] = newIndex;
                    arrNewFreq[filled] = arrFreq[j];
                    ++filled;
                }
            }

            this.externalUnlabeled[kept] = new IntRow(kept, filled, arrNewIndex, arrNewFreq);
            curDoc.setIndex(this.externalDocOffset + kept);
            curDoc.setKey(EXTERNAL_KEY_PREFIX + curDoc.getKey());
            this.unlabeledSet.addDoc(curDoc);
            ++kept;
        }
    }

    /**
     * Registers unlabeled documents that live in the classifier's own index and
     * discards any previously cached external documents.
     *
     * @param docSet unlabeled documents
     */
    public void setUnlabeledData(DocClass docSet) {
        this.unlabeledSet = docSet;
        this.externalUnlabeled = null;
    }

    /**
     * Trains on {@code trainingDocSet} and then classifies {@code testingDocs}.
     *
     * <p>When {@code unlabeledRate > 0}, a reproducible random sample (fixed
     * seed) of the testing documents is temporarily added to the unlabeled pool
     * for the duration of training. The pool configured through
     * {@code setUnlabeledData} is restored afterwards, also when training fails.
     *
     * @param trainingDocSet labeled documents grouped by class
     * @param testingDocs    documents to classify
     * @return the testing documents grouped by predicted class, or {@code null}
     *         when neither an index reader nor a doc-term matrix is available
     */
    public DocClassSet classify(DocClassSet trainingDocSet, DocClass testingDocs) {
        if (this.indexReader == null && this.doctermMatrix == null) {
            return null;
        }
        if (this.unlabeledRate > 0.0) {
            // Build the temporary pool completely before installing it, so a
            // failure while sampling leaves the configured pool untouched.
            DocClass workingSet = new DocClass(0);
            if (this.unlabeledSet != null) {
                for (int i = 0; i < this.unlabeledSet.getDocNum(); ++i) {
                    workingSet.addDoc(this.unlabeledSet.getDoc(i));
                }
            }
            ArrayList<IRDoc> shuffled = new ArrayList<IRDoc>(testingDocs.getDocNum());
            for (int i = 0; i < testingDocs.getDocNum(); ++i) {
                shuffled.add(testingDocs.getDoc(i));
            }
            // Fixed seed keeps the sample identical from run to run.
            Collections.shuffle(shuffled, new Random(10L));
            int sampleSize = (int)(this.unlabeledRate * (double)shuffled.size());
            for (int i = 0; i < sampleSize; ++i) {
                workingSet.addDoc(shuffled.get(i));
            }

            this.unlabeledSetBackup = this.unlabeledSet;
            this.unlabeledSet = workingSet;
            try {
                this.train(trainingDocSet);
            } finally {
                workingSet.removeAll();
                this.unlabeledSet = this.unlabeledSetBackup;
            }
        } else {
            this.train(trainingDocSet);
        }
        return this.classify(testingDocs);
    }

    /**
     * Fits the model on the labeled documents and refines it with the unlabeled
     * pool.
     *
     * <p>Does nothing when neither an index reader nor a doc-term matrix is
     * available. When no unlabeled pool has been set, training stops after the
     * initial supervised fit.
     *
     * @param trainingDocSet labeled documents grouped by class
     */
    public void train(DocClassSet trainingDocSet) {
        if (this.indexReader == null && this.doctermMatrix == null) {
            return;
        }
        this.classNum = trainingDocSet.getClassNum();
        this.arrLabel = new String[this.classNum];
        for (int i = 0; i < this.classNum; ++i) {
            this.arrLabel[i] = trainingDocSet.getDocClass(i).getClassName();
        }

        // Initial supervised estimate from the labeled documents only.
        this.eStep(trainingDocSet);
        if (this.unlabeledSet == null) {
            // Nothing to self-label; previously this path passed null to classify().
            return;
        }

        double prevScore = 0.0;
        double score = -Double.MAX_VALUE;
        for (int run = 0;
                run < this.runNum && Math.abs(score - prevScore) > this.convergeThreshold;
                ++run) {
            prevScore = score;
            DocClassSet classifiedUnlabeledSet = this.classify(this.unlabeledSet);

            // Unlabeled part of the score: classify(IRDoc) stored each document's
            // winning class score in its weight. (The labeled set carries no such
            // weights, so summing over it contributed nothing that changes
            // between rounds.)
            score = sumWeights(classifiedUnlabeledSet) + this.scoreLabeledDocs(trainingDocSet);

            // Re-estimate from labeled docs plus unlabeled docs under their predicted labels.
            DocClassSet newTrainingSet = new DocClassSet(trainingDocSet.getClassNum());
            copyDocs(trainingDocSet, newTrainingSet);
            copyDocs(classifiedUnlabeledSet, newTrainingSet);
            this.eStep(newTrainingSet);
        }
    }

    /**
     * Classifies a single document and stores the winning class's score (taken
     * from {@code lastClassProb}) as the document's weight.
     *
     * @param curDoc document to classify; external unlabeled documents are
     *               recognised by their key prefix
     * @return index of the predicted class
     */
    public int classify(IRDoc curDoc) {
        int[] arrIndex = this.termIndexesOf(curDoc);
        int[] arrFreq = this.termFrequenciesOf(curDoc);
        IntRow row = new IntRow(0, arrIndex.length, arrIndex, arrFreq);
        int label = this.classify(row);
        curDoc.setWeight(this.lastClassProb.get(label));
        return label;
    }

    /**
     * Re-estimates class priors, the selected feature set and the per-class term
     * log-probabilities from the given hard assignment of documents to classes.
     * Term counts use add-one smoothing over the selected features.
     */
    private void eStep(DocClassSet trainingDocSet) {
        this.classPrior = this.getClassPrior(trainingDocSet);
        this.featureSelector.train(this.indexReader, trainingDocSet);
        int featureNum = this.featureSelector.getSelectedFeatureNum();
        this.model = new DoubleFlatDenseMatrix(trainingDocSet.getClassNum(), featureNum);
        // Every (class, term) cell starts at 1 ...
        this.model.assign(1.0);
        for (int i = 0; i < trainingDocSet.getClassNum(); ++i) {
            // ... so the class total starts at the number of selected features.
            int classSum = this.featureSelector.getSelectedFeatureNum();
            DocClass cur = trainingDocSet.getDocClass(i);
            for (int j = 0; j < cur.getDocNum(); ++j) {
                IRDoc curDoc = cur.getDoc(j);
                int[] arrIndex = this.termIndexesOf(curDoc);
                int[] arrFreq = this.termFrequenciesOf(curDoc);
                for (int k = 0; k < arrIndex.length; ++k) {
                    int newTermIndex = this.featureSelector.map(arrIndex[k]);
                    if (newTermIndex >= 0) {
                        classSum += arrFreq[k];
                        this.model.add(i, newTermIndex, arrFreq[k]);
                    }
                }
            }
            // Normalise the counts and store them as log-probabilities.
            double rate = 1.0 / (double)classSum;
            for (int k = 0; k < this.model.columns(); ++k) {
                this.model.setDouble(i, k, Math.log(this.model.getDouble(i, k) * rate));
            }
        }
    }

    /**
     * Maps every term index of {@code src} to the index of the same term key in
     * {@code dest}, or to {@code -1} when {@code dest} does not know the term.
     */
    private int[] getTermMap(IndexReader src, IndexReader dest) {
        int[] termMap = new int[src.getCollection().getTermNum()];
        for (int i = 0; i < termMap.length; ++i) {
            IRTerm irTerm = dest.getIRTerm(src.getTermKey(i));
            termMap[i] = irTerm != null ? irTerm.getIndex() : -1;
        }
        return termMap;
    }

    /** Tells whether the document's term vector is cached in {@link #externalUnlabeled}. */
    private boolean isExternal(IRDoc doc) {
        return doc.getKey().startsWith(EXTERNAL_KEY_PREFIX);
    }

    /** Returns the document's term indexes (never {@code null}) from the cache or the index. */
    private int[] termIndexesOf(IRDoc doc) {
        int[] terms = this.isExternal(doc)
                ? this.externalUnlabeled[doc.getIndex() - this.externalDocOffset].getNonZeroColumns()
                : this.indexReader.getTermIndexList(doc.getIndex());
        return terms != null ? terms : NO_TERMS;
    }

    /** Returns the document's term frequencies (never {@code null}), parallel to {@link #termIndexesOf}. */
    private int[] termFrequenciesOf(IRDoc doc) {
        int[] freqs = this.isExternal(doc)
                ? this.externalUnlabeled[doc.getIndex() - this.externalDocOffset].getNonZeroIntScores()
                : this.indexReader.getTermFrequencyList(doc.getIndex());
        return freqs != null ? freqs : NO_TERMS;
    }

    /** Sums the weights of all documents in all classes of {@code docSet}. */
    private static double sumWeights(DocClassSet docSet) {
        double total = 0.0;
        for (int c = 0; c < docSet.getClassNum(); ++c) {
            DocClass cls = docSet.getDocClass(c);
            for (int d = 0; d < cls.getDocNum(); ++d) {
                total += cls.getDoc(d).getWeight();
            }
        }
        return total;
    }

    /**
     * Scores the labeled documents under the current model: for each document,
     * its class prior entry plus frequency-weighted model entries of its
     * selected terms (presumably log-values, since they are added).
     */
    private double scoreLabeledDocs(DocClassSet trainingDocSet) {
        double total = 0.0;
        for (int c = 0; c < trainingDocSet.getClassNum(); ++c) {
            DocClass cls = trainingDocSet.getDocClass(c);
            for (int d = 0; d < cls.getDocNum(); ++d) {
                IRDoc doc = cls.getDoc(d);
                double docScore = this.classPrior.get(c);
                int[] arrIndex = this.indexReader.getTermIndexList(doc.getIndex());
                int[] arrFreq = this.indexReader.getTermFrequencyList(doc.getIndex());
                // A document without indexed terms contributes only its prior.
                if (arrIndex != null) {
                    for (int k = 0; k < arrIndex.length; ++k) {
                        int mapped = this.featureSelector.map(arrIndex[k]);
                        if (mapped >= 0) {
                            docScore += (double)arrFreq[k] * this.model.getDouble(c, mapped);
                        }
                    }
                }
                total += docScore;
            }
        }
        return total;
    }

    /** Adds every document of {@code from} to the same class position in {@code to}. */
    private static void copyDocs(DocClassSet from, DocClassSet to) {
        for (int c = 0; c < from.getClassNum(); ++c) {
            DocClass cls = from.getDocClass(c);
            for (int d = 0; d < cls.getDocNum(); ++d) {
                to.addDoc(c, cls.getDoc(d));
            }
        }
    }
}

