/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.confidence;

import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.Segment;
import cc.mallet.fst.Transducer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import java.io.PrintStream;
import java.util.ArrayList;

/**
 * Measures how correcting one labeled segment per sequence affects token-level accuracy.
 *
 * <p>For each instance, the model's best output sequence (obtained via {@link MaxLatticeDefault})
 * is compared against a supplied corrected sequence and against the true labels. Tokens outside
 * the corrected segment form the "uncorrected region"; any change of correctness there is
 * attributed to propagation of the correction.
 */
public class ConfidenceCorrectorEvaluator {
    // Stored by the constructor but not read by any method in this class.
    // NOTE(review): presumably the "start" and "inside" tags of the segment labels; confirm with callers.
    Object[] startTags;
    Object[] inTags;

    public ConfidenceCorrectorEvaluator(Object[] startTags, Object[] inTags) {
        this.startTags = startTags;
        this.inTags = inTags;
    }

    /**
     * Reports whether the original prediction mislabels some token outside the corrected segment,
     * while verifying that the corrected sequence matches the truth inside that segment.
     *
     * <p>Indices are scanned left to right; the method returns {@code true} at the first
     * uncorrected-region error it finds, so segment indices after that point are not verified.
     *
     * @param trueSequence the gold labels
     * @param predSequence the model's original prediction
     * @param correctedSequence the prediction after correction
     * @param correctedSegment the segment whose labels were corrected
     * @return {@code true} if {@code predSequence} disagrees with {@code trueSequence} at some
     *     index outside {@code correctedSegment}
     * @throws IllegalStateException if {@code correctedSequence} disagrees with
     *     {@code trueSequence} at an index inside {@code correctedSegment}; the three sequences
     *     are dumped to {@code System.err} first to aid debugging
     */
    private boolean containsErrorInUncorrectedSegments(Sequence<?> trueSequence, Sequence<?> predSequence,
            Sequence<?> correctedSequence, Segment correctedSegment) {
        for (int i = 0; i < trueSequence.size(); ++i) {
            if (correctedSegment.indexInSegment(i)) {
                if (!correctedSequence.get(i).equals(trueSequence.get(i))) {
                    int length = trueSequence.size();
                    printTokens("\nTruth: ", trueSequence, length);
                    printTokens("\nPredicted: ", predSequence, length);
                    printTokens("\nCorrected: ", correctedSequence, length);
                    throw new IllegalStateException("Corrected sequence does not have correct labels for corrected segment: " + correctedSegment);
                }
            } else if (!predSequence.get(i).equals(trueSequence.get(i))) {
                return true;
            }
        }
        return false;
    }

    /** Writes {@code header} and then the first {@code limit} tokens of {@code seq}, space-separated, to System.err. */
    private static void printTokens(String header, Sequence<?> seq, int limit) {
        System.err.println(header);
        for (int j = 0; j < limit; ++j) {
            System.err.print(seq.get(j) + " ");
        }
    }

    /**
     * Compares original and corrected predictions against the truth and prints a summary report.
     *
     * <p>Instances whose corrected segment is {@code null} are skipped before decoding. When no
     * tokens qualify, the reported accuracies are {@code NaN} (0.0 / 0.0).
     *
     * @param model transducer decoded to obtain the original prediction for each instance
     * @param predictions corrected label sequences, one {@code Sequence} per instance of {@code ilist}
     * @param ilist instances whose data is the input {@code Sequence} and whose target is the true
     *     label {@code Sequence}
     * @param correctedSegments the corrected {@code Segment} for each instance, or {@code null}
     *     when that instance was not corrected
     * @param description text printed at the start of the report
     * @param outputStream destination of the report
     * @param errorsInUncorrected if {@code true}, only sequences whose original prediction has an
     *     error outside the corrected segment are evaluated
     * @throws IllegalArgumentException if {@code predictions} or {@code correctedSegments} differ in
     *     size from {@code ilist}, or if a predicted/corrected sequence differs in length from its truth
     * @throws IllegalStateException if {@code errorsInUncorrected} is set and a corrected sequence is
     *     wrong inside its corrected segment
     */
    public void evaluate(Transducer model, ArrayList predictions, InstanceList ilist, ArrayList correctedSegments, String description, PrintStream outputStream, boolean errorsInUncorrected) {
        if (predictions.size() != ilist.size() || correctedSegments.size() != ilist.size()) {
            throw new IllegalArgumentException("number of predicted sequence (" + predictions.size() + ") and number of corrected segments (" + correctedSegments.size() + ") must be equal to length of instancelist (" + ilist.size() + ")");
        }
        int numPropagatedCorrect2Incorrect = 0;
        int numPropagatedIncorrect2Correct = 0;
        int numPredictedCorrect = 0;
        int numCorrectedCorrect = 0;
        int numUncorrectedCorrectBeforePropagation = 0;
        int numUncorrectedCorrectAfterPropagation = 0;
        int totalTokens = 0;
        int totalTokensInUncorrectedRegion = 0;
        int numCorrectedSequences = 0;
        for (int i = 0; i < ilist.size(); ++i) {
            Segment correctedSegment = (Segment) correctedSegments.get(i);
            if (correctedSegment == null) {
                // Nothing was corrected here; skip before paying for a full decode.
                continue;
            }
            Instance instance = (Instance) ilist.get(i);
            Sequence<?> inputSequence = (Sequence<?>) instance.getData();
            Sequence<?> trueSequence = (Sequence<?>) instance.getTarget();
            Sequence<?> predSequence = new MaxLatticeDefault(model, inputSequence).bestOutputSequence();
            Sequence<?> correctedSequence = (Sequence<?>) predictions.get(i);
            if (errorsInUncorrected
                    && !containsErrorInUncorrectedSegments(trueSequence, predSequence, correctedSequence, correctedSegment)) {
                continue;
            }
            ++numCorrectedSequences;
            totalTokens += trueSequence.size();
            boolean[] predictedMatches = getMatches(trueSequence, predSequence);
            boolean[] correctedMatches = getMatches(trueSequence, correctedSequence);
            for (int j = 0; j < predictedMatches.length; ++j) {
                boolean predictedOk = predictedMatches[j];
                boolean correctedOk = correctedMatches[j];
                if (predictedOk) {
                    ++numPredictedCorrect;
                }
                if (correctedOk) {
                    ++numCorrectedCorrect;
                }
                // The segment bounds are inclusive on both ends.
                if (j >= correctedSegment.getStart() && j <= correctedSegment.getEnd()) {
                    continue;
                }
                // Only tokens outside the corrected segment can change through propagation.
                ++totalTokensInUncorrectedRegion;
                if (predictedOk) {
                    ++numUncorrectedCorrectBeforePropagation;
                }
                if (correctedOk) {
                    ++numUncorrectedCorrectAfterPropagation;
                }
                if (!predictedOk && correctedOk) {
                    ++numPropagatedIncorrect2Correct;
                } else if (predictedOk && !correctedOk) {
                    ++numPropagatedCorrect2Incorrect;
                }
            }
        }
        double tokenAccuracyBeforeCorrection = (double) numPredictedCorrect / (double) totalTokens;
        double tokenAccuracyAfterCorrection = (double) numCorrectedCorrect / (double) totalTokens;
        double uncorrectedRegionAccuracyBeforeCorrection = (double) numUncorrectedCorrectBeforePropagation / (double) totalTokensInUncorrectedRegion;
        double uncorrectedRegionAccuracyAfterCorrection = (double) numUncorrectedCorrectAfterPropagation / (double) totalTokensInUncorrectedRegion;
        outputStream.println(description
                + "\nEvaluating effect of error-propagation in sequences containing at least one token error:"
                + "\ntotal number correctedsequences: " + numCorrectedSequences
                + "\ntotal number tokens: " + totalTokens
                + "\ntotal number tokens in \"uncorrected region\":" + totalTokensInUncorrectedRegion
                + "\ntotal number correct tokens before correction:" + numPredictedCorrect
                + "\ntotal number correct tokens after correction:" + numCorrectedCorrect
                + "\ntoken accuracy before correction: " + tokenAccuracyBeforeCorrection
                + "\ntoken accuracy after correction: " + tokenAccuracyAfterCorrection
                + "\nnumber tokens corrected by propagation: " + numPropagatedIncorrect2Correct
                + "\nnumber tokens made incorrect by propagation: " + numPropagatedCorrect2Incorrect
                + "\ntoken accuracy of \"uncorrected region\" before propagation: " + uncorrectedRegionAccuracyBeforeCorrection
                + "\ntoken accuracy of \"uncorrected region\" after propagation: " + uncorrectedRegionAccuracyAfterCorrection);
    }

    /**
     * Compares two equal-length sequences position by position using {@code equals}.
     *
     * @return an array whose element {@code i} is {@code true} iff {@code s1.get(i)} equals {@code s2.get(i)}
     * @throws IllegalArgumentException if the sequences differ in length
     */
    private boolean[] getMatches(Sequence<?> s1, Sequence<?> s2) {
        if (s1.size() != s2.size()) {
            throw new IllegalArgumentException("s1.size: " + s1.size() + " s2.size: " + s2.size());
        }
        boolean[] ret = new boolean[s1.size()];
        for (int i = 0; i < s1.size(); ++i) {
            ret[i] = s1.get(i).equals(s2.get(i));
        }
        return ret;
    }
}

