/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference.gbp;

import cc.mallet.grmm.inference.AbstractInferencer;
import cc.mallet.grmm.inference.gbp.BPRegionGenerator;
import cc.mallet.grmm.inference.gbp.FullMessageStrategy;
import cc.mallet.grmm.inference.gbp.Kikuchi4SquareRegionGenerator;
import cc.mallet.grmm.inference.gbp.MessageArray;
import cc.mallet.grmm.inference.gbp.MessageStrategy;
import cc.mallet.grmm.inference.gbp.Region;
import cc.mallet.grmm.inference.gbp.RegionEdge;
import cc.mallet.grmm.inference.gbp.RegionGraph;
import cc.mallet.grmm.inference.gbp.RegionGraphGenerator;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;

public class ParentChildGBP
extends AbstractInferencer {
    private static final Logger logger = MalletLogger.getLogger(ParentChildGBP.class.getName());
    private static final boolean debug = false;
    private RegionGraphGenerator regioner;
    private MessageStrategy sender;
    private boolean useInertia = true;
    private double inertiaWeight = 0.5;
    private static final double THRESHOLD = 0.001;
    private static final int MAX_ITER = 500;
    private MessageArray oldMessages;
    private MessageArray newMessages;
    private RegionGraph rg;
    private FactorGraph mdl;

    private ParentChildGBP() {
    }

    public ParentChildGBP(RegionGraphGenerator regioner) {
        this(regioner, new FullMessageStrategy());
    }

    public ParentChildGBP(RegionGraphGenerator regioner, MessageStrategy sender) {
        this.regioner = regioner;
        this.sender = sender;
    }

    public static ParentChildGBP makeBPInferencer() {
        ParentChildGBP inferencer = new ParentChildGBP();
        inferencer.regioner = new BPRegionGenerator();
        inferencer.sender = new FullMessageStrategy();
        return inferencer;
    }

    public static ParentChildGBP makeKikuchiInferencer() {
        ParentChildGBP inferencer = new ParentChildGBP();
        inferencer.regioner = new Kikuchi4SquareRegionGenerator();
        inferencer.sender = new FullMessageStrategy();
        return inferencer;
    }

    public boolean getUseInertia() {
        return this.useInertia;
    }

    public void setUseInertia(boolean useInertia) {
        this.useInertia = useInertia;
    }

    public double getInertiaWeight() {
        return this.inertiaWeight;
    }

    public void setInertiaWeight(double inertiaWeight) {
        this.inertiaWeight = inertiaWeight;
    }

    @Override
    public Factor lookupMarginal(Variable variable) {
        Region region = this.rg.findContainingRegion(variable);
        if (region == null) {
            throw new IllegalArgumentException("Could not find region containing variable " + variable + " in region graph " + this.rg);
        }
        Factor belief = this.computeBelief(region);
        Factor varBelief = belief.marginalize(variable);
        return varBelief;
    }

    @Override
    public Factor lookupMarginal(VarSet varSet) {
        Region region = this.rg.findContainingRegion(varSet);
        if (region == null) {
            throw new IllegalArgumentException("Could not find region containing clique " + varSet + " in region graph " + this.rg);
        }
        Factor belief = this.computeBelief(region);
        Factor cliqueBelief = belief.marginalize(varSet);
        return cliqueBelief;
    }

    private Factor computeBelief(Region region) {
        return ParentChildGBP.computeBelief(region, this.newMessages);
    }

    static Factor computeBelief(Region region, MessageArray messages) {
        LogTableFactor result2 = new LogTableFactor(region.vars);
        for (Factor factor : region.factors) {
            result2.multiplyBy(factor);
        }
        for (Region parent : region.parents) {
            DiscreteFactor msg = messages.getMessage(parent, region);
            result2.multiplyBy(msg);
        }
        for (Region child : region.descendants) {
            for (Region uncle : child.parents) {
                if (uncle == region || region.descendants.contains(uncle)) continue;
                result2.multiplyBy(messages.getMessage(uncle, child));
            }
        }
        result2.normalize();
        return result2;
    }

    @Override
    public double lookupLogJoint(Assignment assn) {
        double factorProduct = this.mdl.logValue(assn);
        double F = this.computeFreeEnergy(this.rg);
        double value2 = factorProduct + F;
        return value2;
    }

    private double computeFreeEnergy(RegionGraph rg) {
        double avgEnergy = 0.0;
        double entropy = 0.0;
        Iterator it = rg.iterator();
        while (it.hasNext()) {
            Region region = (Region)it.next();
            Factor belief = this.computeBelief(region);
            double thisEntropy = belief.entropy();
            entropy += (double)region.countingNumber * thisEntropy;
            LogTableFactor product2 = new LogTableFactor(belief.varSet());
            for (Factor ptl : region.factors) {
                product2.multiplyBy(ptl);
            }
            double thisAvgEnergy = 0.0;
            AssignmentIterator assnIt = belief.assignmentIterator();
            while (assnIt.hasNext()) {
                Assignment assn = assnIt.assignment();
                double thisEnergy = -product2.logValue(assn);
                double thisBel = belief.value(assn);
                thisAvgEnergy += thisBel * thisEnergy;
                assnIt.advance();
            }
            avgEnergy += (double)region.countingNumber * thisAvgEnergy;
        }
        return avgEnergy - entropy;
    }

    @Override
    public void computeMarginals(FactorGraph mdl) {
        Timing timing = new Timing();
        this.mdl = mdl;
        this.rg = this.regioner.constructRegionGraph(mdl);
        RegionEdge[] pairs = this.chooseMessageSendingOrder();
        this.newMessages = new MessageArray(this.rg);
        timing.tick("GBP Region Graph construction");
        int iter2 = 0;
        do {
            this.oldMessages = this.newMessages;
            this.newMessages = this.oldMessages.duplicate();
            this.sender.setMessageArray(this.oldMessages, this.newMessages);
            for (int i = 0; i < pairs.length; ++i) {
                RegionEdge edge = pairs[i];
                this.sender.sendMessage(edge);
            }
            if (logger.isLoggable(Level.FINER)) {
                timing.tick("GBP iteration " + iter2);
            }
            ++iter2;
            if (!this.useInertia) continue;
            this.newMessages = this.sender.averageMessages(this.rg, this.oldMessages, this.newMessages, this.inertiaWeight);
        } while (!this.hasConverged() && iter2 < 500);
        logger.info("GBP: Used " + iter2 + " iterations.");
        if (iter2 >= 500) {
            logger.warning("***WARNING: GBP not converged!");
        }
    }

    private RegionEdge[] chooseMessageSendingOrder() {
        ArrayList<RegionEdge> l = new ArrayList<RegionEdge>();
        Iterator it = this.rg.edgeIterator();
        while (it.hasNext()) {
            RegionEdge edge = (RegionEdge)it.next();
            l.add(edge);
        }
        Collections.sort(l, new Comparator(){

            public int compare(Object o1, Object o2) {
                RegionEdge e1 = (RegionEdge)o1;
                RegionEdge e2 = (RegionEdge)o2;
                int l1 = e1.to.vars.size();
                int l2 = e2.to.vars.size();
                return Double.compare(l1, l2);
            }
        });
        return l.toArray(new RegionEdge[l.size()]);
    }

    private boolean hasConverged() {
        Iterator it = this.rg.edgeIterator();
        while (it.hasNext()) {
            RegionEdge edge = (RegionEdge)it.next();
            DiscreteFactor oldMsg = this.oldMessages.getMessage(edge.from, edge.to);
            DiscreteFactor newMsg = this.newMessages.getMessage(edge.from, edge.to);
            if (oldMsg == null) {
                assert (newMsg == null);
                continue;
            }
            if (oldMsg.almostEquals(newMsg, 0.001)) continue;
            return false;
        }
        return true;
    }

    @Override
    public void dump() {
        Iterator it = this.rg.edgeIterator();
        while (it.hasNext()) {
            RegionEdge edge = (RegionEdge)it.next();
            DiscreteFactor newMsg = this.newMessages.getMessage(edge.from, edge.to);
            System.out.println("Message: " + edge.from + " --> " + edge.to + " " + newMsg);
        }
    }
}

