/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.inference.operators.hmc.SecantHessian;
import dr.math.AdaptableCovariance;
import dr.math.AdaptableVector;
import dr.math.MathUtils;
import dr.math.MultivariateFunction;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.RobustEigenDecomposition;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;
import java.util.Arrays;

public interface MassPreconditioner {
    /**
     * Draws a fresh random momentum vector.
     *
     * <p>The implementations in this file sample zero-mean Gaussian values whose spread is
     * governed by the current inverse mass (unit variance when there is no preconditioning).
     *
     * @return the sampled momentum
     */
    WrappedVector drawInitialMomentum();

    /**
     * Returns a single component of the velocity implied by a momentum vector, i.e. entry
     * {@code index} of (inverse mass) x {@code momentum}.
     *
     * @param index    component to compute
     * @param momentum momentum vector
     * @return the velocity component at {@code index}
     */
    double getVelocity(int index, ReadableVector momentum);

    /**
     * Records a pair of vectors that an implementation may use to adapt its mass matrix.
     * {@code Secant} forwards both vectors, the adaptive variants use only {@code second},
     * and the remaining implementations ignore the call.
     * NOTE(review): the meaning of each argument is fixed by the caller (presumably
     * gradient and position) — confirm.
     *
     * @param first  first recorded vector
     * @param second second recorded vector
     */
    void storeSecant(ReadableVector first, ReadableVector second);

    /** Recomputes the mass matrix from whatever the implementation has accumulated so far. */
    void updateMass();

    /**
     * Returns a new momentum vector after a collision between the given coordinates.
     * The implementations in this file either require exactly two indices or throw.
     *
     * @param indices  coordinates taking part in the collision
     * @param momentum momentum before the collision; it is copied, not modified
     * @return a new vector holding the post-collision momentum
     */
    ReadableVector doCollision(int[] indices, ReadableVector momentum);

    /**
     * Full (dense) preconditioner whose inverse mass is derived from an adaptively estimated
     * covariance rather than from a Hessian.
     *
     * <p>The vectors passed as the second argument of {@link #storeSecant} update
     * {@code adaptableCovariance}. {@link #computeInverseMass()} turns the current covariance
     * into a positive-definite inverse mass using {@code PDTransformMatrix.Negate}.
     * The superclass is constructed with a {@code null} Hessian provider, so nothing in this
     * class may rely on {@code hessian}.
     */
    public static class AdaptiveFullHessianPreconditioning
    extends FullHessianPreconditioning {
        private final AdaptableCovariance adaptableCovariance;
        private final GradientWrtParameterProvider gradientProvider;

        /**
         * Log-likelihood of {@code gradientProvider}'s likelihood as a function of the
         * parameter values. Evaluating it overwrites the provider's parameter with the given
         * point. Every argument reports the bounds [0, +infinity).
         * NOTE(review): the 0 lower bound presumes non-negative parameters — confirm.
         */
        protected MultivariateFunction numeric1 = new MultivariateFunction(){

            @Override
            public double evaluate(double[] point) {
                int coordinate = 0;
                while (coordinate < point.length) {
                    gradientProvider.getParameter().setParameterValue(coordinate, point[coordinate]);
                    coordinate++;
                }
                return gradientProvider.getLikelihood().getLogLikelihood();
            }

            @Override
            public int getNumArguments() {
                return gradientProvider.getParameter().getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return 0.0;
            }

            @Override
            public double getUpperBound(int n) {
                return Double.POSITIVE_INFINITY;
            }
        };

        /**
         * @param gradientWrtParameterProvider supplies the gradient/parameter used when a
         *                                     transform is applied
         * @param adaptableCovariance          running covariance estimate to precondition with
         * @param transform                    optional transform (may be {@code null})
         * @param n                            dimension of the inverse-mass matrix
         */
        AdaptiveFullHessianPreconditioning(GradientWrtParameterProvider gradientWrtParameterProvider, AdaptableCovariance adaptableCovariance, Transform transform, int n) {
            super(null, transform, n);
            this.gradientProvider = gradientWrtParameterProvider;
            this.adaptableCovariance = adaptableCovariance;
        }

        /** Builds the inverse mass from the current covariance estimate. */
        @Override
        protected double[] computeInverseMass() {
            WrappedMatrix.ArrayOfArray covariance = (WrappedMatrix.ArrayOfArray)adaptableCovariance.getCovariance();
            // The three-argument overload is private to FullHessianPreconditioning, so it is
            // reached through a reference of that static type.
            FullHessianPreconditioning asFull = this;
            return asFull.computeInverseMass(covariance, gradientProvider, FullHessianPreconditioning.PDTransformMatrix.Negate);
        }

        /** Feeds {@code sample} into the covariance estimate; {@code ignored} is unused. */
        @Override
        public void storeSecant(ReadableVector ignored, ReadableVector sample) {
            adaptableCovariance.update(sample);
        }
    }

    /**
     * Full preconditioner whose Hessian comes from a {@link SecantHessian}. The secant
     * Hessian is fed through {@link #storeSecant}.
     */
    public static class Secant
    extends FullHessianPreconditioning {
        private final SecantHessian secantHessian;

        /**
         * @param secantHessian Hessian source, also used as the superclass's Hessian provider
         * @param transform     optional transform (may be {@code null})
         */
        Secant(SecantHessian secantHessian, Transform transform) {
            super(secantHessian, transform);
            this.secantHessian = secantHessian;
        }

        /** Forwards both vectors unchanged to the underlying {@link SecantHessian}. */
        @Override
        public void storeSecant(ReadableVector first, ReadableVector second) {
            secantHessian.storeSecant(first, second);
        }
    }

    public static class FullHessianPreconditioning
    extends HessianBased {
        FullHessianPreconditioning(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform) {
            super(hessianWrtParameterProvider, transform);
        }

        FullHessianPreconditioning(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform, int n) {
            super(hessianWrtParameterProvider, transform, n);
        }

        @Override
        protected void initializeMass() {
            double[] dArray = new double[this.dim * this.dim];
            for (int i = 0; i < this.dim; ++i) {
                dArray[i * this.dim + i] = 1.0;
            }
            this.inverseMass = dArray;
        }

        private double[] computeInverseMass(WrappedMatrix.ArrayOfArray arrayOfArray, GradientWrtParameterProvider gradientWrtParameterProvider, PDTransformMatrix pDTransformMatrix) {
            double[][] dArray = arrayOfArray.getArrays();
            if (this.transform != null) {
                dArray = this.transform.updateHessianLogDensity(dArray, new double[this.dim][this.dim], gradientWrtParameterProvider.getGradientLogDensity(), gradientWrtParameterProvider.getParameter().getParameterValues(), 0, this.dim);
            }
            return pDTransformMatrix.transformMatrix(dArray, this.dim);
        }

        @Override
        protected double[] computeInverseMass() {
            WrappedMatrix.ArrayOfArray arrayOfArray = new WrappedMatrix.ArrayOfArray(this.hessian.getHessianLogDensity());
            return this.computeInverseMass(arrayOfArray, this.hessian, PDTransformMatrix.Invert);
        }

        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
        }

        @Override
        public WrappedVector drawInitialMomentum() {
            MultivariateNormalDistribution multivariateNormalDistribution = new MultivariateNormalDistribution(new double[this.dim], FullHessianPreconditioning.toArray(this.inverseMass, this.dim, this.dim));
            return new WrappedVector.Raw(multivariateNormalDistribution.nextMultivariateNormal());
        }

        @Override
        public double getVelocity(int n, ReadableVector readableVector) {
            double d = 0.0;
            for (int i = 0; i < this.dim; ++i) {
                d += this.inverseMass[n * this.dim + i] * readableVector.get(i);
            }
            return d;
        }

        private static double[][] toArray(double[] dArray, int n, int n2) {
            double[][] dArrayArray = new double[n][];
            for (int i = 0; i < n; ++i) {
                dArrayArray[i] = new double[n2];
                System.arraycopy(dArray, n2 * i, dArrayArray[i], 0, n2);
            }
            return dArrayArray;
        }

        static enum PDTransformMatrix {
            Invert("Transform inverse matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.inverseNegateEigenvalues(doubleMatrix1D);
                }
            }
            ,
            Default("Transform matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.negateEigenvalues(doubleMatrix1D);
                }
            }
            ,
            Negate("Transform negative matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.negateEigenvalues(doubleMatrix1D);
                }

                @Override
                protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.negateEigenvalues(doubleMatrix1D);
                    this.boundEigenvalues(doubleMatrix1D);
                    this.scaleEigenvalues(doubleMatrix1D);
                }
            }
            ,
            NegateInvert("Transform negative inverse matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.inverseNegateEigenvalues(doubleMatrix1D);
                }

                @Override
                protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.negateEigenvalues(doubleMatrix1D);
                    this.boundEigenvalues(doubleMatrix1D);
                    this.scaleEigenvalues(doubleMatrix1D);
                }
            };

            String desc;
            private static final double MIN_EIGENVALUE = -10.0;
            private static final double MAX_EIGENVALUE = -0.5;

            private PDTransformMatrix(String string2) {
                this.desc = string2;
            }

            public String toString() {
                return this.desc;
            }

            protected void boundEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.cardinality(); ++i) {
                    if (doubleMatrix1D.get(i) > -0.5) {
                        doubleMatrix1D.set(i, -0.5);
                        continue;
                    }
                    if (!(doubleMatrix1D.get(i) < -10.0)) continue;
                    doubleMatrix1D.set(i, -10.0);
                }
            }

            protected void scaleEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                double d = 0.0;
                for (int i = 0; i < doubleMatrix1D.cardinality(); ++i) {
                    d += doubleMatrix1D.get(i);
                }
                double d2 = -d / (double)doubleMatrix1D.cardinality();
                for (int i = 0; i < doubleMatrix1D.cardinality(); ++i) {
                    doubleMatrix1D.set(i, doubleMatrix1D.get(i) / d2);
                }
            }

            protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                this.boundEigenvalues(doubleMatrix1D);
                this.scaleEigenvalues(doubleMatrix1D);
            }

            protected void inverseNegateEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.cardinality(); ++i) {
                    doubleMatrix1D.set(i, -1.0 / doubleMatrix1D.get(i));
                }
            }

            protected void negateEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.cardinality(); ++i) {
                    doubleMatrix1D.set(i, -doubleMatrix1D.get(i));
                }
            }

            public double[] transformMatrix(double[][] dArray, int n) {
                Algebra algebra = new Algebra();
                DenseDoubleMatrix2D denseDoubleMatrix2D = new DenseDoubleMatrix2D(dArray);
                RobustEigenDecomposition robustEigenDecomposition = new RobustEigenDecomposition(denseDoubleMatrix2D);
                DoubleMatrix1D doubleMatrix1D = robustEigenDecomposition.getRealEigenvalues();
                this.normalizeEigenvalues(doubleMatrix1D);
                DoubleMatrix2D doubleMatrix2D = robustEigenDecomposition.getV();
                this.transformEigenvalues(doubleMatrix1D);
                double[][] dArray2 = algebra.mult(algebra.mult(doubleMatrix2D, DoubleFactory2D.dense.diagonal(doubleMatrix1D)), algebra.inverse(doubleMatrix2D)).toArray();
                double[] dArray3 = new double[n * n];
                for (int i = 0; i < n; ++i) {
                    System.arraycopy(dArray2[i], 0, dArray3, i * n, n);
                }
                return dArray3;
            }

            protected abstract void transformEigenvalues(DoubleMatrix1D var1);
        }
    }

    /**
     * Diagonal preconditioner adapted from the per-coordinate variance of the vectors passed
     * as the second argument of {@link #storeSecant}.
     *
     * <p>Once the variance's update count exceeds {@link #MINIMUM_UPDATES}, every call to
     * {@link #computeInverseMass()} folds the current variance into {@code adaptiveDiagonal}.
     * The returned inverse mass is that running mean rescaled so its entries sum to
     * {@code dim}.
     */
    public static class AdaptiveDiagonalPreconditioning
    extends DiagonalPreconditioning {
        /** Update count that must be exceeded before the variance estimate is used. */
        private static final int MINIMUM_UPDATES = 100;

        private final AdaptableVector.AdaptableVariance variance;

        AdaptiveDiagonalPreconditioning(int n, Transform transform) {
            super(n, transform);
            this.variance = new AdaptableVector.AdaptableVariance(n);
        }

        /** Starts at unit inverse mass and seeds the running mean with it. */
        @Override
        protected void initializeMass() {
            super.initializeMass();
            this.adaptiveDiagonal.update(new WrappedVector.Raw(this.inverseMass));
        }

        @Override
        protected double[] computeInverseMass() {
            // Use the named threshold rather than a duplicated literal.
            if (this.variance.getUpdateCount() > MINIMUM_UPDATES) {
                double[] currentVariance = this.variance.getVariance();
                this.adaptiveDiagonal.update(new WrappedVector.Raw(currentVariance));
            }
            return this.normalizeVector(this.adaptiveDiagonal.getMean(), this.dim);
        }

        /**
         * Returns a copy of {@code vector} scaled so that its entries sum to {@code targetSum}.
         * NOTE(review): if the entries sum to zero the result is non-finite.
         */
        private double[] normalizeVector(ReadableVector vector, double targetSum) {
            final int length = vector.getDim();
            double sum = 0.0;
            for (int i = 0; i < length; ++i) {
                sum += vector.get(i);
            }
            double scale = targetSum / sum;
            double[] normalized = new double[length];
            for (int i = 0; i < length; ++i) {
                normalized[i] = vector.get(i) * scale;
            }
            return normalized;
        }

        /** Records {@code sample} in the variance estimate; {@code ignored} is unused. */
        @Override
        public void storeSecant(ReadableVector ignored, ReadableVector sample) {
            this.variance.update(sample);
        }
    }

    /**
     * Diagonal preconditioner driven by the diagonal of the log-density Hessian.
     *
     * <p>Each {@link #computeInverseMass()} call folds the latest diagonal Hessian into
     * {@code adaptiveDiagonal}, after mapping it through {@code transform} when one is set.
     * It then converts the running mean into a clamped, rescaled inverse mass.
     */
    public static class DiagonalHessianPreconditioning
    extends DiagonalPreconditioning {
        /** Smallest inverse-mass entry allowed before rescaling. */
        private static final double MIN_INVERSE_MASS = 0.01;
        /** Largest inverse-mass entry allowed before rescaling. */
        private static final double MAX_INVERSE_MASS = 100.0;

        protected final HessianWrtParameterProvider hessian;

        /**
         * @param hessianWrtParameterProvider source of the diagonal Hessian, gradient and parameter
         * @param transform                   optional transform (may be {@code null})
         * @param n                           when positive, the running mean is an
         *                                    {@code AdaptableVector.LimitedMemory} with this
         *                                    length; otherwise an {@code AdaptableVector.Default}
         */
        DiagonalHessianPreconditioning(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform, int n) {
            super(hessianWrtParameterProvider.getDimension(), transform);
            this.hessian = hessianWrtParameterProvider;
            if (n > 0) {
                this.adaptiveDiagonal = new AdaptableVector.LimitedMemory(hessianWrtParameterProvider.getDimension(), n);
            } else {
                this.adaptiveDiagonal = new AdaptableVector.Default(hessianWrtParameterProvider.getDimension());
            }
        }

        @Override
        protected double[] computeInverseMass() {
            double[] diagonal = hessian.getDiagonalHessianLogDensity();
            if (transform != null) {
                double[] values = hessian.getParameter().getParameterValues();
                double[] gradient = hessian.getGradientLogDensity();
                diagonal = transform.updateDiagonalHessianLogDensity(diagonal, gradient, values, 0, dim);
            }
            adaptiveDiagonal.update(new WrappedVector.Raw(diagonal));
            WrappedVector runningMean = (WrappedVector)adaptiveDiagonal.getMean();
            return boundMassInverse(runningMean.getBuffer());
        }

        /**
         * Maps each diagonal Hessian entry {@code h} to {@code -1 / h}, clamped to
         * [{@code MIN_INVERSE_MASS}, {@code MAX_INVERSE_MASS}]. All entries are then multiplied
         * by the mean of their reciprocals. A NaN entry is not clamped and propagates.
         */
        private double[] boundMassInverse(double[] diagonalHessian) {
            double[] bounded = new double[dim];
            double reciprocalSum = 0.0;
            for (int i = 0; i < dim; ++i) {
                double candidate = -1.0 / diagonalHessian[i];
                double clamped;
                if (candidate < MIN_INVERSE_MASS) {
                    clamped = MIN_INVERSE_MASS;
                } else if (candidate > MAX_INVERSE_MASS) {
                    clamped = MAX_INVERSE_MASS;
                } else {
                    clamped = candidate;
                }
                bounded[i] = clamped;
                reciprocalSum += 1.0 / clamped;
            }
            double meanReciprocal = reciprocalSum / (double)dim;
            for (int i = 0; i < dim; ++i) {
                bounded[i] *= meanReciprocal;
            }
            return bounded;
        }

        /** This preconditioner does not adapt from recorded vectors. */
        @Override
        public void storeSecant(ReadableVector first, ReadableVector second) {
        }
    }

    /**
     * Base class for preconditioners whose inverse mass is diagonal, stored as a
     * length-{@code dim} array in {@code inverseMass}.
     */
    public static abstract class DiagonalPreconditioning
    extends AbstractMassPreconditioning {
        /** Running summary of diagonal estimates; subclasses decide what is fed into it. */
        protected AdaptableVector adaptiveDiagonal;

        protected DiagonalPreconditioning(int n, Transform transform) {
            super(n, transform);
            this.adaptiveDiagonal = new AdaptableVector.Default(n);
            this.initializeMass();
        }

        /** Starts with unit inverse mass in every coordinate. */
        @Override
        protected void initializeMass() {
            double[] ones = new double[this.dim];
            for (int i = 0; i < ones.length; ++i) {
                ones[i] = 1.0;
            }
            this.inverseMass = ones;
        }

        /** Draws coordinate {@code i} as a normal with variance {@code 1 / inverseMass[i]}. */
        @Override
        public WrappedVector drawInitialMomentum() {
            double[] momentum = new double[this.dim];
            int i = 0;
            while (i < momentum.length) {
                momentum[i] = MathUtils.nextGaussian() * Math.sqrt(1.0 / this.inverseMass[i]);
                ++i;
            }
            return new WrappedVector.Raw(momentum);
        }

        @Override
        public double getVelocity(int index, ReadableVector momentum) {
            return momentum.get(index) * this.inverseMass[index];
        }

        /**
         * Exchanges momentum between two coordinates, treating {@code 1 / inverseMass[i]} as
         * coordinate {@code i}'s mass. All other coordinates are copied unchanged.
         */
        @Override
        public ReadableVector doCollision(int[] indices, ReadableVector momentum) {
            if (indices.length != 2) {
                throw new RuntimeException("Not implemented for more than two dimensions yet.");
            }
            final int length = momentum.getDim();
            WrappedVector.Raw result = new WrappedVector.Raw(new double[length]);
            for (int k = 0; k < length; ++k) {
                result.set(k, momentum.get(k));
            }
            final int first = indices[0];
            final int second = indices[1];
            final double invFirst = this.inverseMass[first];
            final double invSecond = this.inverseMass[second];
            final double pFirst = momentum.get(first);
            final double pSecond = momentum.get(second);
            final double totalInverse = invFirst + invSecond;
            double updatedFirst = ((invSecond - invFirst) * pFirst + 2.0 * invSecond * pSecond) / totalInverse;
            double updatedSecond = ((invFirst - invSecond) * pSecond + 2.0 * invFirst * pFirst) / totalInverse;
            result.set(first, updatedFirst);
            result.set(second, updatedSecond);
            return result;
        }
    }

    /**
     * Base class for preconditioners built from a {@link HessianWrtParameterProvider}.
     * Both constructors call {@link #initializeMass()} before returning.
     * NOTE(review): that is an overridable call made from a constructor, so subclass fields
     * are not yet assigned when it runs.
     */
    public static abstract class HessianBased
    extends AbstractMassPreconditioning {
        protected final HessianWrtParameterProvider hessian;

        /** Uses {@code hessianWrtParameterProvider.getDimension()} as the dimension. */
        HessianBased(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform) {
            this(hessianWrtParameterProvider, transform, hessianWrtParameterProvider.getDimension());
        }

        /** Uses an explicit dimension; the provider may be {@code null} here. */
        HessianBased(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform, int n) {
            super(n, transform);
            hessian = hessianWrtParameterProvider;
            initializeMass();
        }

        /** Collisions are not supported; always throws. */
        @Override
        public ReadableVector doCollision(int[] indices, ReadableVector momentum) {
            throw new RuntimeException("Not yet implemented.");
        }
    }

    /**
     * Shared state for preconditioners that keep an explicit inverse-mass array.
     * Subclasses define how that array is initialized and recomputed.
     */
    public static abstract class AbstractMassPreconditioning
    implements MassPreconditioner {
        protected final int dim;
        protected final Transform transform;
        double[] inverseMass;

        protected AbstractMassPreconditioning(int n, Transform transform) {
            this.transform = transform;
            this.dim = n;
        }

        /** Installs the starting value of {@code inverseMass}. */
        protected abstract void initializeMass();

        /** Builds a new inverse-mass array; {@link #updateMass()} stores the result. */
        protected abstract double[] computeInverseMass();

        /** Replaces {@code inverseMass} with a freshly computed array. */
        @Override
        public void updateMass() {
            double[] refreshed = computeInverseMass();
            inverseMass = refreshed;
        }

        @Override
        public abstract void storeSecant(ReadableVector var1, ReadableVector var2);
    }

    /**
     * Identity preconditioner: momenta are drawn as independent standard normals and the
     * velocity equals the momentum.
     */
    public static class NoPreconditioning
    implements MassPreconditioner {
        final int dim;

        NoPreconditioning(int n) {
            this.dim = n;
        }

        @Override
        public WrappedVector drawInitialMomentum() {
            double[] momentum = new double[dim];
            for (int k = 0; k < momentum.length; k++) {
                momentum[k] = MathUtils.nextGaussian();
            }
            return new WrappedVector.Raw(momentum);
        }

        @Override
        public double getVelocity(int index, ReadableVector momentum) {
            return momentum.get(index);
        }

        /** Nothing to adapt. */
        @Override
        public void storeSecant(ReadableVector first, ReadableVector second) {
        }

        /** The mass matrix is fixed. */
        @Override
        public void updateMass() {
        }

        /** Returns a copy of {@code momentum} with the two indexed components swapped. */
        @Override
        public ReadableVector doCollision(int[] indices, ReadableVector momentum) {
            if (indices.length != 2) {
                throw new RuntimeException("Not implemented for more than two dimensions yet.");
            }
            final int size = momentum.getDim();
            WrappedVector.Raw swapped = new WrappedVector.Raw(new double[size]);
            for (int k = 0; k < size; k++) {
                swapped.set(k, momentum.get(k));
            }
            final int a = indices[0];
            final int b = indices[1];
            swapped.set(a, momentum.get(b));
            swapped.set(b, momentum.get(a));
            return swapped;
        }
    }

    /**
     * Available preconditioner kinds. Each one has a user-facing name and a factory that
     * builds an instance for a gradient provider.
     */
    public static enum Type {
        NONE("none"){

            /**
             * Uses the dimension of a {@code Transform.MultivariableTransform} when one is
             * supplied, otherwise the provider parameter's dimension.
             */
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                Parameter parameter = gradientWrtParameterProvider.getParameter();
                int n = parameter.getDimension();
                // instanceof is already false for null, so no separate null check is needed.
                if (transform instanceof Transform.MultivariableTransform) {
                    n = ((Transform.MultivariableTransform)transform).getDimension();
                }
                return new NoPreconditioning(n);
            }
        }
        ,
        DIAGONAL("diagonal"){

            /** The provider must also be a {@link HessianWrtParameterProvider} (cast below). */
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                return new DiagonalHessianPreconditioning((HessianWrtParameterProvider)gradientWrtParameterProvider, transform, options.preconditioningMemory);
            }
        }
        ,
        ADAPTIVE_DIAGONAL("adaptiveDiagonal"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                return new AdaptiveDiagonalPreconditioning(gradientWrtParameterProvider.getDimension(), transform);
            }
        }
        ,
        FULL("full"){

            /** The provider must also be a {@link HessianWrtParameterProvider} (cast below). */
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                return new FullHessianPreconditioning((HessianWrtParameterProvider)gradientWrtParameterProvider, transform);
            }
        }
        ,
        SECANT("secant"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                SecantHessian secantHessian = new SecantHessian(gradientWrtParameterProvider, options.preconditioningMemory);
                return new Secant(secantHessian, transform);
            }
        }
        ,
        ADAPTIVE("adaptive"){

            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, HamiltonianMonteCarloOperator.Options options) {
                AdaptableCovariance adaptableCovariance = new AdaptableCovariance(gradientWrtParameterProvider.getDimension());
                return new AdaptiveFullHessianPreconditioning(gradientWrtParameterProvider, adaptableCovariance, transform, gradientWrtParameterProvider.getDimension());
            }
        };

        private final String name;

        private Type(String string2) {
            this.name = string2;
        }

        /**
         * Builds a preconditioner of this kind.
         *
         * @param var1 gradient provider the preconditioner is built for
         * @param var2 optional transform (may be {@code null})
         * @param var3 operator options (some kinds read {@code preconditioningMemory})
         * @return a new preconditioner
         */
        public abstract MassPreconditioner factory(GradientWrtParameterProvider var1, Transform var2, HamiltonianMonteCarloOperator.Options var3);

        /** Returns the user-facing name of this kind. */
        public String getName() {
            return this.name;
        }

        /**
         * Looks up a kind by its user-facing name, ignoring case.
         *
         * @param string name to look up; may be {@code null}
         * @return the matching kind, or {@link #NONE} when nothing matches (including
         *         {@code null})
         */
        public static Type parseFromString(String string) {
            for (Type type : Type.values()) {
                // equalsIgnoreCase is locale-independent and returns false for null.
                if (type.name.equalsIgnoreCase(string)) {
                    return type;
                }
            }
            return NONE;
        }
    }
}

