/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.BalancedWinnow;
import cc.mallet.classify.Boostable;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import java.io.Serializable;
import java.util.Arrays;

/**
 * Trains a {@link BalancedWinnow} classifier with a multiplicative (Balanced Winnow) update rule.
 *
 * <p>One weight vector is kept per label. Each vector has one extra trailing slot (index
 * {@code numFeatures}) that acts as the weight of an implicit feature present in every instance.
 * All weights start at 1.0. For each training instance the per-label scores are computed as dot
 * products. If the top-scoring label is wrong, or is right but beats the runner-up by a relative
 * margin smaller than {@code delta}, the weights of the instance's active features (and the
 * implicit feature) are multiplied by {@code 1 + epsilon} for the correct label and by
 * {@code 1 - epsilon} for the competing label. After every pass over the data {@code epsilon} is
 * multiplied by {@code 1 - coolingRate}.
 *
 * <p>Not safe for concurrent use: {@link #train(InstanceList)} writes instance fields.
 */
public class BalancedWinnowTrainer extends ClassifierTrainer<BalancedWinnow>
        implements Boostable, Serializable {
    private static final long serialVersionUID = 1L;
    /** Default initial promotion/demotion rate. */
    public static final double DEFAULT_EPSILON = 0.5;
    /** Default relative margin below which a correct prediction still triggers an update. */
    public static final double DEFAULT_DELTA = 0.1;
    /** Default number of passes over the training data. */
    public static final int DEFAULT_MAX_ITERATIONS = 30;
    /** Default fraction by which epsilon is reduced after every pass. */
    public static final double DEFAULT_COOLING_RATE = 0.5;

    /** Initial promotion/demotion rate used at the start of each training run. */
    double m_epsilon;
    /** Relative-margin threshold for updating on correct-but-close predictions. */
    double m_delta;
    /** Number of passes over the training data. */
    int m_maxIterations;
    /** Fraction by which epsilon is reduced after each pass. */
    double m_coolingRate;
    /** Per-label weights from the most recent run; the last column is the implicit-feature weight. */
    double[][] m_weights;
    /** Classifier built by the most recent call to {@link #train}, or {@code null} before training. */
    BalancedWinnow classifier;

    /**
     * Returns the classifier produced by the most recent call to {@link #train(InstanceList)}.
     *
     * @return the trained classifier, or {@code null} if training has not run yet
     */
    @Override
    public BalancedWinnow getClassifier() {
        return this.classifier;
    }

    /** Creates a trainer that uses the {@code DEFAULT_*} hyper-parameters. */
    public BalancedWinnowTrainer() {
        this(DEFAULT_EPSILON, DEFAULT_DELTA, DEFAULT_MAX_ITERATIONS, DEFAULT_COOLING_RATE);
    }

    /**
     * Creates a trainer with explicit hyper-parameters.
     *
     * @param epsilon initial promotion/demotion rate
     * @param delta relative margin below which a correct prediction still triggers an update
     * @param maxIterations number of passes over the training data
     * @param coolingRate fraction by which epsilon is reduced after every pass
     */
    public BalancedWinnowTrainer(double epsilon, double delta, int maxIterations, double coolingRate) {
        this.m_epsilon = epsilon;
        this.m_delta = delta;
        this.m_maxIterations = maxIterations;
        this.m_coolingRate = coolingRate;
    }

    /**
     * Trains a new classifier on {@code trainingList}. Any previously learned weights are discarded.
     * Instances without a labeling are skipped.
     *
     * @param trainingList labeled training data whose instance data are {@link FeatureVector}s
     * @return the trained classifier (also available via {@link #getClassifier()})
     * @throws UnsupportedOperationException if the list carries a {@link FeatureSelection}
     */
    @Override
    public BalancedWinnow train(InstanceList trainingList) {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        double epsilon = this.m_epsilon;
        Alphabet dict = trainingList.getDataAlphabet();
        int numLabels = trainingList.getTargetAlphabet().size();
        int numFeats = dict.size();
        // Column numFeats holds the weight of the implicit always-on feature.
        this.m_weights = new double[numLabels][numFeats + 1];
        for (double[] labelWeights : this.m_weights) {
            Arrays.fill(labelWeights, 1.0);
        }
        double[] scores = new double[numLabels];
        for (int iter = 0; iter < this.m_maxIterations; ++iter) {
            for (int ii = 0; ii < trainingList.size(); ++ii) {
                Instance inst = (Instance) trainingList.get(ii);
                Labeling labeling = inst.getLabeling();
                if (labeling == null) {
                    // No target to learn from; previously this threw a NullPointerException.
                    continue;
                }
                FeatureVector fv = (FeatureVector) inst.getData();
                int correctIndex = labeling.getBestIndex();
                computeScores(fv, numFeats, scores);

                // Locate the best and runner-up labels. Seed with -infinity rather than
                // Double.MIN_VALUE (the smallest *positive* double) so non-positive scores rank
                // correctly; -1 marks "no label seen yet".
                int predictedIndex = -1;
                int runnerUpIndex = -1;
                double best = Double.NEGATIVE_INFINITY;
                double runnerUp = Double.NEGATIVE_INFINITY;
                for (int li = 0; li < numLabels; ++li) {
                    double score = scores[li];
                    if (score > best) {
                        runnerUp = best;
                        runnerUpIndex = predictedIndex;
                        best = score;
                        predictedIndex = li;
                    } else if (score > runnerUp) {
                        runnerUp = score;
                        runnerUpIndex = li;
                    }
                }
                if (predictedIndex < 0) {
                    // Only reachable with NaN/-infinity scores; fall back to label 0 as before.
                    predictedIndex = 0;
                }

                if (predictedIndex != correctIndex) {
                    promoteAndDemote(fv, numFeats, correctIndex, predictedIndex, epsilon);
                } else if (runnerUpIndex >= 0 && isWithinMargin(best, runnerUp)) {
                    // Correct, but too close to the runner-up: widen the gap.
                    promoteAndDemote(fv, numFeats, correctIndex, runnerUpIndex, epsilon);
                }
            }
            // Cool the update rate after each full pass.
            epsilon *= 1.0 - this.m_coolingRate;
        }
        this.classifier = new BalancedWinnow(trainingList.getPipe(), this.m_weights);
        return this.classifier;
    }

    /**
     * Fills {@code scores[l]} with the dot product of {@code fv} and label {@code l}'s weights, plus
     * that label's implicit-feature weight at {@code biasIndex}.
     */
    private void computeScores(FeatureVector fv, int biasIndex, double[] scores) {
        int numLocations = fv.numLocations();
        for (int l = 0; l < scores.length; ++l) {
            double[] w = this.m_weights[l];
            double sum = 0.0;
            for (int loc = 0; loc < numLocations; ++loc) {
                sum += fv.valueAtLocation(loc) * w[fv.indexAtLocation(loc)];
            }
            scores[l] = sum + w[biasIndex];
        }
    }

    /**
     * Returns true when {@code best} beats {@code runnerUp} by a relative margin below
     * {@code m_delta}. For positive runner-up scores this is exactly {@code best/runnerUp - 1 <
     * delta}. For non-positive scores that ratio flips sign, so the gap is measured relative to
     * {@code |runnerUp|} instead.
     */
    private boolean isWithinMargin(double best, double runnerUp) {
        if (runnerUp > 0.0) {
            return best / runnerUp - 1.0 < this.m_delta;
        }
        return best - runnerUp < this.m_delta * Math.abs(runnerUp);
    }

    /**
     * Multiplies the weights of every feature active in {@code fv}, and the implicit-feature weight
     * at {@code biasIndex}, by {@code 1 + epsilon} for label {@code promote} and by
     * {@code 1 - epsilon} for label {@code demote}. Callers pass two distinct labels.
     */
    private void promoteAndDemote(FeatureVector fv, int biasIndex, int promote, int demote,
            double epsilon) {
        double[] up = this.m_weights[promote];
        double[] down = this.m_weights[demote];
        int numLocations = fv.numLocations();
        for (int loc = 0; loc < numLocations; ++loc) {
            int fi = fv.indexAtLocation(loc);
            down[fi] *= 1.0 - epsilon;
            up[fi] *= 1.0 + epsilon;
        }
        down[biasIndex] *= 1.0 - epsilon;
        up[biasIndex] *= 1.0 + epsilon;
    }
}

