/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.test;

import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.Factors;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Variable;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * JUnit tests for the static helpers in {@link Factors}: CPT normalization,
 * mass retention, mutual information, mixing, correlation and log error range.
 */
public class TestFactors extends TestCase {

    /** Absolute tolerance used when comparing factor values element-wise. */
    private static final double VALUE_TOLERANCE = 0.001;

    /** Tolerance used for scalar statistics such as mutual information and correlation. */
    private static final double STAT_TOLERANCE = 1.0E-5;

    /** Tolerance used for the log error range checks. */
    private static final double LOG_RANGE_TOLERANCE = 1.0E-10;

    /**
     * Creates a test case that runs the test method with the given name.
     *
     * @param name name of the test method to run
     */
    public TestFactors(String name) {
        super(name);
    }

    /**
     * Normalizing a 2x2 table as a CPT over {@code v1} should make the
     * entries sharing a value of {@code v2} sum to one.
     */
    public void testNormalizeAsCpt() {
        double[] vals = {1.0, 4.0, 2.0, 6.0};
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        TableFactor ptl = new TableFactor(new Variable[]{v1, v2}, vals);
        Factors.normalizeAsCpt(ptl, v1);
        comparePotentials(ptl, new double[]{0.3333, 0.4, 0.6666, 0.6});
    }

    /**
     * CPT normalization over {@code v3} on a table containing zeros should
     * leave the all-zero rows at zero.
     */
    public void testSparseNormalizeAsCpt() {
        double[] vals = {1.0, 4.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0};
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        Variable v3 = new Variable(2);
        TableFactor ptl = new TableFactor(new Variable[]{v1, v2, v3}, vals);
        Factors.normalizeAsCpt(ptl, v3);
        comparePotentials(ptl, new double[]{0.2, 0.8, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0});
    }

    /**
     * Same table and expectation as {@link #testNormalizeAsCpt()}, but built
     * from log values through {@link LogTableFactor}.
     */
    public void testNormalizeAsCptLogSpace() {
        // Logs of the linear values {1, 4, 2, 6} used in testNormalizeAsCpt.
        double[] logVals = {Math.log(1.0), Math.log(4.0), Math.log(2.0), Math.log(6.0)};
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        LogTableFactor ptl = LogTableFactor.makeFromLogValues(new Variable[]{v1, v2}, logVals);
        Factors.normalizeAsCpt(ptl, v1);
        comparePotentials(ptl, new double[]{0.3333, 0.4, 0.6666, 0.6});
    }

    /**
     * Asserts that {@code ptl.toValueArray()} has the same length as
     * {@code expected} and matches it element-wise within {@link #VALUE_TOLERANCE}.
     *
     * @param ptl      factor under test
     * @param expected expected values, in the order returned by {@code toValueArray()}
     */
    private void comparePotentials(DiscreteFactor ptl, double[] expected) {
        double[] actual = ptl.toValueArray();
        assertEquals("number of factor values", expected.length, actual.length);
        for (int i = 0; i < expected.length; i++) {
            assertEquals("factor value at index " + i, expected[i], actual[i], VALUE_TOLERANCE);
        }
    }

    /**
     * Smoke test for {@link Factors#retainMass}. It only checks that a factor
     * is returned; the exact retained entries are not pinned here.
     */
    public void testRetainMass() {
        Variable v = new Variable(4);
        LogTableFactor ptl = LogTableFactor.makeFromValues(v, new double[]{0.75, 0.0, 0.05, 0.2});
        TableFactor actual = Factors.retainMass(ptl, 0.9);
        assertNotNull("retainMass returned null", actual);
    }

    /**
     * The product of two factors over distinct variables should have zero
     * mutual information.
     */
    public void testMutualInfo1() {
        TableFactor ptl1 = new TableFactor(new Variable(2), new double[]{0.7, 0.3});
        TableFactor ptl2 = new TableFactor(new Variable(2), new double[]{0.2, 0.8});
        Factor joint = ptl1.multiply(ptl2);
        assertEquals(0.0, Factors.mutualInformation(joint), STAT_TOLERANCE);
    }

    /** Checks mutual information of a dependent 2x2 joint against a precomputed value. */
    public void testMutualInfo2() {
        Variable[] vars = {new Variable(2), new Variable(2)};
        TableFactor joint = new TableFactor(vars, new double[]{0.3, 0.2, 0.1, 0.4});
        assertEquals(0.08630462, Factors.mutualInformation(joint), STAT_TOLERANCE);
    }

    /**
     * An equal-weight (0.5) mix of {0.3, 0.7} and a uniform log-space factor
     * should yield {0.4, 0.6}.
     */
    public void testMix() {
        Variable var = new Variable(2);
        TableFactor tf = new TableFactor(var, new double[]{0.3, 0.7});
        LogTableFactor ltf = LogTableFactor.makeFromValues(var, new double[]{0.5, 0.5});
        Factor mix = Factors.mix(tf, ltf, 0.5);
        TableFactor ans = new TableFactor(var, new double[]{0.4, 0.6});
        assertTrue("mixed factor differs from expected", ans.almostEquals(mix));
    }

    /** Checks {@link Factors#corr} on a 2x2 joint against a precomputed value. */
    public void testCorr() {
        Variable var1 = new Variable(2);
        Variable var2 = new Variable(2);
        TableFactor f2 = new TableFactor(new Variable[]{var1, var2}, new double[]{0.3, 0.1, 0.2, 0.4});
        double corr = Factors.corr(f2);
        assertEquals(0.1, corr, STAT_TOLERANCE);
    }

    /**
     * Checks {@link Factors#logErrorRange} against precomputed values. Each
     * comparison is made in both argument orders, so the result is also
     * expected to be symmetric.
     */
    public void testLogErrorRange() {
        Variable var1 = new Variable(2);
        Variable var2 = new Variable(2);
        Variable[] vars = {var1, var2};
        TableFactor f1 = new TableFactor(vars, new double[]{0.3, 0.1, 0.2, 0.4});
        TableFactor f2 = new TableFactor(vars, new double[]{0.2, 0.1, 0.4, 0.3});
        assertEquals(Math.log(2.0), Factors.logErrorRange(f1, f2), LOG_RANGE_TOLERANCE);
        assertEquals(Math.log(2.0), Factors.logErrorRange(f2, f1), LOG_RANGE_TOLERANCE);

        TableFactor f3 = new TableFactor(vars, new double[]{0.2, 0.4, 0.3, 0.1});
        double expected = Math.log(4.0) - Math.log(1.5);
        assertEquals(expected, Factors.logErrorRange(f1, f3), LOG_RANGE_TOLERANCE);
        assertEquals(expected, Factors.logErrorRange(f3, f1), LOG_RANGE_TOLERANCE);
    }

    /**
     * Builds a suite containing every {@code test*} method of this class.
     *
     * @return the full test suite
     */
    public static Test suite() {
        return new TestSuite(TestFactors.class);
    }

    /**
     * Runs the tests from the command line with the JUnit text runner.
     *
     * @param args names of individual test methods to run; when empty, the whole suite runs
     * @throws Throwable propagated from the test runner
     */
    public static void main(String[] args) throws Throwable {
        Test theSuite;
        if (args.length > 0) {
            TestSuite selected = new TestSuite();
            for (String testName : args) {
                selected.addTest(new TestFactors(testName));
            }
            theSuite = selected;
        } else {
            theSuite = suite();
        }
        TestRunner.run(theSuite);
    }
}

