/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.recommenders.jayes.inference.jtree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import org.eclipse.recommenders.internal.jayes.util.ArrayUtils;
import org.eclipse.recommenders.jayes.BayesNet;
import org.eclipse.recommenders.jayes.BayesNode;
import org.eclipse.recommenders.jayes.factor.AbstractFactor;
import org.eclipse.recommenders.jayes.factor.arraywrapper.DoubleArrayWrapper;
import org.eclipse.recommenders.jayes.factor.arraywrapper.IArrayWrapper;
import org.eclipse.recommenders.jayes.inference.AbstractInferer;
import org.eclipse.recommenders.jayes.inference.jtree.JunctionTree;
import org.eclipse.recommenders.jayes.inference.jtree.JunctionTreeBuilder;
import org.eclipse.recommenders.jayes.util.Graph;
import org.eclipse.recommenders.jayes.util.MathUtils;
import org.eclipse.recommenders.jayes.util.NumericalInstabilityException;
import org.eclipse.recommenders.jayes.util.OrderIgnoringPair;
import org.eclipse.recommenders.jayes.util.Pair;
import org.eclipse.recommenders.jayes.util.sharing.CanonicalArrayWrapperManager;
import org.eclipse.recommenders.jayes.util.sharing.CanonicalIntArrayManager;
import org.eclipse.recommenders.jayes.util.triangulation.MinFillIn;

/**
 * Exact inference for a {@link BayesNet} using the junction tree algorithm.
 * <p>
 * {@link #setNetwork(BayesNet)} does the following:
 * <ul>
 * <li>compiles the network into a junction tree with the configured {@link JunctionTreeBuilder};</li>
 * <li>multiplies every CPT into the potential of its home cluster;</li>
 * <li>runs one collect/distribute pass with no evidence selected;</li>
 * <li>snapshots all potential and separator values.</li>
 * </ul>
 * Each later belief update restores that snapshot, applies the current evidence and re-propagates.
 * <p>
 * Propagation is pruned:
 * <ul>
 * <li>During collection, subtrees containing no evidence cluster are skipped.</li>
 * <li>During distribution, subtrees containing no query factor of an unobserved variable are skipped.</li>
 * </ul>
 * Potentials may be in linear or logarithmic scale. Message passing converts values whenever the two
 * ends of an edge use different scales.
 */
public class JunctionTreeAlgorithm
extends AbstractInferer {
    /** Multiplicative identity in linear scale. */
    private static final double ONE = 1.0;
    /** Multiplicative identity in log scale ({@code log(1) == 0}). */
    private static final double ONE_LOG = 0.0;

    /** Separator factor of every junction tree edge, keyed by the unordered pair of cluster ids. */
    protected Map<OrderIgnoringPair<Integer>, AbstractFactor> sepSets;
    /** Adjacency structure of the junction tree; vertices are cluster ids. */
    protected Graph junctionTree;
    /** Cluster potentials, indexed by cluster id. */
    protected AbstractFactor[] nodePotentials;
    /**
     * For a key {@code (a, b)}: the operation prepared between the potential of cluster {@code b} and
     * the separator {@code {a, b}}.
     */
    protected Map<Pair<Integer, Integer>, int[]> preparedMultiplications;
    /** For every variable id: the ids of all clusters whose potential contains that variable. */
    protected int[][] concernedClusters;
    /**
     * For every variable id: the cluster potential its belief is read from. This is the containing
     * potential with the fewest values; the first one wins ties.
     */
    protected AbstractFactor[] queryFactors;
    /** For every variable id: the prepared operation projecting its query factor onto the variable. */
    protected int[][] preparedQueries;
    /** Per-variable flag: is {@code beliefs[id]} up to date with the current evidence? */
    protected boolean[] isBeliefValid;
    /** Values of every potential and separator right after initialisation; replayed before each propagation. */
    protected List<Pair<AbstractFactor, IArrayWrapper>> initializations;
    /** For every cluster id: the variables whose query factor is that cluster's potential. */
    protected int[][] queryFactorReverseMapping;
    /** Ids of the clusters that contain at least one currently observed variable. */
    protected Set<Integer> clustersHavingEvidence;
    /** Per-variable flag: is the variable currently observed? */
    protected boolean[] isObserved;
    /** Reusable buffer, sized to the largest separator, used during message passing. */
    protected double[] scratchpad;
    /** Builder used by {@link #setNetwork(BayesNet)} to construct the junction tree. */
    protected JunctionTreeBuilder junctionTreeBuilder = JunctionTreeBuilder.forHeuristic(new MinFillIn());

    /**
     * Replaces the junction tree builder.
     * <p>
     * The builder is only consulted in {@link #setNetwork(BayesNet)}, so the change affects the next
     * network that is set.
     *
     * @param bldr the builder to use from now on
     */
    public void setJunctionTreeBuilder(JunctionTreeBuilder bldr) {
        this.junctionTreeBuilder = bldr;
    }

    /**
     * Returns the posterior distribution of {@code node} given the current evidence.
     * <p>
     * Propagation happens lazily: at most once per evidence change. Each node's marginal is computed on
     * first request.
     * <p>
     * Validity flags are set only after their computation succeeds. As a result, a
     * {@link NumericalInstabilityException} leaves the node marked invalid, so the next call recomputes
     * instead of returning a partially written vector.
     *
     * @param node the node to query
     * @return the belief vector of {@code node}
     * @throws NumericalInstabilityException if the marginal cannot be normalised
     */
    @Override
    public double[] getBeliefs(BayesNode node) {
        if (!this.beliefsValid) {
            this.updateBeliefs();
            this.beliefsValid = true;
        }
        int nodeId = node.getId();
        if (!this.isBeliefValid[nodeId]) {
            if (this.evidence.containsKey(node)) {
                // an observed node has a degenerate distribution on its observed outcome
                Arrays.fill(this.beliefs[nodeId], 0.0);
                this.beliefs[nodeId][node.getOutcomeIndex(this.evidence.get(node))] = 1.0;
            } else {
                this.validateBelief(nodeId);
            }
            this.isBeliefValid[nodeId] = true;
        }
        return super.getBeliefs(node);
    }

    /**
     * Computes the normalised belief of variable {@code nodeId} from its query factor.
     * <p>
     * The query factor is marginalised into {@code beliefs[nodeId]}. If the factor is in log scale, the
     * result is exponentiated before normalising.
     */
    private void validateBelief(int nodeId) {
        AbstractFactor queryFactor = this.queryFactors[nodeId];
        double[] belief = this.beliefs[nodeId];
        queryFactor.sumPrepared(new DoubleArrayWrapper(belief), this.preparedQueries[nodeId]);
        if (queryFactor.isLogScale()) {
            MathUtils.exp(belief);
        }
        try {
            this.beliefs[nodeId] = MathUtils.normalize(belief);
        }
        catch (IllegalArgumentException exception) {
            throw new NumericalInstabilityException("Numerical instability detected for evidence: " + this.evidence + " and node : " + nodeId + ", consider using logarithmic scale computation (configurable in FactorFactory)", exception);
        }
    }

    /** Invalidates all cached per-node beliefs and re-propagates the current evidence through the tree. */
    @Override
    protected void updateBeliefs() {
        Arrays.fill(this.isBeliefValid, false);
        this.doUpdateBeliefs();
    }

    /**
     * Runs one full belief propagation:
     * <ol>
     * <li>apply the evidence selections;</li>
     * <li>restore the initial potential and separator values;</li>
     * <li>run a pruned collect pass followed by a pruned distribute pass, both rooted at a cluster that
     * contains evidence.</li>
     * </ol>
     */
    private void doUpdateBeliefs() {
        this.incorporateAllEvidence();
        int propagationRoot = this.findPropagationRoot();
        this.replayFactorInitializations();
        this.collectEvidence(propagationRoot, this.skipCollection(propagationRoot));
        this.distributeEvidence(propagationRoot, this.skipDistribution(propagationRoot));
    }

    /** Copies the stored initial values back into every potential and separator. */
    private void replayFactorInitializations() {
        for (Pair<AbstractFactor, IArrayWrapper> init : this.initializations) {
            init.getFirst().copyValues(init.getSecond());
        }
    }

    /**
     * Clears all previous evidence state, then applies the current evidence.
     * <p>
     * The cleared state comprises factor selections, observed flags and evidence clusters.
     */
    private void incorporateAllEvidence() {
        for (Pair<AbstractFactor, IArrayWrapper> init : this.initializations) {
            init.getFirst().resetSelections();
        }
        this.clustersHavingEvidence.clear();
        Arrays.fill(this.isObserved, false);
        for (BayesNode observed : this.evidence.keySet()) {
            this.incorporateEvidence(observed);
        }
    }

    /**
     * Marks {@code node} as observed.
     * <p>
     * The observed outcome is selected in every cluster potential that contains the node.
     */
    private void incorporateEvidence(BayesNode node) {
        int nodeId = node.getId();
        int outcome = node.getOutcomeIndex(this.evidence.get(node));
        this.isObserved[nodeId] = true;
        for (int cluster : this.concernedClusters[nodeId]) {
            this.nodePotentials[cluster].select(nodeId, outcome);
            this.clustersHavingEvidence.add(cluster);
        }
    }

    /**
     * Picks the propagation root.
     *
     * @return the first cluster containing the last evidence node iterated, or {@code 0} if there is no
     *         evidence
     */
    private int findPropagationRoot() {
        int propagationRoot = 0;
        for (BayesNode observed : this.evidence.keySet()) {
            propagationRoot = this.concernedClusters[observed.getId()][0];
        }
        return propagationRoot;
    }

    /**
     * Computes the clusters the collect pass can skip.
     *
     * @param root the propagation root
     * @return the clusters whose whole subtree (seen from {@code root}) contains no evidence cluster
     */
    private Set<Integer> skipCollection(int root) {
        Set<Integer> skipped = new HashSet<Integer>(this.nodePotentials.length);
        this.recursiveSkipCollection(root, new HashSet<Integer>(this.nodePotentials.length), skipped);
        return skipped;
    }

    /**
     * Depth-first helper for {@link #skipCollection(int)}.
     * <p>
     * A node is skipped when all of its unvisited neighbours were skipped and it holds no evidence itself.
     */
    private void recursiveSkipCollection(int node, Set<Integer> visited, Set<Integer> skipped) {
        visited.add(node);
        boolean allDescendantsSkipped = true;
        for (int neighbor : this.junctionTree.getNeighbors(node)) {
            if (visited.contains(neighbor)) {
                continue;
            }
            this.recursiveSkipCollection(neighbor, visited, skipped);
            if (!skipped.contains(neighbor)) {
                allDescendantsSkipped = false;
            }
        }
        if (allDescendantsSkipped && !this.clustersHavingEvidence.contains(node)) {
            skipped.add(node);
        }
    }

    /**
     * Computes the clusters the distribute pass can skip.
     *
     * @param distNode the propagation root
     * @return the clusters whose whole subtree (seen from {@code distNode}) contains no query factor of
     *         an unobserved variable
     */
    private Set<Integer> skipDistribution(int distNode) {
        Set<Integer> skipped = new HashSet<Integer>(this.nodePotentials.length);
        this.recursiveSkipDistribution(distNode, new HashSet<Integer>(this.nodePotentials.length), skipped);
        return skipped;
    }

    /**
     * Depth-first helper for {@link #skipDistribution(int)}.
     * <p>
     * A node is skipped when all of its unvisited neighbours were skipped and it is not the query factor
     * of any unobserved variable.
     */
    private void recursiveSkipDistribution(int node, Set<Integer> visited, Set<Integer> skipped) {
        visited.add(node);
        boolean allDescendantsSkipped = true;
        for (int neighbor : this.junctionTree.getNeighbors(node)) {
            if (visited.contains(neighbor)) {
                continue;
            }
            this.recursiveSkipDistribution(neighbor, visited, skipped);
            if (!skipped.contains(neighbor)) {
                allDescendantsSkipped = false;
            }
        }
        if (allDescendantsSkipped && !this.isQueryFactorOfUnobservedVariable(node)) {
            skipped.add(node);
        }
    }

    /** @return whether cluster {@code node} is the query factor of at least one unobserved variable */
    private boolean isQueryFactorOfUnobservedVariable(int node) {
        for (int variable : this.queryFactorReverseMapping[node]) {
            if (!this.isObserved[variable]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collect pass: post-order traversal from {@code cluster}.
     * <p>
     * Each neighbour that is not already {@code marked} sends a message towards {@code cluster}.
     * {@code marked} doubles as the visited set and as the pre-computed set of clusters to skip.
     */
    private void collectEvidence(int cluster, Set<Integer> marked) {
        marked.add(cluster);
        for (int neighbor : this.junctionTree.getNeighbors(cluster)) {
            if (marked.contains(neighbor)) {
                continue;
            }
            this.collectEvidence(neighbor, marked);
            this.messagePass(neighbor, cluster);
        }
    }

    /**
     * Distribute pass: pre-order traversal from {@code cluster}.
     * <p>
     * {@code cluster} sends a message to every neighbour that is not already {@code marked}.
     */
    private void distributeEvidence(int cluster, Set<Integer> marked) {
        marked.add(cluster);
        for (int neighbor : this.junctionTree.getNeighbors(cluster)) {
            if (marked.contains(neighbor)) {
                continue;
            }
            this.messagePass(cluster, neighbor);
            this.distributeEvidence(neighbor, marked);
        }
    }

    /**
     * Sends a message from cluster {@code v1} to cluster {@code v2} across their separator.
     * <p>
     * Steps:
     * <ol>
     * <li>Project {@code v1}'s potential onto the separator.</li>
     * <li>Form the ratio of new to old separator values; in log scale this is a difference.</li>
     * <li>Multiply the ratio into {@code v2}'s potential.</li>
     * </ol>
     * The message is skipped if every separator variable is observed.
     * <p>
     * Source and target are indexed by {@code v1}/{@code v2} directly, matching the keys of the
     * prepared operations. The code therefore does not depend on the element order inside
     * {@link OrderIgnoringPair}.
     */
    private void messagePass(int v1, int v2) {
        AbstractFactor sepSet = this.sepSets.get(new OrderIgnoringPair<Integer>(v1, v2));
        if (!this.needMessagePass(sepSet)) {
            return;
        }
        AbstractFactor source = this.nodePotentials[v1];
        AbstractFactor target = this.nodePotentials[v2];
        boolean sourceLog = source.isLogScale();
        boolean targetLog = target.isLogScale();

        IArrayWrapper newSepValues = sepSet.getValues();
        // remember the old separator values: they are the denominator of the update ratio
        System.arraycopy(newSepValues.toDoubleArray(), 0, this.scratchpad, 0, newSepValues.length());
        source.sumPrepared(newSepValues, this.preparedMultiplications.get(Pair.newPair(v2, v1)));
        if (sourceLog && !targetLog) {
            // the separator is kept in linear scale unless both ends are logarithmic
            MathUtils.exp(newSepValues);
        }
        if (sourceLog && targetLog) {
            MathUtils.secureSubtract(newSepValues.toDoubleArray(), this.scratchpad, this.scratchpad);
        } else {
            MathUtils.secureDivide(newSepValues.toDoubleArray(), this.scratchpad, this.scratchpad);
        }
        if (!sourceLog && targetLog) {
            MathUtils.log(this.scratchpad);
        }
        // NOTE(review): the whole scratchpad is wrapped even when it is longer than this separator
        target.multiplyPrepared(new DoubleArrayWrapper(this.scratchpad), this.preparedMultiplications.get(Pair.newPair(v1, v2)));
    }

    /** @return whether at least one variable of {@code sepSet} is unobserved */
    private boolean needMessagePass(AbstractFactor sepSet) {
        for (int variable : sepSet.getDimensionIDs()) {
            if (!this.isObserved[variable]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compiles {@code net} into a junction tree and brings it into its initial state.
     * <p>
     * The initial state is propagated without evidence, and the resulting values are stored so that
     * later belief updates can restore them.
     *
     * @param net the network to perform inference on
     */
    @Override
    public void setNetwork(BayesNet net) {
        super.setNetwork(net);
        this.initializeFields(net.getNodes().size());
        JunctionTree jtree = this.buildJunctionTree(net);
        Map<AbstractFactor, Integer> homeClusters = this.computeHomeClusters(net, jtree.getClusters());
        this.initializeClusterFactors(net, jtree.getClusters(), homeClusters);
        this.initializeSepsetFactors(jtree.getSepSets());
        this.determineConcernedClusters();
        this.setQueryFactors();
        this.initializePotentialValues();
        this.multiplyCPTsIntoPotentials(net, homeClusters);
        this.prepareMultiplications();
        this.prepareScratch();
        this.invokeInitialBeliefUpdate();
        this.storePotentialValues();
    }

    /** Builds {@link #concernedClusters}: for every variable, the ids of the clusters containing it. */
    private void determineConcernedClusters() {
        int numVariables = this.queryFactors.length;
        List<List<Integer>> clustersByVariable = new ArrayList<List<Integer>>(numVariables);
        for (int variable = 0; variable < numVariables; variable++) {
            clustersByVariable.add(new ArrayList<Integer>());
        }
        for (int cluster = 0; cluster < this.nodePotentials.length; cluster++) {
            for (int variable : this.nodePotentials[cluster].getDimensionIDs()) {
                clustersByVariable.get(variable).add(cluster);
            }
        }
        this.concernedClusters = new int[numVariables][];
        for (int variable = 0; variable < numVariables; variable++) {
            this.concernedClusters[variable] = ArrayUtils.toIntArray(clustersByVariable.get(variable));
        }
    }

    /** Allocates the per-network bookkeeping structures for a network with {@code numNodes} variables. */
    private void initializeFields(int numNodes) {
        // a fresh boolean[] is all-false, i.e. no belief is valid yet
        this.isBeliefValid = new boolean[this.beliefs.length];
        this.queryFactors = new AbstractFactor[numNodes];
        this.preparedQueries = new int[numNodes][];
        this.sepSets = new HashMap<OrderIgnoringPair<Integer>, AbstractFactor>(numNodes);
        this.preparedMultiplications = new HashMap<Pair<Integer, Integer>, int[]>(numNodes);
        this.initializations = new ArrayList<Pair<AbstractFactor, IArrayWrapper>>();
        this.clustersHavingEvidence = new HashSet<Integer>(numNodes);
        this.isObserved = new boolean[numNodes];
    }

    /** Builds the junction tree for {@code net}, stores its graph and returns the full tree. */
    private JunctionTree buildJunctionTree(BayesNet net) {
        JunctionTree jtree = this.junctionTreeBuilder.buildJunctionTree(net);
        this.junctionTree = jtree.getGraph();
        return jtree;
    }

    /**
     * Maps every CPT to the index of its home cluster.
     * <p>
     * The home cluster is the first cluster containing all of the CPT's variables. A CPT for which no
     * such cluster exists is left unmapped.
     */
    private Map<AbstractFactor, Integer> computeHomeClusters(BayesNet net, List<List<Integer>> clusters) {
        Map<AbstractFactor, Integer> homeClusters = new HashMap<AbstractFactor, Integer>();
        for (BayesNode node : net.getNodes()) {
            AbstractFactor cpt = node.getFactor();
            int[] nodeAndParents = cpt.getDimensionIDs();
            int clusterIndex = 0;
            for (List<Integer> cluster : clusters) {
                if (this.containsAll(cluster, nodeAndParents)) {
                    homeClusters.put(cpt, clusterIndex);
                    break;
                }
                clusterIndex++;
            }
        }
        return homeClusters;
    }

    /** @return whether {@code list} contains every value of {@code ints} */
    private boolean containsAll(List<Integer> list, int[] ints) {
        for (int value : ints) {
            if (!list.contains(value)) {
                return false;
            }
        }
        return true;
    }

    /** Creates one potential per cluster, passing the CPTs homed in that cluster as multiplication partners. */
    private void initializeClusterFactors(BayesNet net, List<List<Integer>> clusters, Map<AbstractFactor, Integer> homeClusters) {
        this.nodePotentials = new AbstractFactor[clusters.size()];
        Map<Integer, List<AbstractFactor>> multiplicationPartners = this.findMultiplicationPartners(net, homeClusters);
        int clusterIndex = 0;
        for (List<Integer> cluster : clusters) {
            List<AbstractFactor> partners = multiplicationPartners.get(clusterIndex);
            if (partners == null) {
                partners = Collections.emptyList();
            }
            this.nodePotentials[clusterIndex] = this.factory.create(cluster, partners);
            clusterIndex++;
        }
    }

    /** Groups the CPTs of {@code net} by the index of their home cluster. */
    private Map<Integer, List<AbstractFactor>> findMultiplicationPartners(BayesNet net, Map<AbstractFactor, Integer> homeClusters) {
        Map<Integer, List<AbstractFactor>> potentialMap = new HashMap<Integer, List<AbstractFactor>>();
        for (BayesNode node : net.getNodes()) {
            Integer nodeHome = homeClusters.get(node.getFactor());
            List<AbstractFactor> partners = potentialMap.get(nodeHome);
            if (partners == null) {
                partners = new ArrayList<AbstractFactor>();
                potentialMap.put(nodeHome, partners);
            }
            partners.add(node.getFactor());
        }
        return potentialMap;
    }

    /** Creates one factor per separator of the junction tree and registers it under its edge. */
    private void initializeSepsetFactors(List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> separators) {
        List<AbstractFactor> noPartners = Collections.emptyList();
        for (Pair<OrderIgnoringPair<Integer>, List<Integer>> sep : separators) {
            this.sepSets.put(sep.getFirst(), this.factory.create(sep.getSecond(), noPartners));
        }
    }

    /**
     * Selects the query factor of each variable and builds the reverse mapping from clusters to the
     * variables they serve.
     * <p>
     * The query factor is the containing potential with the fewest values; the first one wins ties.
     */
    private void setQueryFactors() {
        for (int variable = 0; variable < this.queryFactors.length; variable++) {
            for (int cluster : this.concernedClusters[variable]) {
                AbstractFactor candidate = this.nodePotentials[cluster];
                AbstractFactor current = this.queryFactors[variable];
                if (current == null || current.getValues().length() > candidate.getValues().length()) {
                    this.queryFactors[variable] = candidate;
                }
            }
        }
        this.queryFactorReverseMapping = new int[this.nodePotentials.length][];
        for (int cluster = 0; cluster < this.nodePotentials.length; cluster++) {
            AbstractFactor potential = this.nodePotentials[cluster];
            List<Integer> queryVars = new ArrayList<Integer>();
            for (int variable : potential.getDimensionIDs()) {
                if (this.queryFactors[variable] == potential) {
                    queryVars.add(variable);
                }
            }
            this.queryFactorReverseMapping[cluster] = ArrayUtils.toIntArray(queryVars);
        }
    }

    /**
     * Precomputes the separator and query operations.
     * <p>
     * Identical index arrays are shared through a {@link CanonicalIntArrayManager}.
     */
    private void prepareMultiplications() {
        CanonicalIntArrayManager flyWeight = new CanonicalIntArrayManager();
        this.prepareSepsetMultiplications(flyWeight);
        this.prepareQueries(flyWeight);
    }

    /**
     * For every directed edge {@code (node, neighbor)}, prepares the operation between {@code neighbor}'s
     * potential and the shared separator.
     */
    private void prepareSepsetMultiplications(CanonicalIntArrayManager flyWeight) {
        for (int node = 0; node < this.nodePotentials.length; node++) {
            for (int neighbor : this.junctionTree.getNeighbors(node)) {
                AbstractFactor separator = this.sepSets.get(new OrderIgnoringPair<Integer>(node, neighbor));
                int[] preparedMultiplication = this.nodePotentials[neighbor].prepareMultiplication(separator);
                this.preparedMultiplications.put(Pair.newPair(node, neighbor), flyWeight.getInstance(preparedMultiplication));
            }
        }
    }

    /** For every variable, prepares the projection of its query factor onto that single variable. */
    private void prepareQueries(CanonicalIntArrayManager flyWeight) {
        List<AbstractFactor> noPartners = Collections.emptyList();
        for (int variable = 0; variable < this.queryFactors.length; variable++) {
            AbstractFactor beliefFactor = this.factory.create(Arrays.asList(variable), noPartners);
            int[] preparedQuery = this.queryFactors[variable].prepareMultiplication(beliefFactor);
            this.preparedQueries[variable] = flyWeight.getInstance(preparedQuery);
        }
    }

    /** Allocates {@link #scratchpad} large enough to hold the values of the largest separator. */
    private void prepareScratch() {
        int maxSize = 0;
        for (AbstractFactor sepSet : this.sepSets.values()) {
            maxSize = Math.max(maxSize, sepSet.getValues().length());
        }
        this.scratchpad = new double[maxSize];
    }

    /** Runs an unpruned collect/distribute pass rooted at cluster {@code 0}. */
    private void invokeInitialBeliefUpdate() {
        this.collectEvidence(0, new HashSet<Integer>());
        this.distributeEvidence(0, new HashSet<Integer>());
    }

    /**
     * Fills every potential and separator with the multiplicative identity of its scale.
     * <p>
     * A separator uses log scale only when both of its end clusters do.
     */
    private void initializePotentialValues() {
        for (AbstractFactor potential : this.nodePotentials) {
            potential.fill(potential.isLogScale() ? ONE_LOG : ONE);
        }
        for (Map.Entry<OrderIgnoringPair<Integer>, AbstractFactor> sepSet : this.sepSets.entrySet()) {
            sepSet.getValue().fill(this.areBothEndsLogScale(sepSet.getKey()) ? ONE_LOG : ONE);
        }
    }

    /** Multiplies every CPT into the potential of its home cluster, using the log variant where needed. */
    private void multiplyCPTsIntoPotentials(BayesNet net, Map<AbstractFactor, Integer> homeClusters) {
        for (BayesNode node : net.getNodes()) {
            AbstractFactor nodeHome = this.nodePotentials[homeClusters.get(node.getFactor())];
            if (nodeHome.isLogScale()) {
                nodeHome.multiplyCompatibleToLog(node.getFactor());
            } else {
                nodeHome.multiplyCompatible(node.getFactor());
            }
        }
    }

    /** @return whether the potentials at both ends of {@code edge} are in log scale */
    private boolean areBothEndsLogScale(OrderIgnoringPair<Integer> edge) {
        return this.nodePotentials[edge.getFirst()].isLogScale() && this.nodePotentials[edge.getSecond()].isLogScale();
    }

    /**
     * Records a copy of every potential's and separator's current values.
     * <p>
     * Identical value arrays are shared through a {@link CanonicalArrayWrapperManager}. The copies are
     * later restored by the replay step that runs before each propagation.
     */
    private void storePotentialValues() {
        CanonicalArrayWrapperManager flyweight = new CanonicalArrayWrapperManager();
        for (AbstractFactor pot : this.nodePotentials) {
            this.initializations.add(Pair.newPair(pot, flyweight.getInstance(pot.getValues().clone())));
        }
        for (AbstractFactor sep : this.sepSets.values()) {
            this.initializations.add(Pair.newPair(sep, flyweight.getInstance(sep.getValues().clone())));
        }
    }
}

