/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.topicmodel;

import dragon.ir.topicmodel.AbstractModel;
import dragon.matrix.IntSparseMatrix;
import java.util.Arrays;
import java.util.Random;

/**
 * Cross-collection mixture model estimated with EM.
 *
 * <p>A term {@code t} of document {@code i} in collection {@code n} is explained either by a
 * background distribution (weight {@code bkgCoefficient}) or by one of {@code themeNum} themes.
 * Each theme mixes a component shared by all collections with a collection-specific component:
 *
 * <pre>
 *   p(t | n, m) = comCoefficient * common[m][t] + (1 - comCoefficient) * specific[n][m][t]
 * </pre>
 *
 * and themes are weighted per document by {@code docWeight[n][m][i]}. The iteration count, the
 * random seed and {@code printStatus} are inherited from {@link AbstractModel}.
 */
public class CrossMixtureModel
extends AbstractModel {
    /** One document-term count matrix per collection (rows: documents, columns: term ids). */
    protected IntSparseMatrix[] arrTopicReader;
    /** Background term distribution, already multiplied by {@code bkgCoefficient}. */
    protected double[] bkgModel;
    protected double bkgCoefficient;
    protected double comCoefficient;
    protected int themeNum;
    protected int collectionNum;
    /** Largest column count over all collections; width of every term-indexed array. */
    protected int maxTermNum;
    /** Largest row count over all collections; width of every document-indexed array. */
    protected int maxDocNum;
    /** Document memberships, indexed {@code [collection][theme][doc]}. */
    private double[][][] arrDocWeight;
    /** Collection-specific theme models, indexed {@code [collection][theme][term]}. */
    private double[][][] arrProb;
    /** Common theme models shared across collections, indexed {@code [theme][term]}. */
    private double[][] arrCommonProb;

    /**
     * Creates a model over several document collections.
     *
     * @param arrTopicMatrix one document-term count matrix per collection; must be non-empty
     * @param themeNum number of themes to estimate
     * @param bkgModel background term distribution indexed by term id; a scaled copy is kept and
     *     the caller's array is not modified
     * @param bkgCoefficient mixing weight of the background model
     * @param comCoefficient mixing weight of the common component inside each theme
     * @throws IllegalArgumentException if {@code arrTopicMatrix} is empty
     */
    public CrossMixtureModel(IntSparseMatrix[] arrTopicMatrix, int themeNum, double[] bkgModel, double bkgCoefficient, double comCoefficient) {
        if (arrTopicMatrix.length == 0) {
            throw new IllegalArgumentException("arrTopicMatrix must contain at least one collection");
        }
        this.arrTopicReader = arrTopicMatrix;
        this.themeNum = themeNum;
        this.collectionNum = arrTopicMatrix.length;
        this.bkgCoefficient = bkgCoefficient;
        this.comCoefficient = comCoefficient;
        // Pre-scale once so the E-step can use lambda * p_B(t) directly.
        this.bkgModel = new double[bkgModel.length];
        for (int t = 0; t < bkgModel.length; t++) {
            this.bkgModel[t] = bkgModel[t] * bkgCoefficient;
        }
        // All collections share one array layout, sized to the largest collection.
        this.maxTermNum = arrTopicMatrix[0].columns();
        this.maxDocNum = arrTopicMatrix[0].rows();
        for (int n = 1; n < arrTopicMatrix.length; n++) {
            this.maxTermNum = Math.max(this.maxTermNum, arrTopicMatrix[n].columns());
            this.maxDocNum = Math.max(this.maxDocNum, arrTopicMatrix[n].rows());
        }
    }

    /**
     * Returns the collection-specific theme models, indexed {@code [collection][theme][term]}.
     * This is the internal array (not a copy); it is {@code null} until {@link #estimateModel()}
     * has run.
     */
    public double[][][] getModels() {
        return this.arrProb;
    }

    /**
     * Returns the common theme models, indexed {@code [theme][term]}. This is the internal array;
     * it is {@code null} until {@link #estimateModel()} has run.
     */
    public double[][] getCommonModels() {
        return this.arrCommonProb;
    }

    /**
     * Returns the document memberships, indexed {@code [collection][theme][doc]}. This is the
     * internal array; it is {@code null} until {@link #estimateModel()} has run.
     */
    public double[][][] getDocMemberships() {
        return this.arrDocWeight;
    }

    /**
     * Runs {@code iterations} rounds of EM. It starts from uniform term distributions and random
     * document memberships (see {@link #initialize}).
     *
     * @return always {@code true}
     */
    public boolean estimateModel() {
        this.arrProb = new double[this.collectionNum][this.themeNum][this.maxTermNum];
        this.arrCommonProb = new double[this.themeNum][this.maxTermNum];
        this.arrDocWeight = new double[this.collectionNum][this.themeNum][this.maxDocNum];
        double[][][] arrTempProb = new double[this.collectionNum][this.themeNum][this.maxTermNum];
        double[][] arrTempCommonProb = new double[this.themeNum][this.maxTermNum];
        double[] arrDocWeightSum = new double[this.themeNum];
        this.initialize(this.maxTermNum, this.collectionNum, this.themeNum, this.maxDocNum, this.arrCommonProb, this.arrProb, this.arrDocWeight);
        this.printStatus("Estimating the coefficients of simple mixture model...");
        for (int k = 0; k < this.iterations; k++) {
            this.printStatus("Iteration #" + (k + 1));
            // Expected counts are rebuilt from scratch every iteration.
            for (double[] row : arrTempCommonProb) {
                Arrays.fill(row, 0.0);
            }
            for (double[][] collection : arrTempProb) {
                for (double[] row : collection) {
                    Arrays.fill(row, 0.0);
                }
            }
            // E-step: accumulate expected counts and refresh each document's memberships.
            for (int n = 0; n < this.collectionNum; n++) {
                int docNum = this.arrTopicReader[n].rows();
                for (int i = 0; i < docNum; i++) {
                    this.accumulateDocument(n, i, arrTempProb[n], arrTempCommonProb, arrDocWeightSum);
                }
            }
            // M-step: turn expected counts into term distributions.
            for (int m = 0; m < this.themeNum; m++) {
                normalize(arrTempCommonProb[m], this.arrCommonProb[m]);
                for (int n = 0; n < this.collectionNum; n++) {
                    normalize(arrTempProb[n][m], this.arrProb[n][m]);
                }
            }
        }
        this.printStatus("");
        return true;
    }

    /**
     * Runs the E-step for one document. It adds the document's expected term counts to
     * {@code tempProb} (specific components of collection {@code n}) and {@code tempCommonProb}
     * (common components). It then overwrites the document's memberships in
     * {@code arrDocWeight[n]}.
     *
     * @param n collection index
     * @param doc document (row) index within collection {@code n}
     * @param tempProb expected counts of collection {@code n}, indexed {@code [theme][term]}
     * @param tempCommonProb expected counts of the common components, indexed {@code [theme][term]}
     * @param docWeightSum scratch buffer of length {@code themeNum}; overwritten
     */
    private void accumulateDocument(int n, int doc, double[][] tempProb, double[][] tempCommonProb, double[] docWeightSum) {
        int[] arrIndex = this.arrTopicReader[n].getNonZeroColumnsInRow(doc);
        int[] arrFreq = this.arrTopicReader[n].getNonZeroIntScoresInRow(doc);
        double[][] docWeight = this.arrDocWeight[n];
        Arrays.fill(docWeightSum, 0.0);
        for (int j = 0; j < arrIndex.length; j++) {
            // j is the position inside the sparse row; termIndex is the actual term id.
            int termIndex = arrIndex[j];
            double themeProbSum = 0.0;
            for (int m = 0; m < this.themeNum; m++) {
                themeProbSum += this.themeTermProb(n, m, termIndex) * docWeight[m][doc];
            }
            // Posterior that the occurrence came from the background; guard against 0/0.
            double bkgDenom = themeProbSum * (1.0 - this.bkgCoefficient) + this.bkgModel[termIndex];
            double bkgProb = bkgDenom == 0.0 ? 0.0 : this.bkgModel[termIndex] / bkgDenom;
            for (int m = 0; m < this.themeNum; m++) {
                double commonPart = this.comCoefficient * this.arrCommonProb[m][termIndex];
                double themeTermProb = this.themeTermProb(n, m, termIndex);
                // Posterior of theme m among the non-background explanations.
                double themeProb = themeProbSum == 0.0 ? 0.0 : themeTermProb * docWeight[m][doc] / themeProbSum;
                // Posterior that, within theme m, the occurrence came from the common component.
                double comThemeProb = themeTermProb > 0.0 ? commonPart / themeTermProb : 0.0;
                double termCount = arrFreq[j] * themeProb;
                // Memberships count every occurrence; theme models only count the non-background share.
                docWeightSum[m] += termCount;
                double themeCount = termCount * (1.0 - bkgProb);
                tempProb[m][termIndex] += themeCount * (1.0 - comThemeProb);
                tempCommonProb[m][termIndex] += themeCount * comThemeProb;
            }
        }
        double total = 0.0;
        for (double w : docWeightSum) {
            total += w;
        }
        for (int m = 0; m < this.themeNum; m++) {
            docWeight[m][doc] = total > 0.0 ? docWeightSum[m] / total : 0.0;
        }
    }

    /**
     * Returns the probability of term {@code t} under theme {@code m} of collection {@code n}: the
     * common and collection-specific components mixed by {@code comCoefficient}.
     */
    private double themeTermProb(int n, int m, int t) {
        return this.comCoefficient * this.arrCommonProb[m][t] + (1.0 - this.comCoefficient) * this.arrProb[n][m][t];
    }

    /**
     * Writes {@code counts} scaled to sum to one into {@code target}. If the counts sum to zero
     * (a component that received no mass), {@code target} is set to zeros instead of NaN. This
     * keeps a zero-weighted component from poisoning later iterations via {@code 0 * NaN}.
     */
    private static void normalize(double[] counts, double[] target) {
        double sum = 0.0;
        for (double c : counts) {
            sum += c;
        }
        for (int j = 0; j < counts.length; j++) {
            target[j] = sum == 0.0 ? 0.0 : counts[j] / sum;
        }
    }

    /**
     * Seeds the EM state. Every common and specific theme model is set to the uniform
     * distribution over {@code maxTermNum} terms. Each document gets random memberships normalized
     * to sum to one. A non-negative inherited {@code seed} makes the draw reproducible; a negative
     * seed uses an unseeded {@link Random}.
     */
    protected void initialize(int maxTermNum, int collectionNum, int themeNum, int maxDocNum, double[][] arrCommonModel, double[][][] arrModel, double[][][] arrDocMembership) {
        double termProb = 1.0 / (double) maxTermNum;
        for (int m = 0; m < themeNum; m++) {
            Arrays.fill(arrCommonModel[m], 0, maxTermNum, termProb);
        }
        for (int n = 0; n < collectionNum; n++) {
            for (int m = 0; m < themeNum; m++) {
                Arrays.fill(arrModel[n][m], 0, maxTermNum, termProb);
            }
        }
        Random random = this.seed >= 0 ? new Random(this.seed) : new Random();
        // Loop order (collection, doc, theme) fixes the random draw sequence for a given seed.
        for (int n = 0; n < collectionNum; n++) {
            for (int doc = 0; doc < maxDocNum; doc++) {
                double docProb = 0.0;
                for (int m = 0; m < themeNum; m++) {
                    arrDocMembership[n][m][doc] = random.nextDouble();
                    docProb += arrDocMembership[n][m][doc];
                }
                for (int m = 0; m < themeNum; m++) {
                    arrDocMembership[n][m][doc] /= docProb;
                }
            }
        }
    }
}

