/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TObjectDoubleHashMap;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;

public class HierarchicalLDA {
    /** Training documents; each instance's data is cast to a FeatureSequence by the samplers. */
    InstanceList instances;
    /** Held-out documents as handed to initialize(); only assigned, never read, in the code visible here. */
    InstanceList testing;
    /** Root of the topic tree, created by initialize(). */
    NCRPNode rootNode;
    /** Scratch reference written by initialize() while it seeds the tree. */
    NCRPNode node;
    /** Number of levels on every root-to-leaf path (path arrays are sized with this value). */
    int numLevels;
    /** Number of training documents (instances.size() at initialization). */
    int numDocuments;
    /** Vocabulary size (size of the data alphabet at initialization). */
    int numTypes;
    /** Added to the per-document level counts in sampleTopics(); also the Dirichlet parameter in empiricalLikelihood(). */
    double alpha = 10.0;
    /** Weight given to opening a new branch in calculateNCRP() and NCRPNode.select(). */
    double gamma = 1.0;
    /** Added to every per-node word-type count in the likelihood formulas. */
    double eta = 0.1;
    /** Cached eta * numTypes, computed in initialize(). */
    double etaSum;
    /** levels[doc][token] is the tree level currently assigned to that token. */
    int[][] levels;
    /** documentLeaves[doc] is the leaf node at the end of that document's path. */
    NCRPNode[] documentLeaves;
    /** Running counter used to hand out NCRPNode.nodeID values. */
    int totalNodes = 0;
    /** File name opened by the no-argument printState(). */
    String stateFile = "hlda.state";
    /** Source of all random draws made by the samplers. */
    Randoms random;
    /** When true, estimate() prints a progress dot per iteration. */
    boolean showProgress = true;
    /** estimate() prints the tree whenever the iteration number is a multiple of this value. */
    int displayTopicsInterval = 50;
    /** Word count passed to NCRPNode.getTopWords() by printNode(). */
    int numWordsToDisplay = 10;

    /**
     * Sets alpha, the value added to each per-document level count when token levels are
     * resampled and used as the Dirichlet parameter in empiricalLikelihood().
     *
     * @param alpha the new value
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Sets gamma, the weight given to opening a new branch when paths through the tree
     * are scored or drawn.
     *
     * @param gamma the new value
     */
    public void setGamma(double gamma) {
        this.gamma = gamma;
    }

    /**
     * Sets eta, the smoothing value added to every per-node word-type count.
     *
     * <p>The cached {@code etaSum} (eta times the vocabulary size) is refreshed as well, so
     * changing eta after {@code initialize} no longer leaves the samplers dividing by a
     * normaliser that belongs to the previous value. Before initialization the vocabulary
     * size is still zero; {@code initialize} recomputes the sum in that case.
     *
     * @param eta the new value
     */
    public void setEta(double eta) {
        this.eta = eta;
        this.etaSum = eta * (double)this.numTypes;
    }

    /**
     * Sets the file name that the no-argument printState() writes to.
     *
     * @param stateFile path of the state file
     */
    public void setStateFile(String stateFile) {
        this.stateFile = stateFile;
    }

    /**
     * Configures the periodic tree dump made by estimate().
     *
     * @param interval the tree is printed when the iteration number is a multiple of this value
     * @param words    word count handed to NCRPNode.getTopWords() for each printed node
     */
    public void setTopicDisplay(int interval, int words) {
        this.displayTopicsInterval = interval;
        this.numWordsToDisplay = words;
    }

    /**
     * Turns the per-iteration progress output of estimate() on or off.
     *
     * @param showProgress true to print a dot for every iteration
     */
    public void setProgressDisplay(boolean showProgress) {
        this.showProgress = showProgress;
    }

    /**
     * Builds the initial tree and the initial token-to-level assignment.
     *
     * <p>Every document is given a path of {@code numLevels} nodes starting at a fresh root;
     * below the root each step is drawn with {@code NCRPNode.select()}. Each token is then
     * assigned a level drawn uniformly from {@code [0, numLevels)} and counted at the node
     * on the document's path at that level.
     *
     * <p>Arguments are validated before any field is touched, so a rejected call leaves the
     * sampler in its previous state. An empty training list is accepted and produces a tree
     * consisting of the root only.
     *
     * @param instances training documents whose data are FeatureSequences
     * @param testing   held-out documents; stored as-is
     * @param numLevels number of levels on every path, at least 1
     * @param random    random source used for this and all later sampling
     * @throws IllegalArgumentException if {@code numLevels < 1} or the first training
     *         instance does not hold a FeatureSequence
     */
    public void initialize(InstanceList instances, InstanceList testing, int numLevels, Randoms random) {
        if (numLevels < 1) {
            throw new IllegalArgumentException("numLevels must be at least 1, got " + numLevels);
        }
        // Only the first instance is inspected; an empty list has nothing to check.
        if (instances.size() > 0 && !(((Instance)instances.get(0)).getData() instanceof FeatureSequence)) {
            throw new IllegalArgumentException("Input must be a FeatureSequence, using the --feature-sequence option when importing data, for example");
        }
        this.instances = instances;
        this.testing = testing;
        this.numLevels = numLevels;
        this.random = random;
        this.numDocuments = instances.size();
        this.numTypes = instances.getDataAlphabet().size();
        this.etaSum = this.eta * (double)this.numTypes;
        NCRPNode[] path = new NCRPNode[numLevels];
        this.rootNode = new NCRPNode(this.numTypes);
        this.levels = new int[this.numDocuments][];
        this.documentLeaves = new NCRPNode[this.numDocuments];
        for (int doc = 0; doc < this.numDocuments; ++doc) {
            FeatureSequence fs = (FeatureSequence)((Instance)instances.get(doc)).getData();
            int seqLen = fs.getLength();
            // Seat the document along a path, counting it at every node it passes through.
            path[0] = this.rootNode;
            ++this.rootNode.customers;
            for (int level = 1; level < numLevels; ++level) {
                path[level] = path[level - 1].select();
                ++path[level].customers;
            }
            // The 'node' field is kept up to date exactly as before, in case package-level
            // code inspects it after initialization.
            this.node = path[numLevels - 1];
            this.documentLeaves[doc] = this.node;
            int[] docLevels = new int[seqLen];
            this.levels[doc] = docLevels;
            for (int token = 0; token < seqLen; ++token) {
                int type = fs.getIndexAtPosition(token);
                docLevels[token] = random.nextInt(numLevels);
                this.node = path[docLevels[token]];
                ++this.node.totalTokens;
                ++this.node.typeCounts[type];
            }
        }
    }

    /**
     * Runs the sampler for the given number of iterations.
     *
     * <p>Each iteration first resamples the path of every document, then resamples the level
     * of every token. When progress display is enabled a dot is printed per iteration and
     * the iteration number after every 50th. The tree is printed whenever the iteration
     * number is a multiple of {@code displayTopicsInterval}; an interval of 0 now disables
     * that output instead of failing with a division-by-zero ArithmeticException.
     *
     * @param numIterations number of full sweeps to perform
     */
    public void estimate(int numIterations) {
        for (int iteration = 1; iteration <= numIterations; ++iteration) {
            for (int doc = 0; doc < this.numDocuments; ++doc) {
                this.samplePath(doc, iteration);
            }
            for (int doc = 0; doc < this.numDocuments; ++doc) {
                this.sampleTopics(doc);
            }
            if (this.showProgress) {
                System.out.print(".");
                if (iteration % 50 == 0) {
                    System.out.println(" " + iteration);
                }
            }
            // Guard the remainder: '%' with a zero divisor throws for ints.
            if (this.displayTopicsInterval != 0 && iteration % this.displayTopicsInterval == 0) {
                this.printNodes();
            }
        }
    }

    /**
     * Resamples the tree path of one document.
     *
     * <p>The document is taken out of the tree (customer counts and word counts), every
     * node is scored with the prior from calculateNCRP() plus the word term from
     * calculateWordLikelihood(), one node is drawn in proportion to the exponentiated
     * scores, and the document is put back along the path ending at that node. If the
     * drawn node is not a leaf, a fresh chain of nodes down to leaf level is created first.
     *
     * @param doc       index of the document in the training list
     * @param iteration current iteration number; only forwarded to calculateWordLikelihood()
     */
    public void samplePath(int doc, int iteration) {
        // Recover the current path by walking parent links up from the leaf.
        NCRPNode[] currentPath = new NCRPNode[this.numLevels];
        NCRPNode cursor = this.documentLeaves[doc];
        for (int depth = this.numLevels - 1; depth >= 0; --depth) {
            currentPath[depth] = cursor;
            cursor = cursor.parent;
        }

        // Remove the document's customer counts from the tree.
        this.documentLeaves[doc].dropPath();

        // Prior score for every node still in the tree.
        TObjectDoubleHashMap<NCRPNode> nodeWeights = new TObjectDoubleHashMap<NCRPNode>();
        this.calculateNCRP(nodeWeights, this.rootNode, 0.0);

        // Per-level word-type counts of this document; the same tokens are subtracted
        // from the nodes of the old path as they are tallied.
        TIntIntHashMap[] typeCounts = new TIntIntHashMap[this.numLevels];
        for (int depth = 0; depth < this.numLevels; ++depth) {
            typeCounts[depth] = new TIntIntHashMap();
        }
        int[] docLevels = this.levels[doc];
        FeatureSequence fs = (FeatureSequence)((Instance)this.instances.get(doc)).getData();
        for (int position = 0; position < docLevels.length; ++position) {
            int depth = docLevels[position];
            int type = fs.getIndexAtPosition(position);
            TIntIntHashMap counts = typeCounts[depth];
            if (counts.containsKey(type)) {
                counts.increment(type);
            } else {
                counts.put(type, 1);
            }
            NCRPNode owner = currentPath[depth];
            --owner.typeCounts[type];
            assert (owner.typeCounts[type] >= 0);
            --owner.totalTokens;
            assert (owner.totalTokens >= 0);
        }

        // Log-likelihood of each level's words under a node that holds no tokens yet.
        double[] newTopicWeights = new double[this.numLevels];
        for (int depth = 1; depth < this.numLevels; ++depth) {
            TIntIntHashMap counts = typeCounts[depth];
            int seen = 0;
            for (int type : counts.keys()) {
                int occurrences = counts.get(type);
                for (int k = 0; k < occurrences; ++k) {
                    newTopicWeights[depth] += Math.log((this.eta + (double)k) / (this.etaSum + (double)seen));
                    ++seen;
                }
            }
        }

        this.calculateWordLikelihood(nodeWeights, this.rootNode, 0.0, typeCounts, newTopicWeights, 0, iteration);

        // Exponentiate relative to the largest score so the weights stay in range.
        NCRPNode[] candidates = nodeWeights.keys(new NCRPNode[0]);
        double largest = Double.NEGATIVE_INFINITY;
        for (NCRPNode candidate : candidates) {
            double score = nodeWeights.get(candidate);
            if (score > largest) {
                largest = score;
            }
        }
        double[] weights = new double[candidates.length];
        double total = 0.0;
        for (int k = 0; k < candidates.length; ++k) {
            weights[k] = Math.exp(nodeWeights.get(candidates[k]) - largest);
            total += weights[k];
        }

        NCRPNode chosen = candidates[this.random.nextDiscrete(weights, total)];
        if (!chosen.isLeaf()) {
            chosen = chosen.getNewLeaf();
        }
        chosen.addPath();
        this.documentLeaves[doc] = chosen;

        // Put the document's tokens back, level by level, walking up from the new leaf.
        NCRPNode target = chosen;
        for (int depth = this.numLevels - 1; depth >= 0; --depth) {
            TIntIntHashMap counts = typeCounts[depth];
            for (int type : counts.keys()) {
                int occurrences = counts.get(type);
                target.typeCounts[type] += occurrences;
                target.totalTokens += occurrences;
            }
            target = target.parent;
        }
    }

    /**
     * Stores a log prior score for {@code node} and, recursively, for its whole subtree.
     *
     * <p>Each child is visited with the incoming weight plus the log of its share of the
     * parent's customers; the node itself is stored afterwards with the incoming weight
     * plus the log of the share that gamma holds.
     *
     * @param nodeWeights map receiving one entry per visited node
     * @param node        subtree root to score
     * @param weight      log weight accumulated on the way down to {@code node}
     */
    public void calculateNCRP(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight) {
        double denominator = (double)node.customers + this.gamma;
        for (NCRPNode child : node.children) {
            double childShare = Math.log((double)child.customers / denominator);
            this.calculateNCRP(nodeWeights, child, weight + childShare);
        }
        double newBranchShare = Math.log(this.gamma / denominator);
        nodeWeights.put(node, weight + newBranchShare);
    }

    /**
     * Adds the word log-likelihood of the current document to the score of {@code node}
     * and of every node below it.
     *
     * <p>A node's score receives three parts: the likelihood of the words assigned to the
     * levels above it under its ancestors ({@code weight}), the likelihood of the words at
     * its own level under its own counts, and, for levels below it, the likelihood of those
     * words under brand-new empty nodes ({@code newTopicWeights}).
     *
     * <p>Previously the ancestor part was passed down the recursion but never added, so a
     * candidate was scored on its own level's words only.
     *
     * @param nodeWeights     scores to adjust; must already hold an entry for every node
     * @param node            subtree root being scored
     * @param weight          summed word log-likelihood of the levels above {@code node}
     * @param typeCounts      per-level word-type counts of the document
     * @param newTopicWeights per-level log-likelihood of the document's words under an empty node
     * @param level           level of {@code node}
     * @param iteration       not used by this method; passed through to the recursive calls
     */
    public void calculateWordLikelihood(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight, TIntIntHashMap[] typeCounts, double[] newTopicWeights, int level, int iteration) {
        // Likelihood of this level's words given the counts already held by the node.
        double nodeWeight = 0.0;
        TIntIntHashMap counts = typeCounts[level];
        int seen = 0;
        for (int type : counts.keys()) {
            int occurrences = counts.get(type);
            for (int i = 0; i < occurrences; ++i) {
                nodeWeight += Math.log((this.eta + (double)node.typeCounts[type] + (double)i) / (this.etaSum + (double)node.totalTokens + (double)seen));
                ++seen;
            }
        }

        // Children inherit everything accumulated so far, including this level.
        for (NCRPNode child : node.children) {
            this.calculateWordLikelihood(nodeWeights, child, weight + nodeWeight, typeCounts, newTopicWeights, level + 1, iteration);
        }

        // Choosing an internal node means new nodes for all deeper levels.
        for (int deeper = level + 1; deeper < this.numLevels; ++deeper) {
            nodeWeight += newTopicWeights[deeper];
        }

        // Include the ancestors' contribution so the score covers the full path.
        nodeWeights.adjustValue(node, weight + nodeWeight);
    }

    /**
     * Adds {@code weight} to the score of {@code node} and of all its descendants, stopping
     * at any node that has no entry in {@code nodeWeights} (such a node's subtree is skipped).
     *
     * @param nodeWeights scores to adjust
     * @param node        subtree root
     * @param weight      amount added to every adjusted entry
     */
    public void propagateTopicWeight(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight) {
        if (nodeWeights.containsKey(node)) {
            for (NCRPNode child : node.children) {
                this.propagateTopicWeight(nodeWeights, child, weight);
            }
            nodeWeights.adjustValue(node, weight);
        }
    }

    /**
     * Resamples the level of every token in one document while its path stays fixed.
     *
     * <p>For each token in order: its current assignment is removed from the document's
     * level counts and from the node's word counts, a weight is computed for every level
     * from alpha, eta and the remaining counts, a level is drawn in proportion to those
     * weights, and the token is added back at the drawn level.
     *
     * @param doc index of the document in the training list
     */
    public void sampleTopics(int doc) {
        FeatureSequence fs = (FeatureSequence)((Instance)this.instances.get(doc)).getData();
        int seqLen = fs.getLength();
        int[] docLevels = this.levels[doc];

        // Path of the document, indexed by level, recovered from the leaf upward.
        NCRPNode[] path = new NCRPNode[this.numLevels];
        NCRPNode cursor = this.documentLeaves[doc];
        for (int depth = this.numLevels - 1; depth >= 0; --depth) {
            path[depth] = cursor;
            cursor = cursor.parent;
        }

        // How many of the document's tokens currently sit at each level.
        int[] levelCounts = new int[this.numLevels];
        for (int position = 0; position < seqLen; ++position) {
            ++levelCounts[docLevels[position]];
        }

        double[] levelWeights = new double[this.numLevels];
        for (int position = 0; position < seqLen; ++position) {
            int type = fs.getIndexAtPosition(position);

            // Take the token out of all counts.
            int oldLevel = docLevels[position];
            --levelCounts[oldLevel];
            NCRPNode oldNode = path[oldLevel];
            --oldNode.typeCounts[type];
            --oldNode.totalTokens;

            double total = 0.0;
            for (int depth = 0; depth < this.numLevels; ++depth) {
                levelWeights[depth] = (this.alpha + (double)levelCounts[depth]) * (this.eta + (double)path[depth].typeCounts[type]) / (this.etaSum + (double)path[depth].totalTokens);
                total += levelWeights[depth];
            }

            // Draw the new level and put the token back there.
            int newLevel = this.random.nextDiscrete(levelWeights, total);
            docLevels[position] = newLevel;
            ++levelCounts[newLevel];
            NCRPNode newNode = path[newLevel];
            ++newNode.typeCounts[type];
            ++newNode.totalTokens;
        }
    }

    /**
     * Writes the sampler state to the file named by {@code stateFile}.
     *
     * <p>The writer is closed when the dump finishes or fails. Without that, buffered
     * output could be lost and the file handle leaked.
     *
     * @throws FileNotFoundException declared for compatibility with existing callers
     * @throws IOException if the file cannot be opened for writing
     */
    public void printState() throws IOException, FileNotFoundException {
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(this.stateFile)));
        try {
            this.printState(out);
        } finally {
            out.close();
        }
    }

    /**
     * Writes one line per training token to {@code out}.
     *
     * <p>Each line starts with the node IDs on the document's path, listed from the leaf up
     * to the root and each followed by a space, then the word-type index, the word itself
     * and the token's level, separated by spaces with a trailing space. The writer is
     * neither flushed nor closed here.
     *
     * @param out destination writer
     * @throws IOException declared for callers; nothing in this body throws it directly
     */
    public void printState(PrintWriter out) throws IOException {
        Alphabet alphabet = this.instances.getDataAlphabet();
        int doc = 0;
        for (Instance instance : this.instances) {
            FeatureSequence fs = (FeatureSequence)instance.getData();
            int[] docLevels = this.levels[doc];

            // Node IDs from leaf to root, shared by every line of this document.
            StringBuilder ids = new StringBuilder();
            NCRPNode cursor = this.documentLeaves[doc];
            for (int depth = this.numLevels - 1; depth >= 0; --depth) {
                ids.append(cursor.nodeID).append(" ");
                cursor = cursor.parent;
            }
            String pathPrefix = ids.toString();

            int seqLen = fs.getLength();
            for (int position = 0; position < seqLen; ++position) {
                int type = fs.getIndexAtPosition(position);
                int tokenLevel = docLevels[position];
                out.println(pathPrefix + type + " " + alphabet.lookupObject(type) + " " + tokenLevel + " ");
            }
            ++doc;
        }
    }

    /** Prints the whole topic tree to standard output, starting at the root with no indentation. */
    public void printNodes() {
        this.printNode(this.rootNode, 0);
    }

    /**
     * Prints {@code node} and then its descendants to standard output, one node per line.
     *
     * <p>A line consists of two spaces per indent step, the node's token count, a slash,
     * its customer count, a space and the node's top words. Children are printed one
     * indent step deeper.
     *
     * @param node   node to print
     * @param indent number of indent steps for this node
     */
    public void printNode(NCRPNode node, int indent) {
        StringBuilder line = new StringBuilder();
        for (int step = 0; step < indent; ++step) {
            line.append("  ");
        }
        line.append(node.totalTokens).append("/").append(node.customers).append(" ");
        line.append(node.getTopWords(this.numWordsToDisplay));
        System.out.println(line.toString());

        int childIndent = indent + 1;
        for (NCRPNode child : node.children) {
            this.printNode(child, childIndent);
        }
    }

    /**
     * Estimates held-out log-likelihood by sampling word distributions from the tree.
     *
     * <p>For each sample a path is drawn through existing nodes with selectExisting(),
     * level proportions are drawn from a Dirichlet built from {@code numLevels} and
     * {@code alpha}, and the smoothed per-level word distributions are mixed into a single
     * log distribution over word types. Every test document is scored under it. For each
     * document the per-sample scores are combined as log-mean-exp; the returned value is
     * the sum of those per-document values (it is not divided by the number of documents).
     *
     * @param numSamples number of sampled distributions
     * @param testing    documents to score; this parameter, not the field of the same name, is used
     * @return summed per-document log-likelihood estimate
     */
    public double empiricalLikelihood(int numSamples, InstanceList testing) {
        NCRPNode[] path = new NCRPNode[this.numLevels];
        path[0] = this.rootNode;
        Dirichlet dirichlet = new Dirichlet(this.numLevels, this.alpha);
        double[] multinomial = new double[this.numTypes];
        double[][] likelihoods = new double[testing.size()][numSamples];

        for (int sample = 0; sample < numSamples; ++sample) {
            // Draw a path through existing nodes, then the level proportions.
            for (int depth = 1; depth < this.numLevels; ++depth) {
                path[depth] = path[depth - 1].selectExisting();
            }
            double[] levelWeights = dirichlet.nextDistribution();

            // Log of the level-weighted mixture of smoothed word distributions.
            for (int type = 0; type < this.numTypes; ++type) {
                double mass = 0.0;
                for (int depth = 0; depth < this.numLevels; ++depth) {
                    NCRPNode step = path[depth];
                    mass += levelWeights[depth] * (this.eta + (double)step.typeCounts[type]) / (this.etaSum + (double)step.totalTokens);
                }
                multinomial[type] = Math.log(mass);
            }

            // Score every test document under this sample.
            for (int doc = 0; doc < testing.size(); ++doc) {
                FeatureSequence fs = (FeatureSequence)((Instance)testing.get(doc)).getData();
                int seqLen = fs.getLength();
                double docLogLikelihood = 0.0;
                for (int position = 0; position < seqLen; ++position) {
                    docLogLikelihood += multinomial[fs.getIndexAtPosition(position)];
                }
                likelihoods[doc][sample] = docLogLikelihood;
            }
        }

        // Log-mean-exp over samples per document, shifted by the maximum to stay in range.
        double logNumSamples = Math.log(numSamples);
        double result = 0.0;
        for (int doc = 0; doc < testing.size(); ++doc) {
            double[] scores = likelihoods[doc];
            double largest = Double.NEGATIVE_INFINITY;
            for (int sample = 0; sample < numSamples; ++sample) {
                if (scores[sample] > largest) {
                    largest = scores[sample];
                }
            }
            double shiftedSum = 0.0;
            for (int sample = 0; sample < numSamples; ++sample) {
                shiftedSum += Math.exp(scores[sample] - largest);
            }
            result += Math.log(shiftedSum) + largest - logNumSamples;
        }
        return result;
    }

    /**
     * Command-line entry point: loads a training list from {@code args[0]} and a testing
     * list from {@code args[1]}, then runs a five-level sampler for 250 iterations. Any
     * exception is reported with a stack trace.
     *
     * @param args training file path followed by testing file path
     */
    public static void main(String[] args) {
        try {
            File trainingFile = new File(args[0]);
            InstanceList instances = InstanceList.load(trainingFile);
            File testingFile = new File(args[1]);
            InstanceList testing = InstanceList.load(testingFile);

            Randoms randoms = new Randoms();
            HierarchicalLDA sampler = new HierarchicalLDA();
            sampler.initialize(instances, testing, 5, randoms);
            sampler.estimate(250);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * One node of the topic tree: a topic with its word counts, the number of documents
     * whose path passes through it, and links to its parent and children.
     */
    class NCRPNode {
        /** Number of documents whose path includes this node. */
        int customers = 0;
        /** Child nodes, in creation order. */
        ArrayList<NCRPNode> children;
        /** Parent node; null for the root. */
        NCRPNode parent;
        /** Depth of this node; the root is level 0. */
        int level;
        /** Total number of tokens currently counted at this node. */
        int totalTokens;
        /** Per word-type token counts at this node. */
        int[] typeCounts;
        /** Sequential identifier taken from the enclosing sampler's node counter. */
        public int nodeID;

        /**
         * Creates a node with empty counts and assigns it the next node ID.
         *
         * @param parent     parent node, or null for a root
         * @param dimensions length of the word-type count array
         * @param level      depth of the new node
         */
        public NCRPNode(NCRPNode parent, int dimensions, int level) {
            this.parent = parent;
            this.children = new ArrayList<NCRPNode>();
            this.level = level;
            this.totalTokens = 0;
            this.typeCounts = new int[dimensions];
            this.nodeID = HierarchicalLDA.this.totalNodes++;
        }

        /**
         * Creates a root node (no parent, level 0).
         *
         * @param dimensions length of the word-type count array
         */
        public NCRPNode(int dimensions) {
            this(null, dimensions, 0);
        }

        /**
         * Creates a new empty child one level below this node and attaches it.
         *
         * @return the new child
         */
        public NCRPNode addChild() {
            NCRPNode node = new NCRPNode(this, this.typeCounts.length, this.level + 1);
            this.children.add(node);
            return node;
        }

        /** @return true if this node sits on the deepest level of the tree */
        public boolean isLeaf() {
            return this.level == HierarchicalLDA.this.numLevels - 1;
        }

        /**
         * Extends the tree with a chain of new nodes from this node down to leaf level.
         *
         * @return the new leaf, or this node if it already is on the deepest level
         */
        public NCRPNode getNewLeaf() {
            NCRPNode node = this;
            for (int l = this.level; l < HierarchicalLDA.this.numLevels - 1; ++l) {
                node = node.addChild();
            }
            return node;
        }

        /**
         * Removes one customer from this node and from each ancestor, visiting at most
         * {@code numLevels} nodes, and detaches every node whose count drops to zero.
         *
         * <p>The root has no parent to be detached from, so it simply stays in place with
         * a zero count. Previously that case (for example a single-document corpus)
         * dereferenced the null parent.
         */
        public void dropPath() {
            NCRPNode node = this;
            for (int l = 0; l < HierarchicalLDA.this.numLevels && node != null; ++l) {
                --node.customers;
                if (node.customers == 0 && node.parent != null) {
                    node.parent.remove(node);
                }
                // A detached node keeps its parent link, so the walk can continue upward.
                node = node.parent;
            }
        }

        /**
         * Detaches a child from this node.
         *
         * @param node child to remove
         */
        public void remove(NCRPNode node) {
            this.children.remove(node);
        }

        /** Adds one customer to this node and to each of its {@code numLevels - 1} ancestors. */
        public void addPath() {
            NCRPNode node = this;
            ++node.customers;
            for (int l = 1; l < HierarchicalLDA.this.numLevels; ++l) {
                node = node.parent;
                ++node.customers;
            }
        }

        /**
         * Draws one of the existing children with probability proportional to its
         * customer count.
         *
         * <p>The raw counts and their sum are handed to the two-argument
         * {@code nextDiscrete}, the same form used elsewhere in this class for
         * unnormalised weights. The earlier weights were divided by
         * {@code gamma + customers}, so they summed to less than one. Assuming the
         * one-argument {@code nextDiscrete} does not normalise, the draw was not
         * proportional to the counts.
         *
         * @return the selected child
         */
        public NCRPNode selectExisting() {
            double[] weights = new double[this.children.size()];
            double sum = 0.0;
            int i = 0;
            for (NCRPNode child : this.children) {
                weights[i] = (double)child.customers;
                sum += weights[i];
                ++i;
            }
            int choice = HierarchicalLDA.this.random.nextDiscrete(weights, sum);
            return this.children.get(choice);
        }

        /**
         * Draws either a brand-new child (weight {@code gamma}) or an existing child
         * (weight equal to its customer count).
         *
         * <p>The weights are passed unnormalised together with their sum. Dividing by
         * {@code gamma + customers}, as was done before, does not give a total of one
         * whenever this node's own count already includes the document being seated.
         *
         * @return the selected or newly created child
         */
        public NCRPNode select() {
            double[] weights = new double[this.children.size() + 1];
            weights[0] = HierarchicalLDA.this.gamma;
            double sum = weights[0];
            int i = 1;
            for (NCRPNode child : this.children) {
                weights[i] = (double)child.customers;
                sum += weights[i];
                ++i;
            }
            int choice = HierarchicalLDA.this.random.nextDiscrete(weights, sum);
            if (choice == 0) {
                return this.addChild();
            }
            return this.children.get(choice - 1);
        }

        /**
         * Lists this node's word types in the order produced by sorting IDSorter entries
         * built from the node's counts, keeping the first {@code numWords} of them.
         *
         * <p>The requested count is now honoured (it used to be fixed at 10) and is capped
         * at the vocabulary size, so a small vocabulary no longer overruns the array.
         *
         * @param numWords maximum number of words to return; zero or less yields an empty string
         * @return the selected words, each followed by a single space
         */
        public String getTopWords(int numWords) {
            int vocabularySize = HierarchicalLDA.this.numTypes;
            IDSorter[] sortedTypes = new IDSorter[vocabularySize];
            for (int type = 0; type < vocabularySize; ++type) {
                sortedTypes[type] = new IDSorter(type, this.typeCounts[type]);
            }
            Arrays.sort(sortedTypes);
            Alphabet alphabet = HierarchicalLDA.this.instances.getDataAlphabet();
            int limit = Math.min(numWords, sortedTypes.length);
            StringBuilder out = new StringBuilder();
            for (int i = 0; i < limit; ++i) {
                out.append(alphabet.lookupObject(sortedTypes[i].getID())).append(" ");
            }
            return out.toString();
        }
    }
}

