/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.basta;

import dr.evolution.alignment.PatternList;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.datatype.DataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.TaxonList;
import dr.evolution.util.Units;
import dr.evomodel.bigfasttree.BestSignalsFromBigFastTreeIntervals;
import dr.evomodel.bigfasttree.IntervalChangedEvent;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.substmodel.GeneralSubstitutionModel;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;

public class FasterStructuredCoalescentLikelihood
extends AbstractModelLikelihood
implements Units,
Citable {
    private static final boolean DEBUG = false;
    private static final boolean MATRIX_DEBUG = false;
    private static final boolean UPDATE_DEBUG = false;
    private static final boolean ASSOC_MULTIPLICATION = true;
    private static final boolean USE_TRANSPOSE = false;
    private static final boolean MINIMUM_EVALUATION = false;
    final double[] temp;
    public static Citation CITATION = new Citation(new Author[]{new Author("Nicola", "De Maio"), new Author("Chieh-Hsi", "Wu"), new Author("Kathleen", "O'Reilly"), new Author("Daniel", "Wilson")}, "New routes to phylogeography: a Bayesian structured coalescent approximation", 2015, "PLOS Genetics", 11, "e1005421", "10.1371/journal.pgen.1005421");
    protected double logLikelihood;
    protected double storedLogLikelihood;
    protected boolean likelihoodKnown = false;
    protected boolean storedLikelihoodKnown = false;
    private TreeModel treeModel;
    private BranchRateModel branchRateModel;
    private Parameter popSizes;
    private PatternList patternList;
    private final DataType dataType;
    private boolean useMAP;
    private int[][] reconstructedStates;
    private int[][] storedReconstructedStates;
    protected boolean areStatesRedrawn = false;
    protected boolean storedAreStatesRedrawn = false;
    private BestSignalsFromBigFastTreeIntervals intervals;
    private double[][] intervalStartProbs;
    private double[][] intervalEndProbs;
    private double[][] intervalStartSquareProbs;
    private double[][] intervalEndSquareProbs;
    private double[][] coalescentLeftProbs;
    private double[][] coalescentRightProbs;
    private double[] activeLineages;
    private HashSet<Integer> activeNodeNumbers;
    private GeneralSubstitutionModel generalSubstitutionModel;
    private int demes;
    private final int intervalCount;
    private double[][] migrationMatrices;
    private double[][] storedMigrationMatrices;
    private boolean[] matricesKnown;
    private boolean[] storedMatricesKnown;

    public FasterStructuredCoalescentLikelihood(Tree tree, BranchRateModel branchRateModel, Parameter parameter, PatternList patternList, DataType dataType, String string, GeneralSubstitutionModel generalSubstitutionModel, int n, TaxonList taxonList, List<TaxonList> list, boolean bl) throws TreeUtils.MissingTaxonException {
        super("structuredCoalescent");
        this.treeModel = (TreeModel)tree;
        this.patternList = patternList;
        this.dataType = dataType;
        this.useMAP = bl;
        if (!(tree instanceof TreeModel)) {
            throw new IllegalArgumentException("Please provide a TreeModel for the structured coalescent model.");
        }
        this.intervals = new BestSignalsFromBigFastTreeIntervals((TreeModel)tree);
        this.addModel(this.intervals);
        this.popSizes = parameter;
        this.addVariable(this.popSizes);
        this.branchRateModel = branchRateModel != null ? branchRateModel : new DefaultBranchRateModel();
        this.addModel(this.branchRateModel);
        this.generalSubstitutionModel = generalSubstitutionModel;
        this.addModel(this.generalSubstitutionModel);
        this.demes = generalSubstitutionModel.getDataType().getStateCount();
        int n2 = this.treeModel.getNodeCount();
        this.intervalCount = this.intervals.getIntervalCount();
        this.matricesKnown = new boolean[this.intervalCount];
        this.storedMatricesKnown = new boolean[this.intervalCount];
        this.migrationMatrices = new double[this.intervalCount][this.demes * this.demes];
        this.storedMigrationMatrices = new double[this.intervalCount][this.demes * this.demes];
        this.intervalStartProbs = new double[this.intervalCount][];
        this.intervalEndProbs = new double[this.intervalCount][];
        this.intervalStartSquareProbs = new double[this.intervalCount][];
        this.intervalEndSquareProbs = new double[this.intervalCount][];
        this.coalescentLeftProbs = new double[this.intervalCount][];
        this.coalescentRightProbs = new double[this.intervalCount][];
        for (int i = 0; i < this.intervalCount; ++i) {
            this.intervalStartProbs[i] = new double[this.demes];
            this.intervalEndProbs[i] = new double[this.demes];
            this.intervalStartSquareProbs[i] = new double[this.demes];
            this.intervalEndSquareProbs[i] = new double[this.demes];
            this.coalescentLeftProbs[i] = new double[n2 * this.demes];
            this.coalescentRightProbs[i] = new double[n2 * this.demes];
        }
        this.activeLineages = new double[n2 * this.demes];
        this.activeNodeNumbers = new HashSet(n2);
        this.likelihoodKnown = false;
        this.temp = new double[this.demes];
    }

    @Override
    public final Model getModel() {
        return this;
    }

    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            this.logLikelihood = this.calculateLogLikelihood();
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    public double calculateLogLikelihood() {
        this.computeProbabilityDistributions(0);
        return this.calculateLogLikelihood(0);
    }

    public double calculateLogLikelihood(int n) {
        double d = 0.0;
        for (int i = 0; i < this.intervalCount; ++i) {
            double d2;
            double d3 = this.intervals.getInterval(i);
            if (d3 == 0.0) continue;
            double[] dArray = this.intervalStartProbs[i];
            double[] dArray2 = this.intervalStartSquareProbs[i];
            double[] dArray3 = this.intervalEndProbs[i];
            double[] dArray4 = this.intervalEndSquareProbs[i];
            double d4 = d3 / 2.0;
            if (d4 != 0.0) {
                d2 = 0.0;
                double d5 = 0.0;
                for (int j = 0; j < this.demes; ++j) {
                    d2 += (dArray[j] * dArray[j] - dArray2[j]) / (2.0 * this.popSizes.getParameterValue(j));
                    d2 += (dArray3[j] * dArray3[j] - dArray4[j]) / (2.0 * this.popSizes.getParameterValue(j));
                }
                d += -d4 * d2;
                d += -d4 * d5;
            }
            if (this.intervals.getIntervalType(i) != IntervalType.COALESCENT) continue;
            d2 = 0.0;
            d2 = this.coalescentLeftProbs[i][0];
            d += Math.log(d2);
        }
        return d;
    }

    private void newHardWork(double[] dArray, double[] dArray2) {
        for (int i = 0; i < this.demes; ++i) {
            dArray[i] = 0.0;
            dArray2[i] = 0.0;
        }
        for (int n : this.activeNodeNumbers) {
            for (int i = 0; i < this.demes; ++i) {
                int n2 = i;
                dArray[n2] = dArray[n2] + this.activeLineages[n * this.demes + i];
                int n3 = i;
                dArray2[n3] = dArray2[n3] + this.activeLineages[n * this.demes + i] * this.activeLineages[n * this.demes + i];
            }
        }
    }

    private void handleCoalescense(NodeRef nodeRef, int n, int n2) {
        int n3;
        double d = this.intervals.getInterval(n2);
        this.newHardWork(this.intervalStartProbs[n2], this.intervalStartSquareProbs[n2]);
        this.incrementActiveLineages(this.activeLineages, d, n2);
        nodeRef = this.intervals.getCoalescentNode(n2);
        n = nodeRef.getNumber() * this.demes;
        NodeRef nodeRef2 = this.treeModel.getChild(nodeRef, 0);
        NodeRef nodeRef3 = this.treeModel.getChild(nodeRef, 1);
        int n4 = nodeRef2.getNumber() * this.demes;
        int n5 = nodeRef3.getNumber() * this.demes;
        double d2 = 0.0;
        for (n3 = 0; n3 < this.demes; ++n3) {
            this.temp[n3] = this.activeLineages[n4 + n3] * this.activeLineages[n5 + n3] / this.popSizes.getParameterValue(n3);
            d2 += this.temp[n3];
        }
        this.coalescentLeftProbs[n2][0] = d2;
        for (n3 = 0; n3 < this.demes; ++n3) {
            this.activeLineages[n + n3] = this.temp[n3] / d2;
        }
        this.newHardWork(this.intervalEndProbs[n2], this.intervalEndSquareProbs[n2]);
        this.doShit(nodeRef, nodeRef2, nodeRef3);
    }

    private void doShit(NodeRef nodeRef, NodeRef nodeRef2, NodeRef nodeRef3) {
        this.activeNodeNumbers.remove(nodeRef2.getNumber());
        this.activeNodeNumbers.remove(nodeRef3.getNumber());
        this.activeNodeNumbers.add(nodeRef.getNumber());
    }

    private void computeProbabilityDistributions(int n) {
        int n2;
        this.activeNodeNumbers.clear();
        NodeRef nodeRef = this.intervals.getSamplingNode(-1);
        int n3 = nodeRef.getNumber() * this.demes;
        for (n2 = 0; n2 < this.demes; ++n2) {
            this.activeLineages[n3 + n2] = 0.0;
        }
        this.activeLineages[n3 + this.patternList.getPattern((int)0)[this.patternList.getTaxonIndex((String)this.treeModel.getNodeTaxon((NodeRef)nodeRef).getId())]] = 1.0;
        this.activeNodeNumbers.add(nodeRef.getNumber());
        this.intervalStartProbs[0] = Arrays.copyOfRange(this.activeLineages, n3, n3 + this.demes);
        this.intervalStartSquareProbs[0] = Arrays.copyOfRange(this.activeLineages, n3, n3 + this.demes);
        for (n2 = 0; n2 < this.intervalCount; ++n2) {
            if (this.intervals.getIntervalType(n2) == IntervalType.COALESCENT) {
                this.handleCoalescense(nodeRef, n3, n2);
                continue;
            }
            if (this.intervals.getIntervalType(n2) != IntervalType.SAMPLE) continue;
            this.handleSampling(nodeRef, n3, n2);
        }
    }

    private void handleSampling(NodeRef nodeRef, int n, int n2) {
        if (this.intervals.getInterval(n2) == 0.0) {
            nodeRef = this.intervals.getSamplingNode(n2);
            for (int i = 0; i < this.demes; ++i) {
                this.activeLineages[nodeRef.getNumber() * this.demes + i] = 0.0;
            }
            this.activeLineages[nodeRef.getNumber() * this.demes + this.patternList.getPattern((int)0)[this.patternList.getTaxonIndex((String)this.treeModel.getNodeTaxon((NodeRef)nodeRef).getId())]] = 1.0;
            this.activeNodeNumbers.add(nodeRef.getNumber());
        } else {
            double d = this.intervals.getInterval(n2);
            this.newHardWork(this.intervalStartProbs[n2], this.intervalStartSquareProbs[n2]);
            this.incrementActiveLineages(this.activeLineages, d, n2);
            this.newHardWork(this.intervalEndProbs[n2], this.intervalEndSquareProbs[n2]);
            nodeRef = this.intervals.getSamplingNode(n2);
            for (int i = 0; i < this.demes; ++i) {
                this.activeLineages[nodeRef.getNumber() * this.demes + i] = 0.0;
            }
            this.activeLineages[nodeRef.getNumber() * this.demes + this.patternList.getPattern((int)0)[this.patternList.getTaxonIndex((String)this.treeModel.getNodeTaxon((NodeRef)nodeRef).getId())]] = 1.0;
            this.activeNodeNumbers.add(nodeRef.getNumber());
        }
    }

    private void printIntervalContributions(int n) {
        int n2;
        System.out.print("  starting lineage count: ");
        for (n2 = 0; n2 < this.demes; ++n2) {
            System.out.print(this.intervalStartProbs[n][n2] + " ");
        }
        System.out.println();
        System.out.print("  starting lineage count squared: ");
        for (n2 = 0; n2 < this.demes; ++n2) {
            System.out.print(this.intervalStartSquareProbs[n][n2] + " ");
        }
        System.out.println();
        System.out.print("  ending lineage count: ");
        for (n2 = 0; n2 < this.demes; ++n2) {
            System.out.print(this.intervalEndProbs[n][n2] + " ");
        }
        System.out.println();
        System.out.print("  ending lineage count squared: ");
        for (n2 = 0; n2 < this.demes; ++n2) {
            System.out.print(this.intervalEndSquareProbs[n][n2] + " ");
        }
        System.out.println();
    }

    private void printActiveLineages() {
        System.out.println("printActiveLineages");
        for (int n : this.activeNodeNumbers) {
            System.out.print("  active lineage: ");
            for (int i = 0; i < this.demes; ++i) {
                System.out.print(this.activeLineages[n * this.demes + i] + " ");
            }
            System.out.println();
        }
    }

    private static void transpose(double[] dArray, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = i + 1; j < n; ++j) {
                int n2 = i * n + j;
                int n3 = j * n + i;
                double d = dArray[n2];
                dArray[n2] = dArray[n3];
                dArray[n3] = d;
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void incrementActiveLineages(double[] dArray, double d, int n) {
        double d2;
        Object object = this.branchRateModel;
        synchronized (object) {
            d2 = this.branchRateModel.getBranchRate(this.treeModel, this.treeModel.getRoot());
        }
        if (!this.matricesKnown[n]) {
            this.generalSubstitutionModel.getTransitionProbabilities(d2 * d, this.migrationMatrices[n]);
            this.matricesKnown[n] = true;
        }
        object = this.activeNodeNumbers.iterator();
        while (object.hasNext()) {
            int n2;
            int n3 = (Integer)object.next();
            for (n2 = 0; n2 < this.demes; ++n2) {
                this.temp[n2] = FasterStructuredCoalescentLikelihood.rdot(this.demes, dArray, n3 * this.demes, 1, this.migrationMatrices[n], n2, this.demes);
            }
            for (n2 = 0; n2 < this.demes; ++n2) {
                dArray[n3 * this.demes + n2] = this.temp[n2];
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void incrementActiveLineages(ArrayList<double[]> arrayList, double d, int n) {
        double d2;
        BranchRateModel branchRateModel = this.branchRateModel;
        synchronized (branchRateModel) {
            d2 = this.branchRateModel.getBranchRate(this.treeModel, this.treeModel.getRoot());
        }
        if (!this.matricesKnown[n]) {
            this.generalSubstitutionModel.getTransitionProbabilities(d2 * d, this.migrationMatrices[n]);
            this.matricesKnown[n] = true;
        }
        for (double[] dArray : arrayList) {
            int n2;
            for (n2 = 0; n2 < this.demes; ++n2) {
                this.temp[n2] = FasterStructuredCoalescentLikelihood.rdot(this.demes, dArray, 0, 1, this.migrationMatrices[n], n2, this.demes);
            }
            for (n2 = 0; n2 < this.demes; ++n2) {
                dArray[n2] = this.temp[n2];
            }
        }
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == this.treeModel) {
            if (object instanceof TreeChangedEvent) {
                System.out.println("TreeChangedEvent");
                if (((TreeChangedEvent)object).isNodeChanged()) {
                    System.out.println("current tree = " + this.treeModel);
                    System.out.println("isNodeChanged: " + ((TreeChangedEvent)object).getNode().getNumber());
                    System.out.println("root node number: " + this.treeModel.getRoot().getNumber());
                } else if (((TreeChangedEvent)object).isHeightChanged()) {
                    System.out.println("isHeightChanged: " + ((TreeChangedEvent)object).getNode().getNumber());
                } else if (((TreeChangedEvent)object).isTreeChanged()) {
                    System.out.println("isTreeChanged");
                    System.err.println("Full tree update event - these events currently aren't used\nso either this is in error or a new feature is using them so remove this message.");
                } else {
                    System.err.println("Another tree event has occurred (possibly a trait change).");
                }
            } else if (object instanceof IntervalChangedEvent) {
                // empty if block
            }
        } else if (model == this.branchRateModel) {
            for (int i = 0; i < this.intervalCount; ++i) {
                this.matricesKnown[i] = false;
            }
            this.likelihoodKnown = false;
            this.areStatesRedrawn = false;
        } else if (model == this.generalSubstitutionModel) {
            for (int i = 0; i < this.intervalCount; ++i) {
                this.matricesKnown[i] = false;
            }
            this.likelihoodKnown = false;
            this.areStatesRedrawn = false;
        } else if (model == this.intervals) {
            for (int i = 0; i < this.intervalCount; ++i) {
                this.matricesKnown[i] = false;
            }
            this.likelihoodKnown = false;
            this.areStatesRedrawn = false;
        } else {
            throw new RuntimeException("Unknown handleModelChangedEvent source, exiting.");
        }
        this.fireModelChanged();
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        this.likelihoodKnown = false;
        this.areStatesRedrawn = false;
        for (int i = 0; i < this.intervalCount; ++i) {
            this.matricesKnown[i] = true;
        }
    }

    @Override
    protected void storeState() {
        int n;
        for (n = 0; n < this.intervalCount; ++n) {
            System.arraycopy(this.migrationMatrices[n], 0, this.storedMigrationMatrices[n], 0, this.demes * this.demes);
            this.storedMatricesKnown[n] = this.matricesKnown[n];
        }
        this.storedLikelihoodKnown = this.likelihoodKnown;
        this.storedLogLikelihood = this.logLikelihood;
        if (this.areStatesRedrawn) {
            for (n = 0; n < this.reconstructedStates.length; ++n) {
                System.arraycopy(this.reconstructedStates[n], 0, this.storedReconstructedStates[n], 0, this.reconstructedStates[n].length);
            }
        }
        this.storedAreStatesRedrawn = this.areStatesRedrawn;
    }

    @Override
    protected void restoreState() {
        for (int i = 0; i < this.intervalCount; ++i) {
            double[] dArray = this.migrationMatrices[i];
            this.migrationMatrices[i] = this.storedMigrationMatrices[i];
            this.storedMigrationMatrices[i] = dArray;
            this.matricesKnown[i] = this.storedMatricesKnown[i];
        }
        this.likelihoodKnown = this.storedLikelihoodKnown;
        this.logLikelihood = this.storedLogLikelihood;
        int[][] nArray = this.reconstructedStates;
        this.reconstructedStates = this.storedReconstructedStates;
        this.storedReconstructedStates = nArray;
        this.areStatesRedrawn = this.storedAreStatesRedrawn;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public void makeDirty() {
        for (int i = 0; i < this.intervalCount; ++i) {
            this.matricesKnown[i] = false;
        }
        this.likelihoodKnown = false;
        this.areStatesRedrawn = false;
    }

    public TreeModel getTreeModel() {
        return this.treeModel;
    }

    public static double rdot(int n, double[] dArray, int n2, int n3, double[] dArray2, int n4, int n5) {
        double d = 0.0;
        if (n3 == 1 && n5 == 1 && n2 == 0 && n4 == 0) {
            for (int i = 0; i < n; ++i) {
                d += dArray[i] * dArray2[i];
            }
        } else if (n3 == 1 && n5 == 1) {
            int n6 = 0;
            int n7 = n2;
            int n8 = n4;
            while (n6 < n) {
                d += dArray[n7] * dArray2[n8];
                ++n6;
                ++n7;
                ++n8;
            }
        } else {
            int n9 = 0;
            int n10 = n2;
            int n11 = n4;
            while (n9 < n) {
                d += dArray[n10] * dArray2[n11];
                ++n9;
                n10 += n3;
                n11 += n5;
            }
        }
        return d;
    }

    @Override
    public final void setUnits(Units.Type type) {
        this.treeModel.setUnits(type);
    }

    @Override
    public final Units.Type getUnits() {
        return this.treeModel.getUnits();
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.TREE_PRIORS;
    }

    @Override
    public String getDescription() {
        return "Bayesian structured coalescent approximation";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CITATION);
    }
}

