/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification;

import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.NBClassifier;
import dragon.ir.classification.featureselection.NullFeatureSelector;
import dragon.ir.index.IRDoc;
import dragon.ir.index.IRTerm;
import dragon.ir.index.IndexReader;
import dragon.ir.kngbase.KnowledgeBase;
import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.matrix.DoubleSparseMatrix;
import dragon.util.MathUtil;

/**
 * Naive Bayes classifier whose per-class term model is a linear mixture of
 * up to three components:
 * <ul>
 *   <li>the class's own relative term frequencies,</li>
 *   <li>a collection-wide background model (weight {@code bkgCoefficient}),</li>
 *   <li>optionally, a topic-translation model obtained by pushing the class's
 *       topic counts through {@code topicTransMatrix}
 *       (weight {@code transCoefficient}).</li>
 * </ul>
 * The natural logarithm of the mixture is stored in the inherited
 * {@code model} matrix, one row per class and one column per selected feature.
 */
public class SemanticNBClassifier extends NBClassifier {
    /** Index over the topic signatures of the documents; {@code null} when no translation is used. */
    private IndexReader topicIndexReader;
    /** Topic-to-term translation scores; {@code null} disables the translation component. */
    private DoubleSparseMatrix topicTransMatrix;
    /** Mixture weight of the translation component. */
    private double transCoefficient;
    /** Mixture weight of the background component. */
    private double bkgCoefficient;
    /** Maps a term index of {@code topicIndexReader} to a row of {@code topicTransMatrix}; negative when unmapped. */
    private int[] topicMap;
    /** Maps a column of {@code topicTransMatrix} to a term index of {@code indexReader}; negative when unmapped. */
    private int[] termMap;

    /**
     * Creates a classifier from a model file; the file is handed to the
     * superclass constructor unchanged.
     *
     * @param modelFile path passed to {@code NBClassifier(String)}
     */
    public SemanticNBClassifier(String modelFile) {
        super(modelFile);
    }

    /**
     * Creates a classifier that smooths with the background model only
     * (no topic translation; the translation coefficient is 0).
     *
     * @param indexReader    index of the document collection
     * @param bkgCoefficient mixture weight of the background model
     */
    public SemanticNBClassifier(IndexReader indexReader, double bkgCoefficient) {
        super(indexReader);
        this.topicIndexReader = null;
        this.topicTransMatrix = null;
        this.transCoefficient = 0.0;
        this.bkgCoefficient = bkgCoefficient;
        this.featureSelector = new NullFeatureSelector();
    }

    /**
     * Creates a classifier using a translation matrix whose rows and columns
     * are taken to coincide with the term indices of {@code topicIndexReader}
     * and {@code indexReader} respectively (identity index maps are built).
     *
     * @param indexReader      index of the document collection
     * @param topicIndexReader index of the topic signatures
     * @param topicTransMatrix topic-to-term translation scores
     * @param transCoefficient mixture weight of the translation model
     * @param bkgCoefficient   mixture weight of the background model
     */
    public SemanticNBClassifier(IndexReader indexReader, IndexReader topicIndexReader, DoubleSparseMatrix topicTransMatrix, double transCoefficient, double bkgCoefficient) {
        super(indexReader);
        this.featureSelector = new NullFeatureSelector();
        this.topicIndexReader = topicIndexReader;
        this.topicTransMatrix = topicTransMatrix;
        this.transCoefficient = transCoefficient;
        this.bkgCoefficient = bkgCoefficient;
        this.topicMap = identityMap(topicIndexReader.getCollection().getTermNum());
        this.termMap = identityMap(indexReader.getCollection().getTermNum());
    }

    /**
     * Creates a classifier using the knowledge matrix of a knowledge base.
     * Topic terms are located in the knowledge base's row key list by their
     * term key, and each column key is resolved against {@code indexReader};
     * a column whose key is unknown to {@code indexReader} maps to -1.
     *
     * @param indexReader      index of the document collection
     * @param topicIndexReader index of the topic signatures
     * @param kngBase          knowledge base supplying the translation matrix and its key lists
     * @param transCoefficient mixture weight of the translation model
     * @param bkgCoefficient   mixture weight of the background model
     */
    public SemanticNBClassifier(IndexReader indexReader, IndexReader topicIndexReader, KnowledgeBase kngBase, double transCoefficient, double bkgCoefficient) {
        super(indexReader);
        this.featureSelector = new NullFeatureSelector();
        this.topicIndexReader = topicIndexReader;
        this.topicTransMatrix = kngBase.getKnowledgeMatrix();
        this.transCoefficient = transCoefficient;
        this.bkgCoefficient = bkgCoefficient;

        this.topicMap = new int[topicIndexReader.getCollection().getTermNum()];
        for (int i = 0; i < this.topicMap.length; i++) {
            this.topicMap[i] = kngBase.getRowKeyList().search(topicIndexReader.getTermKey(i));
        }

        this.termMap = new int[kngBase.getColumnKeyList().size()];
        for (int i = 0; i < this.termMap.length; i++) {
            IRTerm curTerm = indexReader.getIRTerm(kngBase.getColumnKeyList().search(i));
            this.termMap[i] = curTerm == null ? -1 : curTerm.getIndex();
        }
    }

    /** @return the mixture weight of the translation model */
    public double getTranslationCoefficient() {
        return this.transCoefficient;
    }

    /** @param transCoefficient new mixture weight of the translation model; takes effect on the next {@link #train} */
    public void setTranslationCoefficient(double transCoefficient) {
        this.transCoefficient = transCoefficient;
    }

    /** @return the mixture weight of the background model */
    public double getBackgroundCoefficient() {
        return this.bkgCoefficient;
    }

    /** @param bkgCoefficient new mixture weight of the background model; takes effect on the next {@link #train} */
    public void setBackgroundCoefficient(double bkgCoefficient) {
        this.bkgCoefficient = bkgCoefficient;
    }

    /**
     * Estimates class priors and the smoothed, log-scaled class term models
     * from the given training set. Does nothing when no index reader is
     * available, because every step below reads from it.
     *
     * @param trainingDocSet labelled training documents, grouped by class
     */
    public void train(DocClassSet trainingDocSet) {
        // The whole procedure reads from indexReader; without it there is nothing to train on.
        if (this.indexReader == null) {
            return;
        }
        this.classPrior = this.getClassPrior(trainingDocSet);
        this.featureSelector.train(this.indexReader, trainingDocSet);

        int classNum = trainingDocSet.getClassNum();
        this.arrLabel = new String[classNum];
        for (int i = 0; i < classNum; i++) {
            this.arrLabel[i] = trainingDocSet.getDocClass(i).getClassName();
        }

        this.model = new DoubleFlatDenseMatrix(classNum, this.featureSelector.getSelectedFeatureNum());
        double[] bkgModel = this.getBackgroundModel(this.indexReader);

        for (int i = 0; i < classNum; i++) {
            DocClass cur = trainingDocSet.getDocClass(i);
            int classSum = this.accumulateClassCounts(i, cur);

            if (this.topicTransMatrix != null) {
                double[] transModel = this.computeTranslationModel(cur);
                // A class without any counted term contributes nothing from its own
                // frequencies; a zero weight avoids 0 * Infinity = NaN.
                double a = classSum > 0
                        ? (1.0 - this.bkgCoefficient) * (1.0 - this.transCoefficient) / (double) classSum
                        : 0.0;
                double b = (1.0 - this.transCoefficient) * this.bkgCoefficient;
                for (int k = 0; k < this.model.columns(); k++) {
                    this.model.setDouble(i, k, Math.log(transModel[k] * this.transCoefficient + this.model.getDouble(i, k) * a + bkgModel[k] * b));
                }
            } else {
                double a = classSum > 0 ? (1.0 - this.bkgCoefficient) / (double) classSum : 0.0;
                for (int k = 0; k < this.model.columns(); k++) {
                    this.model.setDouble(i, k, Math.log(this.model.getDouble(i, k) * a + bkgModel[k] * this.bkgCoefficient));
                }
            }
        }
    }

    /**
     * Adds the raw frequencies of all selected terms in the documents of one
     * class to row {@code classIndex} of {@code model}.
     *
     * @return the total frequency added to the row
     */
    private int accumulateClassCounts(int classIndex, DocClass docClass) {
        int classSum = 0;
        for (int j = 0; j < docClass.getDocNum(); j++) {
            IRDoc curDoc = docClass.getDoc(j);
            int[] arrIndex = this.indexReader.getTermIndexList(curDoc.getIndex());
            int[] arrFreq = this.indexReader.getTermFrequencyList(curDoc.getIndex());
            // Same defensive check as computeTranslationModel applies to the topic index.
            if (arrIndex == null) {
                continue;
            }
            for (int k = 0; k < arrIndex.length; k++) {
                int newTermIndex = this.featureSelector.map(arrIndex[k]);
                if (newTermIndex >= 0) {
                    classSum += arrFreq[k];
                    this.model.add(classIndex, newTermIndex, arrFreq[k]);
                }
            }
        }
        return classSum;
    }

    /**
     * Builds the translation component for one class: topic frequencies of
     * the class's documents are normalised to rates and each topic's
     * translation row, weighted by its rate, is summed into a term vector.
     *
     * @param curClass the class whose documents are aggregated
     * @return one value per selected feature; when the feature selector keeps
     *         only a subset of the terms the vector is renormalised over that subset
     */
    private double[] computeTranslationModel(DocClass curClass) {
        int topicNum = this.topicIndexReader.getCollection().getTermNum();
        int[] arrCount = new int[topicNum];
        int termNum = this.indexReader.getCollection().getTermNum();
        int docNum = this.topicIndexReader.getCollection().getDocNum();

        // Aggregate topic frequencies over the documents present in the topic index.
        for (int i = 0; i < curClass.getDocNum(); i++) {
            IRDoc curDoc = curClass.getDoc(i);
            if (curDoc.getIndex() >= docNum) {
                continue;
            }
            int[] arrIndex = this.topicIndexReader.getTermIndexList(curDoc.getIndex());
            int[] arrFreq = this.topicIndexReader.getTermFrequencyList(curDoc.getIndex());
            if (arrIndex == null) {
                continue;
            }
            for (int j = 0; j < arrIndex.length; j++) {
                arrCount[arrIndex[j]] += arrFreq[j];
            }
        }

        // Discard topics that have no usable row in the translation matrix so
        // they do not take part in the rate normalisation below.
        for (int i = 0; i < this.topicMap.length; i++) {
            int topicIndex = this.topicMap[i];
            if (topicIndex < 0
                    || topicIndex >= this.topicTransMatrix.rows()
                    || this.topicTransMatrix.getNonZeroNumInRow(topicIndex) <= 0) {
                arrCount[i] = 0;
            }
        }

        double sum2 = MathUtil.sumArray(arrCount);
        double[] arrModel = new double[termNum];
        for (int i = 0; i < topicNum; i++) {
            if (arrCount[i] <= 0) {
                continue;
            }
            // arrCount[i] > 0 implies sum2 > 0, so the division is safe.
            int topicIndex = this.topicMap[i];
            double rate = (double) arrCount[i] / sum2;
            int[] arrIndex = this.topicTransMatrix.getNonZeroColumnsInRow(topicIndex);
            double[] arrScore = this.topicTransMatrix.getNonZeroDoubleScoresInRow(topicIndex);
            for (int j = 0; j < arrIndex.length; j++) {
                int termIndex = this.termMap[arrIndex[j]];
                if (termIndex >= 0) {
                    arrModel[termIndex] += rate * arrScore[j];
                }
            }
        }

        int featureNum = this.featureSelector.getSelectedFeatureNum();
        if (arrModel.length == featureNum) {
            return arrModel;
        }

        // Project onto the selected features and renormalise over them.
        double[] arrSelectedModel = new double[featureNum];
        sum2 = 0.0;
        for (int i = 0; i < arrModel.length; i++) {
            int termIndex = this.featureSelector.map(i);
            if (termIndex >= 0) {
                sum2 += arrModel[i];
                arrSelectedModel[termIndex] = arrModel[i];
            }
        }
        // With no translation mass on the selected features the vector stays
        // all-zero instead of turning into NaN through 0/0.
        if (sum2 > 0.0) {
            for (int i = 0; i < arrSelectedModel.length; i++) {
                arrSelectedModel[i] = arrSelectedModel[i] / sum2;
            }
        }
        return arrSelectedModel;
    }

    /**
     * Computes the background model: the frequency of every selected term as
     * reported by {@code reader}, normalised to sum to one.
     *
     * @param reader index whose term frequencies are used
     * @return one value per selected feature; all zeros when the selected
     *         terms have no frequency mass
     */
    private double[] getBackgroundModel(IndexReader reader) {
        int termNum = reader.getCollection().getTermNum();
        int featureNum = this.featureSelector.getSelectedFeatureNum();
        double sum2 = 0.0;
        double[] arrModel = new double[featureNum];
        for (int i = 0; i < termNum; i++) {
            int newIndex = this.featureSelector.map(i);
            if (newIndex >= 0) {
                arrModel[newIndex] = reader.getIRTerm(i).getFrequency();
                sum2 += arrModel[newIndex];
            }
        }
        // Skip normalisation on zero mass; 0/0 would fill the model with NaN.
        if (sum2 > 0.0) {
            for (int i = 0; i < featureNum; i++) {
                arrModel[i] = arrModel[i] / sum2;
            }
        }
        return arrModel;
    }

    /** Returns an array of the given length in which every element equals its own index. */
    private static int[] identityMap(int length) {
        int[] map = new int[length];
        for (int i = 0; i < length; i++) {
            map[i] = i;
        }
        return map;
    }
}

