/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification;

import dragon.ir.classification.AbstractClassifier;
import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.featureselection.FeatureSelector;
import dragon.ir.classification.multiclass.CodeMatrix;
import dragon.ir.classification.multiclass.HingeLoss;
import dragon.ir.classification.multiclass.LossMultiClassDecoder;
import dragon.ir.classification.multiclass.MultiClassDecoder;
import dragon.ir.classification.multiclass.OVACodeMatrix;
import dragon.ir.index.IndexReader;
import dragon.matrix.Row;
import dragon.matrix.SparseMatrix;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import jnisvmlight.FeatureVector;
import jnisvmlight.KernelParam;
import jnisvmlight.LabeledFeatureVector;
import jnisvmlight.LearnParam;
import jnisvmlight.SVMLightInterface;
import jnisvmlight.SVMLightModel;
import jnisvmlight.TrainingParameters;

/**
 * Multi-class document classifier built on top of SVMLight binary models.
 *
 * <p>One binary {@link SVMLightModel} is trained per classifier column of the
 * configured {@link CodeMatrix}; at classification time the per-model outputs
 * are handed to a {@link MultiClassDecoder}, which picks the final class.
 * By default a one-vs-all code matrix and a hinge-loss decoder are used.
 */
public class SVMLightClassifier
extends AbstractClassifier {
    /** Trained binary models, one per code-matrix classifier; {@code null} until trained or loaded. */
    private SVMLightModel[] arrModel;
    private LearnParam learnParam;
    private KernelParam kernelParam;
    private CodeMatrix codeMatrix;
    private MultiClassDecoder classDecoder;
    /** Outputs of the binary models for the most recently classified document. */
    private double[] arrConfidence;
    /** When {@code true}, feature vectors are scaled to unit L2 length. */
    private boolean scale;

    /**
     * Restores a classifier previously written by {@link #saveModel(String)}.
     *
     * <p>The file is read in the order it was written: model count, the models,
     * code matrix, decoder, class count, scaling flag, feature selector, and one
     * label per class. If anything goes wrong the stack trace is printed and the
     * classifier is left without models, so {@link #classify(Row)} returns -1.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load model
     * files from trusted locations.
     *
     * @param modelFile path of the serialized model file
     */
    public SVMLightClassifier(String modelFile) {
        // Initialise the parameter holders so the kernel setters do not hit a
        // null reference on a classifier that was loaded from disk.
        this.learnParam = new LearnParam();
        this.kernelParam = new KernelParam();
        FileInputStream fileIn = null;
        try {
            fileIn = new FileInputStream(modelFile);
            ObjectInputStream oin = new ObjectInputStream(fileIn);
            this.arrModel = new SVMLightModel[oin.readInt()];
            for (int i = 0; i < this.arrModel.length; ++i) {
                this.arrModel[i] = (SVMLightModel)oin.readObject();
            }
            this.codeMatrix = (CodeMatrix)oin.readObject();
            this.classDecoder = (MultiClassDecoder)oin.readObject();
            this.classNum = oin.readInt();
            this.scale = oin.readBoolean();
            this.featureSelector = (FeatureSelector)oin.readObject();
            this.arrLabel = new String[this.classNum];
            for (int i = 0; i < this.arrLabel.length; ++i) {
                this.arrLabel[i] = (String)oin.readObject();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            // Do not keep a half-populated model array around: classify()
            // treats a null array as "no model" and returns -1.
            this.arrModel = null;
        }
        finally {
            if (fileIn != null) {
                try {
                    fileIn.close();
                }
                catch (Exception closeError) {
                    closeError.printStackTrace();
                }
            }
        }
    }

    /**
     * Creates an untrained classifier that reads documents from an index.
     *
     * @param indexReader source of the document vectors
     */
    public SVMLightClassifier(IndexReader indexReader) {
        super(indexReader);
        this.learnParam = new LearnParam();
        this.kernelParam = new KernelParam();
        this.classDecoder = new LossMultiClassDecoder(new HingeLoss());
        this.codeMatrix = new OVACodeMatrix(1);
        this.classNum = 0;
        this.scale = false;
    }

    /**
     * Creates an untrained classifier that reads documents from a document-term matrix.
     *
     * @param doctermMatrix source of the document vectors
     */
    public SVMLightClassifier(SparseMatrix doctermMatrix) {
        super(doctermMatrix);
        this.learnParam = new LearnParam();
        this.kernelParam = new KernelParam();
        this.classDecoder = new LossMultiClassDecoder(new HingeLoss());
        this.codeMatrix = new OVACodeMatrix(1);
        this.classNum = 0;
        this.scale = false;
    }

    /** Selects kernel type 0 for subsequent training. */
    public void setUseLinearKernel() {
        this.kernelParam.kernel_type = 0L;
    }

    /** Selects kernel type 2 for subsequent training. */
    public void setUseRBFKernel() {
        this.kernelParam.kernel_type = 2L;
    }

    /** Selects kernel type 1 for subsequent training. */
    public void setUsePolynomialKernel() {
        this.kernelParam.kernel_type = 1L;
    }

    /** Selects kernel type 3 for subsequent training. */
    public void setUserSigmoidKernel() {
        this.kernelParam.kernel_type = 3L;
    }

    /**
     * Enables or disables unit-length scaling of feature vectors.
     *
     * @param option {@code true} to scale each document vector to unit L2 length
     */
    public void setScalingOption(boolean option) {
        this.scale = option;
    }

    /**
     * Replaces the code matrix that maps classes to binary classifiers.
     *
     * @param matrix the code matrix to use
     */
    public void setCodeMatrix(CodeMatrix matrix) {
        this.codeMatrix = matrix;
    }

    /**
     * Replaces the decoder that turns binary outputs into a class decision.
     *
     * @param decoder2 the decoder to use
     */
    public void setMultiClassDecoder(MultiClassDecoder decoder2) {
        this.classDecoder = decoder2;
    }

    /**
     * Returns the ranking produced by the multi-class decoder.
     *
     * @return the value of {@code MultiClassDecoder.rank()}
     */
    public int[] rank() {
        return this.classDecoder.rank();
    }

    /**
     * Trains the feature selector and one binary SVM per code-matrix classifier.
     *
     * <p>Does nothing when the classifier has neither an index reader nor a
     * document-term matrix. Exceptions are printed, not propagated; in that case
     * the previously held models (if any) are left in place.
     *
     * @param trainingDocSet the labelled training documents
     */
    public void train(DocClassSet trainingDocSet) {
        if (this.indexReader == null && this.doctermMatrix == null) {
            return;
        }
        try {
            this.trainFeatureSelector(trainingDocSet);
            this.arrLabel = new String[trainingDocSet.getClassNum()];
            for (int c = 0; c < trainingDocSet.getClassNum(); ++c) {
                this.arrLabel[c] = trainingDocSet.getDocClass(c).getClassName();
            }
            this.classNum = trainingDocSet.getClassNum();
            this.codeMatrix.setClassNum(this.classNum);

            // Convert every class's documents to feature vectors once; the same
            // vectors are relabelled and reused for each binary problem.
            ArrayList[] arrClass = new ArrayList[this.classNum];
            for (int c = 0; c < this.classNum; ++c) {
                arrClass[c] = this.loadData(trainingDocSet.getDocClass(c));
            }

            TrainingParameters param2 = new TrainingParameters(this.learnParam, this.kernelParam);
            SVMLightInterface svm2 = new SVMLightInterface();
            int classifierNum = this.codeMatrix.getClassifierNum();
            // Train into a local array so a failure part-way through never
            // publishes an array that contains null models.
            SVMLightModel[] models = new SVMLightModel[classifierNum];
            for (int k = 0; k < classifierNum; ++k) {
                LabeledFeatureVector[] arrDoc = this.loadData(arrClass, this.codeMatrix, k);
                param2.getLearningParameters().svm_costratio = 1.0;
                models[k] = svm2.trainModel(arrDoc, param2);
            }
            this.arrModel = models;
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Classifies a single document.
     *
     * @param doc the document's term row
     * @return the class chosen by the decoder, or -1 when no model is available
     *         or the document has no selected features
     */
    public int classify(Row doc) {
        if (this.arrModel == null) {
            return -1;
        }
        LabeledFeatureVector example = this.loadData(doc);
        if (example == null) {
            return -1;
        }
        int classifierNum = this.codeMatrix.getClassifierNum();
        if (this.arrConfidence == null || this.arrConfidence.length != classifierNum) {
            this.arrConfidence = new double[classifierNum];
        }
        for (int k = 0; k < classifierNum; ++k) {
            this.arrConfidence[k] = this.arrModel[k].classify((FeatureVector)example);
        }
        return this.classDecoder.decode(this.codeMatrix, this.arrConfidence);
    }

    /**
     * Returns the binary model outputs for the most recently classified document.
     *
     * @return the internal confidence array (not a copy); {@code null} if no
     *         document has been classified yet
     */
    public double[] getBinaryClassifierConfidence() {
        return this.arrConfidence;
    }

    /**
     * Serializes the trained classifier to a file readable by
     * {@link #SVMLightClassifier(String)}.
     *
     * <p>Does nothing when no model has been trained. Training data is removed
     * from each model before it is written. Exceptions are printed, not propagated.
     *
     * @param modelFile destination path
     */
    public void saveModel(String modelFile) {
        if (this.arrModel == null) {
            return;
        }
        FileOutputStream fileOut = null;
        try {
            fileOut = new FileOutputStream(modelFile);
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            out.writeInt(this.arrModel.length);
            for (int i = 0; i < this.arrModel.length; ++i) {
                this.arrModel[i].removeTrainingData();
                out.writeObject(this.arrModel[i]);
            }
            out.writeObject(this.codeMatrix);
            out.writeObject(this.classDecoder);
            out.writeInt(this.classNum);
            out.writeBoolean(this.scale);
            out.writeObject(this.featureSelector);
            for (int c = 0; c < this.classNum; ++c) {
                out.writeObject(this.getClassLabel(c));
            }
            out.flush();
            out.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        finally {
            // Release the file handle even when writing failed part-way.
            if (fileOut != null) {
                try {
                    fileOut.close();
                }
                catch (Exception closeError) {
                    closeError.printStackTrace();
                }
            }
        }
    }

    /**
     * Builds the training set for one binary classifier.
     *
     * <p>Every document of a class whose code for this classifier is non-zero is
     * included and relabelled with that code. The vectors are shared with
     * {@code arrClass}, so their labels are overwritten in place.
     *
     * @param arrClass per-class lists of {@link LabeledFeatureVector}
     * @param matrix code matrix supplying the label of each class
     * @param classifierIndex index of the binary classifier being trained
     * @return the relabelled training vectors
     */
    private LabeledFeatureVector[] loadData(ArrayList[] arrClass, CodeMatrix matrix, int classifierIndex) {
        ArrayList<LabeledFeatureVector> selected = new ArrayList<LabeledFeatureVector>();
        for (int c = 0; c < this.classNum; ++c) {
            int label = matrix.getCode(c, classifierIndex);
            if (label == 0) {
                // A zero code means this class does not take part in this binary problem.
                continue;
            }
            for (int d = 0; d < arrClass[c].size(); ++d) {
                LabeledFeatureVector curDoc = (LabeledFeatureVector)arrClass[c].get(d);
                curDoc.setLabel((double)label);
                selected.add(curDoc);
            }
        }
        return selected.toArray(new LabeledFeatureVector[selected.size()]);
    }

    /**
     * Converts all documents of one class to feature vectors.
     *
     * @param docs the class whose documents are loaded
     * @return the vectors of those documents that have at least one selected feature
     */
    private ArrayList loadData(DocClass docs) {
        ArrayList<LabeledFeatureVector> vectors = new ArrayList<LabeledFeatureVector>(docs.getDocNum());
        for (int d = 0; d < docs.getDocNum(); ++d) {
            LabeledFeatureVector curDoc = this.loadData(this.getRow(docs.getDoc(d).getIndex()));
            if (curDoc != null) {
                vectors.add(curDoc);
            }
        }
        return vectors;
    }

    /**
     * Converts a document row to an SVMLight feature vector.
     *
     * <p>Only columns that the feature selector maps to a non-negative index are
     * kept. When scaling is enabled the kept values are divided by their L2 norm.
     *
     * @param doc the document's term row; may be {@code null}
     * @return a vector with label 1.0, or {@code null} when {@code doc} is
     *         {@code null} or none of its columns is a selected feature
     */
    protected LabeledFeatureVector loadData(Row doc) {
        if (doc == null) {
            return null;
        }
        int nonZeroNum = doc.getNonZeroNum();
        int[] mappedIds = new int[nonZeroNum];
        double[] mappedValues = new double[nonZeroNum];
        int kept = 0;
        for (int j = 0; j < nonZeroNum; ++j) {
            int newIndex = this.featureSelector.map(doc.getNonZeroColumn(j));
            if (newIndex >= 0) {
                // Shift by one; presumably SVMLight feature ids start at 1 - TODO confirm.
                mappedIds[kept] = newIndex + 1;
                mappedValues[kept] = doc.getNonZeroDoubleScore(j);
                ++kept;
            }
        }
        if (kept == 0) {
            return null;
        }
        int[] ids = new int[kept];
        double[] values2 = new double[kept];
        System.arraycopy(mappedIds, 0, ids, 0, kept);
        System.arraycopy(mappedValues, 0, values2, 0, kept);

        if (this.scale) {
            // Accumulate the full squared norm first, then divide; the two
            // passes must not be nested or share state.
            double squaredSum = 0.0;
            for (int j = 0; j < kept; ++j) {
                squaredSum += values2[j] * values2[j];
            }
            double norm = Math.sqrt(squaredSum);
            if (norm > 0.0) {
                for (int j = 0; j < kept; ++j) {
                    values2[j] = values2[j] / norm;
                }
            }
        }
        return new LabeledFeatureVector(1.0, ids, values2);
    }
}

