/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.GraphicalParameterBound;
import dr.inference.model.Parameter;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;

/**
 * Hamiltonian Monte Carlo operator whose leapfrog position update keeps the parameter inside the
 * region described by a {@link GraphicalParameterBound}.
 *
 * <p>Two kinds of constraint are enforced while a trajectory segment is integrated:
 * <ul>
 *   <li>fixed per-coordinate lower/upper bounds: the coordinate that would cross its bound has its
 *       momentum negated and is pinned to the bound ({@link ReflectionType#Reflection});</li>
 *   <li>the relative order of connected coordinates: when two connected coordinates would pass each
 *       other, their momenta are replaced by the result of
 *       {@link MassPreconditioner#doCollision} ({@link ReflectionType#Collision}).</li>
 * </ul>
 * Events are resolved one at a time, always taking the earliest one within the remaining step.
 */
public class ReflectiveHamiltonianMonteCarloOperator
extends HamiltonianMonteCarloOperator {
    private final GraphicalParameterBound treeParameterBound;

    /**
     * Creates the operator; all arguments except the last are forwarded unchanged to the
     * superclass constructor.
     *
     * @param graphicalParameterBound fixed bounds and connectivity enforced during position updates
     */
    public ReflectiveHamiltonianMonteCarloOperator(AdaptationMode adaptationMode, double d, GradientWrtParameterProvider gradientWrtParameterProvider, Parameter parameter, Transform transform, Parameter parameter2, HamiltonianMonteCarloOperator.Options options, MassPreconditioner.Type type, GraphicalParameterBound graphicalParameterBound) {
        super(adaptationMode, d, gradientWrtParameterProvider, parameter, transform, parameter2, options, type);
        this.treeParameterBound = graphicalParameterBound;
        // The engine is (re)built here because treeParameterBound is only assigned after super(...)
        // returns. NOTE(review): presumably the superclass constructor already created an engine via
        // constructLeapFrogEngine() while treeParameterBound was still null — confirm in the superclass.
        this.leapFrogEngine = this.constructLeapFrogEngine(transform);
    }

    /**
     * Builds a {@link WithGraphBounds} engine over this operator's parameter, preconditioner and mask.
     *
     * @param transform not used by this implementation
     * @return an engine that enforces {@code treeParameterBound} during position updates
     */
    @Override
    protected HamiltonianMonteCarloOperator.LeapFrogEngine constructLeapFrogEngine(Transform transform) {
        return new WithGraphBounds(this.parameter, this.getDefaultInstabilityHandler(), this.preconditioning, this.mask, this.treeParameterBound);
    }

    /**
     * How an event modifies position and momentum. Every variant first advances all coordinates
     * freely for the event time; the variants differ only in what happens at the event itself.
     */
    enum ReflectionType {
        /** Bounce off a fixed bound: negate the momentum of {@code indices[0]} and pin it to the bound. */
        Reflection {
            @Override
            void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum,
                              double eventLocation, int[] indices, double eventTime) {
                this.updatePosition(position, preconditioner, momentum, eventTime);
                int index = indices[0];
                momentum.set(index, -momentum.get(index));
                position[index] = eventLocation;
            }
        },
        /** Two connected coordinates meet: take their new momenta from the preconditioner's collision rule. */
        Collision {
            @Override
            void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum,
                              double eventLocation, int[] indices, double eventTime) {
                this.updatePosition(position, preconditioner, momentum, eventTime);
                ReadableVector outgoing = preconditioner.doCollision(indices, momentum);
                for (int index : indices) {
                    momentum.set(index, outgoing.get(index));
                    // Both colliding coordinates are placed exactly at the meeting point.
                    position[index] = eventLocation;
                }
            }
        },
        /** No event within the remaining time: just move freely. */
        None {
            @Override
            void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum,
                              double eventLocation, int[] indices, double eventTime) {
                this.updatePosition(position, preconditioner, momentum, eventTime);
            }
        };

        /**
         * Advances every coordinate linearly: {@code position[i] += velocity_i * time}, where the
         * velocity is obtained from the preconditioner and the (unchanged) momentum.
         */
        void updatePosition(double[] position, MassPreconditioner preconditioner, WrappedVector momentum, double time) {
            for (int i = 0; i < position.length; ++i) {
                position[i] += preconditioner.getVelocity(i, momentum) * time;
            }
        }

        /**
         * Applies this event to {@code position} (in place) and {@code momentum}.
         *
         * @param position      current position, updated in place
         * @param preconditioner maps momentum to velocity and resolves collisions
         * @param momentum      current momentum, updated in place at the event indices
         * @param eventLocation coordinate value the affected indices are set to at the event
         * @param indices       affected coordinate indices (unused by {@link #None})
         * @param eventTime     free-flight time before the event
         */
        abstract void doReflection(double[] position, MassPreconditioner preconditioner, WrappedVector momentum,
                                   double eventLocation, int[] indices, double eventTime);
    }

    /** The earliest event found within a time budget, together with the data needed to apply it. */
    class ReflectionEvent {
        private final ReflectionType type;
        private final double eventTime;
        private final double eventLocation;
        // The remaining time the event was searched within; stored but not currently read.
        private final double intervalLength;
        private final int[] indices;

        ReflectionEvent(ReflectionType reflectionType, double eventTime, double eventLocation, double intervalLength, int[] indices) {
            this.type = reflectionType;
            this.eventTime = eventTime;
            this.intervalLength = intervalLength;
            this.indices = indices;
            this.eventLocation = eventLocation;
        }

        /** @return free-flight time until this event (equals the whole budget for {@code None}) */
        public double getEventTime() {
            return this.eventTime;
        }

        public ReflectionType getType() {
            return this.type;
        }

        /** Applies this event in place using the enclosing operator's preconditioner. */
        public void doReflection(double[] position, WrappedVector momentum) {
            this.type.doReflection(position, ReflectiveHamiltonianMonteCarloOperator.this.preconditioning, momentum, this.eventLocation, this.indices, this.eventTime);
        }
    }

    /** Leapfrog engine whose position update resolves bound reflections and pairwise collisions. */
    class WithGraphBounds
    extends HamiltonianMonteCarloOperator.LeapFrogEngine.Default {
        private final GraphicalParameterBound graphicalParameterBound;

        protected WithGraphBounds(Parameter parameter, HamiltonianMonteCarloOperator.InstabilityHandler instabilityHandler, MassPreconditioner massPreconditioner, double[] mask, GraphicalParameterBound graphicalParameterBound) {
            super(parameter, instabilityHandler, massPreconditioner, mask);
            this.graphicalParameterBound = graphicalParameterBound;
        }

        /**
         * Moves {@code position} for {@code stepSize} time units, handling each reflection/collision
         * in chronological order, then writes the final position to the parameter.
         */
        @Override
        public void updatePosition(double[] position, WrappedVector momentum, double stepSize) {
            double elapsed = 0.0;
            while (elapsed < stepSize) {
                ReflectionEvent event = this.nextEvent(position, momentum, stepSize - elapsed);
                event.doReflection(position, momentum);
                elapsed += event.getEventTime();
            }
            this.setParameter(position);
        }

        /** Returns whichever of the first bound reflection and first collision happens earlier (ties go to the collision). */
        private ReflectionEvent nextEvent(double[] position, WrappedVector momentum, double remaining) {
            ReflectionEvent reflection = this.firstReflectionAtFixedBounds(position, momentum, remaining);
            ReflectionEvent collision = this.firstCollision(position, momentum, remaining);
            return reflection.getEventTime() < collision.getEventTime() ? reflection : collision;
        }

        /** True when moving from {@code current} to {@code intended} reaches or crosses {@code bound}; false if already on it. */
        private boolean isReflected(double current, double intended, double bound) {
            if (current > bound) {
                return intended <= bound;
            }
            if (current < bound) {
                return intended >= bound;
            }
            return false;
        }

        /** True when two coordinates that start in a strict order would meet or swap order by their intended positions. */
        private boolean isCollision(double currentI, double intendedI, double currentJ, double intendedJ) {
            if (currentI > currentJ) {
                return intendedI <= intendedJ;
            }
            if (currentI < currentJ) {
                return intendedI >= intendedJ;
            }
            return false;
        }

        /**
         * Finds the earliest collision between connected coordinates within {@code maxTime}.
         *
         * @return a {@code Collision} event on the pair {i, j}, or a {@code None} event of duration
         *         {@code maxTime} with indices {-1, -1} when no pair meets
         */
        private ReflectionEvent firstCollision(double[] position, ReadableVector momentum, double maxTime) {
            double[] intended = this.getIntendedPosition(position, momentum, maxTime);
            double firstTime = maxTime;
            double location = -1.0;
            ReflectionType type = ReflectionType.None;
            int first = -1;
            int second = -1;
            for (int i = 0; i < position.length; ++i) {
                double velocityI = ReflectiveHamiltonianMonteCarloOperator.this.preconditioning.getVelocity(i, momentum);
                int[] connected = this.graphicalParameterBound.getConnectedParameterIndices(i);
                if (connected == null) {
                    continue;
                }
                for (int j : connected) {
                    // Visit each unordered pair once.
                    if (j <= i) {
                        continue;
                    }
                    double velocityJ = ReflectiveHamiltonianMonteCarloOperator.this.preconditioning.getVelocity(j, momentum);
                    if (!this.isCollision(position[i], intended[i], position[j], intended[j])) {
                        continue;
                    }
                    // Solve x_i + v_i t = x_j + v_j t for t.
                    double time = (position[j] - position[i]) / (velocityI - velocityJ);
                    if (time < firstTime) {
                        firstTime = time;
                        location = time * velocityI + position[i];
                        first = i;
                        second = j;
                        type = ReflectionType.Collision;
                    }
                }
            }
            return new ReflectionEvent(type, firstTime, location, maxTime, new int[]{first, second});
        }

        /** Returns the position reached after free flight for {@code time}, without modifying {@code position}. */
        private double[] getIntendedPosition(double[] position, ReadableVector momentum, double time) {
            double[] intended = new double[position.length];
            for (int i = 0; i < position.length; ++i) {
                double velocity = ReflectiveHamiltonianMonteCarloOperator.this.preconditioning.getVelocity(i, momentum);
                intended[i] = position[i] + time * velocity;
            }
            return intended;
        }

        /**
         * Finds the earliest crossing of a fixed lower/upper bound within {@code maxTime}. For each
         * coordinate the upper bound is tested first; the lower bound is only tested if the upper is
         * not crossed.
         *
         * @return a {@code Reflection} event at the crossed bound, or a {@code None} event of duration
         *         {@code maxTime} with index -1 when no bound is crossed
         * @throws IllegalStateException if a detected crossing yields a negative time
         */
        private ReflectionEvent firstReflectionAtFixedBounds(double[] position, ReadableVector momentum, double maxTime) {
            double[] intended = this.getIntendedPosition(position, momentum, maxTime);
            double firstTime = maxTime;
            double location = -1.0;
            ReflectionType type = ReflectionType.None;
            int index = -1;
            for (int i = 0; i < position.length; ++i) {
                double velocity = ReflectiveHamiltonianMonteCarloOperator.this.preconditioning.getVelocity(i, momentum);
                double upper = this.graphicalParameterBound.getFixedUpperBound(i);
                double lower = this.graphicalParameterBound.getFixedLowerBound(i);
                double crossed;
                if (this.isReflected(position[i], intended[i], upper)) {
                    crossed = upper;
                } else if (this.isReflected(position[i], intended[i], lower)) {
                    crossed = lower;
                } else {
                    continue;
                }
                double time = (crossed - position[i]) / velocity;
                if (time < 0.0) {
                    throw new IllegalStateException("Check isReflected() function plz.");
                }
                if (time < firstTime) {
                    firstTime = time;
                    type = ReflectionType.Reflection;
                    index = i;
                    location = crossed;
                }
            }
            return new ReflectionEvent(type, firstTime, location, maxTime, new int[]{index});
        }
    }
}

