/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.confidence;

import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.confidence.TransducerSequenceConfidenceEstimator;
import cc.mallet.types.Instance;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Query-by-committee (QBC) confidence estimator for sequence labelings.
 *
 * <p>Each committee member labels the instance with its best-scoring output sequence
 * (via {@link MaxLatticeDefault}). At every position, the disagreement between members
 * is measured as vote entropy normalized to [0, 1]. The confidence is the negated mean
 * of those per-position entropies. More disagreement therefore means lower confidence.
 */
public class QBCSequenceConfidenceEstimator
extends TransducerSequenceConfidenceEstimator {
    private static final Logger logger = MalletLogger.getLogger(QBCSequenceConfidenceEstimator.class.getName());
    /** Transducers whose predictions are compared; each casts one vote per position. */
    Transducer[] committee;

    /**
     * @param model     the transducer whose labelings are being scored (handed to the superclass)
     * @param committee the transducers that vote on each position's label
     */
    public QBCSequenceConfidenceEstimator(Transducer model, Transducer[] committee) {
        super(model);
        this.committee = committee;
    }

    /**
     * Returns the negated average vote entropy of the committee's best labelings of
     * {@code instance}. The value lies in [-1, 0]; 0 means every member agreed at every
     * position. {@code startTags} and {@code inTags} are not used by this estimator.
     *
     * @param instance instance whose data is expected to be a {@link Sequence}
     * @return confidence score in [-1, 0]
     * @throws IllegalStateException if the committee is null or empty
     */
    @Override
    public double estimateConfidenceFor(Instance instance, Object[] startTags, Object[] inTags) {
        if (this.committee == null || this.committee.length == 0) {
            throw new IllegalStateException("QBC committee must contain at least one transducer");
        }
        // The input is the same for every committee member, so cast it once.
        Sequence input = (Sequence) instance.getData();
        Sequence[] predictions = new Sequence[this.committee.length];
        for (int i = 0; i < this.committee.length; ++i) {
            predictions[i] = new MaxLatticeDefault(this.committee[i], input).bestOutputSequence();
        }
        return -1.0 * this.avgVoteEntropy(predictions);
    }

    /**
     * Averages the normalized vote entropy over all positions of the predictions.
     * Iterates over {@code predictions[0].size()} positions; assumes every prediction
     * has at least that length (they all label the same input) -- TODO confirm.
     *
     * @param predictions one predicted sequence per committee member (non-empty array)
     * @return mean per-position vote entropy in [0, 1], or 0.0 for an empty sequence
     */
    private double avgVoteEntropy(Sequence[] predictions) {
        int length = predictions[0].size();
        if (length == 0) {
            // No positions means nothing to disagree on; avoids 0.0 / 0 == NaN.
            return 0.0;
        }
        double total = 0.0;
        for (int position = 0; position < length; ++position) {
            Map<String, Integer> label2Count = new HashMap<String, Integer>();
            for (Sequence prediction : predictions) {
                String label = prediction.get(position).toString();
                Integer count = label2Count.get(label);
                label2Count.put(label, count == null ? 1 : count + 1);
            }
            total += voteEntropy(label2Count, predictions.length);
        }
        return total / length;
    }

    /**
     * Computes the vote entropy of a single position, normalized by
     * {@code log(committeeSize)} so that the result lies in [0, 1].
     *
     * @param label2Count   number of committee votes each label received
     * @param committeeSize total number of votes cast at this position
     * @return normalized vote entropy; 0.0 when the committee has at most one member
     */
    private static double voteEntropy(Map<String, Integer> label2Count, int committeeSize) {
        if (committeeSize <= 1) {
            // A single voter cannot disagree with itself; dividing by log(1) == 0 would yield NaN.
            return 0.0;
        }
        double sum = 0.0;
        for (int count : label2Count.values()) {
            double fraction = (double) count / committeeSize;
            sum += fraction * Math.log(fraction);
        }
        return -sum / Math.log(committeeSize);
    }
}

