/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.hmc;

import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch;
import dr.evomodel.treedatalikelihood.hmc.AbstractDiffusionGradient;
import dr.evomodel.treedatalikelihood.hmc.GradientWrtPrecisionProvider;
import dr.inference.model.CachedMatrixInverse;
import dr.inference.model.CompoundSymmetricMatrix;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;

/**
 * Base class for gradients of a log density with respect to the parameters of a
 * {@link CompoundSymmetricMatrix}.
 *
 * <p>The compound-symmetric matrix enters the model in one of two ways, decided once in the
 * constructor:
 * <ul>
 *   <li>directly, as the precision matrix ({@link Parametrization#AS_PRECISION}); or</li>
 *   <li>as the variance matrix, with a {@link CachedMatrixInverse} of it passed in as the
 *       precision ({@link Parametrization#AS_VARIANCE}).</li>
 * </ul>
 * The gradient with respect to the matrix entries comes from a
 * {@link GradientWrtPrecisionProvider}. Subclasses map that matrix-valued gradient onto the
 * parameter they expose via {@link #getGradientParameter(double[])}, typically using
 * {@link #getGradientCorrelation(double[])} and/or {@link #getGradientDiagonal(double[])}.
 */
public abstract class AbstractPrecisionGradient
        extends AbstractDiffusionGradient
        implements Reportable {

    /** Supplies the gradient with respect to the precision / variance matrix entries. */
    private final GradientWrtPrecisionProvider gradientWrtPrecisionProvider;
    /** The compound-symmetric matrix whose diagonal and off-diagonal parts are differentiated. */
    final CompoundSymmetricMatrix compoundSymmetricMatrix;
    /** Number of columns of the matrix passed to the constructor. */
    private final int dim;
    /** How {@link #compoundSymmetricMatrix} enters the model; fixed at construction. */
    private final Parametrization parametrization;
    /** The matrix parameter passed to the constructor; also reported as the raw parameter. */
    private final MatrixParameterInterface precision;
    /** The variance-side matrix paired with {@link #precision}. */
    private final MatrixParameterInterface variance;

    /**
     * Creates the gradient and selects the parametrization from the runtime type of the matrix.
     *
     * @param gradientWrtPrecisionProvider source of the gradient w.r.t. the matrix entries
     * @param likelihood                   forwarded unchanged to the superclass constructor
     * @param matrixParameterInterface     either a {@link CompoundSymmetricMatrix} used as the
     *                                     precision, or a {@link CachedMatrixInverse} whose base
     *                                     parameter is a {@link CompoundSymmetricMatrix} used as
     *                                     the variance
     * @param d                            forwarded unchanged to the superclass constructor
     * @param d2                           forwarded unchanged to the superclass constructor
     * @throws IllegalArgumentException if the matrix is of neither supported type, or if a
     *                                  {@link CachedMatrixInverse} does not wrap a
     *                                  {@link CompoundSymmetricMatrix}
     */
    AbstractPrecisionGradient(GradientWrtPrecisionProvider gradientWrtPrecisionProvider,
                              Likelihood likelihood,
                              MatrixParameterInterface matrixParameterInterface,
                              double d, double d2) {
        super(likelihood, d, d2);
        this.precision = matrixParameterInterface;
        if (matrixParameterInterface instanceof CachedMatrixInverse) {
            // The compound-symmetric matrix is the variance; its cached inverse is the precision.
            Object base = ((CachedMatrixInverse) matrixParameterInterface).getBaseParameter();
            if (!(base instanceof CompoundSymmetricMatrix)) {
                // Fail with a clear message instead of a ClassCastException / later NPE.
                throw new IllegalArgumentException(
                        "CachedMatrixInverse must wrap a CompoundSymmetricMatrix");
            }
            this.compoundSymmetricMatrix = (CompoundSymmetricMatrix) base;
            this.variance = this.compoundSymmetricMatrix;
            this.parametrization = Parametrization.AS_VARIANCE;
        } else if (matrixParameterInterface instanceof CompoundSymmetricMatrix) {
            // The compound-symmetric matrix is the precision; the variance is its cached inverse.
            this.compoundSymmetricMatrix = (CompoundSymmetricMatrix) matrixParameterInterface;
            this.variance = new CachedMatrixInverse("", this.precision);
            this.parametrization = Parametrization.AS_PRECISION;
        } else {
            throw new IllegalArgumentException("Unimplemented type");
        }
        assert (this.compoundSymmetricMatrix.asCorrelation()) : "PrecisionGradient can only be applied to a CompoundSymmetricMatrix with off-diagonal as correlation.";
        this.gradientWrtPrecisionProvider = gradientWrtPrecisionProvider;
        this.dim = matrixParameterInterface.getColumnDimension();
    }

    /** Returns the matrix parameter passed to the constructor. */
    @Override
    public Parameter getRawParameter() {
        return this.precision;
    }

    /** Always returns {@code WRT_VARIANCE}, whichever parametrization is in use. */
    @Override
    public ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter getDerivationParameter() {
        return ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter.WRT_VARIANCE;
    }

    /** Number of distinct off-diagonal pairs, {@code dim * (dim - 1) / 2}. */
    int getDimensionCorrelation() {
        return this.dim * (this.dim - 1) / 2;
    }

    /** Number of diagonal entries, i.e. {@code dim}. */
    int getDimensionDiagonal() {
        return this.dim;
    }

    /**
     * Computes the gradient using the provider's branch-specific gradient. When the provider has
     * no branch-specific gradient, a {@code null} is passed on, which the overload treats as an
     * all-zero gradient.
     */
    @Override
    public double[] getGradientLogDensity() {
        double[] fullGradient = this.gradientWrtPrecisionProvider.getBranchSpecificGradient() == null
                ? null
                : this.gradientWrtPrecisionProvider.getBranchSpecificGradient().getGradientLogDensity();
        return this.getGradientLogDensity(fullGradient);
    }

    /**
     * Maps a gradient with respect to the matrix entries onto this object's parameter.
     *
     * @param gradient full gradient vector; {@code dim * dim} entries are read starting at the
     *                 inherited {@code offset}. May be {@code null}, in which case zeros are used.
     * @return the result of {@link #getGradientParameter(double[])} applied to the
     *         parametrization-specific gradient
     */
    @Override
    public double[] getGradientLogDensity(double[] gradient) {
        final int length = this.dim * this.dim;
        double[] gradientWrtMatrix = new double[length];
        if (gradient != null) {
            // NOTE(review): `offset` is inherited from AbstractDiffusionGradient; presumably the
            // start of this matrix's block inside the full gradient vector.
            System.arraycopy(gradient, this.offset, gradientWrtMatrix, 0, length);
        }
        // Refresh the variance-side matrix first (a no-op for AS_VARIANCE).
        this.parametrization.updateParameters(this.variance);
        double[] varianceFlat = AbstractPrecisionGradient.flatten(this.variance.getParameterAsMatrix());
        double[] precisionFlat = AbstractPrecisionGradient.flatten(this.precision.getParameterAsMatrix());
        gradientWrtMatrix = this.parametrization.getGradientWrtParameter(
                gradientWrtMatrix, precisionFlat, varianceFlat, this.gradientWrtPrecisionProvider);
        return this.getGradientParameter(gradientWrtMatrix);
    }

    @Override
    String getReportString(double[] dArray, double[] dArray2) {
        return this.getClass().getCanonicalName() + "\nanalytic: " + new Vector(dArray) + "\nnumeric: " + new Vector(dArray2) + "\n";
    }

    @Override
    String getReportString(double[] dArray, double[] dArray2, double[] dArray3) {
        return this.getClass().getCanonicalName() + "\nanalytic: " + new Vector(dArray) + "\nnumeric (no Cholesky): " + new Vector(dArray2) + "\nnumeric (with Cholesky): " + new Vector(dArray3) + "\n";
    }

    /**
     * Maps the gradient with respect to the matrix entries (as returned by the parametrization)
     * onto the parameter exposed by the concrete subclass.
     *
     * @param gradientWrtMatrix gradient with respect to the matrix entries
     * @return gradient with respect to the subclass's parameter
     */
    abstract double[] getGradientParameter(double[] gradientWrtMatrix);

    /** Delegates to {@link CompoundSymmetricMatrix#updateGradientOffDiagonal(double[])}. */
    double[] getGradientCorrelation(double[] gradient) {
        return this.compoundSymmetricMatrix.updateGradientOffDiagonal(gradient);
    }

    /** Delegates to {@link CompoundSymmetricMatrix#updateGradientDiagonal(double[])}. */
    double[] getGradientDiagonal(double[] gradient) {
        return this.compoundSymmetricMatrix.updateGradientDiagonal(gradient);
    }

    /**
     * Concatenates the rows of a (possibly ragged) 2-D array, in order, into one 1-D array.
     *
     * @param matrix the array to flatten; must not be {@code null}
     * @return a new array of length equal to the sum of the row lengths
     */
    public static double[] flatten(double[][] matrix) {
        int total = 0;
        for (double[] row : matrix) {
            total += row.length;
        }
        double[] flat = new double[total];
        int position = 0;
        for (double[] row : matrix) {
            System.arraycopy(row, 0, flat, position, row.length);
            position += row.length;
        }
        return flat;
    }

    /** How the compound-symmetric matrix enters the model. */
    enum Parametrization {
        /** The compound-symmetric matrix is the precision; the variance is a cached inverse. */
        AS_PRECISION {
            @Override
            public double[] getGradientWrtParameter(double[] gradient, double[] vecPrecision,
                                                    double[] vecVariance,
                                                    GradientWrtPrecisionProvider provider) {
                return provider.getGradientWrtPrecision(vecVariance, gradient);
            }

            @Override
            void updateParameters(MatrixParameterInterface variance) {
                // In this parametrization the variance is always the CachedMatrixInverse built
                // in the constructor; presumably this recomputes its stale cached inverse.
                ((CachedMatrixInverse) variance).forceComputeInverse();
            }
        },
        /** The compound-symmetric matrix is the variance; the precision is its cached inverse. */
        AS_VARIANCE {
            @Override
            public double[] getGradientWrtParameter(double[] gradient, double[] vecPrecision,
                                                    double[] vecVariance,
                                                    GradientWrtPrecisionProvider provider) {
                return provider.getGradientWrtVariance(vecPrecision, vecVariance, gradient);
            }

            @Override
            void updateParameters(MatrixParameterInterface variance) {
                // Nothing to refresh: the variance is the compound-symmetric matrix itself.
            }
        };

        /**
         * Converts a gradient with respect to the matrix entries into the gradient appropriate
         * for this parametrization.
         *
         * @param gradient     gradient with respect to the matrix entries ({@code dim * dim})
         * @param vecPrecision flattened precision matrix
         * @param vecVariance  flattened variance matrix
         * @param provider     provider performing the conversion
         * @return the converted gradient
         */
        abstract double[] getGradientWrtParameter(double[] gradient, double[] vecPrecision,
                                                  double[] vecVariance,
                                                  GradientWrtPrecisionProvider provider);

        /** Brings the variance-side matrix up to date before it is read. */
        abstract void updateParameters(MatrixParameterInterface variance);
    }
}

