/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.cluster;

import cc.mallet.cluster.Clusterer;
import cc.mallet.cluster.Clustering;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Metric;
import cc.mallet.types.SparseVector;
import cc.mallet.util.VectorStats;
import java.util.ArrayList;
import java.util.Random;
import java.util.logging.Logger;

/**
 * K-means clusterer over instances whose data is a {@link SparseVector}.
 *
 * <p>Initial centroids are chosen from the input by farthest-point sampling
 * (see {@link #initializeMeansSample}); the main loop then alternates between
 * assigning every instance to its nearest centroid and recomputing each
 * centroid as the mean of its members, until one of the stopping criteria
 * ({@code MEANS_TOLERANCE}, {@code MAX_ITER}, {@code POINTS_TOLERANCE}) is met.
 *
 * <p>What happens when a cluster ends an assignment pass with no members is
 * selected by the {@code emptyAction} constructor argument
 * ({@link #EMPTY_ERROR}, {@link #EMPTY_DROP} or {@link #EMPTY_SINGLE}).
 */
public class KMeans
extends Clusterer {
    private static final long serialVersionUID = 1L;
    /** Iteration stops once the summed centroid movement of a pass is at or below this value. */
    static double MEANS_TOLERANCE = 0.01;
    /** Upper bound on the number of assign/update passes. */
    static int MAX_ITER = 100;
    /** Iteration stops once at most this fraction of the instances changed cluster in a pass. */
    static double POINTS_TOLERANCE = 0.005;
    /** Empty-cluster policy: give up; {@link #cluster} returns {@code null}. */
    public static final int EMPTY_ERROR = 0;
    /** Empty-cluster policy: remove the empty cluster and continue with one cluster fewer. */
    public static final int EMPTY_DROP = 1;
    /** Empty-cluster policy: re-seed the empty cluster with the member lying farthest from its own centroid. */
    public static final int EMPTY_SINGLE = 2;
    Random randinator;
    Metric metric;
    int numClusters;
    int emptyAction;
    ArrayList<SparseVector> clusterMeans;
    private static final Logger logger = Logger.getLogger("edu.umass.cs.mallet.base.cluster.KMeans");

    /**
     * Creates a clusterer.
     *
     * @param instancePipe pipe the clustered instances are expected to share
     * @param numClusters  number of clusters to start with
     * @param metric       distance used both for assignment and for measuring centroid movement
     * @param emptyAction  one of {@link #EMPTY_ERROR}, {@link #EMPTY_DROP}, {@link #EMPTY_SINGLE};
     *                     any other value makes {@link #cluster} return {@code null} when an
     *                     empty cluster is encountered
     */
    public KMeans(Pipe instancePipe, int numClusters, Metric metric, int emptyAction) {
        super(instancePipe);
        this.emptyAction = emptyAction;
        this.metric = metric;
        this.numClusters = numClusters;
        this.clusterMeans = new ArrayList<SparseVector>(numClusters);
        this.randinator = new Random();
    }

    /**
     * Creates a clusterer that uses the {@link #EMPTY_ERROR} policy for empty clusters.
     *
     * @param instancePipe pipe the clustered instances are expected to share
     * @param numClusters  number of clusters to start with
     * @param metric       distance used both for assignment and for measuring centroid movement
     */
    public KMeans(Pipe instancePipe, int numClusters, Metric metric) {
        this(instancePipe, numClusters, metric, EMPTY_ERROR);
    }

    /**
     * Clusters the given instances.
     *
     * <p>Under {@link #EMPTY_DROP} the {@code numClusters} field is decremented in place
     * for every cluster that is removed, and the reduced value stays on this object.
     *
     * @param instances instances to cluster; each instance's data is cast to {@link SparseVector}
     * @return the resulting clustering, or {@code null} if an empty cluster was found and the
     *         policy is {@link #EMPTY_ERROR} or unrecognised, or if the policy is
     *         {@link #EMPTY_SINGLE} and no instance could be found to re-seed the cluster
     */
    @Override
    public Clustering cluster(InstanceList instances) {
        assert (instances.getPipe() == this.instancePipe);
        this.initializeMeansSample(instances, this.metric);
        // Labels start out as 0, so instances whose first assignment is cluster 0
        // are not counted as "moved" in the first pass.
        int[] clusterLabels = new int[instances.size()];
        ArrayList<InstanceList> instanceClusters = new ArrayList<InstanceList>(this.numClusters);
        double deltaMeans = Double.MAX_VALUE;
        double deltaPoints = instances.size();
        int iterations = 0;
        logger.info("Entering KMeans iteration");
        while (deltaMeans > MEANS_TOLERANCE && iterations < MAX_ITER && deltaPoints > (double)instances.size() * POINTS_TOLERANCE) {
            ++iterations;

            // Start every pass with fresh member lists. They are kept intact for the
            // whole update step so that the EMPTY_SINGLE search below can still see
            // the members of clusters whose mean has already been recomputed.
            instanceClusters.clear();
            for (int c = 0; c < this.numClusters; ++c) {
                instanceClusters.add(new InstanceList(this.instancePipe));
            }

            // Assignment step: put each instance into the cluster of its nearest centroid.
            deltaPoints = 0.0;
            for (int n = 0; n < instances.size(); ++n) {
                Instance instance = (Instance)instances.get(n);
                int instClust = this.nearestCluster((SparseVector)instance.getData());
                instanceClusters.get(instClust).add(instance);
                if (clusterLabels[n] != instClust) {
                    clusterLabels[n] = instClust;
                    deltaPoints += 1.0;
                }
            }

            // Update step: recompute centroids and accumulate how far they moved.
            deltaMeans = 0.0;
            for (int c = 0; c < this.numClusters; ++c) {
                InstanceList members = instanceClusters.get(c);
                if (members.size() > 0) {
                    SparseVector clusterMean = VectorStats.mean(members);
                    deltaMeans += this.metric.distance(this.clusterMeans.get(c), clusterMean);
                    this.clusterMeans.set(c, clusterMean);
                    continue;
                }
                logger.info("Empty cluster found.");
                switch (this.emptyAction) {
                    case EMPTY_ERROR: {
                        return null;
                    }
                    case EMPTY_DROP: {
                        logger.fine("Removing cluster " + c);
                        this.clusterMeans.remove(c);
                        instanceClusters.remove(c);
                        // Close the gap in the label numbering left by the removed cluster.
                        for (int n = 0; n < instances.size(); ++n) {
                            assert (clusterLabels[n] != c) : "Cluster " + c + " is empty yet clusterLabels[" + n + "] is " + clusterLabels[n];
                            if (clusterLabels[n] > c) {
                                clusterLabels[n] = clusterLabels[n] - 1;
                            }
                        }
                        --this.numClusters;
                        // The next cluster has shifted into slot c; revisit that slot.
                        --c;
                        break;
                    }
                    case EMPTY_SINGLE: {
                        SparseVector newCentroid = this.findFarthestMember(instanceClusters);
                        if (newCentroid == null) {
                            logger.info("Can't find an instance to move.  Exiting.");
                            return null;
                        }
                        // Count the re-seeded centroid's displacement like any other
                        // centroid movement so it takes part in the convergence check.
                        deltaMeans += this.metric.distance(this.clusterMeans.get(c), newCentroid);
                        this.clusterMeans.set(c, newCentroid);
                        // Without this break the case fell through to 'default' and
                        // returned null, so EMPTY_SINGLE could never succeed.
                        break;
                    }
                    default: {
                        return null;
                    }
                }
            }
            logger.info("Iter " + iterations + " deltaMeans = " + deltaMeans);
        }
        if (deltaMeans <= MEANS_TOLERANCE) {
            logger.info("KMeans converged with deltaMeans = " + deltaMeans);
        } else if (iterations >= MAX_ITER) {
            logger.info("Maximum number of iterations (" + MAX_ITER + ") reached.");
        } else if (deltaPoints <= (double)instances.size() * POINTS_TOLERANCE) {
            logger.info("Minimum number of points (np*" + POINTS_TOLERANCE + "=" + (int)((double)instances.size() * POINTS_TOLERANCE) + ") moved in last iteration. Saying converged.");
        }
        return new Clustering(instances, this.numClusters, clusterLabels);
    }

    /**
     * Returns the index of the centroid closest to {@code point} under {@code metric}.
     * Ties go to the lowest index; 0 is returned when there are no clusters.
     */
    private int nearestCluster(SparseVector point) {
        int nearest = 0;
        double nearestDist = Double.MAX_VALUE;
        for (int c = 0; c < this.numClusters; ++c) {
            double dist = this.metric.distance(this.clusterMeans.get(c), point);
            if (dist < nearestDist) {
                nearest = c;
                nearestDist = dist;
            }
        }
        return nearest;
    }

    /**
     * Finds, among clusters that currently hold more than one member, the member
     * lying farthest from its own cluster's centroid.
     *
     * @param instanceClusters current member lists, indexed like {@code clusterMeans}
     * @return the data vector of that member, or {@code null} if no member of an
     *         eligible cluster has a distance greater than zero to its centroid
     */
    private SparseVector findFarthestMember(ArrayList<InstanceList> instanceClusters) {
        double farthestDist = 0.0;
        SparseVector farthest = null;
        for (int k = 0; k < this.numClusters; ++k) {
            InstanceList members = instanceClusters.get(k);
            // Taking the only member of a cluster would just create another empty cluster.
            if (members.size() <= 1) {
                continue;
            }
            SparseVector centroid = this.clusterMeans.get(k);
            for (int n = 0; n < members.size(); ++n) {
                SparseVector candidate = (SparseVector)((Instance)members.get(n)).getData();
                double dist = this.metric.distance(centroid, candidate);
                if (dist > farthestDist) {
                    farthestDist = dist;
                    farthest = candidate;
                }
            }
        }
        return farthest;
    }

    /**
     * Chooses {@code numClusters} initial centroids from {@code instList} by
     * farthest-point sampling and stores them in {@code clusterMeans}.
     *
     * <p>Instances whose vector has no locations are not eligible. Each round picks
     * the eligible instance whose distance to its nearest already-chosen centroid is
     * largest; with no centroid chosen yet, that is the first eligible instance.
     * The chosen instance's data vector itself (not a copy) becomes the centroid.
     *
     * @param instList instances to draw the initial centroids from
     * @param metric   distance used to compare instances with centroids
     * @throws IndexOutOfBoundsException if there are fewer eligible instances than clusters
     */
    private void initializeMeansSample(InstanceList instList, Metric metric) {
        // Start from an empty list on every run; appending to the means of an earlier
        // cluster() call would leave those stale means in the slots the main loop uses.
        this.clusterMeans = new ArrayList<SparseVector>(this.numClusters);

        ArrayList<Instance> candidates = new ArrayList<Instance>(instList.size());
        for (int i = 0; i < instList.size(); ++i) {
            Instance ins = (Instance)instList.get(i);
            SparseVector sparse = (SparseVector)ins.getData();
            if (sparse.numLocations() == 0) {
                continue;
            }
            candidates.add(ins);
        }
        for (int i = 0; i < this.numClusters; ++i) {
            double bestDist = 0.0;
            int selected = 0;
            for (int k = 0; k < candidates.size(); ++k) {
                SparseVector candidate = (SparseVector)candidates.get(k).getData();
                // Distance from this candidate to its nearest already-chosen centroid.
                double nearestDist = Double.MAX_VALUE;
                for (int j = 0; j < this.clusterMeans.size(); ++j) {
                    double dist = metric.distance(this.clusterMeans.get(j), candidate);
                    if (dist < nearestDist) {
                        nearestDist = dist;
                    }
                }
                if (nearestDist > bestDist) {
                    selected = k;
                    bestDist = nearestDist;
                }
            }
            // Remove the pick so the same instance cannot be chosen twice.
            Instance newCenter = candidates.remove(selected);
            this.clusterMeans.add((SparseVector)newCenter.getData());
        }
    }

    /**
     * Returns the list of cluster centroids held by this object (the internal list, not a copy).
     */
    public ArrayList<SparseVector> getClusterMeans() {
        return this.clusterMeans;
    }
}

