/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.Factors;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.MatrixOps;
import gnu.trove.THashSet;
import java.util.Collection;
import java.util.Iterator;

/**
 * Static helper routines for summarizing and comparing the output of {@link Inferencer}s run on
 * a {@link FactorGraph}, plus a few small helpers for arrays and sets of {@link Variable}s.
 */
public class Utils {
    /**
     * Scores the assignment built from every variable's most likely marginal value.
     *
     * <p>For each variable of {@code mdl} (in index order), the argmax of the marginal reported by
     * {@code inf} is taken. The resulting joint assignment is then scored two ways:
     * <ul>
     *   <li>by the inferencer, via {@code lookupLogJoint};</li>
     *   <li>by the model, via {@code logValue}.</li>
     * </ul>
     * The method returns the first score minus the second. If {@code lookupLogJoint} is the
     * normalized log probability and {@code logValue} is the unnormalized log score (as the method
     * name suggests; confirm against those implementations), the difference equals {@code -log Z}.
     *
     * @param mdl the model whose variables are assigned and scored
     * @param inf an inferencer queried for marginals and the log joint; presumably it has already
     *            been run on {@code mdl}
     * @return {@code inf.lookupLogJoint(x) - mdl.logValue(x)} for the argmax assignment {@code x}
     */
    public static double lookupMinusLogZ(FactorGraph mdl, Inferencer inf) {
        int[] vals = new int[mdl.numVariables()];
        for (int vi = 0; vi < vals.length; ++vi) {
            Variable var = mdl.getVariable(vi);
            Factor mrg = inf.lookupMarginal(var);
            vals[vi] = mrg.argmax();
        }
        Assignment assn = new Assignment(mdl, vals);
        // Both quantities are in log space; the names reflect that.
        double logJoint = inf.lookupLogJoint(assn);
        double logValue = mdl.logValue(assn);
        return logJoint - logValue;
    }

    /**
     * Returns the difference between the first and second entries of a binary variable's marginal.
     *
     * <p>The two entries are read in the order produced by the marginal's assignment iterator,
     * which is presumably outcome 0 followed by outcome 1 (confirm against the Factor implementation).
     *
     * @param inferencer the inferencer from which the marginal of {@code var} is looked up
     * @param var        a variable with exactly two outcomes
     * @return the first marginal value minus the second
     * @throws IllegalArgumentException if {@code var} does not have exactly two outcomes
     */
    public static double localMagnetization(Inferencer inferencer, Variable var) {
        if (var.getNumOutcomes() != 2) {
            throw new IllegalArgumentException("localMagnetization requires a binary variable, but "
                    + var + " has " + var.getNumOutcomes() + " outcomes");
        }
        Factor marg = inferencer.lookupMarginal(var);
        AssignmentIterator it = marg.assignmentIterator();
        double v1 = marg.value(it);
        it.advance();
        double v2 = marg.value(it);
        return v1 - v2;
    }

    /**
     * Computes, for every variable of {@code mdl}, the L1 distance between the marginals that two
     * inferencers report for it.
     *
     * @param mdl  the model whose variables are compared
     * @param inf1 the first inferencer
     * @param inf2 the second inferencer
     * @return an array of length {@code mdl.numVariables()}; entry {@code i} holds
     *         {@code Factors.oneDistance} for the {@code i}-th variable yielded by
     *         {@code mdl.variablesIterator()}
     */
    public static double[] allL1MarginalDistance(FactorGraph mdl, Inferencer inf1, Inferencer inf2) {
        double[] dist = new double[mdl.numVariables()];
        int i = 0;
        Iterator it = mdl.variablesIterator();
        while (it.hasNext()) {
            Variable var = (Variable) it.next();
            Factor bel1 = inf1.lookupMarginal(var);
            Factor bel2 = inf2.lookupMarginal(var);
            dist[i++] = Factors.oneDistance(bel1, bel2);
        }
        return dist;
    }

    /**
     * Returns the mean per-variable L1 marginal distance between two inferencers.
     *
     * @param mdl  the model whose variables are compared
     * @param inf1 the first inferencer
     * @param inf2 the second inferencer
     * @return {@code MatrixOps.mean} of {@link #allL1MarginalDistance}
     */
    public static double avgL1MarginalDistance(FactorGraph mdl, Inferencer inf1, Inferencer inf2) {
        double[] dist = Utils.allL1MarginalDistance(mdl, inf1, inf2);
        return MatrixOps.mean(dist);
    }

    /**
     * Returns the largest per-variable L1 marginal distance between two inferencers.
     *
     * @param mdl  the model whose variables are compared
     * @param inf1 the first inferencer
     * @param inf2 the second inferencer
     * @return {@code MatrixOps.max} of {@link #allL1MarginalDistance}
     */
    public static double maxL1MarginalDistance(FactorGraph mdl, Inferencer inf1, Inferencer inf2) {
        double[] dist = Utils.allL1MarginalDistance(mdl, inf1, inf2);
        return MatrixOps.max(dist);
    }

    /**
     * Collects the number of outcomes of each variable.
     *
     * @param vars the variables to inspect
     * @return a new array whose {@code i}-th entry is {@code vars[i].getNumOutcomes()}
     */
    public static int[] toSizesArray(Variable[] vars) {
        int[] szs = new int[vars.length];
        for (int i = 0; i < vars.length; ++i) {
            szs[i] = vars[i].getNumOutcomes();
        }
        return szs;
    }

    /**
     * Computes the variables that appear in both {@code v1} and {@code v2}.
     *
     * <p>The result lists the shared variables in the order they occur in {@code v1}.
     *
     * @param v1 the first variable set; its iteration order determines the result order
     * @param v2 the second variable set
     * @return a new {@link HashVarSet} that contains the intersection
     */
    public static VarSet defaultIntersection(VarSet v1, VarSet v2) {
        // Hash-based membership set for the intersection; its size fixes the output length.
        THashSet hset = new THashSet((Collection) v1);
        hset.retainAll((Collection) v2);
        Variable[] ret = new Variable[hset.size()];
        int vai = 0;
        // Walk v1 by index so the result keeps v1's variable order.
        for (int vi = 0; vi < v1.size(); ++vi) {
            Variable var = v1.get(vi);
            if (hset.contains(var)) {
                ret[vai++] = var;
            }
        }
        return new HashVarSet(ret);
    }
}

