/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.AdaBoost;
import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.util.Random;
import java.util.logging.Logger;

/**
 * Trains an {@link AdaBoost} ensemble by repeatedly training a boostable weak learner on
 * weighted resamples of the training data.
 *
 * <p>Each round:
 * <ul>
 *   <li>draws a sample of the training instances according to the current instance weights;</li>
 *   <li>trains the weak learner on that sample;</li>
 *   <li>measures the weak classifier's weighted error {@code err} on the full training set;</li>
 *   <li>gives it a vote weight of {@code log((1 - err) / err)};</li>
 *   <li>multiplies the weights of correctly classified instances by {@code err / (1 - err)}
 *       and renormalizes all weights to sum to one.</li>
 * </ul>
 * Boosting stops early when a round's weighted error is (almost) zero or exceeds 0.5.
 */
public class AdaBoostTrainer
extends ClassifierTrainer<AdaBoost> {
    private static final Logger logger = MalletLogger.getLogger(AdaBoostTrainer.class.getName());
    /** Maximum number of resamples attempted per round while the weak learner's error is ~0. */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;
    /** The weak learner trained once per round; the constructor checks it is {@link Boostable}. */
    ClassifierTrainer weakLearner;
    /** Maximum number of boosting rounds. */
    int numRounds;
    /** Ensemble produced by the most recent call to {@link #train}; {@code null} before training. */
    AdaBoost classifier;

    /**
     * Returns the ensemble produced by the most recent {@link #train} call.
     *
     * @return the last trained ensemble, or {@code null} if {@code train} has not been called
     */
    @Override
    public AdaBoost getClassifier() {
        return this.classifier;
    }

    /**
     * Creates a trainer that boosts {@code weakLearner} for at most {@code numRounds} rounds.
     *
     * @param weakLearner the weak learner; must implement {@link Boostable}
     * @param numRounds maximum number of boosting rounds; must be positive
     * @throws IllegalArgumentException if the learner is not boostable or {@code numRounds <= 0}
     */
    public AdaBoostTrainer(ClassifierTrainer weakLearner, int numRounds) {
        if (!(weakLearner instanceof Boostable)) {
            throw new IllegalArgumentException("weak learner not boostable");
        }
        if (numRounds <= 0) {
            throw new IllegalArgumentException("number of rounds must be positive");
        }
        this.weakLearner = weakLearner;
        this.numRounds = numRounds;
    }

    /**
     * Creates a trainer that boosts {@code weakLearner} for at most 100 rounds.
     *
     * @param weakLearner the weak learner; must implement {@link Boostable}
     * @throws IllegalArgumentException if the learner is not boostable
     */
    public AdaBoostTrainer(ClassifierTrainer weakLearner) {
        this(weakLearner, 100);
    }

    /**
     * Runs boosting on {@code trainingList} and returns the resulting ensemble.
     *
     * <p>The returned ensemble is also stored so that {@link #getClassifier()} returns it.
     * This holds both after a full run and when boosting stops early.
     *
     * @param trainingList the training instances; must not carry a feature selection
     * @return the trained ensemble
     * @throws UnsupportedOperationException if {@code trainingList} has a feature selection
     */
    @Override
    public AdaBoost train(InstanceList trainingList) {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        Random random = new Random();
        // Start from a uniform distribution over the training instances.
        double initialWeight = 1.0 / (double)trainingList.size();
        InstanceList trainingInsts = new InstanceList(trainingList.getPipe(), trainingList.size());
        for (int i = 0; i < trainingList.size(); ++i) {
            trainingInsts.add((Instance)trainingList.get(i), initialWeight);
        }
        boolean[] correct = new boolean[trainingInsts.size()];
        int numClasses = trainingInsts.getTargetAlphabet().size();
        if (numClasses != 2) {
            // The vote-weight formula used below is intended for two-class problems.
            logger.info("AdaBoostTrainer.train: WARNING: expected two classes but found " + numClasses);
        }
        Classifier[] weakLearners = new Classifier[this.numRounds];
        double[] alphas = new double[this.numRounds];
        InstanceList roundTrainingInsts = new InstanceList(trainingInsts.getPipe());
        for (int round = 0; round < this.numRounds; ++round) {
            logger.info("===========  AdaBoostTrainer round " + (round + 1) + " begin");
            double err;
            int resamplingIterations = 0;
            // Keep resampling until the weak learner makes a non-zero error on the full training set.
            // This ensures some "hard" instances were sampled.
            // Give up after MAX_NUM_RESAMPLING_ITERATIONS attempts.
            do {
                roundTrainingInsts = trainingInsts.sampleWithInstanceWeights(random);
                weakLearners[round] = this.weakLearner.train(roundTrainingInsts);
                err = weightedError(weakLearners[round], trainingInsts, correct);
            } while (Maths.almostEquals(err, 0.0) && ++resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS);
            // Stop when the error is ~0 or worse than 0.5; this round's classifier is dropped
            // (except in round 0, see truncatedEnsemble).
            if (Maths.almostEquals(err, 0.0) || err > 0.5) {
                logger.info("AdaBoostTrainer stopped at " + (round + 1) + " / " + this.numRounds + " rounds: numClasses=" + numClasses + " error=" + err);
                // Store the early-stop result too, so getClassifier() agrees with the return value.
                this.classifier = truncatedEnsemble(roundTrainingInsts, weakLearners, alphas, round);
                return this.classifier;
            }
            alphas[round] = Math.log((1.0 - err) / err);
            // Shrink the weights of correctly classified instances so the next round focuses on mistakes.
            reweightInstances(trainingInsts, correct, err / (1.0 - err));
            logger.info("===========  AdaBoostTrainer round " + (round + 1) + " finished, weak classifier training error = " + err);
        }
        for (int i = 0; i < alphas.length; ++i) {
            logger.info("AdaBoostTrainer weight[weakLearner[" + i + "]]=" + alphas[i]);
        }
        this.classifier = new AdaBoost(roundTrainingInsts.getPipe(), weakLearners, alphas);
        return this.classifier;
    }

    /**
     * Classifies every instance of {@code insts} with {@code weak}, records per-instance
     * correctness in {@code correct}, and returns the summed weight of misclassified instances.
     */
    private static double weightedError(Classifier weak, InstanceList insts, boolean[] correct) {
        double err = 0.0;
        for (int i = 0; i < insts.size(); ++i) {
            Instance inst = (Instance)insts.get(i);
            correct[i] = weak.classify(inst).bestLabelIsCorrect();
            if (!correct[i]) {
                err += insts.getInstanceWeight(i);
            }
        }
        return err;
    }

    /**
     * Multiplies the weight of every correctly classified instance by {@code factor}, then
     * divides all weights by their new sum.
     */
    private static void reweightInstances(InstanceList insts, boolean[] correct, double factor) {
        double sum = 0.0;
        for (int i = 0; i < insts.size(); ++i) {
            double w = insts.getInstanceWeight(i);
            if (correct[i]) {
                w *= factor;
            }
            insts.setInstanceWeight(i, w);
            sum += w;
        }
        for (int i = 0; i < insts.size(); ++i) {
            insts.setInstanceWeight(i, insts.getInstanceWeight(i) / sum);
        }
    }

    /**
     * Builds the ensemble returned when boosting stops early in {@code round}.
     *
     * <p>It keeps only the weak classifiers from earlier rounds. In round 0 there is nothing to
     * fall back on, so that round's classifier is kept with weight 1.0.
     */
    private static AdaBoost truncatedEnsemble(InstanceList roundTrainingInsts, Classifier[] weakLearners, double[] alphas, int round) {
        int numClassifiersToUse = round == 0 ? 1 : round;
        if (round == 0) {
            alphas[0] = 1.0;
        }
        double[] betas = new double[numClassifiersToUse];
        Classifier[] weakClassifiers = new Classifier[numClassifiersToUse];
        System.arraycopy(alphas, 0, betas, 0, numClassifiersToUse);
        System.arraycopy(weakLearners, 0, weakClassifiers, 0, numClassifiersToUse);
        for (int i = 0; i < betas.length; ++i) {
            logger.info("AdaBoostTrainer weight[weakLearner[" + i + "]]=" + betas[i]);
        }
        return new AdaBoost(roundTrainingInsts.getPipe(), weakClassifiers, betas);
    }
}

