/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchmodel.BranchSpecificSubstitutionParameterBranchModel;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.substmodel.DifferentiableSubstitutionModel;
import dr.evomodel.substmodel.DifferentialMassProvider;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.ProcessSimulation;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.BranchDifferentialMassProvider;
import dr.evomodel.treedatalikelihood.discrete.BranchSubstitutionParameterDelegate;
import dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MachineAccuracy;
import dr.xml.Reportable;
import java.io.Serializable;
import java.util.ArrayList;

/**
 * Gradient provider for a {@link TreeDataLikelihood} with respect to the rate parameter exposed by a
 * {@link DifferentiableBranchRates} model, where each non-root branch carries its own differentiable
 * substitution-model parameter.
 *
 * <p>Per-branch derivatives are read from a tree trait whose name is
 * {@code BranchSubstitutionParameterDelegate.getName(traitName)}. If the likelihood does not yet
 * expose that trait, the constructor builds one from a {@link BranchSubstitutionParameterDelegate}
 * wrapped in a {@link ProcessSimulation} and registers it on the likelihood.
 *
 * <p>Hessian computations are not implemented.
 */
public class BranchSubstitutionParameterGradient
implements GradientWrtParameterProvider,
HessianWrtParameterProvider,
Reportable,
Loggable {
    protected final TreeDataLikelihood treeDataLikelihood;
    /** Trait that yields the per-branch derivative array when queried for the whole tree. */
    protected final TreeTrait treeTraitProvider;
    protected final Tree tree;
    /** Stored from the constructor; not read by this class. */
    protected final boolean useHessian;
    /** One sub-parameter per node, indexed by node number (see trait registration). */
    protected final CompoundParameter branchParameter;
    protected final DifferentiableBranchRates branchRateModel;
    /** Tolerance forwarded to the numerical gradient check in {@link #getReport()}; may be null. */
    protected final Double nullableTolerance;
    // NOTE(review): DEBUG and smallGradientThreshold are not read anywhere in this class.
    private static final boolean DEBUG = true;
    protected static final boolean COUNT_TOTAL_OPERATIONS = true;
    /** Number of calls made to {@link #getGradientLogDensity()}. */
    protected long getGradientLogDensityCount = 0L;
    private final double smallGradientThreshold = 0.5;
    /** Mass provider built when this instance registers the gradient trait itself; otherwise null. */
    private BranchDifferentialMassProvider save;

    /**
     * Creates the gradient provider, registering the backing tree trait on the likelihood if needed.
     *
     * @param traitName          base trait name; the actual trait is looked up under
     *                           {@code BranchSubstitutionParameterDelegate.getName(traitName)}
     * @param treeDataLikelihood likelihood whose gradient is computed
     * @param likelihoodDelegate BEAGLE delegate; its branch model must be a
     *                           {@link BranchSpecificSubstitutionParameterBranchModel} whose
     *                           per-node substitution models are {@link DifferentiableSubstitutionModel}s
     * @param branchParameter    compound parameter holding one sub-parameter per node number
     * @param branchRateModel    branch-rate model that defines the differentiated parameter
     * @param nullableTolerance  tolerance used by {@link #getReport()}; {@code null} selects the default
     * @param useHessian         stored in {@link #useHessian}
     * @param dim                forwarded to {@link DifferentiableSubstitutionModel#factory}
     * @param mode               mode of each {@link DifferentialMassProvider.DifferentialWrapper}
     * @throws UnsupportedOperationException if the likelihood delegate has more than one trait
     * @throws IllegalStateException         if the gradient trait cannot be found after registration
     */
    public BranchSubstitutionParameterGradient(String traitName, TreeDataLikelihood treeDataLikelihood, BeagleDataLikelihoodDelegate likelihoodDelegate, CompoundParameter branchParameter, DifferentiableBranchRates branchRateModel, Double nullableTolerance, boolean useHessian, int dim, DifferentialMassProvider.Mode mode) {
        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.branchParameter = branchParameter;
        this.branchRateModel = branchRateModel;
        this.useHessian = useHessian;
        this.nullableTolerance = nullableTolerance;

        // Validate before touching the shared likelihood so that a rejected configuration does not
        // leave a newly registered trait behind.
        int traitCount = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
        if (traitCount != 1) {
            throw new UnsupportedOperationException("Not yet implemented for >1 traits");
        }

        String fullTraitName = BranchSubstitutionParameterDelegate.getName(traitName);
        if (treeDataLikelihood.getTreeTrait(fullTraitName) == null) {
            registerGradientTrait(traitName, likelihoodDelegate, branchParameter, branchRateModel, dim, mode);
        }

        this.treeTraitProvider = treeDataLikelihood.getTreeTrait(fullTraitName);
        // Explicit check: a bare assert is skipped unless the JVM runs with -ea.
        if (this.treeTraitProvider == null) {
            throw new IllegalStateException("Tree trait '" + fullTraitName + "' is not available on the likelihood");
        }
    }

    /**
     * Builds one differential wrapper per non-root node and registers the resulting gradient trait
     * on {@link #treeDataLikelihood}. Also records the mass provider in {@link #save}.
     */
    private void registerGradientTrait(String traitName, BeagleDataLikelihoodDelegate likelihoodDelegate, CompoundParameter branchParameter, DifferentiableBranchRates branchRateModel, int dim, DifferentialMassProvider.Mode mode) {
        BranchSpecificSubstitutionParameterBranchModel branchModel = (BranchSpecificSubstitutionParameterBranchModel) likelihoodDelegate.getBranchModel();

        ArrayList<DifferentialMassProvider> differentials = new ArrayList<DifferentialMassProvider>();
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            NodeRef node = tree.getNode(i);
            if (tree.isRoot(node)) {
                continue;
            }
            DifferentiableSubstitutionModel substitutionModel = (DifferentiableSubstitutionModel) branchModel.getSubstitutionModel(node);
            Parameter parameter = branchParameter.getParameter(node.getNumber());
            DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt = substitutionModel.factory(parameter, dim);
            differentials.add(new DifferentialMassProvider.DifferentialWrapper(substitutionModel, wrt, mode));
        }

        BranchDifferentialMassProvider massProvider = new BranchDifferentialMassProvider(branchRateModel, differentials);
        this.save = massProvider;

        BranchSubstitutionParameterDelegate delegate = new BranchSubstitutionParameterDelegate(traitName, treeDataLikelihood.getTree(), likelihoodDelegate, treeDataLikelihood.getBranchRateModel(), massProvider);
        ProcessSimulation simulation = new ProcessSimulation(treeDataLikelihood, delegate);
        treeDataLikelihood.addTraits(simulation.getTreeTraits());
    }

    /** Not implemented. @throws UnsupportedOperationException always */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not implemented. @throws UnsupportedOperationException always */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** @return the tree data likelihood this gradient is taken of */
    @Override
    public Likelihood getLikelihood() {
        return this.treeDataLikelihood;
    }

    /** @return the rate parameter of the branch-rate model */
    @Override
    public Parameter getParameter() {
        return this.branchRateModel.getRateParameter();
    }

    /** @return the dimension of {@link #getParameter()} */
    @Override
    public int getDimension() {
        return this.getParameter().getDimension();
    }

    /**
     * Computes the gradient of the log density with respect to {@link #getParameter()}.
     *
     * <p>Reads the full-tree derivative array from the trait, scales each non-root node's entry
     * (at its parameter index) by the branch-rate differential, and passes the result through
     * {@link DifferentiableBranchRates#updateGradientLogDensity}.
     *
     * @return the gradient as produced by {@code updateGradientLogDensity}
     */
    @Override
    public double[] getGradientLogDensity() {
        ++this.getGradientLogDensityCount;
        double[] gradient = new double[this.getDimension()];
        // A null node requests the trait for the whole tree in one call.
        double[] branchDerivatives = (double[]) this.treeTraitProvider.getTrait(this.tree, null);
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            NodeRef node = this.tree.getNode(i);
            if (this.tree.isRoot(node)) {
                continue;
            }
            int index = this.branchRateModel.getParameterIndexFromNode(node);
            gradient[index] = branchDerivatives[index] * this.branchRateModel.getBranchRateDifferential(this.tree, node);
        }
        return this.branchRateModel.updateGradientLogDensity(gradient, null, 0, branchDerivatives.length);
    }

    /**
     * Chain-rule factor for the given node. For {@link ArbitraryBranchRates} it is the transform's
     * differential at the node's current parameter value; otherwise it is {@code 1.0}.
     */
    protected double getChainGradient(Tree tree, NodeRef nodeRef) {
        if (this.branchRateModel instanceof ArbitraryBranchRates) {
            double value = this.getParameter().getParameterValue(this.branchRateModel.getParameterIndexFromNode(nodeRef));
            return this.branchRateModel.getTransform().differential(value, tree, nodeRef);
        }
        return 1.0;
    }

    @Override
    public LogColumn[] getColumns() {
        return Loggable.getColumnsFromReport(this, "BranchSubstitutionParameterGradientReport");
    }

    /**
     * @return {@code false} if any value's magnitude is below {@code 1.2 * sqrt(machine epsilon)};
     *         {@code true} otherwise (NaN entries do not trigger a {@code false})
     */
    protected boolean valuesAreSufficientlyLarge(double[] dArray) {
        final double threshold = MachineAccuracy.SQRT_EPSILON * 1.2;
        for (double value : dArray) {
            if (Math.abs(value) < threshold) {
                return false;
            }
        }
        return true;
    }

    /**
     * Report comparing the analytic gradient against a numerical one, using the tolerance supplied at
     * construction (null selects the default tolerance of the check).
     */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, this.nullableTolerance);
    }
}

