/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.DiscreteTraitBranchRateGradient;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.Loggable;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.Arrays;

/**
 * Gradient of the tree data log-likelihood with respect to the heights of the
 * internal nodes of a {@link TreeModel}.
 *
 * <p>The parameter exposed through {@link #getParameter()} is a
 * {@link NodeHeightProxyParameter} built over the tree model. The gradient is
 * assembled from the per-branch values returned by the inherited
 * {@code treeTraitProvider}, each multiplied by the rate of the corresponding
 * branch (see {@link #getGradientLogDensity()}).
 *
 * <p>The fields {@code tree}, {@code treeDataLikelihood}, {@code treeTraitProvider}
 * and {@code branchRateModel}, and the method {@code valuesAreSufficientlyLarge},
 * are inherited from {@link DiscreteTraitBranchRateGradient}.
 */
public class NodeHeightGradientForDiscreteTrait
extends DiscreteTraitBranchRateGradient
implements GradientWrtParameterProvider,
Reportable,
Loggable {
    /** Enables the finite-difference cross-check performed by {@link #getReport()}. */
    private static final boolean DEBUG = true;

    /** Scratch buffer, one entry per internal node, refilled by {@link #getNodeHeights()}. */
    private final double[] nodeHeights;
    /** The likelihood's tree, kept as a {@link TreeModel} so node heights can be set. */
    private final TreeModel treeModel;
    /** Maps node numbers to indices into the per-branch trait array. */
    protected TreeParameterModel indexHelper;
    /** Parameter view of the internal node heights returned by {@link #getParameter()}. */
    private final NodeHeightProxyParameter nodeHeightProxyParameter;

    /**
     * Log-likelihood as a function of the internal node heights; used only for the
     * numerical gradient in {@link #getReport()}. Evaluating it writes the supplied
     * heights into the tree model, so the caller must restore the heights afterwards.
     */
    private final MultivariateFunction numeric1 = new MultivariateFunction(){

        @Override
        public double evaluate(double[] heights) {
            // Argument i is the height of internal node i.
            for (int i = 0; i < heights.length; ++i) {
                NodeRef node = NodeHeightGradientForDiscreteTrait.this.tree.getInternalNode(i);
                NodeHeightGradientForDiscreteTrait.this.treeModel.setNodeHeight(node, heights[i]);
            }
            NodeHeightGradientForDiscreteTrait.this.treeDataLikelihood.makeDirty();
            return NodeHeightGradientForDiscreteTrait.this.treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return NodeHeightGradientForDiscreteTrait.this.tree.getInternalNodeCount();
        }

        @Override
        public double getLowerBound(int index) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int index) {
            return Double.POSITIVE_INFINITY;
        }
    };

    /**
     * Creates the gradient provider.
     *
     * @param name                name passed through to the superclass constructor
     * @param treeDataLikelihood  likelihood whose tree must be a {@link TreeModel}
     * @param likelihoodDelegate  delegate passed through to the superclass constructor
     * @param rateParameter       parameter passed through to the superclass constructor
     * @throws IllegalArgumentException if the likelihood's tree is not a {@link TreeModel}
     */
    public NodeHeightGradientForDiscreteTrait(String name, TreeDataLikelihood treeDataLikelihood, BeagleDataLikelihoodDelegate likelihoodDelegate, Parameter rateParameter) {
        // NOTE(review): the meaning of the trailing 'false' flag is defined by the superclass - confirm there.
        super(name, treeDataLikelihood, likelihoodDelegate, rateParameter, false);
        if (!(treeDataLikelihood.getTree() instanceof TreeModel)) {
            throw new IllegalArgumentException("Must provide a TreeModel");
        }
        this.treeModel = (TreeModel)treeDataLikelihood.getTree();
        this.nodeHeights = new double[this.tree.getInternalNodeCount()];
        // Helper used purely for node-number -> index lookups; its backing parameter has nodeCount - 1 entries.
        this.indexHelper = new TreeParameterModel(this.treeModel, new Parameter.Default(this.tree.getNodeCount() - 1), false);
        this.nodeHeightProxyParameter = new NodeHeightProxyParameter("internalNodeHeights", this.treeModel, true);
    }

    /**
     * @return the proxy parameter over the tree model's node heights
     */
    @Override
    public Parameter getParameter() {
        return this.nodeHeightProxyParameter;
    }

    /**
     * Computes one gradient entry per internal node.
     *
     * <p>For internal node {@code i} the entry is the sum, over its children, of
     * {@code trait[child] * rate(child)}, minus {@code trait[node] * rate(node)} when
     * the node is not the root. {@code trait} is the array returned by the inherited
     * {@code treeTraitProvider}, indexed through {@link #indexHelper}.
     * NOTE(review): presumably {@code trait} holds per-branch derivatives of the
     * log-likelihood - confirm against the superclass.
     *
     * @return a new array of length {@code tree.getInternalNodeCount()}
     */
    @Override
    public double[] getGradientLogDensity() {
        final int internalNodeCount = this.tree.getInternalNodeCount();
        // A freshly allocated double[] is already zero-filled.
        final double[] gradient = new double[internalNodeCount];
        final double[] branchTrait = (double[])this.treeTraitProvider.getTrait(this.tree, null);
        for (int i = 0; i < internalNodeCount; ++i) {
            final NodeRef node = this.tree.getInternalNode(i);
            if (!this.tree.isRoot(node)) {
                // Contribution of the branch above this node (the root has none).
                final int branchIndex = this.indexHelper.getParameterIndexFromNodeNumber(node.getNumber());
                gradient[i] -= branchTrait[branchIndex] * this.branchRateModel.getBranchRate(this.tree, node);
            }
            // Contributions of the branches leading to each child.
            for (int c = 0; c < this.tree.getChildCount(node); ++c) {
                final NodeRef child = this.tree.getChild(node, c);
                final int childIndex = this.indexHelper.getParameterIndexFromNodeNumber(child.getNumber());
                gradient[i] += branchTrait[childIndex] * this.branchRateModel.getBranchRate(this.tree, child);
            }
        }
        return gradient;
    }

    /**
     * Refreshes the shared scratch buffer with the current height of every internal node.
     *
     * @return the shared buffer itself (not a copy); later calls overwrite its contents
     */
    private double[] getNodeHeights() {
        for (int i = 0; i < this.tree.getInternalNodeCount(); ++i) {
            NodeRef node = this.tree.getInternalNode(i);
            this.nodeHeights[i] = this.tree.getNodeHeight(node);
        }
        return this.nodeHeights;
    }

    /**
     * Builds a two-line report: the gradient from {@link #getGradientLogDensity()}
     * and, when the node heights pass {@code valuesAreSufficientlyLarge}, a numerical
     * gradient for comparison. The likelihood is marked dirty on entry and on exit,
     * and the node heights are written back to the tree model after the numerical pass.
     *
     * @return the report text, each line terminated by {@code "\n"}
     */
    @Override
    public String getReport() {
        this.treeDataLikelihood.makeDirty();
        // Snapshot the heights: getNodeHeights() hands out the shared scratch buffer, which is
        // also the array given to the numerical differentiator below, so restoring from the
        // buffer itself would depend on the differentiator leaving it unmodified.
        final double[] savedHeights = this.getNodeHeights().clone();
        final boolean largeEnough = this.valuesAreSufficientlyLarge(this.getNodeHeights());
        double[] numericGradient = null;
        if (DEBUG && largeEnough) {
            numericGradient = NumericalDerivative.gradient(this.numeric1, this.getNodeHeights());
        }
        // Evaluating numeric1 moved the nodes; put every internal node back where it was.
        for (int i = 0; i < savedHeights.length; ++i) {
            NodeRef node = this.tree.getInternalNode(i);
            this.treeModel.setNodeHeight(node, savedHeights[i]);
        }
        StringBuilder report = new StringBuilder();
        report.append("Peeling: ").append(new Vector(this.getGradientLogDensity()));
        report.append("\n");
        if (numericGradient != null) {
            report.append("numeric: ").append(new Vector(numericGradient));
        } else {
            report.append("numeric: too close to 0");
        }
        report.append("\n");
        this.treeDataLikelihood.makeDirty();
        return report.toString();
    }
}

