/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Labeling;
import cc.mallet.util.CommandOption;
import cc.mallet.util.Randoms;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.zip.GZIPOutputStream;

public class PolylingualTopicModel
implements Serializable {
    // Command-line options, all registered against PolylingualTopicModel.class.
    // NOTE(review): they are presumably parsed and applied by a main method outside this view.
    // Input / output files.
    static CommandOption.SpacedStrings languageInputFiles = new CommandOption.SpacedStrings(PolylingualTopicModel.class, "language-inputs", "FILENAME [FILENAME ...]", true, null, "Filenames for polylingual topic model. Each language should have its own file, with the same number of instances in each file. If a document is missing in one language, there should be an empty instance.", null);
    static CommandOption.String outputModelFilename = new CommandOption.String(PolylingualTopicModel.class, "output-model", "FILENAME", true, null, "The filename in which to write the binary topic model at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
    // NOTE(review): this help text refers to an "--input" option that is not declared in this class.
    static CommandOption.String inputModelFilename = new CommandOption.String(PolylingualTopicModel.class, "input-model", "FILENAME", true, null, "The filename from which to read the binary topic model to which the --input will be appended, allowing incremental training.  By default this is null, indicating that no file will be read.", null);
    static CommandOption.String inferencerFilename = new CommandOption.String(PolylingualTopicModel.class, "inferencer-filename", "FILENAME", true, null, "A topic inferencer applies a previously trained topic model to new documents.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.String evaluatorFilename = new CommandOption.String(PolylingualTopicModel.class, "evaluator-filename", "FILENAME", true, null, "A held-out likelihood evaluator for new documents.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.String stateFile = new CommandOption.String(PolylingualTopicModel.class, "output-state", "FILENAME", true, null, "The filename in which to write the Gibbs sampling state after at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.String topicKeysFile = new CommandOption.String(PolylingualTopicModel.class, "output-topic-keys", "FILENAME", true, null, "The filename in which to write the top words for each topic and any Dirichlet parameters.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.String docTopicsFile = new CommandOption.String(PolylingualTopicModel.class, "output-doc-topics", "FILENAME", true, null, "The filename in which to write the topic proportions per document, at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
    // Filters applied when writing per-document topic proportions.
    static CommandOption.Double docTopicsThreshold = new CommandOption.Double(PolylingualTopicModel.class, "doc-topics-threshold", "DECIMAL", true, 0.0, "When writing topic proportions per document with --output-doc-topics, do not print topics with proportions less than this threshold value.", null);
    static CommandOption.Integer docTopicsMax = new CommandOption.Integer(PolylingualTopicModel.class, "doc-topics-max", "INTEGER", true, -1, "When writing topic proportions per document with --output-doc-topics, do not print more than INTEGER number of topics.  A negative value indicates that all topics should be printed.", null);
    // Periodic checkpointing intervals (0 presumably disables them).
    static CommandOption.Integer outputModelIntervalOption = new CommandOption.Integer(PolylingualTopicModel.class, "output-model-interval", "INTEGER", true, 0, "The number of iterations between writing the model (and its Gibbs sampling state) to a binary file.  You must also set the --output-model to use this option, whose argument will be the prefix of the filenames.", null);
    static CommandOption.Integer outputStateIntervalOption = new CommandOption.Integer(PolylingualTopicModel.class, "output-state-interval", "INTEGER", true, 0, "The number of iterations between writing the sampling state to a text file.  You must also set the --output-state to use this option, whose argument will be the prefix of the filenames.", null);
    // Sampler configuration.
    static CommandOption.Integer numTopicsOption = new CommandOption.Integer(PolylingualTopicModel.class, "num-topics", "INTEGER", true, 10, "The number of topics to fit.", null);
    static CommandOption.Integer numIterationsOption = new CommandOption.Integer(PolylingualTopicModel.class, "num-iterations", "INTEGER", true, 1000, "The number of iterations of Gibbs sampling.", null);
    static CommandOption.Integer randomSeedOption = new CommandOption.Integer(PolylingualTopicModel.class, "random-seed", "INTEGER", true, 0, "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);
    static CommandOption.Integer topWordsOption = new CommandOption.Integer(PolylingualTopicModel.class, "num-top-words", "INTEGER", true, 20, "The number of most probable words to print for each topic after model estimation.", null);
    static CommandOption.Integer showTopicsIntervalOption = new CommandOption.Integer(PolylingualTopicModel.class, "show-topics-interval", "INTEGER", true, 50, "The number of iterations between printing a brief summary of the topics so far.", null);
    static CommandOption.Integer optimizeIntervalOption = new CommandOption.Integer(PolylingualTopicModel.class, "optimize-interval", "INTEGER", true, 0, "The number of iterations between reestimating dirichlet hyperparameters.", null);
    static CommandOption.Integer optimizeBurnInOption = new CommandOption.Integer(PolylingualTopicModel.class, "optimize-burn-in", "INTEGER", true, 200, "The number of iterations to run before first estimating dirichlet hyperparameters.", null);
    // Dirichlet hyperparameters.
    // NOTE(review): the default alpha of 50.0 looks like a total (alphaSum) rather than a per-topic
    // value; the constructors take alphaSum, but how this option is passed is not visible here.
    static CommandOption.Double alphaOption = new CommandOption.Double(PolylingualTopicModel.class, "alpha", "DECIMAL", true, 50.0, "Alpha parameter: smoothing over topic distribution.", null);
    static CommandOption.Double betaOption = new CommandOption.Double(PolylingualTopicModel.class, "beta", "DECIMAL", true, 0.01, "Beta parameter: smoothing over unigram distribution.", null);
    /** Number of languages; set to {@code training.length} by addInstances. */
    int numLanguages = 1;
    /** One entry per training document: its per-language instances and topic sequences. */
    protected ArrayList<TopicAssignment> data = new ArrayList();
    /** Names the topics; its size is the number of topics. */
    protected LabelAlphabet topicAlphabet;
    // NOTE(review): numStopwords is not read or written in the visible code.
    protected int numStopwords = 0;
    protected int numTopics;
    /** Instance names excluded from training (see loadTestingIDs); null means none are excluded. */
    HashSet<String> testingIDs = null;
    /**
     * Each type-topic count is packed into one int as {@code (count << topicBits) + topic}.
     * {@code topicMask} extracts the topic from the low bits; {@code topicBits} is its width.
     */
    protected int topicMask;
    protected int topicBits;
    /** Per-language data alphabet (vocabulary). */
    protected Alphabet[] alphabets;
    /** Per-language vocabulary size. */
    protected int[] vocabularySizes;
    /** Per-topic Dirichlet parameter over topics; initialized to alphaSum / numTopics each. */
    protected double[] alpha;
    /** Initially the constructor argument; replaced by Dirichlet.learnParameters when optimizing. */
    protected double alphaSum;
    /** Per-language word-smoothing parameter, and that value times the vocabulary size. */
    protected double[] betas;
    protected double[] betaSums;
    /** Per-language largest total count of any single word type; sizes the histogram in optimizeBetas. */
    protected int[] languageMaxTypeCounts;
    /** Same value as the initial per-language beta assigned in addInstances. */
    public static final double DEFAULT_BETA = 0.01;
    /** Per-language cached sum over topics of alpha[t] * beta / (tokensPerTopic[t] + betaSum). */
    protected double[] languageSmoothingOnlyMasses;
    /**
     * Per-language, per-topic cached alpha[t] / (tokensPerTopic[t] + betaSum). While a document
     * is being sampled, entries for its topics also include the document's local topic count.
     */
    protected double[][] languageCachedCoefficients;
    // Counters reset at the start of every iteration in estimate(int).
    // NOTE(review): nothing in the visible code increments them.
    int topicTermCount = 0;
    int betaTopicCount = 0;
    int smoothingOnlyCount = 0;
    // NOTE(review): oneDocTopicCounts is not used in the visible code.
    protected int[] oneDocTopicCounts;
    /**
     * [language][type] -> packed (count, topic) entries kept sorted in descending order;
     * the first 0 entry marks the end of the used slots.
     */
    protected int[][][] languageTypeTopicCounts;
    /** [language][topic] -> number of that language's tokens assigned to the topic. */
    protected int[][] languageTokensPerTopic;
    /**
     * Histograms passed to Dirichlet.learnParameters: docLengthCounts[len] counts documents of
     * total length len; topicDocCounts[topic][n] counts documents with n tokens of that topic.
     */
    protected int[] docLengthCounts;
    protected int[][] topicDocCounts;
    /** Sampling iteration counter; starts at 1 and persists across calls to estimate. */
    protected int iterationsSoFar = 1;
    /** Number of iterations run by estimate(). */
    public int numIterations = 1000;
    /** Optimization runs only when iterationsSoFar > burninPeriod; histograms are collected when >=. */
    public int burninPeriod = 5;
    /** Histogram samples are collected on iterations divisible by this value. */
    public int saveSampleInterval = 5;
    public int optimizeInterval = 10;
    public int showTopicsInterval = 10;
    public int wordsPerTopic = 7;
    // NOTE(review): saveModelInterval / modelFilename are set by setModelOutput but not read in the visible code.
    protected int saveModelInterval = 0;
    protected String modelFilename;
    /** If nonzero, estimate writes the state to stateFilename + "." + iteration at this interval. */
    protected int saveStateInterval = 0;
    protected String stateFilename = null;
    protected Randoms random;
    protected NumberFormat formatter;
    // NOTE(review): printLogLikelihood is not read in the visible code.
    protected boolean printLogLikelihood = false;
    private static final long serialVersionUID = 1L;
    // Presumably used by custom serialization code outside this view — TODO confirm.
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a model with {@code numberOfTopics} topics whose total topic concentration
     * equals the number of topics, so every topic starts with alpha = 1.0.
     *
     * @param numberOfTopics number of topics to fit
     */
    public PolylingualTopicModel(int numberOfTopics) {
        this(numberOfTopics, (double) numberOfTopics);
    }

    /**
     * Creates a model with topics named "topic0".."topic(n-1)" and a new {@code Randoms}
     * built with its no-argument constructor.
     *
     * @param numberOfTopics number of topics to fit
     * @param alphaSum       total Dirichlet concentration, split evenly across topics
     */
    public PolylingualTopicModel(int numberOfTopics, double alphaSum) {
        this(newLabelAlphabet(numberOfTopics), alphaSum, new Randoms());
    }

    /**
     * Builds a topic alphabet containing the labels "topic0", "topic1", ..., one per topic,
     * in index order.
     *
     * @param numTopics number of labels to create
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        final LabelAlphabet labels = new LabelAlphabet();
        int topic = 0;
        while (topic < numTopics) {
            labels.lookupIndex("topic" + topic);
            topic++;
        }
        return labels;
    }

    /**
     * Creates a model with topics named "topic0".."topic(n-1)".
     *
     * @param numberOfTopics number of topics to fit
     * @param alphaSum       total Dirichlet concentration, split evenly across topics
     * @param random         random number source used by the sampler
     */
    public PolylingualTopicModel(int numberOfTopics, double alphaSum, Randoms random) {
        this(newLabelAlphabet(numberOfTopics), alphaSum, random);
    }

    public PolylingualTopicModel(LabelAlphabet topicAlphabet, double alphaSum, Randoms random) {
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();
        if (Integer.bitCount(this.numTopics) == 1) {
            this.topicMask = this.numTopics - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        } else {
            this.topicMask = Integer.highestOneBit(this.numTopics) * 2 - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        }
        this.alphaSum = alphaSum;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, alphaSum / (double)this.numTopics);
        this.random = random;
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        System.err.println("Polylingual LDA: " + this.numTopics + " topics, " + this.topicBits + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /**
     * Reads instance names, one per line, from {@code testingIDFile}. Training documents whose
     * name is in this set are skipped by {@link #addInstances(InstanceList[])}, so this must be
     * called before adding instances to have any effect.
     *
     * <p>The set is replaced only if the whole file is read successfully; on failure the
     * previously loaded set (if any) is left unchanged.
     *
     * @param testingIDFile text file with one instance name per line
     * @throws IOException if the file cannot be opened or read
     */
    public void loadTestingIDs(File testingIDFile) throws IOException {
        HashSet<String> ids = new HashSet<>();
        // try-with-resources closes the reader even if readLine() throws.
        try (BufferedReader in = new BufferedReader(new FileReader(testingIDFile))) {
            String id;
            while ((id = in.readLine()) != null) {
                ids.add(id);
            }
        }
        this.testingIDs = ids;
    }

    /**
     * Returns the alphabet naming this model's topics.
     *
     * @return the live topic alphabet (not a copy)
     */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /**
     * Returns how many topics this model fits.
     *
     * @return the number of topics
     */
    public int getNumTopics() {
        return numTopics;
    }

    /**
     * Returns the per-document training data and topic assignments.
     *
     * @return the model's internal, mutable list (not a copy)
     */
    public ArrayList<TopicAssignment> getData() {
        return data;
    }

    /**
     * Sets the number of iterations run by {@link #estimate()}.
     *
     * @param numIterations iteration count
     */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /**
     * Sets how many iterations must pass before hyperparameter optimization and histogram
     * collection begin.
     *
     * @param burninPeriod number of burn-in iterations
     */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * Configures periodic printing of top words during estimation.
     *
     * @param interval iterations between printouts (0 disables)
     * @param n        number of words printed per topic
     */
    public void setTopicDisplay(int interval, int n) {
        wordsPerTopic = n;
        showTopicsInterval = interval;
    }

    /**
     * Replaces the sampler's random number source with a new {@code Randoms} constructed
     * from {@code seed}.
     *
     * @param seed seed passed to the {@code Randoms} constructor
     */
    public void setRandomSeed(int seed) {
        random = new Randoms(seed);
    }

    /**
     * Sets how often (in iterations, after burn-in) alpha and beta are re-estimated.
     *
     * @param interval iterations between optimizations (0 disables)
     */
    public void setOptimizeInterval(int interval) {
        optimizeInterval = interval;
    }

    /**
     * Records the interval and filename for saving the model.
     * NOTE(review): neither value is read by any code visible in this class chunk.
     *
     * @param interval iterations between model saves
     * @param filename model output filename (prefix)
     */
    public void setModelOutput(int interval, String filename) {
        modelFilename = filename;
        saveModelInterval = interval;
    }

    /**
     * Configures periodic writing of the sampling state during estimation; each file is
     * named {@code filename + "." + iteration}.
     *
     * @param interval iterations between state dumps (0 disables)
     * @param filename filename prefix
     */
    public void setSaveState(int interval, String filename) {
        stateFilename = filename;
        saveStateInterval = interval;
    }

    /**
     * Loads one position-aligned {@link InstanceList} per language and randomly initializes
     * all topic assignments.
     *
     * <p>Document {@code doc} of every list is treated as one multilingual document; list 0
     * determines the number of documents. For each language this records the alphabet and
     * vocabulary size, sets beta to 0.01, and sizes the packed type-topic count arrays. It then
     * assigns each token a topic drawn with {@code random.nextInt(numTopics)}, updates the
     * counts, builds the length histograms and caches the smoothing terms. Documents whose
     * language-0 instance name is in {@code testingIDs} are skipped.
     *
     * <p>NOTE(review): a size mismatch between languages only prints a warning; if any list is
     * shorter than list 0, {@code training[language].get(doc)} below presumably throws.
     * Also, word totals (which size the count arrays) skip test documents by each language's
     * own instance name, while the main loop skips by the language-0 name; if names differ
     * across languages the arrays could be undersized.
     *
     * @param training one instance list per language; instance data must be FeatureSequences
     */
    public void addInstances(InstanceList[] training) {
        this.numLanguages = training.length;
        this.languageTokensPerTopic = new int[this.numLanguages][this.numTopics];
        this.alphabets = new Alphabet[this.numLanguages];
        this.vocabularySizes = new int[this.numLanguages];
        this.betas = new double[this.numLanguages];
        this.betaSums = new double[this.numLanguages];
        this.languageMaxTypeCounts = new int[this.numLanguages];
        this.languageTypeTopicCounts = new int[this.numLanguages][][];
        int numInstances = training[0].size();
        // NOTE(review): stoplists is never used in this method.
        HashSet[] stoplists = new HashSet[this.numLanguages];
        for (int language = 0; language < this.numLanguages; ++language) {
            if (training[language].size() != numInstances) {
                System.err.println("Warning: language " + language + " has " + training[language].size() + " instances, lang 0 has " + numInstances);
            }
            this.alphabets[language] = training[language].getDataAlphabet();
            this.vocabularySizes[language] = this.alphabets[language].size();
            this.betas[language] = 0.01;
            this.betaSums[language] = this.betas[language] * (double)this.vocabularySizes[language];
            this.languageTypeTopicCounts[language] = new int[this.vocabularySizes[language]][];
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            // Total occurrences of each word type (excluding test documents).
            int[] typeTotals = new int[this.vocabularySizes[language]];
            for (Instance instance : training[language]) {
                if (this.testingIDs != null && this.testingIDs.contains(instance.getName())) continue;
                FeatureSequence tokens = (FeatureSequence)instance.getData();
                for (int position = 0; position < tokens.getLength(); ++position) {
                    int type;
                    int n = type = tokens.getIndexAtPosition(position);
                    typeTotals[n] = typeTotals[n] + 1;
                }
            }
            for (int type = 0; type < this.vocabularySizes[language]; ++type) {
                if (typeTotals[type] > this.languageMaxTypeCounts[language]) {
                    this.languageMaxTypeCounts[language] = typeTotals[type];
                }
                // A type can carry at most min(numTopics, occurrences) distinct topics.
                typeTopicCounts[type] = new int[Math.min(this.numTopics, typeTotals[type])];
            }
        }
        for (int doc = 0; doc < numInstances; ++doc) {
            if (this.testingIDs != null && this.testingIDs.contains(((Instance)training[0].get(doc)).getName())) continue;
            Instance[] instances = new Instance[this.numLanguages];
            LabelSequence[] topicSequences = new LabelSequence[this.numLanguages];
            for (int language = 0; language < this.numLanguages; ++language) {
                int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
                int[] tokensPerTopic = this.languageTokensPerTopic[language];
                instances[language] = (Instance)training[language].get(doc);
                FeatureSequence tokens = (FeatureSequence)instances[language].getData();
                topicSequences[language] = new LabelSequence(this.topicAlphabet, new int[tokens.size()]);
                int[] topics = topicSequences[language].getFeatures();
                for (int position = 0; position < tokens.size(); ++position) {
                    int topic;
                    int type = tokens.getIndexAtPosition(position);
                    int[] currentTypeTopicCounts = typeTopicCounts[type];
                    topics[position] = topic = this.random.nextInt(this.numTopics);
                    int n = topic;
                    tokensPerTopic[n] = tokensPerTopic[n] + 1;
                    // Find this topic's packed entry, or the first empty (0) slot.
                    int index = 0;
                    int currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                    while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
                        currentTopic = currentTypeTopicCounts[++index] & this.topicMask;
                    }
                    int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                    if (currentValue == 0) {
                        // New entry with count 1; appended after all nonzero entries, so order holds.
                        currentTypeTopicCounts[index] = (1 << this.topicBits) + topic;
                        continue;
                    }
                    currentTypeTopicCounts[index] = (currentValue + 1 << this.topicBits) + topic;
                    // Bubble the incremented entry up to keep the array sorted in descending order.
                    while (index > 0 && currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                        int temp = currentTypeTopicCounts[index];
                        currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                        currentTypeTopicCounts[index - 1] = temp;
                        --index;
                    }
                }
            }
            TopicAssignment t = new TopicAssignment(instances, topicSequences);
            this.data.add(t);
        }
        // Allocate the length histograms and compute the per-language cached sampling terms.
        this.initializeHistograms();
        this.languageSmoothingOnlyMasses = new double[this.numLanguages];
        this.languageCachedCoefficients = new double[this.numLanguages][this.numTopics];
        this.cacheValues();
    }

    /**
     * Sizes the hyperparameter-optimization histograms from the longest document. A document's
     * length is its token count summed over all languages. Prints the maximum and total token
     * counts to standard error.
     */
    private void initializeHistograms() {
        int longestDoc = 0;
        int tokenTotal = 0;
        for (TopicAssignment assignment : data) {
            int docTokens = 0;
            for (LabelSequence sequence : assignment.topicSequences) {
                docTokens += sequence.getLength();
            }
            longestDoc = Math.max(longestDoc, docTokens);
            tokenTotal += docTokens;
        }
        System.err.println("max tokens: " + longestDoc);
        System.err.println("total tokens: " + tokenTotal);
        docLengthCounts = new int[longestDoc + 1];
        topicDocCounts = new int[numTopics][longestDoc + 1];
    }

    /**
     * Recomputes, for every language, the smoothing-only mass
     * (sum over topics of alpha[t] * beta / (tokensPerTopic[t] + betaSum)) and the per-topic
     * coefficients alpha[t] / (tokensPerTopic[t] + betaSum) used by the sampler.
     */
    private void cacheValues() {
        for (int lang = 0; lang < numLanguages; lang++) {
            final double beta = betas[lang];
            final double betaSum = betaSums[lang];
            final int[] tokensPerTopic = languageTokensPerTopic[lang];
            final double[] coefficients = languageCachedCoefficients[lang];
            double mass = 0.0;
            for (int t = 0; t < numTopics; t++) {
                final double denominator = tokensPerTopic[t] + betaSum;
                mass += alpha[t] * beta / denominator;
                coefficients[t] = alpha[t] / denominator;
            }
            languageSmoothingOnlyMasses[lang] = mass;
        }
    }

    /**
     * Zeroes the document-length and per-topic document-count histograms so a fresh set of
     * samples can be collected after hyperparameter optimization.
     */
    private void clearHistograms() {
        Arrays.fill(docLengthCounts, 0);
        for (int[] row : topicDocCounts) {
            Arrays.fill(row, 0);
        }
    }

    /**
     * Runs Gibbs sampling for the configured {@code numIterations} by delegating to
     * {@link #estimate(int)}.
     *
     * @throws IOException if writing intermediate output fails
     */
    public void estimate() throws IOException {
        final int rounds = numIterations;
        estimate(rounds);
    }

    /**
     * Runs {@code iterationsThisRound} Gibbs sampling sweeps over all training documents,
     * continuing the iteration counter from previous calls.
     *
     * <p>Before each sweep, depending on the configured intervals, it may print the top words,
     * write the sampling state to {@code stateFilename + "." + iteration}, and (after burn-in)
     * re-estimate alpha and beta. After each sweep it prints the sweep's elapsed milliseconds;
     * on every tenth sweep it also prints the cumulative time and the model log-likelihood.
     *
     * @param iterationsThisRound number of sweeps to run; values &lt;= 0 run none
     * @throws IOException if writing intermediate output fails
     */
    public void estimate(int iterationsThisRound) throws IOException {
        // Exclusive upper bound, so exactly iterationsThisRound sweeps run. The previous
        // inclusive bound ran one extra sweep on every call.
        int maxIteration = this.iterationsSoFar + iterationsThisRound;
        long totalTime = 0L;
        while (this.iterationsSoFar < maxIteration) {
            long iterationStart = System.currentTimeMillis();
            if (this.showTopicsInterval != 0 && this.iterationsSoFar != 0 && this.iterationsSoFar % this.showTopicsInterval == 0) {
                System.out.println();
                this.printTopWords(System.out, this.wordsPerTopic, false);
            }
            if (this.saveStateInterval != 0 && this.iterationsSoFar % this.saveStateInterval == 0) {
                this.printState(new File(this.stateFilename + '.' + this.iterationsSoFar));
            }
            if (this.iterationsSoFar > this.burninPeriod && this.optimizeInterval != 0 && this.iterationsSoFar % this.optimizeInterval == 0) {
                // Learn alpha from the collected histograms and beta from the current counts,
                // then restart histogram collection and refresh the cached sampling terms.
                this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts);
                this.optimizeBetas();
                this.clearHistograms();
                this.cacheValues();
            }
            this.smoothingOnlyCount = 0;
            this.betaTopicCount = 0;
            this.topicTermCount = 0;
            // Histograms are collected only after burn-in, every saveSampleInterval sweeps.
            boolean collectSample = this.iterationsSoFar >= this.burninPeriod && this.iterationsSoFar % this.saveSampleInterval == 0;
            for (int doc = 0; doc < this.data.size(); ++doc) {
                this.sampleTopicsForOneDoc(this.data.get(doc), collectSample);
            }
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            totalTime += elapsedMillis;
            if ((this.iterationsSoFar + 1) % 10 == 0) {
                double ll = this.modelLogLikelihood();
                System.out.println(elapsedMillis + "\t" + totalTime + "\t" + ll);
            } else {
                System.out.print(elapsedMillis + " ");
            }
            ++this.iterationsSoFar;
        }
    }

    /**
     * Re-estimates each language's symmetric beta from the current counts.
     *
     * <p>For each language, two histograms are passed to
     * {@code Dirichlet.learnSymmetricConcentration}: one of all nonzero (type, topic) counts and
     * one of per-topic token totals. The result becomes the new betaSum, and beta is set to
     * betaSum / vocabularySize.
     */
    public void optimizeBetas() {
        for (int lang = 0; lang < numLanguages; lang++) {
            final int[] countHistogram = new int[languageMaxTypeCounts[lang] + 1];
            final int[][] typeTopicCounts = languageTypeTopicCounts[lang];
            final int[] tokensPerTopic = languageTokensPerTopic[lang];

            // Unpack each used (count, topic) entry; a non-positive entry ends the used prefix.
            for (int type = 0; type < vocabularySizes[lang]; type++) {
                for (int packed : typeTopicCounts[type]) {
                    if (packed <= 0) {
                        break;
                    }
                    countHistogram[packed >> topicBits]++;
                }
            }

            int largestTopic = 0;
            for (int t = 0; t < numTopics; t++) {
                largestTopic = Math.max(largestTopic, tokensPerTopic[t]);
            }
            final int[] topicSizeHistogram = new int[largestTopic + 1];
            for (int t = 0; t < numTopics; t++) {
                topicSizeHistogram[tokensPerTopic[t]]++;
            }

            final int vocabSize = vocabularySizes[lang];
            betaSums[lang] = Dirichlet.learnSymmetricConcentration(countHistogram, topicSizeHistogram, vocabSize, betaSums[lang]);
            betas[lang] = betaSums[lang] / vocabSize;
        }
    }

    /**
     * Resamples the topic of every token of one document, across all of its languages.
     *
     * <p>Document-topic counts ({@code localTopicCounts}) are pooled over all languages of the
     * document. Word-topic counts, topic totals, beta and the cached coefficients are per
     * language. For each token the sampling mass is split into three buckets:
     * <ul>
     *   <li>{@code smoothingOnlyMass}: sum of alpha[t] * beta / (tokensPerTopic[t] + betaSum);</li>
     *   <li>{@code topicBetaMass}: beta * n_local[t] / (...), over the document's nonzero topics;</li>
     *   <li>{@code topicTermMass}: coefficient[t] * count(type, t), over topics with a nonzero
     *       count for the token's word type.</li>
     * </ul>
     *
     * <p>Invariant: outside this method, languageCachedCoefficients[lang][t] equals
     * alpha[t] / (tokensPerTopic[t] + betaSum). Each language sweep temporarily adds the
     * document's local counts for its topics and restores the invariant when it finishes.
     *
     * @param topicAssignment the document's instances and mutable topic sequences
     * @param shouldSaveState if true, add this document's final per-topic counts and total length
     *        to {@code topicDocCounts} / {@code docLengthCounts} for alpha optimization
     */
    protected void sampleTopicsForOneDoc(TopicAssignment topicAssignment, boolean shouldSaveState) {
        // Document-level topic counts, pooled over all languages.
        int[] localTopicCounts = new int[this.numTopics];
        // Sorted topics with a nonzero local count; only the first nonZeroTopics entries are valid.
        int[] localTopicIndex = new int[this.numTopics];
        for (int language = 0; language < this.numLanguages; ++language) {
            int[] oneDocTopics = topicAssignment.topicSequences[language].getFeatures();
            int docLength = topicAssignment.topicSequences[language].getLength();
            for (int position = 0; position < docLength; ++position) {
                localTopicCounts[oneDocTopics[position]]++;
            }
        }
        int denseIndex = 0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            if (localTopicCounts[topic] == 0) continue;
            localTopicIndex[denseIndex] = topic;
            ++denseIndex;
        }
        int nonZeroTopics = denseIndex;

        for (int language = 0; language < this.numLanguages; ++language) {
            int[] oneDocTopics = topicAssignment.topicSequences[language].getFeatures();
            int docLength = topicAssignment.topicSequences[language].getLength();
            FeatureSequence tokenSequence = (FeatureSequence) topicAssignment.instances[language].getData();
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            int[] tokensPerTopic = this.languageTokensPerTopic[language];
            double beta = this.betas[language];
            double betaSum = this.betaSums[language];
            double smoothingOnlyMass = this.languageSmoothingOnlyMasses[language];
            double[] cachedCoefficients = this.languageCachedCoefficients[language];

            // Initialize the document/beta bucket and fold local counts into the coefficients.
            double topicBetaMass = 0.0;
            for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                int topic = localTopicIndex[denseIndex];
                int n = localTopicCounts[topic];
                topicBetaMass += beta * (double) n / ((double) tokensPerTopic[topic] + betaSum);
                cachedCoefficients[topic] = (this.alpha[topic] + (double) n) / ((double) tokensPerTopic[topic] + betaSum);
            }

            double topicTermMass = 0.0;
            double[] topicTermScores = new double[this.numTopics];
            for (int position = 0; position < docLength; ++position) {
                int type = tokenSequence.getIndexAtPosition(position);
                int oldTopic = oneDocTopics[position];
                if (oldTopic == -1) {
                    continue;
                }
                int[] currentTypeTopicCounts = typeTopicCounts[type];

                // Remove the token from the document-level and language-level totals,
                // keeping the bucket masses and oldTopic's coefficient in sync.
                smoothingOnlyMass -= this.alpha[oldTopic] * beta / ((double) tokensPerTopic[oldTopic] + betaSum);
                topicBetaMass -= beta * (double) localTopicCounts[oldTopic] / ((double) tokensPerTopic[oldTopic] + betaSum);
                localTopicCounts[oldTopic]--;
                if (localTopicCounts[oldTopic] == 0) {
                    // oldTopic left the document: remove it from the dense index.
                    denseIndex = 0;
                    while (localTopicIndex[denseIndex] != oldTopic) {
                        ++denseIndex;
                    }
                    while (denseIndex < nonZeroTopics) {
                        if (denseIndex < localTopicIndex.length - 1) {
                            localTopicIndex[denseIndex] = localTopicIndex[denseIndex + 1];
                        }
                        ++denseIndex;
                    }
                    --nonZeroTopics;
                }
                tokensPerTopic[oldTopic]--;
                smoothingOnlyMass += this.alpha[oldTopic] * beta / ((double) tokensPerTopic[oldTopic] + betaSum);
                topicBetaMass += beta * (double) localTopicCounts[oldTopic] / ((double) tokensPerTopic[oldTopic] + betaSum);
                cachedCoefficients[oldTopic] = (this.alpha[oldTopic] + (double) localTopicCounts[oldTopic]) / ((double) tokensPerTopic[oldTopic] + betaSum);

                // Decrement oldTopic's packed entry for this type while summing the topic-term bucket.
                int index = 0;
                boolean alreadyDecremented = false;
                topicTermMass = 0.0;
                while (index < currentTypeTopicCounts.length && currentTypeTopicCounts[index] > 0) {
                    int currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                    int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                    if (!alreadyDecremented && currentTopic == oldTopic) {
                        currentValue--;
                        currentTypeTopicCounts[index] = currentValue == 0 ? 0 : (currentValue << this.topicBits) + oldTopic;
                        // Bubble the decremented entry down to keep the array sorted descending.
                        for (int subIndex = index; subIndex < currentTypeTopicCounts.length - 1 && currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]; ++subIndex) {
                            int temp = currentTypeTopicCounts[subIndex];
                            currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
                            currentTypeTopicCounts[subIndex + 1] = temp;
                        }
                        alreadyDecremented = true;
                        // Re-examine this index: it now holds the entry that was swapped in.
                        continue;
                    }
                    double score = cachedCoefficients[currentTopic] * (double) currentValue;
                    topicTermMass += score;
                    topicTermScores[index] = score;
                    ++index;
                }

                double sample = this.random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
                double origSample = sample;
                int newTopic = -1;
                if (sample < topicTermMass) {
                    // Topic-term bucket. The first score is subtracted unconditionally so a draw of
                    // exactly 0.0 selects entry 0 instead of leaving i at -1.
                    int i = 0;
                    sample -= topicTermScores[0];
                    while (sample > 0.0) {
                        sample -= topicTermScores[++i];
                    }
                    newTopic = currentTypeTopicCounts[i] & this.topicMask;
                    int currentValue = currentTypeTopicCounts[i] >> this.topicBits;
                    currentTypeTopicCounts[i] = ((currentValue + 1) << this.topicBits) + newTopic;
                    while (i > 0 && currentTypeTopicCounts[i] > currentTypeTopicCounts[i - 1]) {
                        int temp = currentTypeTopicCounts[i];
                        currentTypeTopicCounts[i] = currentTypeTopicCounts[i - 1];
                        currentTypeTopicCounts[i - 1] = temp;
                        --i;
                    }
                } else {
                    sample -= topicTermMass;
                    if (sample < topicBetaMass) {
                        // Document/beta bucket: walk the document's nonzero topics.
                        sample /= beta;
                        for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                            int topic = localTopicIndex[denseIndex];
                            sample -= (double) localTopicCounts[topic] / ((double) tokensPerTopic[topic] + betaSum);
                            if (sample <= 0.0) {
                                newTopic = topic;
                                break;
                            }
                        }
                    } else {
                        // Smoothing-only bucket: walk all topics.
                        sample -= topicBetaMass;
                        sample /= beta;
                        newTopic = 0;
                        sample -= this.alpha[newTopic] / ((double) tokensPerTopic[newTopic] + betaSum);
                        while (sample > 0.0) {
                            ++newTopic;
                            sample -= this.alpha[newTopic] / ((double) tokensPerTopic[newTopic] + betaSum);
                        }
                    }
                    // Apply the fallback BEFORE touching the packed counts; inserting topic -1
                    // would write a corrupt entry into currentTypeTopicCounts.
                    if (newTopic == -1) {
                        System.err.println("PolylingualTopicModel sampling error: " + origSample + " " + sample + " " + smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
                        newTopic = this.numTopics - 1;
                    }
                    // Increment newTopic's packed entry, or claim the first empty slot.
                    index = 0;
                    while (currentTypeTopicCounts[index] > 0 && (currentTypeTopicCounts[index] & this.topicMask) != newTopic) {
                        ++index;
                    }
                    if (currentTypeTopicCounts[index] == 0) {
                        currentTypeTopicCounts[index] = (1 << this.topicBits) + newTopic;
                    } else {
                        int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                        currentTypeTopicCounts[index] = ((currentValue + 1) << this.topicBits) + newTopic;
                        while (index > 0 && currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                            int temp = currentTypeTopicCounts[index];
                            currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                            currentTypeTopicCounts[index - 1] = temp;
                            --index;
                        }
                    }
                }

                // Add the token back under newTopic, keeping masses and coefficients in sync.
                oneDocTopics[position] = newTopic;
                smoothingOnlyMass -= this.alpha[newTopic] * beta / ((double) tokensPerTopic[newTopic] + betaSum);
                topicBetaMass -= beta * (double) localTopicCounts[newTopic] / ((double) tokensPerTopic[newTopic] + betaSum);
                localTopicCounts[newTopic]++;
                if (localTopicCounts[newTopic] == 1) {
                    // newTopic is new to this document: insert it into the sorted dense index.
                    for (denseIndex = nonZeroTopics; denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic; --denseIndex) {
                        localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
                    }
                    localTopicIndex[denseIndex] = newTopic;
                    ++nonZeroTopics;
                }
                tokensPerTopic[newTopic]++;
                cachedCoefficients[newTopic] = (this.alpha[newTopic] + (double) localTopicCounts[newTopic]) / ((double) tokensPerTopic[newTopic] + betaSum);
                topicBetaMass += beta * (double) localTopicCounts[newTopic] / ((double) tokensPerTopic[newTopic] + betaSum);
                smoothingOnlyMass += this.alpha[newTopic] * beta / ((double) tokensPerTopic[newTopic] + betaSum);
                this.languageSmoothingOnlyMasses[language] = smoothingOnlyMass;
            }

            // Restore the invariant: strip this document's local counts from the coefficients.
            // Otherwise a later document that lacks one of these topics would reuse the stale
            // (alpha + n_local) / (...) value in its topic-term bucket.
            for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                int topic = localTopicIndex[denseIndex];
                cachedCoefficients[topic] = this.alpha[topic] / ((double) tokensPerTopic[topic] + betaSum);
            }
        }

        if (shouldSaveState) {
            int totalLength = 0;
            for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                int topic = localTopicIndex[denseIndex];
                this.topicDocCounts[topic][localTopicCounts[topic]]++;
                totalLength += localTopicCounts[topic];
            }
            this.docLengthCounts[totalLength]++;
        }
    }

    /**
     * Writes the top words of every topic, for every language, to {@code file}.
     *
     * @param file destination file; created or truncated
     * @param numWords number of top words to print per topic and language
     * @param useNewLines forwarded to {@link #printTopWords(PrintStream, int, boolean)}
     * @throws IOException if the file cannot be opened for writing
     */
    public void printTopWords(File file, int numWords, boolean useNewLines) throws IOException {
        // try-with-resources closes the stream even if printing throws part-way.
        try (PrintStream out = new PrintStream(file)) {
            this.printTopWords(out, numWords, useNewLines);
        }
    }

    /**
     * Prints, for each topic, its alpha and then one line per language with the topic's token
     * count in that language, the language's beta, and up to {@code numWords} words of that
     * language in decreasing order of their count in the topic.
     *
     * <p>Layout per topic:
     * <pre>
     * topic\talpha
     *  language\ttokensPerTopic\tbeta\tword word ...
     * </pre>
     *
     * @param out destination stream (not closed by this method)
     * @param numWords maximum number of words to print per topic and language
     * @param usingNewLines currently ignored by this implementation
     */
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        // Generic arrays cannot be created directly; this array only ever holds TreeSet<IDSorter>.
        @SuppressWarnings("unchecked")
        TreeSet<IDSorter>[][] languageTopicSortedWords = new TreeSet[this.numLanguages][this.numTopics];
        for (int language = 0; language < this.numLanguages; ++language) {
            TreeSet<IDSorter>[] topicSortedWords = languageTopicSortedWords[language];
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            for (int topic = 0; topic < this.numTopics; ++topic) {
                topicSortedWords[topic] = new TreeSet<IDSorter>();
            }
            for (int type = 0; type < this.vocabularySizes[language]; ++type) {
                int[] topicCounts = typeTopicCounts[type];
                // Each entry packs (count << topicBits) | topic; the scan stops at the first
                // non-positive (unused) entry.
                for (int index = 0; index < topicCounts.length && topicCounts[index] > 0; ++index) {
                    int topic = topicCounts[index] & this.topicMask;
                    int count = topicCounts[index] >> this.topicBits;
                    topicSortedWords[topic].add(new IDSorter(type, count));
                }
            }
        }
        for (int topic = 0; topic < this.numTopics; ++topic) {
            out.println(topic + "\t" + this.formatter.format(this.alpha[topic]));
            for (int language = 0; language < this.numLanguages; ++language) {
                out.print(" " + language + "\t" + this.languageTokensPerTopic[language][topic] + "\t" + this.betas[language] + "\t");
                Alphabet alphabet = this.alphabets[language];
                Iterator<IDSorter> iterator = languageTopicSortedWords[language][topic].iterator();
                // Counting from 0 prints exactly numWords words (when that many are available).
                for (int word = 0; iterator.hasNext() && word < numWords; ++word) {
                    out.print(alphabet.lookupObject(iterator.next().getID()) + " ");
                }
                out.println();
            }
        }
    }

    /**
     * Writes document-topic proportions for all documents to {@code f2} as UTF-8 text, using
     * the defaults of {@code printDocumentTopics(PrintWriter)}.
     *
     * @param f2 destination file; created or truncated
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File f2) throws IOException {
        // The writer must be closed; otherwise buffered output is never flushed to disk.
        try (PrintWriter pw = new PrintWriter(f2, "UTF-8")) {
            this.printDocumentTopics(pw);
        }
    }

    /**
     * Writes document-topic proportions for every document, listing all topics
     * (no proportion threshold, no cap on the number of topics).
     *
     * @param pw destination writer (not closed by this method)
     */
    public void printDocumentTopics(PrintWriter pw) {
        final double noThreshold = 0.0;
        final int allTopics = -1;
        printDocumentTopics(pw, noThreshold, allTopics);
    }

    /**
     * Writes one line per document listing topics in decreasing order of proportion.
     *
     * <p>Each line is {@code doc topic proportion topic proportion ... }. A proportion is the
     * fraction of the document's tokens, pooled over all languages, that are assigned to the
     * topic. A document with no tokens reports a proportion of 0 for every topic.
     *
     * @param pw destination writer (not closed by this method)
     * @param threshold stop listing a document's topics at the first proportion below this value
     * @param max2 maximum number of topics per document; a negative value or one larger than the
     *     number of topics means "all topics"
     */
    public void printDocumentTopics(PrintWriter pw, double threshold, int max2) {
        pw.print("#doc source topic proportion ...\n");
        int[] topicCounts = new int[this.numTopics];
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            sortedTopics[topic] = new IDSorter(topic, topic);
        }
        if (max2 < 0 || max2 > this.numTopics) {
            max2 = this.numTopics;
        }
        for (int di = 0; di < this.data.size(); ++di) {
            pw.print(di);
            pw.print(' ');
            TopicAssignment assignment = this.data.get(di);
            int totalLength = 0;
            for (int language = 0; language < this.numLanguages; ++language) {
                LabelSequence topicSequence = assignment.topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();
                int docLength = topicSequence.getLength();
                totalLength += docLength;
                for (int position = 0; position < docLength; ++position) {
                    topicCounts[currentDocTopics[position]]++;
                }
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                // Guard 0/0: an empty document would otherwise yield NaN for every topic.
                float proportion = totalLength == 0 ? 0.0f : (float) topicCounts[topic] / (float) totalLength;
                sortedTopics[topic].set(topic, proportion);
            }
            Arrays.sort(sortedTopics);
            for (int i = 0; i < max2 && !(sortedTopics[i].getWeight() < threshold); ++i) {
                pw.print(sortedTopics[i].getID() + " " + sortedTopics[i].getWeight() + " ");
            }
            pw.print(" \n");
            Arrays.fill(topicCounts, 0);
        }
    }

    /**
     * Writes the full per-token sampling state to {@code f2} as gzip-compressed UTF-8 text.
     *
     * @param f2 destination file; created or truncated
     * @throws IOException if the file cannot be opened or the gzip header cannot be written
     */
    public void printState(File f2) throws IOException {
        // Declaring the raw file stream as its own resource ensures it is closed even if building
        // the gzip/print chain on top of it fails. Closing the PrintStream writes the gzip trailer.
        try (FileOutputStream fileOut = new FileOutputStream(f2);
             PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(fileOut)), false, "UTF-8")) {
            this.printState(out);
        }
    }

    /**
     * Dumps the current sampling state, one line per token: document index, language index,
     * position within that language's token sequence, type index, the word itself, and the
     * topic currently assigned to the token.
     *
     * @param out destination stream (not closed by this method)
     */
    public void printState(PrintStream out) {
        out.println("#doc lang pos typeindex type topic");
        int numDocs = this.data.size();
        for (int doc = 0; doc < numDocs; doc++) {
            TopicAssignment assignment = this.data.get(doc);
            for (int lang = 0; lang < this.numLanguages; lang++) {
                FeatureSequence tokens = (FeatureSequence) assignment.instances[lang].getData();
                LabelSequence topics = assignment.topicSequences[lang];
                Alphabet vocabulary = this.alphabets[lang];
                int length = topics.getLength();
                for (int pos = 0; pos < length; pos++) {
                    int typeIndex = tokens.getIndexAtPosition(pos);
                    int assignedTopic = topics.getIndexAtPosition(pos);
                    out.println(doc + " " + lang + " " + pos + " " + typeIndex + " "
                            + vocabulary.lookupObject(typeIndex) + " " + assignedTopic);
                }
            }
        }
    }

    /**
     * Computes log P(z) + sum over languages of log P(w | z) for the current topic assignments,
     * with topic-word distributions integrated out under a symmetric Dirichlet(beta) per
     * language and document-topic distributions under Dirichlet(alpha).
     *
     * <p>Document-topic counts are pooled over all languages of a document, because every
     * language of a document shares the same topic proportions.
     *
     * @return the model log likelihood
     * @throws IllegalStateException if the computation produces NaN
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;
        int[] topicCounts = new int[this.numTopics];
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }
        // Document-topic part, per document d:
        //   sum_k [logGamma(alpha_k + n_dk) - logGamma(alpha_k)] - logGamma(alphaSum + n_d) + logGamma(alphaSum)
        for (int doc = 0; doc < this.data.size(); ++doc) {
            TopicAssignment assignment = this.data.get(doc);
            int totalLength = 0;
            for (int language = 0; language < this.numLanguages; ++language) {
                LabelSequence topicSequence = assignment.topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();
                int docLength = topicSequence.getLength();
                totalLength += docLength;
                for (int position = 0; position < docLength; ++position) {
                    topicCounts[currentDocTopics[position]]++;
                }
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (topicCounts[topic] > 0) {
                    logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + (double) topicCounts[topic]) - topicLogGammas[topic];
                }
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + (double) totalLength);
            Arrays.fill(topicCounts, 0);
        }
        logLikelihood += (double) this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // Topic-word part, per language with vocabulary size V and prior beta:
        //   sum_k [logGamma(V*beta) - logGamma(V*beta + n_k)]
        //   + sum over (k, w) with n_kw > 0 of [logGamma(beta + n_kw) - logGamma(beta)]
        for (int language = 0; language < this.numLanguages; ++language) {
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            int[] tokensPerTopic = this.languageTokensPerTopic[language];
            double beta = this.betas[language];
            // The word Dirichlet has one beta per vocabulary entry, so its total mass is V * beta.
            double vocabularyBetaSum = beta * (double) this.vocabularySizes[language];
            int nonZeroTypeTopics = 0;
            for (int type = 0; type < this.vocabularySizes[language]; ++type) {
                int[] packedCounts = typeTopicCounts[type];
                // Entries pack (count << topicBits) | topic; the scan stops at the first unused (<= 0) entry.
                for (int index = 0; index < packedCounts.length && packedCounts[index] > 0; ++index) {
                    int count = packedCounts[index] >> this.topicBits;
                    ++nonZeroTypeTopics;
                    logLikelihood += Dirichlet.logGammaStirling(beta + (double) count);
                    if (Double.isNaN(logLikelihood)) {
                        throw new IllegalStateException("NaN log likelihood at language " + language + ", type " + type + ", count " + count);
                    }
                }
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                logLikelihood -= Dirichlet.logGammaStirling(vocabularyBetaSum + (double) tokensPerTopic[topic]);
                if (Double.isNaN(logLikelihood)) {
                    throw new IllegalStateException("NaN log likelihood after topic " + topic + " (language " + language + ", tokens " + tokensPerTopic[topic] + ")");
                }
            }
            logLikelihood += (double) this.numTopics * Dirichlet.logGammaStirling(vocabularyBetaSum)
                    - (double) nonZeroTypeTopics * Dirichlet.logGammaStirling(beta);
        }
        if (Double.isNaN(logLikelihood)) {
            throw new IllegalStateException("NaN log likelihood at the end of the computation");
        }
        return logLikelihood;
    }

    /**
     * Builds a {@link TopicInferencer} for new documents in one language. The inferencer uses
     * this model's type-topic counts, topic totals, alphabet and beta for that language,
     * together with the alpha shared by all languages.
     *
     * @param language 0-based language index
     * @return an inferencer for documents in the given language
     */
    public TopicInferencer getInferencer(int language) {
        int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
        int[] tokensPerTopic = this.languageTokensPerTopic[language];
        Alphabet vocabulary = this.alphabets[language];
        double beta = this.betas[language];
        double betaSum = this.betaSums[language];
        return new TopicInferencer(typeTopicCounts, tokensPerTopic, vocabulary, this.alpha, beta, betaSum);
    }

    /**
     * Custom Java serialization hook. Writes a leading int 0 (read back as {@code version} by
     * readObject), then the model's fields in a fixed order.
     *
     * <p>The sequence of writes here must match the sequence of reads in readObject exactly.
     * Adding, removing or reordering a field breaks compatibility with existing model files.
     *
     * @param out the serialization stream
     * @throws IOException if writing to the stream fails
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        // Serialized-format marker; readObject stores it in `version`.
        out.writeInt(0);
        out.writeInt(this.numLanguages);
        out.writeObject(this.data);
        out.writeObject(this.topicAlphabet);
        out.writeInt(this.numTopics);
        out.writeObject(this.testingIDs);
        out.writeInt(this.topicMask);
        out.writeInt(this.topicBits);
        out.writeObject(this.alphabets);
        out.writeObject(this.vocabularySizes);
        out.writeObject(this.alpha);
        out.writeDouble(this.alphaSum);
        out.writeObject(this.betas);
        out.writeObject(this.betaSums);
        out.writeObject(this.languageMaxTypeCounts);
        out.writeObject(this.languageTypeTopicCounts);
        out.writeObject(this.languageTokensPerTopic);
        out.writeObject(this.languageSmoothingOnlyMasses);
        out.writeObject(this.languageCachedCoefficients);
        out.writeObject(this.docLengthCounts);
        out.writeObject(this.topicDocCounts);
        // Training configuration.
        out.writeInt(this.numIterations);
        out.writeInt(this.burninPeriod);
        out.writeInt(this.saveSampleInterval);
        out.writeInt(this.optimizeInterval);
        out.writeInt(this.showTopicsInterval);
        out.writeInt(this.wordsPerTopic);
        out.writeInt(this.saveStateInterval);
        out.writeObject(this.stateFilename);
        out.writeInt(this.saveModelInterval);
        out.writeObject(this.modelFilename);
        out.writeObject(this.random);
        out.writeObject(this.formatter);
        out.writeBoolean(this.printLogLikelihood);
    }

    /**
     * Custom Java deserialization hook, mirroring writeObject. It reads every field back in
     * exactly the order writeObject wrote it.
     *
     * <p>NOTE(review): {@code version} is read but never checked. A stream written in a
     * different layout would be silently misread rather than rejected.
     *
     * @param in the serialization stream
     * @throws IOException if reading from the stream fails
     * @throws ClassNotFoundException if a serialized field's class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        this.numLanguages = in.readInt();
        // Unchecked casts: the types here must match what writeObject wrote at each position.
        this.data = (ArrayList)in.readObject();
        this.topicAlphabet = (LabelAlphabet)in.readObject();
        this.numTopics = in.readInt();
        this.testingIDs = (HashSet)in.readObject();
        this.topicMask = in.readInt();
        this.topicBits = in.readInt();
        this.alphabets = (Alphabet[])in.readObject();
        this.vocabularySizes = (int[])in.readObject();
        this.alpha = (double[])in.readObject();
        this.alphaSum = in.readDouble();
        this.betas = (double[])in.readObject();
        this.betaSums = (double[])in.readObject();
        this.languageMaxTypeCounts = (int[])in.readObject();
        this.languageTypeTopicCounts = (int[][][])in.readObject();
        this.languageTokensPerTopic = (int[][])in.readObject();
        this.languageSmoothingOnlyMasses = (double[])in.readObject();
        this.languageCachedCoefficients = (double[][])in.readObject();
        this.docLengthCounts = (int[])in.readObject();
        this.topicDocCounts = (int[][])in.readObject();
        // Training configuration.
        this.numIterations = in.readInt();
        this.burninPeriod = in.readInt();
        this.saveSampleInterval = in.readInt();
        this.optimizeInterval = in.readInt();
        this.showTopicsInterval = in.readInt();
        this.wordsPerTopic = in.readInt();
        this.saveStateInterval = in.readInt();
        this.stateFilename = (String)in.readObject();
        this.saveModelInterval = in.readInt();
        this.modelFilename = (String)in.readObject();
        this.random = (Randoms)in.readObject();
        this.formatter = (NumberFormat)in.readObject();
        this.printLogLikelihood = in.readBoolean();
    }

    /**
     * Serializes this model to {@code serializedModelFile} using Java object serialization.
     * Failures are reported on standard error rather than thrown (best-effort save).
     *
     * @param serializedModelFile destination file; created or overwritten
     */
    public void write(File serializedModelFile) {
        // Separate resources: the file is closed even if the ObjectOutputStream header write or
        // writeObject fails part-way.
        try (FileOutputStream fileOut = new FileOutputStream(serializedModelFile);
             ObjectOutputStream oos = new ObjectOutputStream(fileOut)) {
            oos.writeObject(this);
        }
        catch (IOException e) {
            System.err.println("Problem serializing PolylingualTopicModel to file " + serializedModelFile + ": " + e);
        }
    }

    /**
     * Reads a model saved with {@link #write(File)} and rebuilds its histograms.
     *
     * <p>This uses Java native deserialization, which can run code embedded in a crafted file.
     * Only read model files from trusted sources.
     *
     * @param f2 file to read
     * @return the deserialized model
     * @throws Exception if the file cannot be read or does not contain a PolylingualTopicModel
     */
    public static PolylingualTopicModel read(File f2) throws Exception {
        PolylingualTopicModel topicModel;
        // try-with-resources closes the file even when deserialization throws.
        try (FileInputStream fileIn = new FileInputStream(f2);
             ObjectInputStream ois = new ObjectInputStream(fileIn)) {
            topicModel = (PolylingualTopicModel) ois.readObject();
        }
        topicModel.initializeHistograms();
        return topicModel;
    }

    /**
     * Command-line entry point. It loads a saved model (--input-model) or builds a new one from
     * one instance file per language (--language-inputs), then runs {@code estimate()}. Finally
     * it writes whichever outputs were requested: topic keys, sampling state, document topics,
     * one inferencer per language, and the serialized model.
     *
     * @param args command-line options, parsed by {@link CommandOption#process}
     * @throws IOException if writing one of the requested output files fails
     */
    public static void main(String[] args) throws IOException {
        CommandOption.setSummary(PolylingualTopicModel.class, "A tool for estimating, saving and printing diagnostics for topic models over comparable corpora.");
        CommandOption.process(PolylingualTopicModel.class, args);
        PolylingualTopicModel topicModel = null;
        if (PolylingualTopicModel.inputModelFilename.value != null) {
            try {
                topicModel = PolylingualTopicModel.read(new File(PolylingualTopicModel.inputModelFilename.value));
            }
            catch (Exception e) {
                System.err.println("Unable to restore saved topic model " + PolylingualTopicModel.inputModelFilename.value + ": " + e);
                System.exit(1);
            }
        } else {
            String[] inputFiles = PolylingualTopicModel.languageInputFiles.value;
            // Without this check a missing option surfaces as a NullPointerException.
            if (inputFiles == null || inputFiles.length == 0) {
                System.err.println("No input: specify --language-inputs or --input-model.");
                System.exit(1);
            }
            int numLanguages = inputFiles.length;
            InstanceList[] training = new InstanceList[numLanguages];
            for (int i = 0; i < numLanguages; ++i) {
                training[i] = InstanceList.load(new File(inputFiles[i]));
                // Fail here with a clear message instead of a NullPointerException later on.
                if (training[i] == null) {
                    System.err.println("Unable to load instances for language " + i + " from " + inputFiles[i]);
                    System.exit(1);
                }
            }
            System.out.println("Data loaded.");
            // Every language must be imported as token sequences, not just the first one.
            for (int i = 0; i < numLanguages; ++i) {
                if (training[i].size() > 0 && training[i].get(0) != null && !(((Instance) training[i].get(0)).getData() instanceof FeatureSequence)) {
                    System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
                    System.exit(1);
                }
            }
            topicModel = new PolylingualTopicModel(PolylingualTopicModel.numTopicsOption.value, PolylingualTopicModel.alphaOption.value);
            if (PolylingualTopicModel.randomSeedOption.value != 0) {
                topicModel.setRandomSeed(PolylingualTopicModel.randomSeedOption.value);
            }
            topicModel.addInstances(training);
        }
        topicModel.setTopicDisplay(PolylingualTopicModel.showTopicsIntervalOption.value, PolylingualTopicModel.topWordsOption.value);
        topicModel.setNumIterations(PolylingualTopicModel.numIterationsOption.value);
        topicModel.setOptimizeInterval(PolylingualTopicModel.optimizeIntervalOption.value);
        topicModel.setBurninPeriod(PolylingualTopicModel.optimizeBurnInOption.value);
        if (PolylingualTopicModel.outputStateIntervalOption.value != 0) {
            topicModel.setSaveState(PolylingualTopicModel.outputStateIntervalOption.value, PolylingualTopicModel.stateFile.value);
        }
        if (PolylingualTopicModel.outputModelIntervalOption.value != 0) {
            topicModel.setModelOutput(PolylingualTopicModel.outputModelIntervalOption.value, PolylingualTopicModel.outputModelFilename.value);
        }
        topicModel.estimate();
        if (PolylingualTopicModel.topicKeysFile.value != null) {
            topicModel.printTopWords(new File(PolylingualTopicModel.topicKeysFile.value), PolylingualTopicModel.topWordsOption.value, false);
        }
        if (PolylingualTopicModel.stateFile.value != null) {
            topicModel.printState(new File(PolylingualTopicModel.stateFile.value));
        }
        if (PolylingualTopicModel.docTopicsFile.value != null) {
            // UTF-8 for consistency with printDocumentTopics(File); closed even on failure.
            try (PrintWriter out = new PrintWriter(new File(PolylingualTopicModel.docTopicsFile.value), "UTF-8")) {
                topicModel.printDocumentTopics(out, PolylingualTopicModel.docTopicsThreshold.value, PolylingualTopicModel.docTopicsMax.value);
            }
        }
        if (PolylingualTopicModel.inferencerFilename.value != null) {
            // One inferencer file per language: <name>.<languageIndex>. Best-effort: a failure is
            // reported and the remaining languages are still written.
            for (int language = 0; language < topicModel.numLanguages; ++language) {
                String inferencerFile = PolylingualTopicModel.inferencerFilename.value + "." + language;
                try (FileOutputStream fileOut = new FileOutputStream(inferencerFile);
                     ObjectOutputStream oos = new ObjectOutputStream(fileOut)) {
                    oos.writeObject(topicModel.getInferencer(language));
                }
                catch (Exception e) {
                    System.err.println("Unable to write inferencer " + inferencerFile + ": " + e);
                }
            }
        }
        if (PolylingualTopicModel.outputModelFilename.value != null) {
            assert (topicModel != null);
            topicModel.write(new File(PolylingualTopicModel.outputModelFilename.value));
        }
    }

    /**
     * Per-document sampling state: the document's instance in each language, paired with
     * the topic currently assigned to each token of that instance.
     */
    public class TopicAssignment
    implements Serializable {
        /** Topic assignments per language; position i of topicSequences[l] belongs to token i of instances[l]. */
        public LabelSequence[] topicSequences;
        /** The document's instance for each language, indexed by language. */
        public Instance[] instances;
        /** Not set by the constructor; assigned elsewhere if used at all. */
        public Labeling topicDistribution;

        /**
         * Pairs a document's per-language instances with their topic sequences.
         *
         * @param instances one instance per language
         * @param topicSequences one topic sequence per language, aligned with {@code instances}
         */
        public TopicAssignment(Instance[] instances, LabelSequence[] topicSequences) {
            this.topicSequences = topicSequences;
            this.instances = instances;
        }
    }
}

