/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.arg.coalescent;

import dr.evomodel.arg.ARGModel;
import dr.evomodel.arg.coalescent.ARGCoalescentLikelihood;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;

public class ARGUniformPrior
extends ARGCoalescentLikelihood {
    /** Name passed to the superclass constructor and returned by {@link #PARSER} as its parser name. */
    public static final String ARG_UNIFORM_PRIOR = "argUniformPrior";
    /** Optional XML attribute read by {@link #PARSER}; its integer value is handed to the constructor as the number of ARG counts to compute up front. */
    public static final String INITIAL_CALCULATIONS = "initialCalculations";
    /** Default number of up-front calculations; the parser in this class falls back to the same value (5) when the attribute is absent. */
    public static final int INITIAL_DEFAULT = 5;
    /**
     * Stored natural-log ARG counts that the constructor uses instead of recomputing them.
     * The constructor selects the row at index {@code externalNodeCount - 3} and copies it, so that
     * entry {@code i} of the row answers {@link #getLogARGNumber(int)} for {@code i} reassortment nodes.
     */
    public static final double[][] logARGCoalescentCount = new double[][]{{1.0986122886681098, 4.276666119016055, 8.265650165580329, 12.882968485504067, 18.022948777876554, 23.61061585799083, 29.589035299467483, 35.9136593725818, 42.54885778516801, 49.465568765313016, 56.63967062533004, 64.05083469593846, 71.6817017141028, 79.51727811078219, 87.54448452157624}, {2.8903717578961645, 7.049254841255837, 11.72325121817378, 16.888557143580204, 22.489284935076803, 28.475109395164285, 34.804239049061785, 41.44233437710982, 48.360992220084476, 55.53645069653148, 62.948589014628716, 70.58017373916618, 78.41629051730933, 86.44391155847957, 94.65156231759461}, {5.192956850890211, 10.127430784020902, 15.386536070471918, 21.031563309246433, 27.041414773420907, 33.38480151281157, 40.03192406503731, 46.95658863790172, 54.13620464676437, 61.55131458837549, 69.18508096220405, 77.02283620054189, 85.0517120025277, 93.26034012092553, 101.63861177667185}, {7.90100705199242, 13.480736877978641, 19.249258325378765, 25.32054548939157, 31.698755104859107, 38.36733349428991, 45.30602732354885, 52.49523955146093, 59.917151710215464, 67.55588200276209, 75.39734855785784, 83.42905687418532, 91.63988547749693, 100.01989325102055, 108.56015364817871}, {10.945529489715843, 17.07892753271249, 23.29838718386135, 29.753407176990198, 36.467041602426036, 43.43431706668038, 50.64269194407989, 58.077999881992824, 65.7264083635128, 73.57505740705963, 81.61221669682189, 89.8272615722345, 98.21058447998666, 106.75348902699808, 115.44808566836372}, {14.277733999891046, 20.896472983408266, 27.5204956568849, 34.32463786712059, 41.34646380715327, 48.590254195514014, 56.04978852563933, 63.71543385947847, 71.57673212169905, 79.62341044573522, 87.84577046007783, 96.23481579273648, 104.78226604090887, 113.48052233027693, 122.32261419599219}, {17.861252938347157, 24.912242385415205, 31.903176098451674, 39.02769936297709, 46.334764858303096, 53.836163974100344, 61.53092519032464, 69.4132445398445, 77.47555198269822, 
85.70981250210274, 94.10810078400579, 102.66285174696894, 111.36695890554311, 120.21379947792728, 129.1972243144967}, {21.667915428117478, 29.108649135506738, 36.43531279311537, 43.855945609774885, 51.42860726365353, 59.171269494486445, 67.08736948530374, 75.17436983012065, 83.4271948608094, 91.83976176721744, 100.40570480277214, 109.11872661806457, 117.97276669219758, 126.96207643311605, 136.08124559689557}, {25.67524861334995, 33.470895149684544, 41.10703544874183, 48.802991754828916, 56.62422391351782, 64.59385162827832, 72.71905241455046, 81.00010003398279, 89.43409011267886, 98.0166551565224, 106.74281045596386, 115.60738869034444, 124.6052669677749, 133.73148419256228, 142.98129861277982}, {29.864903355376374, 37.986383730127116, 45.90958634624165, 53.86285360119818, 61.91774245806648, 70.10170706136994, 78.42512661349667, 86.890712820353, 95.49746436335495, 104.24253036074631, 113.12215615381358, 122.13218720871606, 131.26834550980237, 140.52638287915215, 149.90216524590267}, {34.22161218206597, 42.644274889635966, 50.83517511899286, 59.02998615127507, 67.30534956176908, 75.69240403974487, 84.20429128794189, 92.84585184006139, 101.61776251108103, 110.51852076251443, 119.54547309663305, 128.69537701990959, 137.96471924891867, 147.34989978600277, 156.84733922685047}, {38.732471688582805, 47.43514810013058, 55.87684474986188, 64.2992772236583, 72.78337269607047, 81.36342882144557, 90.05498796388106, 98.86476107289788, 107.79491286633757, 116.84514548111035, 126.01379567653075, 135.29844575074105, 144.69627736969085, 154.20428208422834, 163.8193886161766}, {43.38643203874033, 52.35074398686485, 61.028354989069626, 69.6660218145462, 78.34831802900398, 87.11227051188368, 95.97552137406072, 104.94643415554035, 114.02849935687857, 123.22250024459373, 132.52766774855348, 141.9423328912873, 151.46431112672838, 161.0911356940919, 170.82020228928107}};
    /**
     * Cache of log ARG counts indexed by reassortment-node count: filled by the constructor and
     * extended on demand by {@link #getLogARGNumber(int)}.
     */
    private ArrayList<Double> argNumber;
    /**
     * XML parser that builds an {@link ARGUniformPrior} from an element holding an {@link ARGModel} child.
     * Two optional integer attributes are honoured: {@code maxReassortments} (defaults to
     * {@link Integer#MAX_VALUE}) and {@link #INITIAL_CALCULATIONS} (defaults to {@link #INITIAL_DEFAULT}).
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        /** Syntax rules: the element must contain an ARGModel child. */
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{new ElementRule(ARGModel.class)};

        @Override
        public String getParserName() {
            return ARGUniformPrior.ARG_UNIFORM_PRIOR;
        }

        @Override
        public String getParserDescription() {
            return "A uniform prior for an ARG model";
        }

        @Override
        public Class getReturnType() {
            return ARGUniformPrior.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            ARGModel model = (ARGModel)xMLObject.getChild(ARGModel.class);

            // Upper bound on reassortment nodes; unlimited unless the attribute is supplied.
            int reassortmentLimit = xMLObject.hasAttribute("maxReassortments")
                    ? xMLObject.getIntegerAttribute("maxReassortments")
                    : Integer.MAX_VALUE;

            // Number of ARG counts to compute eagerly when no stored table row is used.
            int initialCalculations = xMLObject.hasAttribute(ARGUniformPrior.INITIAL_CALCULATIONS)
                    ? xMLObject.getIntegerAttribute(ARGUniformPrior.INITIAL_CALCULATIONS)
                    : ARGUniformPrior.INITIAL_DEFAULT;

            return new ARGUniformPrior(model, reassortmentLimit, initialCalculations);
        }
    };

    /**
     * Creates the prior and seeds the cache of log ARG counts.
     *
     * <p>If the stored table has a row for the model's number of external nodes, that row is copied
     * into the cache. Otherwise the first {@code n2} counts (for 0 .. n2-1 reassortments) are
     * computed with {@link #logNumberARGS(int, int)}.
     *
     * @param aRGModel the ARG model this prior is attached to
     * @param n        maximum number of reassortments, forwarded to the superclass constructor
     * @param n2       number of counts to compute up front when no stored row is available
     */
    public ARGUniformPrior(ARGModel aRGModel, int n, int n2) {
        super(ARG_UNIFORM_PRIOR, aRGModel, n);
        this.addModel(aRGModel);
        this.argNumber = new ArrayList<Double>(15);

        Logger logger = Logger.getLogger("dr.evomodel");
        int externalNodes = aRGModel.getExternalNodeCount();
        // Row 0 of the table corresponds to 3 external nodes.
        int tableRow = externalNodes - 3;

        // The lower-bound check keeps models with fewer than 3 external nodes from indexing the
        // table with a negative row; they take the calculation branch instead.
        if (tableRow >= 0 && tableRow < logARGCoalescentCount.length) {
            logger.info("Creating ARGUniformPrior using stored arg counts");
            for (double storedLogCount : logARGCoalescentCount[tableRow]) {
                this.argNumber.add(storedLogCount);
            }
        } else {
            logger.info("Creating ARGUniformPrior by calculating arg counts");
            for (int reassortments = 0; reassortments < n2; ++reassortments) {
                this.argNumber.add(ARGUniformPrior.logNumberARGS(externalNodes, reassortments));
            }
        }
    }

    /**
     * Returns the cached log ARG count for {@code n} reassortment nodes, computing it on demand.
     *
     * <p>The cache is index-aligned (entry {@code i} is the count for {@code i} reassortments), so
     * when {@code n} lies beyond the end of the cache every missing entry up to and including
     * {@code n} is computed and appended. Appending only a single value would store it at the
     * wrong index and make the lookup below fail whenever {@code n} skips past the cache end.
     *
     * @param n number of reassortment nodes
     * @return the log of the ARG count as produced by {@link #logNumberARGS(int, int)} or the stored table
     */
    public double getLogARGNumber(int n) {
        if (n >= this.argNumber.size()) {
            int externalNodes = this.arg.getExternalNodeCount();
            while (this.argNumber.size() <= n) {
                // The next missing index is exactly the current cache size.
                this.argNumber.add(ARGUniformPrior.logNumberARGS(externalNodes, this.argNumber.size()));
            }
        }
        return this.argNumber.get(n);
    }

    /**
     * Returns the log likelihood, recomputing it only when the cached value is not marked as known.
     *
     * <p>The value is {@link Double#NEGATIVE_INFINITY} when the ARG has more reassortment nodes than
     * {@code maxReassortments} or when {@code currentARGValid(true)} reports the ARG as invalid;
     * otherwise it is {@link #calculateLogLikelihood()}. The calculation is skipped entirely when the
     * reassortment limit is exceeded, so no ARG count is computed for a state that is rejected anyway.
     *
     * @return the (possibly cached) log likelihood
     */
    @Override
    public double getLogLikelihood() {
        if (this.likelihoodKnown) {
            return this.logLikelihood;
        }
        this.likelihoodKnown = true;
        if (this.arg.getReassortmentNodeCount() > this.maxReassortments) {
            this.logLikelihood = Double.NEGATIVE_INFINITY;
        } else {
            this.logLikelihood = this.calculateLogLikelihood();
        }
        // The validity check is still performed for every recomputation, as before.
        if (!this.currentARGValid(true)) {
            this.logLikelihood = Double.NEGATIVE_INFINITY;
        }
        return this.logLikelihood;
    }

    /**
     * Computes {@code log(k!) - k * log(h) - L}, where {@code h} is the height of the root node,
     * {@code k} is the internal node count minus one, and {@code L} is
     * {@link #getLogARGNumber(int)} for the current number of reassortment nodes.
     *
     * @return the log density; asserted to be finite and not NaN when assertions are enabled
     */
    @Override
    public double calculateLogLikelihood() {
        double rootHeight = this.arg.getNodeHeight(this.arg.getRoot());
        int eventCount = this.arg.getInternalNodeCount() - 1;

        double logOrderings = this.logFactorial(eventCount);
        double logHeightTerm = (double)eventCount * Math.log(rootHeight);
        double logArgCount = this.getLogARGNumber(this.arg.getReassortmentNodeCount());

        double logDensity = logOrderings - logHeightTerm - logArgCount;
        assert (!(Double.isInfinite(logDensity) || Double.isNaN(logDensity)));
        return logDensity;
    }

    /**
     * Returns {@code log(n!)} as the sum of {@code log(i)} for {@code i = n, n-1, ..., 1}.
     * Values of {@code n} below 1 yield 0.0.
     *
     * @param n the factorial argument
     * @return the natural log of {@code n!}
     */
    private double logFactorial(int n) {
        double total = 0.0;
        int factor = n;
        // Accumulate from the largest factor downwards.
        while (factor > 0) {
            total += Math.log(factor);
            factor--;
        }
        return total;
    }

    /**
     * Evaluates {@link #shurikoRecursion(int, int)} at {@code (n, 2 * n2 + n - 1)}.
     * NOTE(review): presumably {@code n} is the number of tips and {@code n2} the number of
     * reassortments, mirroring {@code logNumberARGS} — confirm. The result is a plain {@code int}
     * and is not protected against overflow.
     *
     * @param n  first argument of the recursion
     * @param n2 value from which the second argument of the recursion is derived
     * @return the value of the recursion
     */
    private int numberARGS(int n, int n2) {
        return this.shurikoRecursion(n, n - 1 + 2 * n2);
    }

    /**
     * Two-argument integer recursion used by {@code numberARGS}.
     *
     * <ul>
     *   <li>{@code n == 0}: 0</li>
     *   <li>{@code n == 1}: 1 if {@code n2 == 0}, otherwise 0</li>
     *   <li>{@code n2 == 0} (with {@code n} neither 0 nor 1): 0</li>
     *   <li>{@code n == n2 + 1}: {@code n(n-1)/2 * f(n-1, n2-1)}</li>
     *   <li>otherwise: {@code n * f(n+1, n2-1) + n(n-1)/2 * f(n-1, n2-1)}</li>
     * </ul>
     *
     * NOTE(review): this reads like a count over sequences of {@code n2} events acting on
     * {@code n} lineages (split: n choices, merge: n(n-1)/2 choices) — confirm. All arithmetic is
     * {@code int} and may overflow for larger arguments.
     *
     * @param n  first recursion argument
     * @param n2 second recursion argument
     * @return the value of the recursion
     */
    private int shurikoRecursion(int n, int n2) {
        if (n == 0) {
            return 0;
        }
        if (n == 1) {
            return n2 == 0 ? 1 : 0;
        }
        if (n2 == 0) {
            // n is neither 0 nor 1 here, so nothing is counted.
            return 0;
        }
        int pairCount = n * (n - 1) / 2;
        if (n == n2 + 1) {
            return pairCount * this.shurikoRecursion(n - 1, n2 - 1);
        }
        int viaIncrease = n * this.shurikoRecursion(n + 1, n2 - 1);
        int viaDecrease = pairCount * this.shurikoRecursion(n - 1, n2 - 1);
        return viaIncrease + viaDecrease;
    }

    /**
     * Computes the natural log of the ARG count for {@code n} tips and {@code n2} reassortments.
     *
     * <p>For {@code n2 == 0} the result is the sum of {@code log(i(i-1)/2)} for {@code i = n .. 3}.
     * Otherwise the method enumerates arrangements of {@code n2} entries {@code +1} and
     * {@code n - 3 + n2} entries {@code -1}, keeps those accepted by {@code testCombination}, and
     * sums the product of each arrangement's per-step values (see {@code generateValues}) divided by
     * the product for the reference arrangement (all {@code +1} first). The log of that sum is then
     * added to the log-product of the reference arrangement extended by two further {@code -1} steps.
     * Working with ratios keeps the intermediate numbers small.
     *
     * <p>A warning is logged on every call because the enumeration can be slow.
     *
     * @param n  number of tips (external nodes)
     * @param n2 number of reassortments
     * @return the log of the ARG count
     */
    public static double logNumberARGS(int n, int n2) {
        Logger.getLogger("dr.evomodel").warning("Calculating ARG count for " + n2 + " reassortments.  This may take awhile");
        if (n2 == 0) {
            double logCount = 0.0;
            for (int lineages = n; lineages > 2; --lineages) {
                logCount += Math.log((double)(lineages * (lineages - 1)) / 2.0);
            }
            return logCount;
        }

        // Reference arrangement: every +1 step first, then every -1 step.
        int stepCount = n - 3 + n2 * 2;
        int[] reference = new int[stepCount];
        Arrays.fill(reference, -1);
        for (int i = 0; i < n2; ++i) {
            reference[i] = 1;
        }
        int[] current = reference.clone();

        // The reference values do not change during the enumeration, so compute them once.
        int[] referenceValues = generateValues(reference, n);

        double ratioSum = 0.0;
        while (current[0] != -9 && !stopCombination(current, n)) {
            if (testCombination(current, n)) {
                ratioSum += reduceThenDivide(generateValues(current, n), referenceValues);
            }
            nextCombination(current);
        }

        double logCount = Math.log(ratioSum);

        // Multiply back by the reference product, extended by the final two -1 steps.
        int[] extended = Arrays.copyOf(reference, stepCount + 2);
        extended[stepCount] = -1;
        extended[stepCount + 1] = -1;
        for (int value : generateValues(extended, n)) {
            logCount += Math.log(value);
        }
        return logCount;
    }

    /**
     * Returns the product of {@code nArray[i] / nArray2[i]} after sorting both arrays ascending.
     * Both arguments are sorted in place. Pairing sorted elements does not change the overall
     * ratio of products but keeps each individual quotient moderate.
     */
    private static double reduceThenDivide(int[] nArray, int[] nArray2) {
        Arrays.sort(nArray);
        Arrays.sort(nArray2);
        double ratio = 1.0;
        for (int i = 0; i < nArray.length; ++i) {
            ratio *= (double)nArray[i] / (double)nArray2[i];
        }
        return ratio;
    }

    /**
     * Maps an arrangement of {@code +1}/{@code -1} steps to per-step values while tracking a running
     * count that starts at {@code n}: a {@code +1} step contributes the current count, any other
     * step contributes {@code count * (count - 1) / 2}; the step is then added to the count.
     */
    private static int[] generateValues(int[] nArray, int n) {
        int[] values = new int[nArray.length];
        int running = n;
        for (int i = 0; i < nArray.length; ++i) {
            values[i] = nArray[i] == 1 ? running : running * (running - 1) / 2;
            running += nArray[i];
        }
        return values;
    }

    /**
     * Returns {@code false} if the running count (starting at {@code n}) equals 1 after any step of
     * the arrangement, {@code true} otherwise.
     */
    private static boolean testCombination(int[] nArray, int n) {
        int running = n;
        for (int step : nArray) {
            running += step;
            if (running == 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code true} if the leading run of {@code -1} entries alone brings the running count
     * (starting at {@code n}) down to 1; used by the caller to end the enumeration.
     */
    private static boolean stopCombination(int[] nArray, int n) {
        int running = n;
        for (int i = 0; i < nArray.length && nArray[i] == -1; ++i) {
            if (--running == 1) {
                return true;
            }
        }
        return false;
    }

    /**
     * Advances {@code nArray} in place to the next arrangement of its {@code +1} entries, moving
     * them rightwards. When no further arrangement exists, {@code -9} is written to index 0 as an
     * end marker (the caller checks for it).
     */
    private static void nextCombination(int[] nArray) {
        int last = nArray.length - 1;
        if (nArray[last] == -1) {
            // Shift the right-most +1 one position to the right.
            for (int i = last; i > -1; --i) {
                if (nArray[i] == 1) {
                    nArray[i] = -1;
                    nArray[i + 1] = 1;
                    return;
                }
            }
            // No +1 entry at all: there is nothing left to enumerate.
            nArray[0] = -9;
            return;
        }

        // Count the block of +1 entries at the end; the bounds check covers an all-+1 array.
        int trailingOnes = 0;
        int pos = last;
        while (pos > -1 && nArray[pos] == 1) {
            ++trailingOnes;
            --pos;
        }
        // Find the next +1 to the left of that block.
        while (pos > -1 && nArray[pos] != 1) {
            --pos;
        }
        if (pos == -1) {
            nArray[0] = -9;
            return;
        }
        // Move it one step right, pack the trailing block directly behind it, clear the rest.
        nArray[pos] = -1;
        nArray[pos + 1] = 1;
        Arrays.fill(nArray, pos + 2, pos + 2 + trailingOnes, 1);
        Arrays.fill(nArray, pos + 2 + trailingOnes, nArray.length, -1);
    }
}

