/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification;

import dragon.ir.classification.AbstractClassifier;
import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.featureselection.FeatureSelector;
import dragon.ir.index.IRDoc;
import dragon.ir.index.IndexReader;
import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.matrix.Row;
import dragon.matrix.SparseMatrix;
import dragon.matrix.vector.DoubleVector;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

/**
 * Multinomial naive Bayes text classifier.
 *
 * <p>Training stores, for each class, the log of the add-one smoothed probability of every
 * selected feature, plus the log of the add-one smoothed class prior. A document is scored per
 * class as {@code log prior + sum(featureWeight * log P(feature | class))}. The returned label is
 * {@code rank[0]} of {@code lastClassProb.rank(true)}; this is presumably the highest score
 * (TODO confirm the sort direction of {@code DoubleVector.rank(true)}).
 */
public class NBClassifier
extends AbstractClassifier {
    /** Log feature probabilities: one row per class, one column per selected feature. */
    protected DoubleFlatDenseMatrix model;
    /** Log of the smoothed prior probability of each class. */
    protected DoubleVector classPrior;
    /** Per-class scores produced by the most recent call to {@link #classify(Row)}. */
    protected DoubleVector lastClassProb;
    /** Class indices ranked by score from the most recent classification; null before any. */
    private int[] rank;

    /**
     * Loads a model previously written by {@link #saveModel(String)}.
     *
     * <p>The file must contain, in order: the model matrix, the class-prior vector, the feature
     * selector, and one label string per model row. Any error is printed and swallowed, which
     * leaves the instance partially initialized (existing best-effort behavior).
     *
     * <p>NOTE(review): this uses Java native deserialization. Never point it at a file from an
     * untrusted source; consider an {@code ObjectInputFilter} if that can happen.
     *
     * @param modelFile path of the serialized model
     */
    public NBClassifier(String modelFile) {
        FileInputStream fileIn = null;
        try {
            fileIn = new FileInputStream(modelFile);
            ObjectInputStream oin = new ObjectInputStream(fileIn);
            this.model = (DoubleFlatDenseMatrix)oin.readObject();
            this.classPrior = (DoubleVector)oin.readObject();
            this.classNum = this.classPrior.size();
            this.featureSelector = (FeatureSelector)oin.readObject();
            this.arrLabel = new String[this.model.rows()];
            for (int i = 0; i < this.arrLabel.length; i++) {
                this.arrLabel[i] = (String)oin.readObject();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        finally {
            // Closing the underlying file releases the handle even if the
            // ObjectInputStream constructor itself failed (e.g. a corrupt header).
            closeQuietly(fileIn);
        }
    }

    /**
     * Creates an untrained classifier that reads document rows through an index reader.
     *
     * @param indexReader source of document-term rows
     */
    public NBClassifier(IndexReader indexReader) {
        super(indexReader);
    }

    /**
     * Creates an untrained classifier that reads document rows from a document-term matrix.
     *
     * @param doctermMatrix source of document-term rows
     */
    public NBClassifier(SparseMatrix doctermMatrix) {
        super(doctermMatrix);
    }

    /**
     * Trains the class priors, the feature selector and the per-class log feature probabilities.
     *
     * <p>Does nothing if neither an index reader nor a document-term matrix is available.
     *
     * @param trainingDocSet labelled training documents, one {@link DocClass} per class
     */
    public void train(DocClassSet trainingDocSet) {
        if (this.indexReader == null && this.doctermMatrix == null) {
            return;
        }
        this.classNum = trainingDocSet.getClassNum();
        this.classPrior = this.getClassPrior(trainingDocSet);
        this.trainFeatureSelector(trainingDocSet);
        this.arrLabel = new String[this.classNum];
        for (int i = 0; i < this.classNum; i++) {
            this.arrLabel[i] = trainingDocSet.getDocClass(i).getClassName();
        }
        int featureNum = this.featureSelector.getSelectedFeatureNum();
        this.model = new DoubleFlatDenseMatrix(this.classNum, featureNum);
        // Every count starts at 1.0 (add-one smoothing), so unseen features never yield log(0).
        this.model.assign(1.0);
        for (int i = 0; i < this.classNum; i++) {
            // The denominator holds one pseudo-count per feature, matching the smoothing above.
            // It is a double: the former int accumulator truncated fractional term weights.
            double classSum = featureNum;
            DocClass cur = trainingDocSet.getDocClass(i);
            for (int j = 0; j < cur.getDocNum(); j++) {
                Row row = this.getRow(cur.getDoc(j).getIndex());
                for (int k = 0; k < row.getNonZeroNum(); k++) {
                    int newTermIndex = this.featureSelector.map(row.getNonZeroColumn(k));
                    if (newTermIndex >= 0) {
                        double score = row.getNonZeroDoubleScore(k);
                        classSum += score;
                        this.model.add(i, newTermIndex, score);
                    }
                }
            }
            double rate = 1.0 / classSum;
            for (int k = 0; k < this.model.columns(); k++) {
                this.model.setDouble(i, k, Math.log(this.model.getDouble(i, k) * rate));
            }
        }
    }

    /**
     * Computes the log of the add-one smoothed prior of each class:
     * {@code log((docNum_i + 1) / (totalDocs + classNum))}.
     *
     * @param docSet labelled training documents
     * @return vector of log priors, one entry per class
     */
    protected DoubleVector getClassPrior(DocClassSet docSet) {
        int classCount = docSet.getClassNum();
        DoubleVector vector = new DoubleVector(classCount);
        // One pseudo-document per class. The counts are added on top of it (the old code
        // overwrote it with set()), so an empty class no longer produces log(0).
        vector.assign(1.0);
        double total = classCount;
        for (int i = 0; i < classCount; i++) {
            int docNum = docSet.getDocClass(i).getDocNum();
            vector.add(i, docNum);
            total += docNum;
        }
        for (int i = 0; i < classCount; i++) {
            vector.set(i, Math.log(vector.get(i) / total));
        }
        return vector;
    }

    /**
     * Classifies an indexed document and stores the winning class's score as its weight.
     *
     * @param doc document to classify; its weight is overwritten
     * @return index of the predicted class
     */
    public int classify(IRDoc doc) {
        int label = this.classify(this.getRow(doc.getIndex()));
        doc.setWeight(this.lastClassProb.get(label));
        return label;
    }

    /**
     * Scores a document row against every class and returns the top-ranked class.
     *
     * <p>Side effects: replaces {@link #lastClassProb} and the array returned by {@link #rank()}.
     *
     * @param doc document-term row to classify
     * @return index of the predicted class
     */
    public int classify(Row doc) {
        this.lastClassProb = this.classPrior.copy();
        int modelClassNum = this.model.rows();
        for (int k = 0; k < doc.getNonZeroNum(); k++) {
            int newTermIndex = this.featureSelector.map(doc.getNonZeroColumn(k));
            if (newTermIndex < 0) {
                continue; // feature was not selected during training
            }
            // Weight of the k-th non-zero entry. The old code mistakenly indexed it with the class index.
            double weight = doc.getNonZeroDoubleScore(k);
            for (int j = 0; j < modelClassNum; j++) {
                this.lastClassProb.add(j, weight * this.model.getDouble(j, newTermIndex));
            }
        }
        this.rank = this.lastClassProb.rank(true);
        return this.rank[0];
    }

    /**
     * Returns the class ranking from the most recent classification.
     *
     * @return class indices in rank order, or null if nothing has been classified yet
     */
    public int[] rank() {
        return this.rank;
    }

    /**
     * Serializes the model in the format read by {@link #NBClassifier(String)}: model matrix,
     * class priors, feature selector, then one label per model row. Errors are printed and
     * swallowed (existing best-effort behavior), but the file handle is always released.
     *
     * @param modelFile destination path
     */
    public void saveModel(String modelFile) {
        FileOutputStream fileOut = null;
        try {
            fileOut = new FileOutputStream(modelFile);
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            out.writeObject(this.model);
            out.writeObject(this.classPrior);
            out.writeObject(this.featureSelector);
            for (int i = 0; i < this.model.rows(); i++) {
                out.writeObject(this.getClassLabel(i));
            }
            out.flush();
            out.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        finally {
            // Harmless after a successful close(); releases the file if writing failed midway.
            closeQuietly(fileOut);
        }
    }

    /** Closes a stream if non-null, ignoring close failures (cleanup is best effort). */
    private static void closeQuietly(Closeable stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        }
        catch (IOException ignored) {
            // Nothing useful to do: the primary error, if any, was already reported.
        }
    }
}

