/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference.gbp;

import cc.mallet.grmm.inference.gbp.AbstractMessageStrategy;
import cc.mallet.grmm.inference.gbp.MessageArray;
import cc.mallet.grmm.inference.gbp.RegionEdge;
import cc.mallet.grmm.inference.gbp.RegionGraph;
import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.Factors;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.TableFactor;
import java.util.Iterator;

/**
 * Message-passing strategy that prunes messages sent into regions with no children,
 * using {@link Factors#retainMass(DiscreteFactor, double)} with the configured epsilon.
 * Messages into regions that do have children are sent unpruned.
 */
public class SparseMessageSender
extends AbstractMessageStrategy {
    /** Tolerance forwarded to {@link Factors#retainMass} whenever a message is pruned. */
    private final double epsilon;

    /**
     * Creates a sender that prunes leaf-bound messages.
     *
     * @param epsilon tolerance passed to {@link Factors#retainMass} when pruning
     */
    public SparseMessageSender(double epsilon) {
        this.epsilon = epsilon;
    }

    /**
     * Computes and stores the message along {@code edge}.
     *
     * <p>The message is built as follows:
     * <ol>
     *   <li>Start from {@code msgProduct(edge)}.</li>
     *   <li>Multiply in every factor of {@code edge.factorsToSend}.</li>
     *   <li>Marginalize onto {@code edge.to.vars} and normalize.</li>
     *   <li>If the destination region has no children, prune the message with
     *       {@link Factors#retainMass} and normalize it again.</li>
     * </ol>
     * The result is recorded in {@code newMessages}.
     *
     * @param edge the region-graph edge to send a message along
     */
    @Override
    public void sendMessage(RegionEdge edge) {
        Factor product = this.msgProduct(edge);
        for (Factor ptl : edge.factorsToSend) {
            product.multiplyBy(ptl);
        }
        TableFactor result = (TableFactor) product.marginalize(edge.to.vars);
        result.normalize();

        TableFactor pruned;
        if (shouldPruneMessage(edge, result)) {
            pruned = Factors.retainMass(result, this.epsilon);
            // Pruning drops entries, so renormalize to keep the message a distribution.
            pruned.normalize();
        } else {
            pruned = result;
        }
        this.newMessages.setMessage(edge.from, edge.to, pruned);
    }

    /**
     * Builds a new message array by averaging {@code a1} and {@code a2} edge by edge with
     * {@link Factors#average}. Averaged messages into childless regions are then pruned with
     * {@link Factors#retainMass}. Edges for which {@code a1} holds no message are left without
     * a message in the result.
     *
     * <p>As a diagnostic side effect, a sparsity summary is printed to {@code System.out}.
     *
     * @param rg            the region graph whose edges are visited
     * @param a1            first message array; edges with no message here are skipped
     * @param a2            second message array
     * @param inertiaWeight weight forwarded to {@link Factors#average}
     * @return the newly built message array
     */
    @Override
    public MessageArray averageMessages(RegionGraph rg, MessageArray a1, MessageArray a2, double inertiaWeight) {
        MessageArray arr = new MessageArray(rg);
        Iterator it = rg.edgeIterator();
        while (it.hasNext()) {
            RegionEdge edge = (RegionEdge) it.next();
            DiscreteFactor msg1 = a1.getMessage(edge.from, edge.to);
            DiscreteFactor msg2 = a2.getMessage(edge.from, edge.to);
            if (msg1 == null) {
                continue;
            }
            TableFactor averaged = (TableFactor) Factors.average(msg1, msg2, inertiaWeight);
            // NOTE(review): unlike sendMessage, the pruned result is not renormalized here.
            // Confirm whether that asymmetry is intended.
            TableFactor pruned = shouldPruneMessage(edge, averaged)
                    ? Factors.retainMass(averaged, this.epsilon)
                    : averaged;
            arr.setMessage(edge.from, edge.to, pruned);
        }
        reportSparsity(rg, arr);
        return arr;
    }

    /**
     * Prints the total {@code numLocations()} of the messages stored in {@code arr} next to
     * the total {@code weight()} of their variable sets. Edges without a message are ignored.
     */
    private static void reportSparsity(RegionGraph rg, MessageArray arr) {
        int locs = 0;
        int idxs = 0;
        Iterator it = rg.edgeIterator();
        while (it.hasNext()) {
            RegionEdge edge = (RegionEdge) it.next();
            DiscreteFactor msg = arr.getMessage(edge.from, edge.to);
            if (msg == null) {
                // Edges skipped during averaging (null source message) have no entry here.
                // Dereferencing them previously threw NullPointerException.
                continue;
            }
            locs += msg.numLocations();
            idxs += new HashVarSet(msg.varSet()).weight();
        }
        System.out.println("Sparsity quotient = " + locs + " of " + idxs);
    }

    /**
     * Decides whether the message along {@code edge} should be pruned.
     * Currently true exactly when the destination region has no children; {@code msg} is unused.
     */
    private boolean shouldPruneMessage(RegionEdge edge, Factor msg) {
        return edge.to.children.isEmpty();
    }
}

