/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.inference.JunctionTree;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Runs two-pass message propagation over a {@link JunctionTree} and reads
 * marginals back out of the calibrated tree.
 *
 * <p>How a message is computed and how a belief is extracted from a cluster
 * potential is delegated to a {@link MessageStrategy}; two strategies are
 * provided, {@link SumProductMessageStrategy} and
 * {@link MaxProductMessageStrategy}.
 *
 * <p>The message counter is {@code transient}, so it is not written when an
 * instance is serialized.
 */
class JunctionTreePropagation implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Version tag written after the default serialized fields. */
    private static final int CURRENT_SERIAL_VERSION = 1;

    private static final Logger logger = MalletLogger.getLogger(JunctionTreePropagation.class.getName());

    /** Number of messages sent by this instance across all propagation runs. */
    private transient int totalMessagesSent = 0;

    /** Strategy that computes messages and extracts beliefs. */
    private MessageStrategy strategy;

    /**
     * Creates a propagator that uses the given strategy.
     *
     * @param strategy strategy used to send messages and extract beliefs
     */
    public JunctionTreePropagation(MessageStrategy strategy) {
        this.strategy = strategy;
    }

    /**
     * Creates a propagator backed by a {@link SumProductMessageStrategy}.
     *
     * @return a new propagator
     */
    public static JunctionTreePropagation createSumProductInferencer() {
        return new JunctionTreePropagation(new SumProductMessageStrategy());
    }

    /**
     * Creates a propagator backed by a {@link MaxProductMessageStrategy}.
     *
     * @return a new propagator
     */
    public static JunctionTreePropagation createMaxProductInferencer() {
        return new JunctionTreePropagation(new MaxProductMessageStrategy());
    }

    /**
     * Returns how many messages this instance has sent so far. The count is
     * never reset, so it accumulates over repeated calls to
     * {@link #computeMarginals(JunctionTree)}.
     *
     * @return the cumulative message count
     */
    public int getTotalMessagesSent() {
        return this.totalMessagesSent;
    }

    /**
     * Propagates messages through the tree (leaves to root, then root to
     * leaves) and then calls {@code normalizeAll()} on the tree.
     *
     * @param jt the junction tree to calibrate; it is modified in place
     */
    public void computeMarginals(JunctionTree jt) {
        this.propagate(jt);
        jt.normalizeAll();
    }

    /**
     * Upward pass: recursively collects evidence from the subtree below
     * {@code child}, then sends one message from {@code child} to
     * {@code parent}.
     *
     * @param jt     the junction tree
     * @param parent the cluster receiving the message, or {@code null} when
     *               {@code child} is the root (no message is sent then)
     * @param child  the cluster whose subtree is being collected
     */
    private void collectEvidence(JunctionTree jt, VarSet parent, VarSet child) {
        // Guard the log call so the cluster strings are not built on every
        // recursive step when FINER logging is disabled.
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("collectEvidence " + parent + " --> " + child);
        }
        for (VarSet gchild : jt.getChildren(child)) {
            this.collectEvidence(jt, child, gchild);
        }
        if (parent != null) {
            ++this.totalMessagesSent;
            this.strategy.sendMessage(jt, child, parent);
        }
    }

    /**
     * Downward pass: sends a message from {@code parent} to each of its
     * children, recursing into each child right after it has been updated.
     *
     * @param jt     the junction tree
     * @param parent the cluster to distribute evidence from
     */
    private void distributeEvidence(JunctionTree jt, VarSet parent) {
        for (VarSet child : jt.getChildren(parent)) {
            ++this.totalMessagesSent;
            this.strategy.sendMessage(jt, parent, child);
            this.distributeEvidence(jt, child);
        }
    }

    /**
     * Runs the collect pass followed by the distribute pass, both starting
     * at the tree's root.
     *
     * @param jt the junction tree
     */
    private void propagate(JunctionTree jt) {
        VarSet root = (VarSet)jt.getRoot();
        this.collectEvidence(jt, null, root);
        this.distributeEvidence(jt, root);
    }

    /**
     * Looks up the normalized belief over a set of variables by extracting it
     * from the potential of a cluster that the tree reports as the parent
     * cluster of {@code varSet}.
     *
     * @param jt     the junction tree; expected to have been calibrated by
     *               {@link #computeMarginals(JunctionTree)}
     * @param varSet the variables whose joint belief is wanted
     * @return the normalized belief over {@code varSet}
     * @throws IllegalStateException         if {@code jt} is {@code null}
     * @throws UnsupportedOperationException if the tree has no parent cluster
     *                                       for {@code varSet}
     */
    public Factor lookupMarginal(JunctionTree jt, VarSet varSet) {
        if (jt == null) {
            throw new IllegalStateException("Call computeMarginals() first.");
        }
        VarSet parent = jt.findParentCluster(varSet);
        if (parent == null) {
            throw new UnsupportedOperationException("No parent cluster in " + jt + " for clique " + varSet);
        }
        Factor cpf = jt.getCPF(parent);
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Lookup jt marginal: clique " + varSet + " cluster " + parent);
            logger.finest("  cpf " + cpf);
        }
        Factor marginal = this.strategy.extractBelief(cpf, varSet);
        marginal.normalize();
        return marginal;
    }

    /**
     * Looks up the normalized belief over a single variable by extracting it
     * from the potential of a cluster that the tree reports as the parent
     * cluster of {@code var}.
     *
     * @param jt  the junction tree; expected to have been calibrated by
     *            {@link #computeMarginals(JunctionTree)}
     * @param var the variable whose belief is wanted
     * @return the normalized belief over {@code var}
     * @throws IllegalStateException         if {@code jt} is {@code null}
     * @throws UnsupportedOperationException if the tree has no parent cluster
     *                                       for {@code var}
     */
    public Factor lookupMarginal(JunctionTree jt, Variable var) {
        if (jt == null) {
            throw new IllegalStateException("Call computeMarginals() first.");
        }
        VarSet parent = jt.findParentCluster(var);
        // Mirror the VarSet overload: fail with a descriptive error instead of
        // passing a null cluster further down.
        if (parent == null) {
            throw new UnsupportedOperationException("No parent cluster in " + jt + " for variable " + var);
        }
        Factor cpf = jt.getCPF(parent);
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Lookup jt marginal: var " + var + " cluster " + parent);
            logger.finest(" cpf " + cpf);
        }
        Factor marginal = this.strategy.extractBelief(cpf, new HashVarSet(new Variable[]{var}));
        marginal.normalize();
        return marginal;
    }

    /** Writes the default fields followed by the serial version tag. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        out.writeInt(CURRENT_SERIAL_VERSION);
    }

    /** Reads the default fields and consumes the serial version tag. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        // The version tag is read to keep the stream aligned; it is not
        // otherwise used.
        in.readInt();
    }

    /**
     * Strategy whose messages and beliefs are computed with
     * {@code Factor.extractMax}.
     */
    public static class MaxProductMessageStrategy
    implements MessageStrategy,
    Serializable {
        private static final long serialVersionUID = 1L;

        /** Version tag written after the default serialized fields. */
        private static final int CURRENT_SERIAL_VERSION = 1;

        /**
         * Sends one message from cluster {@code from} to cluster {@code to}.
         * The message is the normalized {@code extractMax} of the sender's
         * potential over the separator. It replaces the stored separator
         * potential, and the receiver's potential is multiplied by it, divided
         * by the previous separator potential, normalized, and stored back.
         */
        @Override
        public void sendMessage(JunctionTree jt, VarSet from, VarSet to) {
            Set sepset = jt.getSepset(from, to);
            Factor fromCpf = jt.getCPF(from);
            Factor toCpf = jt.getCPF(to);
            // Fetch the old separator potential before it is overwritten below.
            Factor oldSepsetPot = jt.getSepsetPot(from, to);
            Factor lambda = fromCpf.extractMax(sepset);
            lambda.normalize();
            jt.setSepsetPot(lambda, from, to);
            toCpf = toCpf.multiply(lambda);
            toCpf.divideBy(oldSepsetPot);
            toCpf.normalize();
            jt.setCPF(to, toCpf);
        }

        /** Returns {@code cpf.extractMax(varSet)}. */
        @Override
        public Factor extractBelief(Factor cpf, VarSet varSet) {
            return cpf.extractMax(varSet);
        }

        /** Writes the default fields followed by the serial version tag. */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.defaultWriteObject();
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        /** Reads the default fields and consumes the serial version tag. */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            in.readInt();
        }
    }

    /**
     * Strategy whose messages and beliefs are computed with
     * {@code Factor.marginalize}.
     */
    public static class SumProductMessageStrategy
    implements MessageStrategy,
    Serializable {
        private static final long serialVersionUID = 1L;

        /** Version tag written after the default serialized fields. */
        private static final int CURRENT_SERIAL_VERSION = 1;

        /**
         * Sends one message from cluster {@code from} to cluster {@code to}.
         * The message is the normalized {@code marginalize} of the sender's
         * potential over the separator. It replaces the stored separator
         * potential, and the receiver's potential is multiplied by it, divided
         * by the previous separator potential, normalized, and stored back.
         */
        @Override
        public void sendMessage(JunctionTree jt, VarSet from, VarSet to) {
            Set sepset = jt.getSepset(from, to);
            Factor fromCpf = jt.getCPF(from);
            Factor toCpf = jt.getCPF(to);
            // Fetch the old separator potential before it is overwritten below.
            Factor oldSepsetPot = jt.getSepsetPot(from, to);
            Factor lambda = fromCpf.marginalize(sepset);
            lambda.normalize();
            jt.setSepsetPot(lambda, from, to);
            toCpf = toCpf.multiply(lambda);
            toCpf.divideBy(oldSepsetPot);
            toCpf.normalize();
            jt.setCPF(to, toCpf);
        }

        /** Returns {@code cpf.marginalize(varSet)}. */
        @Override
        public Factor extractBelief(Factor cpf, VarSet varSet) {
            return cpf.marginalize(varSet);
        }

        /** Writes the default fields followed by the serial version tag. */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.defaultWriteObject();
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        /** Reads the default fields and consumes the serial version tag. */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            in.readInt();
        }
    }

    /**
     * Defines how a message between two adjacent clusters is computed and how
     * a belief over a subset of variables is read from a cluster potential.
     */
    public static interface MessageStrategy {
        /**
         * Sends a message between two clusters, updating the tree in place.
         *
         * @param jt   the junction tree holding the potentials
         * @param from the sending cluster
         * @param to   the receiving cluster
         */
        public void sendMessage(JunctionTree jt, VarSet from, VarSet to);

        /**
         * Extracts the belief over {@code varSet} from a cluster potential.
         *
         * @param cpf    the cluster potential
         * @param varSet the variables to keep
         * @return the belief over {@code varSet}
         */
        public Factor extractBelief(Factor cpf, VarSet varSet);
    }
}

