/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.util;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.inference.JunctionTree;
import cc.mallet.grmm.inference.JunctionTreeInferencer;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.ConstantFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.UndirectedModel;
import cc.mallet.grmm.types.Variable;
import gnu.trove.THashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;

public class Models {
    public static FactorGraph addEvidence(FactorGraph mdl, Assignment assn) {
        return Models.addEvidence(mdl, assn, null);
    }

    public static FactorGraph addEvidence(FactorGraph mdl, Assignment assn, Map toSlicedMap) {
        FactorGraph newMdl = new FactorGraph(mdl.numVariables());
        Models.addSlicedPotentials(mdl, newMdl, assn, toSlicedMap);
        return newMdl;
    }

    public static UndirectedModel addEvidence(UndirectedModel mdl, Assignment assn) {
        UndirectedModel newMdl = new UndirectedModel(mdl.numVariables());
        Models.addSlicedPotentials(mdl, newMdl, assn, null);
        return newMdl;
    }

    private static void addSlicedPotentials(FactorGraph fromMdl, FactorGraph toMdl, Assignment assn, Map toSlicedMap) {
        THashSet inputVars = new THashSet(Arrays.asList(assn.getVars()));
        THashSet remainingVars = new THashSet((Collection)fromMdl.variablesSet());
        remainingVars.removeAll((Collection<?>)inputVars);
        Iterator it = fromMdl.factorsIterator();
        while (it.hasNext()) {
            Factor ptl = (Factor)it.next();
            THashSet theseVars = new THashSet((Collection)ptl.varSet());
            theseVars.retainAll((Collection<?>)remainingVars);
            Factor slicedPtl = ptl.slice(assn);
            toMdl.addFactor(slicedPtl);
            if (toSlicedMap == null) continue;
            toSlicedMap.put(ptl, slicedPtl);
        }
    }

    public static Assignment bestAssignment(FactorGraph mdl, Inferencer inf) {
        inf.computeMarginals(mdl);
        int[] outcomes = new int[mdl.numVariables()];
        for (int i = 0; i < outcomes.length; ++i) {
            int best;
            Variable var = mdl.get(i);
            outcomes[i] = best = inf.lookupMarginal(var).argmax();
        }
        return new Assignment(mdl, outcomes);
    }

    public static double entropy(FactorGraph mdl) {
        JunctionTreeInferencer inf = new JunctionTreeInferencer();
        inf.computeMarginals(mdl);
        JunctionTree jt = inf.lookupJunctionTree();
        return jt.entropy();
    }

    public static double KL(FactorGraph mdl1, FactorGraph mdl2) {
        AssignmentIterator assnIt;
        Factor marg1;
        JunctionTreeInferencer inf1 = new JunctionTreeInferencer();
        inf1.computeMarginals(mdl1);
        JunctionTree jt1 = inf1.lookupJunctionTree();
        JunctionTreeInferencer inf2 = new JunctionTreeInferencer();
        inf2.computeMarginals(mdl2);
        JunctionTree jt2 = inf2.lookupJunctionTree();
        double entropy = jt1.entropy();
        double energy = 0.0;
        for (Factor marg2 : jt2.clusterPotentials()) {
            marg1 = inf1.lookupMarginal(marg2.varSet());
            assnIt = marg2.assignmentIterator();
            while (assnIt.hasNext()) {
                energy += marg1.value(assnIt) * marg2.logValue(assnIt);
                assnIt.advance();
            }
        }
        for (Factor marg2 : jt2.sepsetPotentials()) {
            marg1 = inf1.lookupMarginal(marg2.varSet());
            assnIt = marg2.assignmentIterator();
            while (assnIt.hasNext()) {
                energy -= marg1.value(assnIt) * marg2.logValue(assnIt);
                assnIt.advance();
            }
        }
        return -entropy - energy;
    }

    public static void removeConstantFactors(FactorGraph sliced) {
        ArrayList factors = new ArrayList(sliced.factors());
        for (Factor factor : factors) {
            if (!(factor instanceof ConstantFactor)) continue;
            sliced.divideBy(factor);
        }
    }
}

