/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.tree;

import dr.evolution.io.Importer;
import dr.evolution.io.TreeTrace;
import dr.evolution.tree.Clade;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.AbstractCladeImportanceDistribution;
import dr.evomodel.tree.ConditionalCladeFrequency;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Likelihood;
import dr.math.MathUtils;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Set;

/**
 * Importance distribution over tree topologies driven by sampled clade counts.
 *
 * <p>Every tree passed to {@code addTree} increments {@code samples} and has its
 * clades merged into {@code cladeProbabilities}, keyed by the clade's bit set.
 * Log tree probabilities and clade splits are then computed from the stored
 * sample counts, each smoothed by the additive constant {@code EPSILON}.
 *
 * <p>Instances built with the {@code (Tree, double)} constructor have no traces
 * attached; {@link #getTree(int)} and {@link #analyzeTrace(boolean)} dereference
 * the trace array and therefore require the trace-based constructor.
 */
public class WeightedMultiplicativeBinary
extends AbstractCladeImportanceDistribution {
    /** Number of external nodes of the reference tree, fixed at construction. */
    private final int TAXA_COUNT;
    /** Additive smoothing constant applied to every clade count. */
    private final double EPSILON;
    /** Number of trees added so far. */
    private long samples = 0L;
    /** Clade accumulators keyed by the clade's bit set. */
    private final HashMap<BitSet, Clade> cladeProbabilities = new HashMap<BitSet, Clade>();
    /** Source traces; {@code null} when built from a single tree. */
    private TreeTrace[] traces;
    /** Burn-in, multiplied by a trace's step size before being handed to that trace. */
    private int burnin;

    /**
     * Creates an empty distribution; trees are supplied later through {@code addTree}.
     *
     * @param tree    tree whose external node count is recorded as the taxon count
     * @param epsilon additive smoothing constant
     */
    public WeightedMultiplicativeBinary(Tree tree, double epsilon) {
        this.EPSILON = epsilon;
        this.TAXA_COUNT = tree.getExternalNodeCount();
    }

    /**
     * Creates a distribution from one or more tree traces and analyzes them immediately.
     *
     * @param traces  traces to read trees from; must contain at least one trace
     * @param epsilon additive smoothing constant
     * @param burnIn  requested burn-in; when negative or not below the smallest
     *                maximum state of all traces, a fallback derived from that
     *                smallest maximum state and the first trace's step size is used
     * @param verbose whether to print warnings and progress to standard output
     * @throws IllegalArgumentException if {@code traces} is {@code null} or empty
     */
    public WeightedMultiplicativeBinary(TreeTrace[] traces, double epsilon, int burnIn, boolean verbose) {
        if (traces == null || traces.length == 0) {
            throw new IllegalArgumentException("At least one tree trace is required");
        }
        this.EPSILON = epsilon;
        this.traces = traces;

        int smallestMaximumState = Integer.MAX_VALUE;
        for (TreeTrace trace : traces) {
            if (trace.getMaximumState() < smallestMaximumState) {
                smallestMaximumState = trace.getMaximumState();
            }
        }

        if (burnIn < 0 || burnIn >= smallestMaximumState) {
            this.burnin = smallestMaximumState / (10 * traces[0].getStepSize());
            if (verbose) {
                System.out.println("WARNING: Burn-in larger than total number of states - using 10% of smallest trace");
            }
        } else {
            this.burnin = burnIn;
        }

        // Resolve the burn-in BEFORE touching any tree: the taxon count is read
        // from the first post-burn-in tree using the same offset convention as
        // every other lookup in this class, instead of the unvalidated request.
        // Assumes all sampled trees share one taxon set - TODO confirm.
        this.TAXA_COUNT = this.getTree(0).getExternalNodeCount();

        this.analyzeTrace(verbose);
    }

    /**
     * Reads the post-burn-in trees of every trace and adds them to the distribution.
     *
     * @param verbose whether to print a 60-step progress bar per trace
     */
    public void analyzeTrace(boolean verbose) {
        if (verbose && this.traces.length > 1) {
            System.out.println("Combining " + this.traces.length + " traces.");
        }

        // The result is unused, but the lookup throws when no post-burn-in tree
        // exists, so the call is retained for that fail-fast behaviour.
        this.getTree(0);

        for (TreeTrace trace : this.traces) {
            final int burninOffset = this.burnin * trace.getStepSize();
            final int treeCount = trace.getTreeCount(burninOffset);
            final double treesPerMark = (double)treeCount / 60.0;
            int marks = 1;

            if (verbose) {
                System.out.println("Analyzing " + treeCount + " trees...");
                System.out.println("0              25             50             75            100");
                System.out.println("|--------------|--------------|--------------|--------------|");
                System.out.print("*");
            }

            // NOTE(review): iteration starts at index 1, so the tree at index 0 of
            // each trace is never added - confirm whether that is intentional.
            for (int i = 1; i < treeCount; ++i) {
                this.addTree(trace.getTree(i, burninOffset));

                if (i >= (int)Math.round((double)marks * treesPerMark) && marks <= 60) {
                    if (verbose) {
                        System.out.print("*");
                        System.out.flush();
                    }
                    ++marks;
                }
            }

            if (verbose) {
                System.out.println("*");
            }
        }
    }

    /**
     * Prints the log probability of {@code tree} to standard output.
     *
     * @param tree tree to report on; it is copied into a {@link SimpleTree} first
     * @throws IOException declared for interface compatibility
     */
    public void report(Tree tree) throws IOException {
        System.err.println("making report");
        // Keep the static type SimpleTree: it selects the SimpleTree overload of
        // getTreeProbability rather than the Tree override.
        SimpleTree simpleTree = new SimpleTree(tree);
        System.out.println("Estimated marginal posterior by condiational clade frequencies:");
        System.out.println(this.getTreeProbability(simpleTree));
        System.out.flush();
    }

    /**
     * Returns the log probability of {@code simpleTree} as a sum over its clades.
     *
     * @param simpleTree tree to evaluate
     * @return sum of log smoothed clade frequencies
     */
    public double getTreeProbability(SimpleTree simpleTree) {
        return this.calculateTreeProbabilityLog(simpleTree);
    }

    /**
     * Returns the log probability of {@code simpleTree}, with clades extracted
     * using the supplied map and counts down-weighted by tree counts.
     *
     * @param simpleTree tree to evaluate
     * @param hashMap    map forwarded to the inherited clade extraction
     * @return sum of log weighted, smoothed clade frequencies
     */
    public double getTreeProbability(SimpleTree simpleTree, HashMap<String, Integer> hashMap) {
        return this.calculateTreeProbabilityLog(simpleTree, hashMap);
    }

    /**
     * Sums {@code log((count + EPSILON) / (samples + splits * EPSILON))} over the
     * clades of {@code tree}, where {@code splits = 2^(tips - 1) - 1}.
     */
    private double calculateTreeProbabilityLog(Tree tree) {
        final double splits = Math.pow(2.0, tree.getExternalNodeCount() - 1) - 1.0;
        final double denominator = (double)this.samples + splits * this.EPSILON;

        ArrayList<Clade> clades = new ArrayList<Clade>();
        ArrayList<Clade> parentClades = new ArrayList<Clade>();
        this.getClades(tree, tree.getRoot(), parentClades, clades);

        double logProbability = 0.0;
        for (Clade clade : clades) {
            logProbability += Math.log(this.smoothedCount(clade.getBits()) / denominator);
        }
        return logProbability;
    }

    /**
     * Same as the single-argument variant, except that each observed count is
     * divided by {@code getTrees(k) * getTrees(TAXA_COUNT - k + 1)}, with
     * {@code k} the clade's cardinality, before smoothing.
     */
    private double calculateTreeProbabilityLog(Tree tree, HashMap<String, Integer> hashMap) {
        final double splits = Math.pow(2.0, tree.getExternalNodeCount() - 1) - 1.0;
        final double denominator = (double)this.samples + splits * this.EPSILON;

        ArrayList<Clade> clades = new ArrayList<Clade>();
        ArrayList<Clade> parentClades = new ArrayList<Clade>();
        this.getClades(tree, tree.getRoot(), parentClades, clades, hashMap);

        double logProbability = 0.0;
        for (Clade clade : clades) {
            double weight = this.EPSILON;
            Clade observed = this.cladeProbabilities.get(clade.getBits());
            if (observed != null) {
                int cardinality = clade.getBits().cardinality();
                double treeCount = this.getTrees(cardinality) * this.getTrees(this.TAXA_COUNT - cardinality + 1);
                weight += (double)observed.getSampleCount() / treeCount;
            }
            logProbability += Math.log(weight / denominator);
        }
        return logProbability;
    }

    /**
     * Returns the product of {@code (2i - 3)} for {@code i = 3..n}, i.e.
     * {@code (2n - 3)!!}; the result is 1 for {@code n < 3}.
     */
    private double getTrees(int n) {
        double product = 1.0;
        for (int i = 3; i <= n; ++i) {
            product *= (double)(2 * i - 3);
        }
        return product;
    }

    /**
     * Returns the stored sample count of the clade with the given bits plus
     * {@code EPSILON}, or just {@code EPSILON} when the clade was never stored.
     */
    private double smoothedCount(BitSet bits) {
        Clade observed = this.cladeProbabilities.get(bits);
        return observed != null ? (double)observed.getSampleCount() + this.EPSILON : this.EPSILON;
    }

    /**
     * Recursively accumulates the log probability of the subtree below {@code node}.
     *
     * <p>A node with two external children contributes nothing. A node with two
     * internal children contributes the log of the mean smoothed frequency of its
     * child clades. Otherwise the contribution is the log of the summed smoothed
     * frequencies of those child clades whose size exceeds one.
     */
    private double calculateTreeProbabilityLogRecursive(Tree tree, NodeRef node) {
        final NodeRef left = tree.getChild(node, 0);
        final NodeRef right = tree.getChild(node, 1);

        if (tree.isExternal(left) && tree.isExternal(right)) {
            return 0.0;
        }

        final Clade leftClade = this.getClade(tree, left);
        final Clade rightClade = this.getClade(tree, right);
        double logProbability = 0.0;
        double frequency = 0.0;

        if (!tree.isExternal(left) && !tree.isExternal(right)) {
            frequency += this.smoothedCount(leftClade.getBits()) / (double)this.samples;
            frequency += this.smoothedCount(rightClade.getBits()) / (double)this.samples;
            logProbability += Math.log(frequency / 2.0);
            logProbability += this.calculateTreeProbabilityLogRecursive(tree, left);
            logProbability += this.calculateTreeProbabilityLogRecursive(tree, right);
            return logProbability;
        }

        if (leftClade.getSize() > 1) {
            frequency += this.smoothedCount(leftClade.getBits()) / (double)this.samples;
        }
        if (rightClade.getSize() > 1) {
            frequency += this.smoothedCount(rightClade.getBits()) / (double)this.samples;
        }
        logProbability += Math.log(frequency);

        if (!tree.isExternal(left)) {
            logProbability += this.calculateTreeProbabilityLogRecursive(tree, left);
        }
        if (!tree.isExternal(right)) {
            logProbability += this.calculateTreeProbabilityLogRecursive(tree, right);
        }
        return logProbability;
    }

    /** Node heights are not modelled by this distribution; always returns 0. */
    @Override
    public double getChanceForNodeHeights(TreeModel treeModel, Likelihood likelihood) {
        return 0.0;
    }

    /** Node heights are not modelled by this distribution; always returns 0. */
    @Override
    public double setNodeHeights(TreeModel treeModel, Likelihood likelihood) {
        return 0.0;
    }

    /**
     * Returns the tree at position {@code n} of the concatenated post-burn-in traces.
     *
     * @param n index across all traces, in trace order
     * @return the tree at that position
     * @throws RuntimeException if {@code n} is not below the total tree count
     */
    public final Tree getTree(int n) {
        int precedingTrees = 0;
        int cumulativeTrees = 0;
        for (TreeTrace trace : this.traces) {
            final int burninOffset = this.burnin * trace.getStepSize();
            cumulativeTrees += trace.getTreeCount(burninOffset);
            if (n < cumulativeTrees) {
                return trace.getTree(n - precedingTrees, burninOffset);
            }
            precedingTrees = cumulativeTrees;
        }
        throw new RuntimeException("Couldn't find tree " + n);
    }

    /**
     * Counts {@code tree} as one more sample and merges its clades into the store.
     *
     * @param tree tree to add
     */
    @Override
    public void addTree(Tree tree) {
        ++this.samples;
        ArrayList<Clade> clades = new ArrayList<Clade>();
        ArrayList<Clade> parentClades = new ArrayList<Clade>();
        this.getClades(tree, tree.getRoot(), parentClades, clades);
        this.recordClades(clades);
    }

    /**
     * Counts {@code tree} as one more sample and merges its clades, extracted
     * with the supplied map, into the store.
     *
     * @param tree    tree to add
     * @param hashMap map forwarded to the inherited clade extraction
     */
    public void addTree(Tree tree, HashMap<String, Integer> hashMap) {
        ++this.samples;
        ArrayList<Clade> clades = new ArrayList<Clade>();
        ArrayList<Clade> parentClades = new ArrayList<Clade>();
        this.getClades(tree, tree.getRoot(), parentClades, clades, hashMap);
        this.recordClades(clades);
    }

    /**
     * Merges {@code clades} into the store: a stored clade receives the new
     * clade's height, an unseen clade records its own height and is inserted.
     */
    private void recordClades(List<Clade> clades) {
        for (Clade clade : clades) {
            Clade stored = this.cladeProbabilities.get(clade.getBits());
            if (stored != null) {
                stored.addHeight(clade.getHeight());
            } else {
                clade.addHeight(clade.getHeight());
                this.cladeProbabilities.put(clade.getBits(), clade);
            }
        }
    }

    /**
     * Loads one tree trace per reader and builds a distribution from them.
     * Each reader is closed once its trace has been loaded or loading has failed.
     *
     * <p>NOTE(review): this constructs a {@link ConditionalCladeFrequency}, not a
     * {@code WeightedMultiplicativeBinary} - confirm that is intended.
     *
     * @param readerArray readers to load traces from
     * @param d           additive smoothing constant
     * @param n           requested burn-in
     * @param bl          whether to print progress
     * @return the distribution built from the loaded traces
     * @throws IOException      if reading or closing a reader fails
     * @throws RuntimeException if a trace cannot be imported; the import
     *                          exception is attached as the cause
     */
    public static ConditionalCladeFrequency analyzeLogFile(Reader[] readerArray, double d, int n, boolean bl) throws IOException {
        TreeTrace[] loadedTraces = new TreeTrace[readerArray.length];
        for (int i = 0; i < readerArray.length; ++i) {
            try {
                loadedTraces[i] = TreeTrace.loadTreeTrace(readerArray[i]);
            }
            catch (Importer.ImportException importException) {
                // Same message as before, but keep the cause for diagnosis.
                throw new RuntimeException(importException.toString(), importException);
            }
            finally {
                readerArray[i].close();
            }
        }
        return new ConditionalCladeFrequency(loadedTraces, d, n, bl);
    }

    /**
     * Returns the log probability of {@code tree} from the recursive per-node scheme.
     *
     * @param tree tree to evaluate
     * @return accumulated log probability starting at the root
     */
    @Override
    public double getTreeProbability(Tree tree) {
        return this.calculateTreeProbabilityLogRecursive(tree, tree.getRoot());
    }

    /**
     * Randomly splits {@code clade} into two child clades.
     *
     * <p>A stored sub-clade is drawn by roulette wheel with weight
     * {@code count + EPSILON}; if the wheel runs past all stored candidates, a
     * random proper subset that is not stored is drawn instead.
     *
     * @param clade      clade to split
     * @param cladeArray output; index 0 receives the drawn child, index 1 its complement
     * @return log of the probability assigned to the chosen split
     */
    @Override
    public double splitClade(Clade clade, Clade[] cladeArray) {
        final double subsetCount = Math.pow(2.0, clade.getSize()) - 1.0;
        final List<Clade> candidates = this.getPossibleChildren(clade);

        double totalWeight = 0.0;
        for (Clade candidate : candidates) {
            totalWeight += (double)candidate.getSampleCount();
        }
        totalWeight += this.EPSILON * subsetCount;

        double remaining = Math.random() * totalWeight;
        double splitProbability = 0.0;

        for (Clade candidate : candidates) {
            remaining -= (double)candidate.getSampleCount() + this.EPSILON;
            if (remaining < 0.0) {
                cladeArray[0] = candidate;
                double frequency = ((double)candidate.getSampleCount() + this.EPSILON) / (double)this.samples;
                BitSet sibling = (BitSet)cladeArray[0].getBits().clone();
                sibling.xor(clade.getBits());
                if (sibling.cardinality() > 1) {
                    frequency += this.smoothedCount(sibling) / (double)this.samples;
                    splitProbability = frequency / 2.0;
                } else {
                    splitProbability = frequency;
                }
                break;
            }
        }

        if (remaining >= 0.0) {
            // No stored candidate was hit: draw a random, non-empty, proper subset
            // that is not stored.
            // NOTE(review): nothing visible here guarantees such a subset exists,
            // so this loop may not terminate for some clades - confirm callers.
            BitSet subset;
            do {
                subset = (BitSet)clade.getBits().clone();
                for (int bit = subset.nextSetBit(0); bit > -1; bit = subset.nextSetBit(bit + 1)) {
                    if (MathUtils.nextBoolean()) {
                        subset.clear(bit);
                    }
                }
            } while (subset.cardinality() == 0 || subset.cardinality() == clade.getSize() || this.cladeProbabilities.containsKey(subset));

            cladeArray[0] = new Clade(subset, 0.5);
            BitSet complement = (BitSet)cladeArray[0].getBits().clone();
            complement.xor(clade.getBits());
            Clade storedComplement = this.cladeProbabilities.get(complement);
            cladeArray[1] = storedComplement != null ? storedComplement : new Clade(complement, 0.5);

            if (cladeArray[0].getSize() > 1 && cladeArray[1].getSize() > 1) {
                splitProbability = ((double)(cladeArray[0].getSampleCount() + cladeArray[1].getSampleCount()) + 2.0 * this.EPSILON) / ((double)this.samples * 2.0);
            } else if (cladeArray[0].getSize() > 1) {
                splitProbability = ((double)cladeArray[0].getSampleCount() + this.EPSILON) / (double)this.samples;
            } else {
                splitProbability = ((double)cladeArray[1].getSampleCount() + this.EPSILON) / (double)this.samples;
            }
        } else {
            BitSet complement = (BitSet)cladeArray[0].getBits().clone();
            complement.xor(clade.getBits());
            cladeArray[1] = this.cladeProbabilities.get(complement);
            if (cladeArray[1] == null) {
                cladeArray[1] = new Clade(complement, 0.5);
                cladeArray[1].addHeight(0.5);
            }
        }
        return Math.log(splitProbability);
    }

    /**
     * Collects every stored clade that has fewer members than {@code clade} and
     * that the inherited {@code containsClade} check accepts against its bits.
     */
    private List<Clade> getPossibleChildren(Clade clade) {
        ArrayList<Clade> children = new ArrayList<Clade>();
        Set<BitSet> storedBits = this.cladeProbabilities.keySet();
        for (BitSet bits : storedBits) {
            if (bits.cardinality() < clade.getSize() && this.containsClade(clade.getBits(), bits)) {
                children.add(this.cladeProbabilities.get(bits));
            }
        }
        return children;
    }
}

