/*
 * Decompiled with CFR 0.152.
 */
package dragon.ir.classification;

import dragon.ir.classification.AbstractClassifier;
import dragon.ir.classification.DocClass;
import dragon.ir.classification.DocClassSet;
import dragon.ir.classification.featureselection.FeatureSelector;
import dragon.ir.classification.multiclass.AllPairCodeMatrix;
import dragon.ir.classification.multiclass.CodeMatrix;
import dragon.ir.classification.multiclass.MultiClassDecoder;
import dragon.ir.index.IndexReader;
import dragon.matrix.Row;
import dragon.matrix.SparseMatrix;
import dragon.util.MathUtil;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Vector;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

public class LibSVMClassifier
extends AbstractClassifier {
    private svm_parameter param;
    private svm_model model;
    private CodeMatrix codeMatrix;
    private MultiClassDecoder classDecoder;
    private double[] arrProb;
    private double[] arrConfidence;
    private boolean scale;
    private int[] rank;

    /**
     * Restores a classifier previously written by {@link #saveModel(String)}.
     *
     * <p>The fields are read in exactly the order {@code saveModel} writes them:
     * model, parameters, code matrix, decoder, class count, scaling flag, feature selector,
     * then one label string per class. Any error is printed and swallowed (best-effort load),
     * which leaves the instance partially initialized.
     *
     * <p>NOTE(review): this uses Java-native deserialization; only load model files from
     * trusted sources.
     *
     * @param modelFile path of the serialized model file
     */
    public LibSVMClassifier(String modelFile) {
        FileInputStream fileIn = null;
        try {
            fileIn = new FileInputStream(modelFile);
            ObjectInputStream oin = new ObjectInputStream(fileIn);
            this.model = (svm_model) oin.readObject();
            this.param = (svm_parameter) oin.readObject();
            this.codeMatrix = (CodeMatrix) oin.readObject();
            this.classDecoder = (MultiClassDecoder) oin.readObject();
            this.classNum = oin.readInt();
            this.scale = oin.readBoolean();
            this.featureSelector = (FeatureSelector) oin.readObject();
            this.arrLabel = new String[this.classNum];
            for (int i = 0; i < this.arrLabel.length; i++) {
                this.arrLabel[i] = (String) oin.readObject();
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Previously the stream was never closed, leaking a file handle on every load.
            closeQuietly(fileIn);
        }
    }

    /**
     * Creates an untrained classifier that reads documents through an index reader.
     *
     * @param indexReader source of document rows
     */
    public LibSVMClassifier(IndexReader indexReader) {
        super(indexReader);
        this.initDefaults();
    }

    /**
     * Creates an untrained classifier that reads documents from a doc-term matrix.
     *
     * @param doctermMatrix source of document rows
     */
    public LibSVMClassifier(SparseMatrix doctermMatrix) {
        super(doctermMatrix);
        this.initDefaults();
    }

    /**
     * Shared initialization for the "fresh classifier" constructors.
     *
     * <p>Sets the default parameters, an all-pair code matrix, no decoder, no model, and
     * scaling on. Also configures LIBSVM's global message flags: errors shown, progress hidden.
     */
    private void initDefaults() {
        this.param = this.getDefaultParameter();
        this.codeMatrix = new AllPairCodeMatrix(1);
        this.classDecoder = null;
        svm.showErrorMessage = true;
        svm.showMessage = false;
        this.model = null;
        this.scale = true;
    }

    /**
     * Sets a decoder used to turn raw SVM decision values into a label.
     *
     * <p>Pass {@code null} to use LIBSVM's built-in multi-class prediction instead.
     *
     * @param decoder2 the decoder, or {@code null}
     */
    public void setMultiClassDecoder(MultiClassDecoder decoder2) {
        this.classDecoder = decoder2;
    }

    /**
     * Enables or disables LIBSVM probability estimates.
     *
     * <p>This setting affects both training and the non-decoder path of {@link #classify(Row)}.
     *
     * @param option {@code true} to train and predict with probability estimates
     */
    public void setUseProbEstimate(boolean option) {
        this.param.probability = option ? 1 : 0;
    }

    /**
     * Enables or disables L2 normalization of each document vector in {@link #readDoc(Row)}.
     *
     * @param option {@code true} to normalize each document to unit Euclidean length
     */
    public void setScalingOption(boolean option) {
        this.scale = option;
    }

    /**
     * Trains the SVM model on the given labelled documents.
     *
     * <p>This does nothing if the classifier has neither an index reader nor a doc-term matrix.
     *
     * @param trainingDocSet labelled training documents grouped by class
     */
    public void train(DocClassSet trainingDocSet) {
        if (this.indexReader == null && this.doctermMatrix == null) {
            return;
        }
        this.trainFeatureSelector(trainingDocSet);
        int numClasses = trainingDocSet.getClassNum();
        this.arrLabel = new String[numClasses];
        for (int i = 0; i < numClasses; i++) {
            this.arrLabel[i] = trainingDocSet.getDocClass(i).getClassName();
        }
        this.classNum = numClasses;
        this.codeMatrix.setClassNum(this.classNum);
        svm_problem prob = this.getTrainingProblem(trainingDocSet);
        this.model = svm.svm_train(prob, this.param);
    }

    /**
     * Predicts the class of a single document.
     *
     * <p>NOTE(review): on the probability path, {@link #rank()} ranks positions in the
     * probability array that LIBSVM fills. Whether those positions equal class IDs depends on
     * the model's label ordering; verify this if some classes can have no training documents.
     *
     * @param doc the document row
     * @return the predicted label, or {@code -1} if no model is available or the document has
     *     no selected features
     */
    public int classify(Row doc) {
        // No trained/loaded model: report "cannot classify" instead of throwing an NPE inside LIBSVM.
        if (this.model == null) {
            return -1;
        }
        svm_node[] curDoc = this.readDoc(doc);
        if (curDoc == null) {
            return -1;
        }
        int label;
        if (this.classDecoder == null) {
            if (this.param.probability == 1) {
                if (this.arrProb == null || this.arrProb.length != this.classNum) {
                    this.arrProb = new double[this.classNum];
                }
                label = (int) svm.svm_predict_probability(this.model, curDoc, this.arrProb);
                this.rank = MathUtil.rankElementInArray(this.arrProb, true);
            } else {
                label = (int) svm.svm_predict(this.model, curDoc);
            }
        } else {
            int classifierNum = this.codeMatrix.getClassifierNum();
            if (this.arrConfidence == null || this.arrConfidence.length != classifierNum) {
                this.arrConfidence = new double[classifierNum];
            }
            svm.svm_predict_values(this.model, curDoc, this.arrConfidence);
            label = this.classDecoder.decode(this.codeMatrix, this.arrConfidence);
        }
        return label;
    }

    /**
     * Returns the ranking produced by the most recent {@link #classify(Row)} call.
     *
     * @return the decoder's ranking when a decoder is set; otherwise the ranking computed
     *     from probability estimates, which may be {@code null}
     */
    public int[] rank() {
        if (this.classDecoder == null) {
            return this.rank;
        }
        return this.classDecoder.rank();
    }

    /**
     * Serializes the trained classifier so {@link #LibSVMClassifier(String)} can restore it.
     *
     * <p>This does nothing if no model has been trained. I/O errors are printed and swallowed.
     *
     * @param modelFile destination path
     */
    public void saveModel(String modelFile) {
        if (this.model == null) {
            return;
        }
        FileOutputStream fileOut = null;
        try {
            fileOut = new FileOutputStream(modelFile);
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            // The field order must match the model-file constructor exactly.
            out.writeObject(this.model);
            out.writeObject(this.param);
            out.writeObject(this.codeMatrix);
            out.writeObject(this.classDecoder);
            out.writeInt(this.classNum);
            out.writeBoolean(this.scale);
            out.writeObject(this.featureSelector);
            for (int i = 0; i < this.classNum; i++) {
                out.writeObject(this.getClassLabel(i));
            }
            out.flush();
            out.close();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Ensures the file handle is released even when a write above throws.
            closeQuietly(fileOut);
        }
    }

    /**
     * Closes {@code closeable} if it is non-null, ignoring any I/O error raised while closing.
     *
     * @param closeable the resource to close, may be {@code null}
     */
    private static void closeQuietly(Closeable closeable) {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        } catch (IOException ignored) {
            // Best-effort cleanup; any primary error has already been reported by the caller.
        }
    }

    /**
     * Builds the default LIBSVM parameters.
     *
     * <p>The defaults are a linear C-SVC with C = 1 and probability estimates on.
     * {@code gamma} is left at 0 so {@link #getTrainingProblem} can fill it in from the
     * training data.
     *
     * @return a fresh parameter object
     */
    private svm_parameter getDefaultParameter() {
        svm_parameter param2 = new svm_parameter();
        param2.svm_type = svm_parameter.C_SVC;
        param2.kernel_type = svm_parameter.LINEAR;
        param2.degree = 3;
        param2.gamma = 0.0;
        param2.coef0 = 0.0;
        param2.nu = 0.5;
        param2.cache_size = 100.0;
        param2.C = 1.0;
        param2.eps = 0.001;
        param2.p = 0.1;
        param2.shrinking = 1;
        param2.probability = 1;
        param2.nr_weight = 0;
        param2.weight_label = new int[0];
        param2.weight = new double[0];
        return param2;
    }

    /**
     * Converts the training set into a LIBSVM problem.
     *
     * <p>Documents with no selected features are skipped. If {@code param.gamma} is still 0,
     * it is set to {@code 1 / maxFeatureIndex}, but only when at least one positive feature
     * index was seen; this avoids an infinite gamma.
     *
     * @param trainingDocSet labelled training documents
     * @return the assembled problem
     */
    private svm_problem getTrainingProblem(DocClassSet trainingDocSet) {
        Vector<svm_node[]> vx = new Vector<svm_node[]>();
        Vector<Integer> vy = new Vector<Integer>();
        int maxIndex = 0;
        for (int i = 0; i < trainingDocSet.getClassNum(); i++) {
            DocClass curClass = trainingDocSet.getDocClass(i);
            for (int j = 0; j < curClass.getDocNum(); j++) {
                svm_node[] curDoc = this.readDoc(this.getRow(curClass.getDoc(j).getIndex()));
                if (curDoc == null) {
                    continue;
                }
                vx.addElement(curDoc);
                vy.addElement(Integer.valueOf(curClass.getClassID()));
                // Scan every node rather than trusting the last one to hold the largest index.
                for (svm_node node : curDoc) {
                    maxIndex = Math.max(maxIndex, node.index);
                }
            }
        }
        svm_problem prob = new svm_problem();
        prob.l = vy.size();
        prob.x = new svm_node[prob.l][];
        prob.y = new double[prob.l];
        for (int i = 0; i < prob.l; i++) {
            prob.x[i] = vx.elementAt(i);
            prob.y[i] = vy.elementAt(i).intValue();
        }
        if (this.param.gamma == 0.0 && maxIndex > 0) {
            this.param.gamma = 1.0 / (double) maxIndex;
        }
        return prob;
    }

    /**
     * Converts a document row into LIBSVM nodes, keeping only the selected features.
     *
     * <p>Each node's index is the feature selector's mapped index, and its value is the row's
     * double score. When scaling is enabled, the vector is normalized to unit L2 length.
     *
     * <p>NOTE(review): LIBSVM expects node indices in ascending order. This method emits
     * nodes in row order, so it relies on {@code featureSelector.map} preserving column order.
     * Confirm this for every selector in use.
     *
     * <p>NOTE(review): the previous scaling code divided by the absolute value of the first
     * feature instead of the L2 norm. Models trained and saved with scaling on under that code
     * should be retrained.
     *
     * @param curDoc the document row, may be {@code null}
     * @return the nodes, or {@code null} if the row is {@code null} or has no selected features
     */
    protected svm_node[] readDoc(Row curDoc) {
        if (curDoc == null) {
            return null;
        }
        int nonZeroNum = curDoc.getNonZeroNum();
        int kept = 0;
        for (int j = 0; j < nonZeroNum; j++) {
            if (this.featureSelector.map(curDoc.getNonZeroColumn(j)) >= 0) {
                kept++;
            }
        }
        if (kept == 0) {
            return null;
        }
        svm_node[] arrNode = new svm_node[kept];
        int pos = 0;
        for (int j = 0; j < nonZeroNum; j++) {
            int newIndex = this.featureSelector.map(curDoc.getNonZeroColumn(j));
            if (newIndex >= 0) {
                svm_node node = new svm_node();
                node.index = newIndex;
                node.value = curDoc.getNonZeroDoubleScore(j);
                arrNode[pos++] = node;
            }
        }
        if (this.scale) {
            normalizeL2(arrNode);
        }
        return arrNode;
    }

    /**
     * Scales the node values in place to unit Euclidean (L2) length.
     *
     * <p>An all-zero vector is left unchanged, so no NaN values are produced.
     *
     * @param nodes the nodes to normalize
     */
    private static void normalizeL2(svm_node[] nodes) {
        double sumOfSquares = 0.0;
        for (svm_node node : nodes) {
            sumOfSquares += node.value * node.value;
        }
        if (sumOfSquares <= 0.0) {
            return;
        }
        double norm = Math.sqrt(sumOfSquares);
        for (svm_node node : nodes) {
            node.value /= norm;
        }
    }
}

