/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.AutoCorrelatedBranchRatesDistribution;
import dr.evomodel.branchratemodel.AutoCorrelatedGradientWrtIncrements;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;

/**
 * Gradient provider that re-expresses a gradient with respect to branch rates as a gradient
 * with respect to the increment parameter exposed by an {@link AutoCorrelatedGradientWrtIncrements}.
 *
 * <p>The per-branch gradient is obtained from {@code rateGradientProvider}. It is then mapped onto
 * increments by a post-order tree traversal, in which each non-root node receives the sum of
 * transformed rate gradients over itself and all of its descendants. That sum is then rescaled by
 * the distribution's {@code BranchVarianceScaling}.
 */
public class BranchRateGradientWrtIncrements
implements GradientWrtParameterProvider,
Reportable {
    /** Source of the gradient with respect to the (per-branch) rates; also supplies the likelihood. */
    private final GradientWrtParameterProvider rateGradientProvider;
    /** Supplies the increment parameter and its dimension. */
    private final AutoCorrelatedGradientWrtIncrements priorGradientProvider;
    private final ArbitraryBranchRates branchRates;
    private final Tree tree;
    private final AutoCorrelatedBranchRatesDistribution.BranchVarianceScaling scaling;
    private final AutoCorrelatedBranchRatesDistribution.BranchRateUnits units;

    /**
     * @param gradientWrtParameterProvider        gradient provider with respect to the branch rates
     * @param autoCorrelatedGradientWrtIncrements provider whose distribution supplies the branch
     *                                            rate model, tree, variance scaling and rate units
     */
    public BranchRateGradientWrtIncrements(GradientWrtParameterProvider gradientWrtParameterProvider, AutoCorrelatedGradientWrtIncrements autoCorrelatedGradientWrtIncrements) {
        this.rateGradientProvider = gradientWrtParameterProvider;
        this.priorGradientProvider = autoCorrelatedGradientWrtIncrements;
        AutoCorrelatedBranchRatesDistribution distribution = autoCorrelatedGradientWrtIncrements.getDistribution();
        this.branchRates = distribution.getBranchRateModel();
        this.tree = distribution.getTree();
        this.scaling = distribution.getScaling();
        this.units = distribution.getUnits();
    }

    /** Returns the likelihood of the wrapped rate-gradient provider. */
    @Override
    public Likelihood getLikelihood() {
        return this.rateGradientProvider.getLikelihood();
    }

    /** Returns the increment parameter of the auto-correlated prior gradient provider. */
    @Override
    public Parameter getParameter() {
        return this.priorGradientProvider.getParameter();
    }

    /** Returns the dimension of the increment parameter. */
    @Override
    public int getDimension() {
        return this.priorGradientProvider.getDimension();
    }

    /**
     * Computes the gradient of the log density with respect to the increments.
     *
     * @return a new array, the same length as the rate gradient, indexed by branch-rate parameter
     *         index; the entry for the root (which has no branch) stays {@code 0.0}
     */
    @Override
    public double[] getGradientLogDensity() {
        double[] gradientWrtRates = this.rateGradientProvider.getGradientLogDensity();
        double[] gradientWrtIncrements = new double[gradientWrtRates.length];
        this.recursePostOrderToAccumulateGradient(this.tree.getRoot(), gradientWrtRates, gradientWrtIncrements);
        return gradientWrtIncrements;
    }

    /**
     * Post-order pass that fills {@code gradientWrtIncrements} for the subtree rooted at {@code node}.
     *
     * <p>For every non-root node, the value written is
     * {@code scaling.inverseRescaleIncrement(S, branchLength)}. Here {@code S} is the sum, over the
     * node and all its descendants, of {@code units.inverseTransformGradient(rateGradient, rate)}.
     * This matches each increment propagating to every descendant branch rate.
     *
     * @param node                  root of the subtree to process
     * @param gradientWrtRates      input gradient, indexed by branch-rate parameter index
     * @param gradientWrtIncrements output array written at each non-root node's parameter index
     * @return the accumulated (unscaled) gradient sum {@code S} for the subtree, which the parent
     *         adds to its own sum
     */
    private double recursePostOrderToAccumulateGradient(NodeRef node, double[] gradientWrtRates, double[] gradientWrtIncrements) {
        double subtreeGradient = 0.0;

        if (!this.tree.isExternal(node)) {
            // Visit all children (not just the first two) so multifurcating nodes contribute fully.
            final int childCount = this.tree.getChildCount(node);
            for (int i = 0; i < childCount; ++i) {
                subtreeGradient += this.recursePostOrderToAccumulateGradient(
                        this.tree.getChild(node, i), gradientWrtRates, gradientWrtIncrements);
            }
        }

        // The root has no branch, hence no rate/increment of its own.
        if (!this.tree.isRoot(node)) {
            final int index = this.branchRates.getParameterIndexFromNode(node);
            final double untransformedRate = this.branchRates.getUntransformedBranchRate(this.tree, node);
            subtreeGradient += this.units.inverseTransformGradient(gradientWrtRates[index], untransformedRate);
            gradientWrtIncrements[index] = this.scaling.inverseRescaleIncrement(subtreeGradient, this.tree.getBranchLength(node));
        }

        return subtreeGradient;
    }

    /** Report that checks this gradient, with unbounded parameter limits. */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, null);
    }
}

