/*
 * Decompiled with CFR 0.152.
 */
package banner.tagging;

import banner.tagging.FeatureSet;
import banner.tagging.TagFormat;
import banner.tagging.Tagger;
import banner.types.Sentence;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByStochasticGradient;
import cc.mallet.fst.Transducer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.types.SparseVector;
import dragon.nlp.tool.Lemmatiser;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * {@link Tagger} implementation backed by a MALLET {@link CRF} that is trained with
 * {@link CRFTrainerByStochasticGradient}.
 *
 * <p>An instance bundles the trained model, the {@link FeatureSet} whose pipe was used to
 * build it, and the CRF order (1 or 2). Instances are obtained through {@link #train} or
 * {@link #load} and can be persisted with {@link #write}.
 */
public class CRFTaggerStochasticGradient
implements Tagger {
    /** Trained CRF used for tagging and for all weight/alphabet introspection. */
    protected CRF model;
    /** Feature set serialized alongside the model by {@link #write(File)}. */
    private final FeatureSet featureSet;
    /** CRF order this tagger was created with; {@link #train} only accepts 1 or 2. */
    private final int order;

    /**
     * Creates a tagger from already-built parts.
     *
     * @param model the CRF to tag with
     * @param featureSet the feature set associated with the model
     * @param order the CRF order
     */
    protected CRFTaggerStochasticGradient(CRF model, FeatureSet featureSet, int order) {
        this.model = model;
        this.featureSet = featureSet;
        this.order = order;
    }

    /**
     * Reads a tagger previously stored with {@link #write(File)}.
     *
     * <p>The file is expected to be a gzip-compressed Java object stream holding, in this
     * order, the CRF, the feature set and the order as an int. Note that this uses Java
     * native deserialization, so only files from a trusted source should be loaded.
     *
     * @param f2 the file to read
     * @param lemmatiser installed on the loaded feature set when not {@code null}
     * @param posTagger installed on the loaded feature set when not {@code null}
     * @param preTagger installed on the loaded feature set when not {@code null}
     * @return the restored tagger
     * @throws IOException if the file cannot be opened or read
     * @throws RuntimeException if a serialized class cannot be resolved
     */
    public static CRFTaggerStochasticGradient load(File f2, Lemmatiser lemmatiser, dragon.nlp.tool.Tagger posTagger, Tagger preTagger) throws IOException {
        // The file stream is declared on its own so it is released even when constructing
        // one of the wrapping streams fails (for example on a corrupt gzip header).
        try (FileInputStream fis = new FileInputStream(f2);
                ObjectInputStream ois = new ObjectInputStream(new GZIPInputStream(fis))) {
            CRF model = (CRF)ois.readObject();
            FeatureSet featureSet = (FeatureSet)ois.readObject();
            int order = ois.readInt();
            if (lemmatiser != null) {
                featureSet.setLemmatiser(lemmatiser);
            }
            if (posTagger != null) {
                featureSet.setPosTagger(posTagger);
            }
            if (preTagger != null) {
                featureSet.setPreTagger(preTagger);
            }
            return new CRFTaggerStochasticGradient(model, featureSet, order);
        }
        catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Trains a new CRF on the given sentences.
     *
     * <p>Progress information is printed to standard output. After training, a textual
     * dump of the model is written to {@code model_print.txt} in the working directory on
     * a best-effort basis: a failure to write it is reported on standard error and does
     * not abort training.
     *
     * @param sentences the training sentences; must not be empty
     * @param order the CRF order, 1 or 2
     * @param format2 the tag format (only reported in the progress output here)
     * @param featureSet the feature set whose pipe converts sentences into instances
     * @return a tagger wrapping the trained model
     * @throws RuntimeException if {@code sentences} is empty
     * @throws IllegalArgumentException if {@code order} is neither 1 nor 2
     */
    public static CRFTaggerStochasticGradient train(Set<Sentence> sentences, int order, TagFormat format2, FeatureSet featureSet) {
        System.out.println("CRF order: " + order);
        System.out.println("Tag Format: " + format2.toString());
        System.out.println("Size of Target Alphabet: " + featureSet.getPipe().getTargetAlphabet().size());
        if (sentences.isEmpty()) {
            throw new RuntimeException("Number of sentences must be greater than zero");
        }
        // Reject an unsupported order before any sentence is pushed through the pipe.
        if (order != 1 && order != 2) {
            throw new IllegalArgumentException("Order must be equal to 1 or 2");
        }
        InstanceList instances = new InstanceList(featureSet.getPipe());
        for (Sentence sentence : sentences) {
            instances.addThruPipe(new Instance(sentence, null, sentence.getSentenceId(), sentence));
        }
        System.out.println("Number of Instances: " + instances.size());
        CRF model = new CRF(featureSet.getPipe(), null);
        if (order == 1) {
            model.addStatesForLabelsConnectedAsIn(instances);
        } else {
            model.addStatesForBiLabelsConnectedAsIn(instances);
        }
        // NOTE(review): 0.01 is the trainer's second constructor argument, presumably the
        // learning rate -- confirm against the MALLET version in use.
        CRFTrainerByStochasticGradient crfTrainer = new CRFTrainerByStochasticGradient(model, 0.01);
        System.out.println("start training...");
        crfTrainer.train(instances);
        File file = new File("model_print.txt");
        try (PrintWriter writer = new PrintWriter(file)) {
            model.print(writer);
        }
        catch (IOException e) {
            // The dump is diagnostic only; training has already succeeded.
            System.err.println("Exception writing file " + file + ": " + e);
        }
        return new CRFTaggerStochasticGradient(model, featureSet, order);
    }

    /**
     * Stores this tagger as a gzip-compressed object stream containing the model, the
     * feature set and the order, in that order.
     *
     * <p>I/O failures are not propagated; they are reported on standard error.
     *
     * @param f2 the destination file
     */
    public void write(File f2) {
        // Closing the object stream finishes the gzip stream; the separately declared
        // file stream guarantees the descriptor is released if a wrapper cannot be built.
        try (FileOutputStream fos = new FileOutputStream(f2);
                ObjectOutputStream oos = new ObjectOutputStream(new GZIPOutputStream(fos))) {
            oos.writeObject(this.model);
            oos.writeObject(this.featureSet);
            oos.writeInt(this.order);
        }
        catch (IOException e) {
            System.err.println("Exception writing file " + f2 + ": " + e);
        }
    }

    /**
     * Runs the model over the sentence and adds the resulting tags to it as mentions.
     *
     * @param sentence the sentence to tag; modified in place
     */
    @Override
    public void tag(Sentence sentence) {
        Instance instance = this.getInstance(sentence);
        Sequence tags = this.model.transduce((Sequence)instance.getData());
        // NOTE(review): the meaning of the 1.0 argument is defined by Sentence.addMentions,
        // which is not visible here.
        sentence.addMentions(CRFTaggerStochasticGradient.getTagList(tags), 1.0);
    }

    /**
     * Converts a sentence into a MALLET instance by passing it through the model's input pipe.
     *
     * @param sentence the sentence to convert
     * @return the piped instance
     */
    protected Instance getInstance(Sentence sentence) {
        InstanceList iList = new InstanceList(this.model.getInputPipe());
        iList.addThruPipe(new Instance(sentence, null, sentence.getSentenceId(), sentence));
        return (Instance)iList.get(0);
    }

    /**
     * Converts a tag sequence into a list holding the string form of each element.
     *
     * @param tags the sequence to convert
     * @return a new mutable list with one entry per sequence position, in order
     */
    protected static List<String> getTagList(Sequence<Object> tags) {
        int size = tags.size();
        List<String> tagList = new ArrayList<String>(size);
        for (int i = 0; i < size; i++) {
            tagList.add(tags.get(i).toString());
        }
        return tagList;
    }

    /**
     * Returns the CRF order this tagger was created with.
     *
     * @return the order
     */
    public int getOrder() {
        return this.order;
    }

    /**
     * Returns the string form of every entry in the model's input alphabet.
     *
     * @return an unmodifiable set of feature names
     */
    public Set<String> getFeatureNames() {
        Alphabet inputAlphabet = this.model.getInputAlphabet();
        int size = inputAlphabet.size();
        Set<String> featureNames = new HashSet<String>();
        for (int i = 0; i < size; i++) {
            featureNames.add(inputAlphabet.lookupObject(i).toString());
        }
        return Collections.unmodifiableSet(featureNames);
    }

    /**
     * Lists, for each token of the sentence, the features produced by the model's input pipe.
     *
     * <p>Each feature is rendered as its name, followed by {@code =value} when the feature
     * vector stores explicit values. The features of each token are sorted as strings.
     *
     * @param sentence the sentence to describe
     * @return one sorted list of feature strings per sequence position
     */
    public List<List<String>> getFeatureRepresentation(Sentence sentence) {
        Instance instance = this.getInstance(sentence);
        Sequence sentenceSequence = (Sequence)instance.getData();
        Alphabet alphabet = this.model.getInputAlphabet();
        List<List<String>> sentenceFeatures = new ArrayList<List<String>>();
        for (int i = 0; i < sentenceSequence.size(); i++) {
            FeatureVector tokenVector = (FeatureVector)sentenceSequence.get(i);
            int[] indices = tokenVector.getIndices();
            double[] values = tokenVector.getValues();
            List<String> tokenFeatures = new ArrayList<String>(indices.length);
            for (int j = 0; j < indices.length; j++) {
                StringBuilder feature = new StringBuilder();
                feature.append(alphabet.lookupObject(indices[j]).toString());
                // A null value array means no explicit values are stored; print the name only.
                if (values != null) {
                    feature.append("=");
                    feature.append(values[j]);
                }
                tokenFeatures.add(feature.toString());
            }
            Collections.sort(tokenFeatures);
            sentenceFeatures.add(tokenFeatures);
        }
        return sentenceFeatures;
    }

    /**
     * Prints a summary of the model to standard output and writes one tab-separated line
     * per input feature to the given file.
     *
     * <p>Each line contains: feature index, full feature name, feature type (the name up
     * to the first {@code =} or {@code @}), offset (the text after {@code @}, or {@code 0}),
     * feature data (the text between the type and the offset) and the largest weight of
     * the feature over all weight vectors whose name does not end in {@code O:O}. Every
     * column, including the last, is followed by a tab.
     *
     * @param fileName path of the file to write
     * @throws IOException if the file cannot be opened or an error occurs while writing
     */
    public void describe(String fileName) throws IOException {
        System.out.println("Number of default weights = " + this.model.getDefaultWeights().length);
        int numStates = this.model.numStates();
        System.out.println("Number of states = " + numStates);
        for (int i = 0; i < numStates; i++) {
            Transducer.State state = this.model.getState(i);
            System.out.println("State " + i + " is " + state.getName());
        }
        SparseVector[] weights = this.model.getWeights();
        System.out.println("Size of weights vector = " + weights.length);
        // The name test does not depend on the feature, so evaluate it once per weight vector.
        boolean[] skipVector = new boolean[weights.length];
        for (int j = 0; j < weights.length; j++) {
            String weightsName = this.model.getWeightsName(j);
            System.out.print("Number of non-zero values for weight vector " + j);
            System.out.println(" (" + weightsName + ") is " + weights[j].numLocations());
            skipVector[j] = weightsName.endsWith("O:O");
        }
        Alphabet inputAlphabet = this.model.getInputAlphabet();
        int size = inputAlphabet.size();
        System.out.println("Size of input alphabet: " + size);
        try (PrintWriter output = new PrintWriter(fileName)) {
            for (int i = 0; i < size; i++) {
                String featureName = inputAlphabet.lookupObject(i).toString();
                int equalsIndex = featureName.indexOf("=");
                int atIndex = featureName.indexOf("@");
                // The type ends at whichever of '=' and '@' comes first, or at end of name.
                int featureTypeEnd = featureName.length();
                if (equalsIndex != -1 && equalsIndex < featureTypeEnd) {
                    featureTypeEnd = equalsIndex;
                }
                if (atIndex != -1 && atIndex < featureTypeEnd) {
                    featureTypeEnd = atIndex;
                }
                String featureType = featureName.substring(0, featureTypeEnd);
                String featureOffset = "0";
                int featureDataEnd = featureName.length();
                if (atIndex != -1) {
                    featureDataEnd = atIndex;
                    featureOffset = featureName.substring(atIndex + 1, featureName.length());
                }
                String featureData = "";
                if (featureDataEnd > featureTypeEnd) {
                    featureData = featureName.substring(featureTypeEnd + 1, featureDataEnd);
                }
                // NOTE(review): in a replaceAll replacement string a backslash only escapes
                // the character after it, so this substitutes a leading quote with an
                // identical quote and has no visible effect. If escaping the quote was the
                // intent, Matcher.quoteReplacement is needed; left unchanged so the output
                // format stays as it is today -- confirm before changing.
                featureData = featureData.replaceAll("^\"", "\\\"");
                double maxWeight = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < weights.length; j++) {
                    if (skipVector[j]) {
                        continue;
                    }
                    double weight = weights[j].value(i);
                    if (maxWeight < weight) {
                        maxWeight = weight;
                    }
                }
                output.print(i + "\t");
                output.print(featureName + "\t");
                output.print(featureType + "\t");
                output.print(featureOffset + "\t");
                output.print(featureData + "\t");
                output.print(maxWeight + "\t");
                output.println();
            }
            // PrintWriter never throws on write failures; surface them explicitly.
            if (output.checkError()) {
                throw new IOException("Error writing feature description to " + fileName);
            }
        }
    }

    /**
     * Computes, for every input feature, the largest value it has in any weight vector.
     *
     * <p>The search starts from {@link Double#NEGATIVE_INFINITY} so that features whose
     * weights are all zero or negative report their true maximum; that start value is
     * also what is reported when the model has no weight vectors.
     *
     * @return a new map from feature name to maximum weight
     */
    public Map<String, Double> getMaxWeights() {
        Map<String, Double> weightMap = new HashMap<String, Double>();
        SparseVector[] weights = this.model.getWeights();
        Alphabet inputAlphabet = this.model.getInputAlphabet();
        int size = inputAlphabet.size();
        for (int i = 0; i < size; i++) {
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < weights.length; j++) {
                double weight = weights[j].value(i);
                if (max < weight) {
                    max = weight;
                }
            }
            weightMap.put(inputAlphabet.lookupObject(i).toString(), max);
        }
        return weightMap;
    }

    /**
     * Computes, for every input feature, the smallest value it has in any weight vector.
     *
     * <p>The search starts from {@link Double#POSITIVE_INFINITY}, mirroring
     * {@link #getMaxWeights()}; that start value is also what is reported when the model
     * has no weight vectors.
     *
     * @return a new map from feature name to minimum weight
     */
    public Map<String, Double> getMinWeights() {
        Map<String, Double> weightMap = new HashMap<String, Double>();
        SparseVector[] weights = this.model.getWeights();
        Alphabet inputAlphabet = this.model.getInputAlphabet();
        int size = inputAlphabet.size();
        for (int i = 0; i < size; i++) {
            double min = Double.POSITIVE_INFINITY;
            for (int j = 0; j < weights.length; j++) {
                double weight = weights[j].value(i);
                if (min > weight) {
                    min = weight;
                }
            }
            weightMap.put(inputAlphabet.lookupObject(i).toString(), min);
        }
        return weightMap;
    }
}

