/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.colouring;

import dr.evolution.alignment.Alignment;
import dr.evolution.coalescent.structure.MetaPopulation;
import dr.evolution.colouring.BranchColouring;
import dr.evolution.colouring.ColourChangeMatrix;
import dr.evolution.colouring.ColourSampler;
import dr.evolution.colouring.DefaultBranchColouring;
import dr.evolution.colouring.DefaultTreeColouring;
import dr.evolution.colouring.TreeColouring;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.math.MathUtils;

/**
 * Samples a colouring of a tree (the colour of every node plus the history of colour changes
 * along every branch) conditional on the observed colours of its tips.
 *
 * <p>Sampling proceeds in three stages:
 * <ol>
 *   <li>a post-order pruning pass ({@link #prune}) computes, for each internal node, the
 *       probability of the tip colours below it given each possible colour of that node;</li>
 *   <li>a pre-order pass ({@link #sampleInternalNodes}) draws internal node colours from those
 *       partials, starting from the equilibrium distribution at the root;</li>
 *   <li>each branch history is drawn by simulating from the parent down to the child and
 *       rejecting histories whose final colour does not match the child's colour.</li>
 * </ol>
 *
 * <p>The branch sampler flips colours with {@code 1 - colour}, so only two-colour models are
 * supported by {@link #sampleTreeColouring}.
 *
 * <p>Instances keep per-sample scratch state (node colours and pruning partials) in fields, so a
 * single instance must not be used concurrently.
 */
public class BasicColourSampler
implements ColourSampler {
    /** Maximum number of rejection-sampling attempts per branch before forcing a colour change. */
    static final int maxIterations = 1000;
    /** Number of distinct colours. */
    private final int colourCount;
    /** Current colour of each node, indexed by node number; tip colours are fixed at construction. */
    private final int[] nodeColours;
    /** Pruning partials of internal nodes, indexed by node number then colour. */
    private final double[][] nodePartials;
    /** Number of tips observed with each colour. */
    private final int[] leafColourCounts;

    /**
     * Builds a sampler whose tip colours are the states of a single-column alignment.
     *
     * @param alignment one column whose state for each taxon is that tip's colour; the number of
     *                  colours is the data type's state count
     * @param tree      the tree whose tips are coloured
     * @throws IllegalArgumentException if the alignment has more than one column, a tip's taxon is
     *                                  not found in it, or a tip state is not a valid colour
     */
    public BasicColourSampler(Alignment alignment, Tree tree) {
        if (alignment.getSiteCount() != 1) {
            throw new IllegalArgumentException("Tip colour alignment must consist of a single column!");
        }
        this.colourCount = alignment.getDataType().getStateCount();
        this.nodeColours = new int[tree.getNodeCount()];
        this.leafColourCounts = new int[this.colourCount];
        for (int i = 0; i < tree.getExternalNodeCount(); ++i) {
            int taxonIndex = alignment.getTaxonIndex(tree.getTaxonId(i));
            // NOTE(review): assumes a negative index signals a missing taxon, as TaxonList's -1 does
            // in the other constructor — confirm for Alignment.
            if (taxonIndex < 0) {
                throw new IllegalArgumentException("Taxon " + tree.getTaxonId(i) + " not found in tip colour alignment");
            }
            int colour = alignment.getState(taxonIndex, 0);
            // setColour rejects states outside [0, colourCount), e.g. gap/ambiguity codes.
            this.setColour(tree.getExternalNode(i), colour);
            this.leafColourCounts[colour]++;
        }
        this.nodePartials = new double[tree.getNodeCount()][this.colourCount];
    }

    /**
     * Builds a sampler whose tip colours come from taxon-list membership: a tip contained in
     * {@code taxonListArray[j]} gets colour {@code j + 1}, and a tip in none of them gets colour 0.
     * If a tip appears in several lists, the last such list wins.
     *
     * @param taxonListArray one list per non-default colour
     * @param tree           the tree whose tips are coloured
     */
    public BasicColourSampler(TaxonList[] taxonListArray, Tree tree) {
        this.colourCount = taxonListArray.length + 1;
        this.nodeColours = new int[tree.getNodeCount()];
        this.leafColourCounts = new int[this.colourCount];
        for (int i = 0; i < tree.getExternalNodeCount(); ++i) {
            int colour = 0;
            for (int j = 0; j < taxonListArray.length; ++j) {
                if (taxonListArray[j].getTaxonIndex(tree.getTaxonId(i)) != -1) {
                    colour = j + 1;
                }
            }
            this.setColour(tree.getExternalNode(i), colour);
            this.leafColourCounts[colour]++;
        }
        this.nodePartials = new double[tree.getNodeCount()][this.colourCount];
    }

    /**
     * Returns the number of tips observed with each colour.
     * The internal array is returned (not a copy), so callers must not modify it.
     */
    @Override
    public int[] getLeafColourCounts() {
        return this.leafColourCounts;
    }

    /**
     * Draws a new colouring of {@code tree} conditional on the tip colours and records its log
     * density: the joint log density of the sampled colouring minus the log probability of the
     * observed tip colours (integrated over the root colour).
     *
     * @param tree               tree to colour; must be bifurcating
     * @param colourChangeMatrix colour-change rates, transition function and equilibrium
     * @param metaPopulation     only its population sizes at time 0 are read; they are passed on
     *                           but not currently used by the calculations
     * @return the sampled colouring, with its log probability density set
     */
    @Override
    public DefaultTreeColouring sampleTreeColouring(Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        // The colouring is created for exactly 2 colours, matching the two-colour branch sampler.
        DefaultTreeColouring treeColouring = new DefaultTreeColouring(2, tree);
        double[] populationSizes = metaPopulation.getPopulationSizes(0.0);
        NodeRef root = tree.getRoot();

        double[] rootPartials = this.prune(tree, root, colourChangeMatrix, populationSizes);
        double tipColourProbability = 0.0;
        for (int i = 0; i < rootPartials.length; ++i) {
            tipColourProbability += colourChangeMatrix.getEquilibrium(i) * rootPartials[i];
        }

        this.sampleInternalNodes(tree, root, colourChangeMatrix);
        this.sampleBranchColourings(treeColouring, tree, root, colourChangeMatrix);

        double logDensity = this.calculateLogProbabilityDensity(treeColouring, tree, root, colourChangeMatrix, populationSizes)
                - Math.log(tipColourProbability);
        treeColouring.setLogProbabilityDensity(logDensity);
        return treeColouring;
    }

    /** Returns the current colour of {@code nodeRef}. */
    private final int getColour(NodeRef nodeRef) {
        return this.nodeColours[nodeRef.getNumber()];
    }

    /**
     * Sets the colour of {@code nodeRef}.
     *
     * @throws IllegalArgumentException if {@code n} is not in {@code [0, colourCount)}
     */
    private final void setColour(NodeRef nodeRef, int n) {
        if (n < 0 || n >= this.colourCount) {
            throw new IllegalArgumentException("colour value " + n + " is outside of range of colours, [0, " + (this.colourCount - 1) + "]");
        }
        this.nodeColours[nodeRef.getNumber()] = n;
    }

    /**
     * Post-order pruning pass. Entry {@code i} of the result is the probability of the tip colours
     * below {@code nodeRef} given that {@code nodeRef} has colour {@code i}. Each child contributes
     * {@code sum_j forwardTimeEvolution(i, j, branchLength) * childPartials[j]}. Tips return an
     * indicator of their fixed colour. Results for internal nodes are cached in
     * {@link #nodePartials} for {@link #sampleInternalNodes}.
     *
     * <p>Only children 0 and 1 are visited, so the tree must be bifurcating. Partials are not
     * rescaled, so very large trees may underflow. {@code dArray} (population sizes) is unused.
     */
    private final double[] prune(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, double[] dArray) {
        double[] partials = new double[this.colourCount];
        if (tree.isExternal(nodeRef)) {
            partials[this.getColour(nodeRef)] = 1.0;
            return partials;
        }
        NodeRef left = tree.getChild(nodeRef, 0);
        NodeRef right = tree.getChild(nodeRef, 1);
        double[] leftPartials = this.prune(tree, left, colourChangeMatrix, dArray);
        double[] rightPartials = this.prune(tree, right, colourChangeMatrix, dArray);
        double height = tree.getNodeHeight(nodeRef);
        double leftLength = height - tree.getNodeHeight(left);
        double rightLength = height - tree.getNodeHeight(right);
        for (int i = 0; i < partials.length; ++i) {
            double leftSum = 0.0;
            double rightSum = 0.0;
            for (int j = 0; j < leftPartials.length; ++j) {
                leftSum += leftPartials[j] * colourChangeMatrix.forwardTimeEvolution(i, j, leftLength);
                rightSum += rightPartials[j] * colourChangeMatrix.forwardTimeEvolution(i, j, rightLength);
            }
            partials[i] = leftSum * rightSum;
        }
        this.nodePartials[nodeRef.getNumber()] = partials;
        return partials;
    }

    /**
     * Pre-order pass drawing the colour of {@code nodeRef} and of every internal node below it.
     * Colour {@code i} is drawn with weight {@code prior[i] * partials[i]}, where the prior is the
     * equilibrium distribution at the root. Elsewhere it is
     * {@code forwardTimeEvolution(parentColour, i, branchLength)}. Requires {@link #prune} to have
     * run first.
     */
    private final void sampleInternalNodes(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        double[] partials = this.nodePartials[nodeRef.getNumber()];
        double[] weights;
        if (tree.isRoot(nodeRef)) {
            // Copy before scaling in place below. If getEquilibrium() hands back the matrix's own
            // array, mutating it would corrupt the equilibrium later read by
            // calculateLogProbabilityDensity. NOTE(review): confirm whether the getter copies.
            weights = colourChangeMatrix.getEquilibrium().clone();
        } else {
            NodeRef parent = tree.getParent(nodeRef);
            int parentColour = this.getColour(parent);
            double branchLength = tree.getNodeHeight(parent) - tree.getNodeHeight(nodeRef);
            weights = new double[partials.length];
            for (int i = 0; i < partials.length; ++i) {
                weights[i] = colourChangeMatrix.forwardTimeEvolution(parentColour, i, branchLength);
            }
        }
        for (int i = 0; i < partials.length; ++i) {
            weights[i] *= partials[i];
        }
        this.setColour(nodeRef, MathUtils.randomChoicePDF(weights));
        for (int i = 0; i < tree.getChildCount(nodeRef); ++i) {
            NodeRef child = tree.getChild(nodeRef, i);
            if (!tree.isExternal(child)) {
                this.sampleInternalNodes(tree, child, colourChangeMatrix);
            }
        }
    }

    /** Draws a branch colouring for every branch at or below {@code nodeRef} (the root has none). */
    private void sampleBranchColourings(DefaultTreeColouring treeColouring, Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        if (!tree.isRoot(nodeRef)) {
            NodeRef parent = tree.getParent(nodeRef);
            DefaultBranchColouring branchColouring = this.sampleConditionalBranchColouring(
                    this.getColour(parent), tree.getNodeHeight(parent),
                    this.getColour(nodeRef), tree.getNodeHeight(nodeRef),
                    colourChangeMatrix);
            treeColouring.setBranchColouring(nodeRef, branchColouring);
        }
        for (int i = 0; i < tree.getChildCount(nodeRef); ++i) {
            this.sampleBranchColourings(treeColouring, tree, tree.getChild(nodeRef, i), colourChangeMatrix);
        }
    }

    /**
     * Draws a colour-change history for one branch by simulation plus rejection. Exponential
     * waiting times at rate {@code -getForwardRate(c, c)} are drawn from the parent height
     * downward. Attempts whose final colour differs from {@code childColour} are retried, up to
     * {@link #maxIterations} times.
     *
     * <p>When the end colours differ, at least one change must occur. The first waiting time is
     * therefore drawn from the exponential truncated to the branch length. After
     * {@link #maxIterations} failures, a change to the child colour is forced just above the child
     * and a diagnostic is printed.
     *
     * <p>Changes always flip with {@code 1 - colour}, so only two colours are supported.
     */
    private DefaultBranchColouring sampleConditionalBranchColouring(int parentColour, double parentHeight, int childColour, double childHeight, ColourChangeMatrix colourChangeMatrix) {
        DefaultBranchColouring branchColouring = new DefaultBranchColouring(parentColour, childColour);
        double branchLength = parentHeight - childHeight;
        int colour;
        double height;
        double lastWait;
        int attemptsLeft = maxIterations;
        do {
            branchColouring.clear();
            colour = parentColour;
            height = parentHeight;
            do {
                double totalRate = -colourChangeMatrix.getForwardRate(colour, colour);
                double u;
                do {
                    // Reject 0 so that -log(u) stays finite.
                    u = MathUtils.nextDouble();
                } while (u == 0.0);
                if (parentColour != childColour && branchColouring.getNumEvents() == 0) {
                    // Map u into [exp(-rate*T), 1) so the first change falls within the branch.
                    double noChange = Math.exp(-totalRate * branchLength);
                    u = noChange + u * (1.0 - noChange);
                }
                lastWait = -Math.log(u) / totalRate;
                height -= lastWait;
                if (height > childHeight) {
                    colour = 1 - colour;
                    branchColouring.addEvent(colour, height);
                }
            } while (height > childHeight);
        } while (colour != childColour && --attemptsLeft > 0);

        if (colour != childColour) {
            // Height of the final attempt's last event (or the parent if there were none).
            double lastEventHeight = height + lastWait;
            double forcedHeight = childHeight + 0.01 * (lastEventHeight - childHeight);
            branchColouring.addEvent(childColour, forcedHeight);
            System.out.println("dr.evolution.colouring.BranchColourSampler: failed to generate sample after " + maxIterations + " trials.");
            System.out.println(": parentColour=" + parentColour);
            System.out.println(": parentHeight=" + parentHeight);
            System.out.println(": childColour=" + childColour);
            System.out.println(": childHeight=" + childHeight);
            System.out.println(": migration rate 0->1 = " + colourChangeMatrix.getForwardRate(0, 1));
            System.out.println(": migration rate 1->0 = " + colourChangeMatrix.getForwardRate(1, 0));
        }
        return branchColouring;
    }

    /**
     * Returns the joint log density of the colouring at and below {@code nodeRef}.
     *
     * <p>The root contributes log equilibrium(root colour). Each branch contributes, for every
     * interval of constant colour c, {@code duration * getForwardRate(c, c)} (the log of no change,
     * since the diagonal rate is minus the exit rate). Each change c -> c' adds
     * {@code log getForwardRate(c, c')}.
     *
     * <p>Terms are summed in log space so long branches cannot underflow to -infinity.
     * {@code dArray} (population sizes) is unused.
     */
    private final double calculateLogProbabilityDensity(TreeColouring treeColouring, Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, double[] dArray) {
        double logDensity;
        if (tree.isRoot(nodeRef)) {
            logDensity = Math.log(colourChangeMatrix.getEquilibrium(treeColouring.getNodeColour(nodeRef)));
        } else {
            NodeRef parent = tree.getParent(nodeRef);
            BranchColouring branchColouring = treeColouring.getBranchColouring(nodeRef);
            int colour = treeColouring.getNodeColour(parent);
            double height = tree.getNodeHeight(parent);
            logDensity = 0.0;
            for (int i = 1; i <= branchColouring.getNumEvents(); ++i) {
                int newColour = branchColouring.getForwardColourBelow(i);
                double eventHeight = branchColouring.getForwardTime(i);
                logDensity += (height - eventHeight) * colourChangeMatrix.getForwardRate(colour, colour);
                logDensity += Math.log(colourChangeMatrix.getForwardRate(colour, newColour));
                height = eventHeight;
                colour = newColour;
            }
            logDensity += (height - tree.getNodeHeight(nodeRef)) * colourChangeMatrix.getForwardRate(colour, colour);
        }
        for (int i = 0; i < tree.getChildCount(nodeRef); ++i) {
            logDensity += this.calculateLogProbabilityDensity(treeColouring, tree, tree.getChild(nodeRef, i), colourChangeMatrix, dArray);
        }
        return logDensity;
    }

    /**
     * Sums {@code log(t^k / k!)} over every branch at or below {@code nodeRef}, where {@code t} is
     * the branch length and {@code k} the number of colour-change events on that branch. The root
     * has no branch and contributes 0. Computed as {@code sum_{i=1..k} log(t / i)}, so large
     * {@code k} or {@code t} cannot overflow.
     */
    public static final double calculateLogNormalization(TreeColouring treeColouring, Tree tree, NodeRef nodeRef) {
        double logNorm = 0.0;
        if (!tree.isRoot(nodeRef)) {
            double branchLength = tree.getNodeHeight(tree.getParent(nodeRef)) - tree.getNodeHeight(nodeRef);
            int eventCount = treeColouring.getBranchColouring(nodeRef).getNumEvents();
            for (int i = 1; i <= eventCount; ++i) {
                logNorm += Math.log(branchLength / (double) i);
            }
        }
        for (int i = 0; i < tree.getChildCount(nodeRef); ++i) {
            logNorm += BasicColourSampler.calculateLogNormalization(treeColouring, tree, tree.getChild(nodeRef, i));
        }
        return logNorm;
    }

    /**
     * Not supported by this sampler.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public double getProposalProbability(TreeColouring treeColouring, Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        throw new IllegalArgumentException("Not implemented for BasicColourSampler; you can only use <ColouredOperator>s");
    }
}

