/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.ProcessSimulation;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.DiscreteTraitBranchRateDelegate;
import dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MachineAccuracy;
import dr.math.MultivariateFunction;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.xml.Reportable;
import java.util.Collections;
import java.util.List;

/**
 * Gradient (and, optionally, diagonal Hessian) of a {@link TreeDataLikelihood}'s log-likelihood with respect to
 * a branch-rate {@link Parameter}.
 *
 * <p>Per-branch derivatives are read from a tree trait on the likelihood. If that trait is not yet registered,
 * it is created at construction time from the delegate returned by
 * {@link #makeGradientDelegate(String, Tree, BeagleDataLikelihoodDelegate)}. Each per-branch value is then
 * multiplied by a chain factor ({@link #getChainGradient(Tree, NodeRef)}) and stored at the parameter index
 * given by {@link #getParameterIndexFromNode(NodeRef)}.
 */
public class DiscreteTraitBranchRateGradient
implements GradientWrtParameterProvider,
HessianWrtParameterProvider,
Reportable,
Loggable,
Citable {
    protected final TreeDataLikelihood treeDataLikelihood;
    /** Tree trait returning per-branch first derivatives for the whole tree (queried with node == null). */
    protected final TreeTrait treeTraitProvider;
    protected final Tree tree;
    /** Whether {@link #getReport()} also checks the diagonal Hessian. */
    protected final boolean useHessian;
    protected final Parameter rateParameter;
    /** The likelihood's branch-rate model if it is a {@link DifferentiableBranchRates}, otherwise {@code null}. */
    protected final DifferentiableBranchRates branchRateModel;

    /**
     * Log-likelihood as a function of the rate-parameter values, with bounds [0, +inf) in every dimension.
     * NOTE(review): not referenced inside this class; presumably kept for numerical checks by subclasses or
     * callers in the same package. Confirm before removing.
     */
    MultivariateFunction numeric1 = new MultivariateFunction(){

        @Override
        public double evaluate(double[] argument) {
            for (int i = 0; i < argument.length; ++i) {
                DiscreteTraitBranchRateGradient.this.rateParameter.setParameterValue(i, argument[i]);
            }
            return DiscreteTraitBranchRateGradient.this.treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return DiscreteTraitBranchRateGradient.this.rateParameter.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };

    /** Name of the tree trait holding per-branch second derivatives. */
    private static final String HESSIAN_TRAIT_NAME = "BranchRateHessian";
    private static final boolean CHECK_GRADIENT_IN_REPORT = true;
    protected static final boolean COUNT_TOTAL_OPERATIONS = true;
    private static final long NANOS_PER_MILLI = 1_000_000L;

    private long getGradientLogDensityCount = 0L;
    /** Accumulated in nanoseconds so that sub-millisecond calls are not truncated to zero. */
    private long totalGradientTimeNanos = 0L;

    private static final Citation CITATION = new Citation(new Author[]{new Author("X", "Ji"), new Author("Z", "Zhang"), new Author("A", "Holbrook"), new Author("A", "Nishimura"), new Author("G", "Beale"), new Author("A", "Rambaut"), new Author("P", "Lemey"), new Author("MA", "Suchard")}, "Gradients do grow on trees: a linear-time O(N)-dimensional gradient for statistical phylogenetics", 2020, "Molecular Biology and Evolution", 37, 3047, 3060, Citation.Status.PUBLISHED);

    /**
     * Creates the gradient provider, registering the gradient tree trait on the likelihood if needed.
     *
     * <p>Note: calls the overridable {@link #getTraitName(String)} and
     * {@link #makeGradientDelegate(String, Tree, BeagleDataLikelihoodDelegate)} during construction, so
     * subclass overrides must not depend on subclass fields.
     *
     * @param traitName          name forwarded to {@link #makeGradientDelegate}
     * @param treeDataLikelihood likelihood whose log density is differentiated; must not be null
     * @param likelihoodDelegate BEAGLE delegate used to build the gradient delegate
     * @param rateParameter      the parameter this gradient is reported against
     * @param useHessian         whether {@link #getReport()} also checks the diagonal Hessian
     * @throws RuntimeException if the likelihood's data delegate has a trait count other than 1
     */
    public DiscreteTraitBranchRateGradient(String traitName, TreeDataLikelihood treeDataLikelihood, BeagleDataLikelihoodDelegate likelihoodDelegate, Parameter rateParameter, boolean useHessian) {
        assert (treeDataLikelihood != null);
        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.rateParameter = rateParameter;
        this.useHessian = useHessian;

        BranchRateModel rateModel = treeDataLikelihood.getBranchRateModel();
        this.branchRateModel = rateModel instanceof DifferentiableBranchRates ? (DifferentiableBranchRates) rateModel : null;

        String gradientTraitName = this.getTraitName(traitName);
        if (treeDataLikelihood.getTreeTrait(gradientTraitName) == null) {
            // First user of this likelihood: attach a process that supplies the gradient trait.
            ProcessSimulationDelegate gradientDelegate = this.makeGradientDelegate(traitName, this.tree, likelihoodDelegate);
            ProcessSimulation simulation = new ProcessSimulation(treeDataLikelihood, gradientDelegate);
            treeDataLikelihood.addTraits(simulation.getTreeTraits());
        }
        this.treeTraitProvider = treeDataLikelihood.getTreeTrait(gradientTraitName);
        assert (this.treeTraitProvider != null);

        int traitCount = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
        if (traitCount != 1) {
            throw new RuntimeException("Not yet implemented for >1 traits");
        }
    }

    /**
     * Returns the name under which the gradient tree trait is looked up.
     *
     * <p>NOTE(review): the argument is ignored and {@code null} is passed to
     * {@link DiscreteTraitBranchRateDelegate#getName(String)}. This presumably means the trait name is fixed;
     * confirm before relying on {@code traitName} here.
     *
     * @param traitName unused by this implementation
     * @return the gradient trait name
     */
    protected String getTraitName(String traitName) {
        return DiscreteTraitBranchRateDelegate.getName(null);
    }

    /**
     * Builds the delegate that supplies the gradient tree trait.
     *
     * @param traitName          name forwarded to the delegate
     * @param tree               the likelihood's tree
     * @param likelihoodDelegate the BEAGLE likelihood delegate
     * @return a new {@link DiscreteTraitBranchRateDelegate}
     */
    protected ProcessSimulationDelegate makeGradientDelegate(String traitName, Tree tree, BeagleDataLikelihoodDelegate likelihoodDelegate) {
        return new DiscreteTraitBranchRateDelegate(traitName, tree, likelihoodDelegate);
    }

    @Override
    public Likelihood getLikelihood() {
        return this.treeDataLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.rateParameter;
    }

    @Override
    public int getDimension() {
        return this.getParameter().getDimension();
    }

    /**
     * Computes the diagonal of the Hessian of the log density.
     *
     * <p>For each non-root branch, the per-branch second derivative H and first derivative G are combined with
     * the chain factors c1 = {@link #getChainGradient} and c2 = {@link #getChainSecondDerivative} as
     * {@code H * c1 * c1 + G * c2} (second-order chain rule). The result is then passed through the
     * differentiable branch-rate model, if there is one.
     *
     * @return array of length {@code nodeCount - 1}, indexed by parameter index
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        double[] result = new double[this.tree.getNodeCount() - 1];
        // A single call with node == null returns values for every non-root branch.
        double[] branchHessian = (double[]) this.treeDataLikelihood.getTreeTrait(HESSIAN_TRAIT_NAME).getTrait(this.tree, null);
        double[] branchGradient = (double[]) this.treeTraitProvider.getTrait(this.tree, null);

        int branch = 0; // position in the trait arrays: non-root nodes in node-index order
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            NodeRef node = this.tree.getNode(i);
            if (this.tree.isRoot(node)) {
                continue;
            }
            double chain = this.getChainGradient(this.tree, node);
            double chainSecond = this.getChainSecondDerivative(this.tree, node);
            result[this.getParameterIndexFromNode(node)] = branchHessian[branch] * chain * chain + branchGradient[branch] * chainSecond;
            ++branch;
        }

        // branchRateModel may legitimately be null (see constructor); only then is there nothing to transform.
        if (this.branchRateModel != null) {
            result = this.branchRateModel.updateDiagonalHessianLogDensity(result, branchGradient, null, 0, branchGradient.length);
        }
        return result;
    }

    @Override
    public double[][] getHessianLogDensity() {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Computes the gradient of the log density with respect to the rate parameter.
     *
     * <p>Each non-root branch's trait value is multiplied by {@link #getChainGradient} and written only to that
     * branch's parameter index. The array is then post-processed by
     * {@link #updateBranchRateGradientLogDensity(double[])}.
     *
     * @return array of length {@code nodeCount - 1}, indexed by parameter index
     */
    @Override
    public double[] getGradientLogDensity() {
        long startTime = 0L;
        if (COUNT_TOTAL_OPERATIONS) {
            startTime = System.nanoTime();
        }

        double[] result = new double[this.tree.getNodeCount() - 1];
        // A single call with node == null returns values for every non-root branch.
        double[] branchGradient = (double[]) this.treeTraitProvider.getTrait(this.tree, null);

        int branch = 0; // position in branchGradient: non-root nodes in node-index order
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            NodeRef node = this.tree.getNode(i);
            if (this.tree.isRoot(node)) {
                continue;
            }
            // Write only the parameter slot. A second store at the visit position would overwrite a slot
            // belonging to another branch whenever the two index orders differ.
            result[this.getParameterIndexFromNode(node)] = branchGradient[branch] * this.getChainGradient(this.tree, node);
            ++branch;
        }
        result = this.updateBranchRateGradientLogDensity(result);

        if (COUNT_TOTAL_OPERATIONS) {
            ++this.getGradientLogDensityCount;
            this.totalGradientTimeNanos += System.nanoTime() - startTime;
        }
        return result;
    }

    /**
     * Post-processing hook for the gradient; the default returns its argument unchanged.
     *
     * @param gradient gradient indexed by parameter index
     * @return the (possibly transformed) gradient
     */
    double[] updateBranchRateGradientLogDensity(double[] gradient) {
        return gradient;
    }

    /**
     * First-order chain factor for a branch. The default is the tree's branch length for {@code node}.
     */
    protected double getChainGradient(Tree tree, NodeRef node) {
        return tree.getBranchLength(node);
    }

    /**
     * Second-order chain factor for a branch. The default is 0.
     */
    protected double getChainSecondDerivative(Tree tree, NodeRef node) {
        return 0.0;
    }

    /**
     * Maps a node to its slot in the parameter/gradient array.
     *
     * <p>Uses the differentiable branch-rate model's mapping when available, otherwise the node number.
     */
    protected int getParameterIndexFromNode(NodeRef node) {
        return this.branchRateModel == null ? node.getNumber() : this.branchRateModel.getParameterIndexFromNode(node);
    }

    /**
     * Returns {@code false} if any value has magnitude below {@code 1.2 * SQRT_EPSILON}, otherwise {@code true}.
     * NaN entries do not cause a {@code false} result.
     */
    protected boolean valuesAreSufficientlyLarge(double[] values) {
        final double threshold = MachineAccuracy.SQRT_EPSILON * 1.2;
        for (double value : values) {
            if (Math.abs(value) < threshold) {
                return false;
            }
        }
        return true;
    }

    /**
     * Reports the call count and the average gradient time. The average is in whole milliseconds, computed
     * from the accumulated nanosecond total. Optionally the report also includes the gradient (and Hessian)
     * checks.
     */
    @Override
    public String getReport() {
        StringBuilder report = new StringBuilder();
        report.append("\n\tgetGradientLogDensityCount = ").append(this.getGradientLogDensityCount);
        report.append("\n\taverageGradientTime = ");
        if (this.getGradientLogDensityCount > 0L) {
            report.append(this.totalGradientTimeNanos / this.getGradientLogDensityCount / NANOS_PER_MILLI);
        } else {
            report.append("NA");
        }
        report.append("\n");
        if (CHECK_GRADIENT_IN_REPORT) {
            report.append(GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, null));
            if (this.useHessian) {
                report.append(HessianWrtParameterProvider.getReportAndCheckForError(this, null));
            }
        }
        return report.toString();
    }

    @Override
    public LogColumn[] getColumns() {
        return Loggable.getColumnsFromReport(this, "gradient report");
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.FRAMEWORK;
    }

    @Override
    public String getDescription() {
        return "Using linear-time branch-specific parameter differential calculations";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CITATION);
    }
}

