/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.treedatalikelihood.continuous.AbstractDriftDiffusionModelDelegate;
import dr.evomodel.treedatalikelihood.continuous.BranchRateGradient;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.DiffusionProcessDelegate;
import dr.evomodel.treedatalikelihood.continuous.OUDiffusionModelDelegate;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.evomodel.treedatalikelihood.hmc.MultivariateChainRule;
import dr.evomodel.treedatalikelihood.preorder.BranchSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider;
import dr.evomodel.treedatalikelihood.preorder.NormalSufficientStatistics;
import dr.math.matrixAlgebra.missingData.MissingOps;
import java.util.Arrays;
import java.util.List;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

public interface ContinuousTraitGradientForBranch {
    /**
     * Computes this gradient's contribution from a single branch.
     *
     * @param statistics sufficient statistics gathered for the branch
     * @param node       node identifying the branch
     * @return the branch's gradient contribution; the implementations in this file return an
     *         array of length {@link #getDimension()}
     */
    double[] getGradientForBranch(BranchSufficientStatistics statistics, NodeRef node);

    /**
     * Maps a node to the index of the parameter its branch contribution belongs to.
     *
     * @param node node identifying the branch
     * @return parameter index for that branch
     */
    int getParameterIndexFromNode(NodeRef node);

    /**
     * @return length of the gradient vector produced for each branch
     */
    int getDimension();

    public static class SamplingVarianceGradient
    extends ContinuousProcessParameterGradient {
        ModelExtensionProvider.NormalExtensionProvider dataModel;

        public SamplingVarianceGradient(int n, Tree tree, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, ModelExtensionProvider.NormalExtensionProvider normalExtensionProvider) {
            super(n, tree, continuousDataLikelihoodDelegate, Arrays.asList(ContinuousProcessParameterGradient.DerivationParameter.WRT_SAMPLING_VARIANCE));
            this.dataModel = normalExtensionProvider;
        }

        @Override
        public double[] getGradientForBranch(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef) {
            if (!this.tree.isExternal(nodeRef)) {
                return new double[this.getDimension()];
            }
            double[] dArray = this.getGradientForBranch(branchSufficientStatistics, nodeRef, true, false);
            this.dataModel.chainRuleWrtVariance(dArray, nodeRef);
            return dArray;
        }

        /*
         * WARNING - void declaration
         */
        @Override
        void getSufficientStatistics(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef) {
            int n;
            void object;
            NormalSufficientStatistics normalSufficientStatistics = branchSufficientStatistics.getBelow();
            NormalSufficientStatistics normalSufficientStatistics2 = branchSufficientStatistics.getAbove();
            DenseMatrix64F denseMatrix64F = this.dataModel.getExtensionVariance(nodeRef);
            DenseMatrix64F denseMatrix64F2 = new DenseMatrix64F(this.dim, this.dim);
            DenseMatrix64F denseMatrix64F3 = new DenseMatrix64F(this.dim, this.dim);
            CommonOps.add((D1Matrix64F)normalSufficientStatistics2.getRawVariance(), denseMatrix64F, (D1Matrix64F)denseMatrix64F2);
            MissingOps.safeInvert2(denseMatrix64F2, denseMatrix64F3, false);
            NormalSufficientStatistics normalSufficientStatistics3 = new NormalSufficientStatistics(normalSufficientStatistics2.getRawMeanCopy(), denseMatrix64F3, denseMatrix64F2);
            int[] nArray = branchSufficientStatistics.getMissing();
            DenseMatrix64F denseMatrix64F4 = new DenseMatrix64F(this.dim, this.dim);
            boolean i = false;
            while (object < this.dim) {
                denseMatrix64F4.unsafe_set((int)object, (int)object, Double.POSITIVE_INFINITY);
                ++object;
            }
            int[] nArray2 = nArray;
            int n2 = nArray2.length;
            for (n = 0; n < n2; ++n) {
                int n3 = nArray2[n];
                denseMatrix64F4.unsafe_set(n3, n3, 0.0);
            }
            NormalSufficientStatistics normalSufficientStatistics4 = new NormalSufficientStatistics(normalSufficientStatistics.getRawMeanCopy(), denseMatrix64F4);
            NormalSufficientStatistics normalSufficientStatistics5 = BranchRateGradient.ContinuousTraitGradientForBranch.Default.computeJointStatistics(normalSufficientStatistics4, normalSufficientStatistics3, this.dim);
            this.matrixQ = normalSufficientStatistics3.getRawPrecision();
            this.matrixW = normalSufficientStatistics3.getRawVariance();
            this.matrixV = normalSufficientStatistics5.getRawVariance();
            for (n = 0; n < this.dim; ++n) {
                this.matrixDelta.unsafe_set(n, 0, normalSufficientStatistics5.getMean(n) - normalSufficientStatistics3.getMean(n));
            }
        }

        @Override
        public int getParameterIndexFromNode(NodeRef nodeRef) {
            return 0;
        }
    }

    /**
     * Gradient with respect to one or more diffusion-process parameters. The per-branch Q^{-1}
     * and displacement gradients are pushed through each {@link DerivationParameter}'s chain
     * rule. The resulting pieces are concatenated, in list order, into one flat vector.
     */
    public static class ContinuousProcessParameterGradient
    extends Default {
        ContinuousDataLikelihoodDelegate likelihoodDelegate;
        ContinuousDiffusionIntegrator cdi;
        DiffusionProcessDelegate diffusionProcessDelegate;
        final List<DerivationParameter> derivationParameter;

        /**
         * @param n                                trait dimension
         * @param tree                             tree being differentiated
         * @param continuousDataLikelihoodDelegate source of the integrator and diffusion process
         * @param list                             parameters whose gradients are concatenated
         */
        public ContinuousProcessParameterGradient(int n, Tree tree, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, List<DerivationParameter> list) {
            super(n, tree);
            this.derivationParameter = list;
            this.likelihoodDelegate = continuousDataLikelihoodDelegate;
            this.cdi = continuousDataLikelihoodDelegate.getIntegrator();
            this.diffusionProcessDelegate = continuousDataLikelihoodDelegate.getDiffusionProcessDelegate();
        }

        /** Every node maps to parameter index 0. */
        @Override
        public int getParameterIndexFromNode(NodeRef nodeRef) {
            return 0;
        }

        /** Sum of {@code getDimension(dim)} over the configured parameters. */
        @Override
        public int getDimension() {
            int total = 0;
            for (DerivationParameter parameter : this.derivationParameter) {
                total += parameter.getDimension(this.dim);
            }
            return total;
        }

        /**
         * Zeroes the rows and columns of the Q^{-1} gradient that correspond to missing traits
         * (modifying it in place), then concatenates each parameter's chain-rule output.
         */
        @Override
        public double[] chainRule(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef, DenseMatrix64F denseMatrix64F, DenseMatrix64F denseMatrix64F2) {
            removeMissing(denseMatrix64F, branchSufficientStatistics.getMissing());
            return assemble(false, branchSufficientStatistics, nodeRef, denseMatrix64F, denseMatrix64F2);
        }

        /** Concatenates each parameter's root chain-rule output (no missing-trait removal). */
        @Override
        public double[] chainRuleRoot(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef, DenseMatrix64F denseMatrix64F, DenseMatrix64F denseMatrix64F2) {
            return assemble(true, branchSufficientStatistics, nodeRef, denseMatrix64F, denseMatrix64F2);
        }

        /** Concatenates the per-parameter pieces from chainRule or chainRuleRoot into one array. */
        private double[] assemble(boolean atRoot, BranchSufficientStatistics statistics, NodeRef node,
                                  DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
            double[] gradient = new double[this.getDimension()];
            int offset = 0;
            for (DerivationParameter parameter : this.derivationParameter) {
                int length = parameter.getDimension(this.dim);
                double[] piece = atRoot
                        ? parameter.chainRuleRoot(this.cdi, this.diffusionProcessDelegate, this.likelihoodDelegate, statistics, node, gradQInv, gradN)
                        : parameter.chainRule(this.cdi, this.diffusionProcessDelegate, this.likelihoodDelegate, statistics, node, gradQInv, gradN);
                System.arraycopy(piece, 0, gradient, offset, length);
                offset += length;
            }
            return gradient;
        }

        /** Sets every listed row and the matching column of {@code matrix} to zero (in place). */
        private static void removeMissing(DenseMatrix64F matrix, int[] missing) {
            final int columns = matrix.getNumCols();
            for (int index : missing) {
                for (int k = 0; k < columns; ++k) {
                    matrix.unsafe_set(index, k, 0.0);
                    matrix.unsafe_set(k, index, 0.0);
                }
            }
        }

        /** @return the configured parameter list (the same instance, not a copy) */
        List<DerivationParameter> getDerivationParameter() {
            return this.derivationParameter;
        }

        /**
         * Process parameters that branch gradients can be chained onto. Each constant supplies
         * the mapping for ordinary branches ({@code chainRule}) and for the root
         * ({@code chainRuleRoot}). It also reports how many entries it contributes for a trait
         * of dimension {@code dim}.
         */
        public enum DerivationParameter {
            /** Diffusion variance; {@code dim * dim} entries; the root uses the ordinary rule. */
            WRT_VARIANCE {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                          ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                          NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return process.getGradientVarianceWrtVariance(node, cdi, likelihood, gradQInv).getData();
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                              ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                              NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return chainRule(cdi, process, likelihood, statistics, node, gradQInv, gradN);
                }

                @Override
                public int getDimension(int dim) {
                    return dim * dim;
                }
            },

            /** Constant drift; {@code dim} entries; the root contributes zeros. */
            WRT_CONSTANT_DRIFT {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                          ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                          NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    AbstractDriftDiffusionModelDelegate drift = (AbstractDriftDiffusionModelDelegate) process;
                    return drift.getGradientDisplacementWrtDrift(node, cdi, likelihood, gradN).getData();
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                              ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                              NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return new double[gradN.getNumRows()];
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },

            /** Root mean; {@code dim} entries; the root uses the ordinary rule. */
            WRT_ROOT_MEAN {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                          ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                          NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return process.getGradientDisplacementWrtRoot(node, cdi, likelihood, gradN);
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                              ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                              NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return chainRule(cdi, process, likelihood, statistics, node, gradQInv, gradN);
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },

            /**
             * Element-wise sum of the constant-drift and root-mean contributions ({@code dim}
             * entries); the root uses the root-mean rule only.
             */
            WRT_CONSTANT_DRIFT_AND_ROOT_MEAN {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                          ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                          NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    double[] drift = WRT_CONSTANT_DRIFT.chainRule(cdi, process, likelihood, statistics, node, gradQInv, gradN);
                    double[] combined = WRT_ROOT_MEAN.chainRule(cdi, process, likelihood, statistics, node, gradQInv, gradN);
                    for (int k = 0; k < combined.length; ++k) {
                        combined[k] += drift[k];
                    }
                    return combined;
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                              ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                              NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return WRT_ROOT_MEAN.chainRuleRoot(cdi, process, likelihood, statistics, node, gradQInv, gradN);
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },

            /**
             * Diagonal OU selection strength: the variance and displacement attenuation gradients
             * are summed. {@code dim} entries; the root contributes zeros.
             */
            WRT_DIAGONAL_SELECTION_STRENGTH {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                          ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                          NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    OUDiffusionModelDelegate ou = (OUDiffusionModelDelegate) process;
                    DenseMatrix64F fromVariance = ou.getGradientVarianceWrtAttenuation(node, cdi, statistics, gradQInv);
                    DenseMatrix64F fromDisplacement = ou.getGradientDisplacementWrtAttenuation(node, cdi, statistics, gradN);
                    CommonOps.addEquals(fromVariance, fromDisplacement);
                    return fromVariance.getData();
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                              ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                              NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return new double[likelihood.getTraitDim()];
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },

            /**
             * Branch-specific drift. Uses the constant-drift rule on branches and the root-mean
             * rule at the root; {@code dim} entries.
             */
            WRT_BRANCH_SPECIFIC_DRIFT {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                          ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                          NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return WRT_CONSTANT_DRIFT.chainRule(cdi, process, likelihood, statistics, node, gradQInv, gradN);
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                              ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                              NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return WRT_ROOT_MEAN.chainRuleRoot(cdi, process, likelihood, statistics, node, gradQInv, gradN);
                }

                @Override
                public int getDimension(int dim) {
                    return dim;
                }
            },

            /**
             * Sampling variance: returns the Q^{-1} gradient's backing array unchanged
             * ({@code dim * dim} entries). Must never be asked for a root contribution.
             */
            WRT_SAMPLING_VARIANCE {
                @Override
                public double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                          ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                          NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    return gradQInv.getData();
                }

                @Override
                public double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                              ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                              NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
                    throw new RuntimeException("Should never be called.");
                }

                @Override
                public int getDimension(int dim) {
                    return dim * dim;
                }
            };

            /** Maps the branch gradients onto this parameter for a non-root branch. */
            abstract double[] chainRule(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                        ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                        NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN);

            /** Maps the branch gradients onto this parameter at the root. */
            abstract double[] chainRuleRoot(ContinuousDiffusionIntegrator cdi, DiffusionProcessDelegate process,
                                            ContinuousDataLikelihoodDelegate likelihood, BranchSufficientStatistics statistics,
                                            NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN);

            /** Number of gradient entries this parameter contributes for trait dimension {@code dim}. */
            public abstract int getDimension(int dim);
        }
    }

    /**
     * Gradient with respect to branch rates. Each branch contributes one scalar, which
     * {@link #getParameterIndexFromNode} maps to a parameter index. Computing branch gradients
     * requires an {@link ArbitraryBranchRates} model.
     */
    public static class RateGradient
    extends Default {
        private final DenseMatrix64F matrixJacobianQInv;
        private final DenseMatrix64F matrixJacobianN;
        private final ArbitraryBranchRates branchRateModel;

        /**
         * @param n               trait dimension
         * @param tree            tree whose branches are differentiated
         * @param branchRateModel branch-rate model. A model that is not an
         *                        {@link ArbitraryBranchRates} is stored as {@code null}: node
         *                        numbers are then used as parameter indices, and
         *                        {@link #chainRule} throws {@link IllegalStateException}.
         */
        public RateGradient(int n, Tree tree, BranchRateModel branchRateModel) {
            super(n, tree);
            this.branchRateModel = branchRateModel instanceof ArbitraryBranchRates ? (ArbitraryBranchRates)branchRateModel : null;
            this.matrixJacobianQInv = new DenseMatrix64F(n, n);
            this.matrixJacobianN = new DenseMatrix64F(n, 1);
        }

        /** Root maps to 0; otherwise delegates to the rate model, or uses the node number if none. */
        @Override
        public int getParameterIndexFromNode(NodeRef nodeRef) {
            if (this.tree.isRoot(nodeRef)) {
                return 0;
            }
            return this.branchRateModel == null ? nodeRef.getNumber() : this.branchRateModel.getParameterIndexFromNode(nodeRef);
        }

        /** One scalar per branch. */
        @Override
        public int getDimension() {
            return 1;
        }

        /**
         * Scales the branch variance and displacement by
         * {@code getBranchRateDifferential / getBranchRate}. Each scaled quantity is contracted
         * element-wise with the corresponding incoming gradient, and the two sums are added.
         *
         * @throws IllegalStateException if no {@link ArbitraryBranchRates} model was supplied
         */
        @Override
        public double[] chainRule(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
            if (this.branchRateModel == null) {
                // Previously this surfaced as a bare NullPointerException.
                throw new IllegalStateException("RateGradient requires an ArbitraryBranchRates branch-rate model to compute branch gradients");
            }
            double rate = this.branchRateModel.getBranchRate(this.tree, nodeRef);
            double differential = this.branchRateModel.getBranchRateDifferential(this.tree, nodeRef);
            // NOTE(review): scaling by differential/rate presumes the branch variance and
            // displacement are linear in the branch rate — confirm against the process model.
            double scale = differential / rate;

            DenseMatrix64F jacobianQInv = this.matrixJacobianQInv;
            CommonOps.scale(scale, branchSufficientStatistics.getBranch().getRawVariance(), jacobianQInv);
            double[] gradient = new double[1];
            for (int i = 0; i < jacobianQInv.getNumElements(); ++i) {
                gradient[0] += jacobianQInv.get(i) * gradQInv.get(i);
            }

            DenseMatrix64F jacobianN = this.matrixJacobianN;
            CommonOps.scale(scale, branchSufficientStatistics.getBranch().getRawDisplacement(), jacobianN);
            for (int i = 0; i < jacobianN.numRows; ++i) {
                gradient[0] += jacobianN.get(i) * gradN.get(i);
            }
            return gradient;
        }

        /** The root contributes nothing to the rate gradient. */
        @Override
        public double[] chainRuleRoot(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef, DenseMatrix64F gradQInv, DenseMatrix64F gradN) {
            return new double[1];
        }
    }

    /**
     * Shared per-branch machinery. {@link #getSufficientStatistics} fills {@code matrixQ},
     * {@code matrixW}, {@code matrixV} and {@code matrixDelta}. The two gradient buffers are then
     * optionally recomputed and handed to the subclass's {@code chainRule}, or to
     * {@code chainRuleRoot} at the root.
     */
    public abstract static class Default
    implements ContinuousTraitGradientForBranch {
        private final DenseMatrix64F matrixGradientQInv;
        private final DenseMatrix64F matrixGradientN;
        final DenseMatrix64F matrixDelta;
        DenseMatrix64F matrixQ;
        DenseMatrix64F matrixW;
        DenseMatrix64F matrixV;
        final int dim;
        final Tree tree;
        static final boolean DEBUG = false;

        /**
         * @param dimension trait dimension; all work buffers are sized from it
         * @param tree      tree whose branches are differentiated
         */
        public Default(int dimension, Tree tree) {
            this.dim = dimension;
            this.tree = tree;
            this.matrixGradientQInv = new DenseMatrix64F(dimension, dimension);
            this.matrixGradientN = new DenseMatrix64F(dimension, 1);
            this.matrixDelta = new DenseMatrix64F(dimension, 1);
            this.matrixQ = new DenseMatrix64F(dimension, dimension);
            this.matrixW = new DenseMatrix64F(dimension, dimension);
            this.matrixV = new DenseMatrix64F(dimension, dimension);
        }

        /** Defaults to the node's number. */
        @Override
        public int getParameterIndexFromNode(NodeRef nodeRef) {
            return nodeRef.getNumber();
        }

        /** Computes both gradient buffers and applies the chain rule. */
        @Override
        public double[] getGradientForBranch(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef) {
            return getGradientForBranch(branchSufficientStatistics, nodeRef, true, true);
        }

        /**
         * @param computeQInv recompute the Q^{-1} gradient buffer; if false, its previous contents
         *                    are passed on unchanged
         * @param computeN    recompute the displacement gradient buffer; if false, its previous
         *                    contents are passed on unchanged
         * @return result of {@code chainRuleRoot} at the root, otherwise of {@code chainRule}
         */
        double[] getGradientForBranch(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef, boolean computeQInv, boolean computeN) {
            getSufficientStatistics(branchSufficientStatistics, nodeRef);
            final DenseMatrix64F gradQInv = matrixGradientQInv;
            final DenseMatrix64F gradN = matrixGradientN;
            if (computeQInv) {
                getGradientQInvForBranch(matrixQ, matrixW, matrixV, matrixDelta, gradQInv);
            }
            if (computeN) {
                getGradientNForBranch(matrixQ, matrixDelta, gradN);
            }
            return tree.isRoot(nodeRef)
                    ? chainRuleRoot(branchSufficientStatistics, nodeRef, gradQInv, gradN)
                    : chainRule(branchSufficientStatistics, nodeRef, gradQInv, gradN);
        }

        /**
         * Sets matrixQ/matrixW to the "above" precision/variance and matrixV to the joint
         * variance. Sets matrixDelta to (joint mean - above mean).
         */
        void getSufficientStatistics(BranchSufficientStatistics branchSufficientStatistics, NodeRef nodeRef) {
            final NormalSufficientStatistics below = branchSufficientStatistics.getBelow();
            final NormalSufficientStatistics above = branchSufficientStatistics.getAbove();
            final NormalSufficientStatistics joint = BranchRateGradient.ContinuousTraitGradientForBranch.Default.computeJointStatistics(below, above, this.dim);
            matrixQ = above.getRawPrecision();
            matrixW = above.getRawVariance();
            matrixV = joint.getRawVariance();
            for (int row = 0; row < dim; ++row) {
                matrixDelta.unsafe_set(row, 0, joint.getMean(row) - above.getMean(row));
            }
        }

        abstract double[] chainRule(BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN);

        abstract double[] chainRuleRoot(BranchSufficientStatistics statistics, NodeRef node, DenseMatrix64F gradQInv, DenseMatrix64F gradN);

        /**
         * Writes {@code 0.5*w - 0.5*delta*delta^T - 0.5*v} into {@code gradient}. It then chains
         * that gradient through {@code MultivariateChainRule.InverseGeneral} built from {@code q}.
         */
        static void getGradientQInvForBranch(DenseMatrix64F q, DenseMatrix64F w, DenseMatrix64F v, DenseMatrix64F delta, DenseMatrix64F gradient) {
            CommonOps.scale(0.5, w, gradient);
            CommonOps.multAddTransB(-0.5, delta, delta, gradient);
            CommonOps.addEquals(gradient, -0.5, v);
            new MultivariateChainRule.InverseGeneral(q).chainGradient(gradient);
        }

        /** Writes {@code q^T * delta} into {@code gradient}. */
        private void getGradientNForBranch(DenseMatrix64F q, DenseMatrix64F delta, DenseMatrix64F gradient) {
            CommonOps.multTransA(q, delta, gradient);
        }
    }
}

