/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.hmc.ReversibleHMCProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.AbstractParticleOperator;
import dr.inference.operators.hmc.AbstractZigZagOperator;
import dr.inference.operators.hmc.MassPreconditionScheduler;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.inference.operators.hmc.MinimumTravelInformation;
import dr.inference.operators.hmc.MinimumTravelInformationBinary;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.TaskPool;
import dr.util.Transform;
import dr.xml.Reportable;
import java.util.function.BinaryOperator;

/**
 * Reversible zig-zag particle operator.
 *
 * <p>Between events every coordinate moves linearly with its current velocity. The next event is
 * the earliest of a categorical boundary crossing, a binary boundary crossing, or a momentum
 * component reaching zero (a "gradient" event). At a boundary event the affected momentum
 * component(s) are reflected. At a gradient event the momentum component is set to zero.
 *
 * <p>The class also implements {@link ReversibleHMCProvider}, so an outer sampler can drive the
 * trajectory in either direction.
 */
public class ReversibleZigZagOperator extends AbstractZigZagOperator
        implements Reportable, ReversibleHMCProvider {

    /** Forwards every argument unchanged to {@link AbstractZigZagOperator}. */
    public ReversibleZigZagOperator(GradientWrtParameterProvider gradientWrtParameterProvider,
                                    PrecisionMatrixVectorProductProvider precisionMatrixVectorProductProvider,
                                    PrecisionColumnProvider precisionColumnProvider,
                                    double d,
                                    AbstractParticleOperator.Options options,
                                    AbstractParticleOperator.NativeCodeOptions nativeCodeOptions,
                                    boolean bl,
                                    Parameter parameter,
                                    Parameter parameter2,
                                    int n,
                                    MassPreconditioner massPreconditioner,
                                    MassPreconditionScheduler.Type type) {
        super(gradientWrtParameterProvider, precisionMatrixVectorProductProvider, precisionColumnProvider,
                d, options, nativeCodeOptions, bl, parameter, parameter2, n, massPreconditioner, type);
    }

    @Override
    public String getOperatorName() {
        return "Zig-zag particle operator";
    }

    /**
     * Finds the earliest next event.
     *
     * <p>The search uses the task pool when one is configured, otherwise native code when
     * enabled, otherwise a serial Java loop. When {@code testNativeFindNextBounce} is set, the
     * serial Java result replaces the value computed above. It is then cross-checked against the
     * native implementation.
     */
    @Override
    MinimumTravelInformation getNextBounce(WrappedVector position, WrappedVector velocity,
                                           WrappedVector action, WrappedVector gradient,
                                           WrappedVector momentum) {
        timer.startTimer("getNext");
        MinimumTravelInformation next;
        if (taskPool != null) {
            next = getNextBounceParallel(position, velocity, action, gradient, momentum);
        } else if (nativeCodeOptions.useNativeFindNextBounce) {
            MinimumTravelInformationBinary nativeNext =
                    getNextBounceNative(position, velocity, action, gradient, momentum);
            next = new MinimumTravelInformation(nativeNext.time, nativeNext.index, nativeNext.type);
        } else {
            next = getNextBounceSerial(position, velocity, action, gradient, momentum);
        }
        timer.stopTimer("getNext");

        if (nativeCodeOptions.testNativeFindNextBounce) {
            // Debug path: the serial Java search is the reference answer; it is returned and
            // compared against the native search (which terminates the JVM on mismatch).
            next = getNextBounceSerial(position, velocity, action, gradient, momentum);
            MinimumTravelInformationBinary reference =
                    new MinimumTravelInformationBinary(next.time, next.index[0], next.type);
            testNative(reference, position, velocity, action, gradient, momentum);
        }
        return next;
    }

    /** Runs the event search over the full coordinate range on the calling thread. */
    private MinimumTravelInformation getNextBounceSerial(WrappedVector position, WrappedVector velocity,
                                                        WrappedVector action, WrappedVector gradient,
                                                        WrappedVector momentum) {
        return getNextBounceImpl(0, position.getDim(), position.getBuffer(), velocity.getBuffer(),
                action.getBuffer(), gradient.getBuffer(), momentum.getBuffer());
    }

    /**
     * Splits the event search into ranges on the task pool and keeps the earliest event.
     *
     * <p>On ties the right-hand operand of the reduction wins, matching the original ordering.
     */
    private MinimumTravelInformation getNextBounceParallel(WrappedVector position, WrappedVector velocity,
                                                          WrappedVector action, WrappedVector gradient,
                                                          WrappedVector momentum) {
        final double[] p = position.getBuffer();
        final double[] v = velocity.getBuffer();
        final double[] a = action.getBuffer();
        final double[] g = gradient.getBuffer();
        final double[] m = momentum.getBuffer();

        TaskPool.RangeCallable<MinimumTravelInformation> map =
                (begin, end, unused) -> getNextBounceImpl(begin, end, p, v, a, g, m);

        // Must be parameterized: a raw BinaryOperator types the lambda arguments as Object,
        // so `.time` would not resolve and mapReduce's generic return would be erased.
        BinaryOperator<MinimumTravelInformation> reduce =
                (lhs, rhs) -> lhs.time < rhs.time ? lhs : rhs;

        return taskPool.mapReduce(map, reduce);
    }

    /**
     * Finds the earliest event among the categorical boundary and the per-coordinate binary
     * boundary and gradient-root events for coordinates in {@code [begin, end)}.
     *
     * @return the earliest event. The index is {@code {-1, -1}} with type {@code NONE} when no
     *     finite event time is found. Single-coordinate events use {@code index[1] == -1}.
     */
    private MinimumTravelInformation getNextBounceImpl(int begin, int end,
                                                       double[] position, double[] velocity,
                                                       double[] action, double[] gradient,
                                                       double[] momentum) {
        double minimumTime = Double.POSITIVE_INFINITY;
        int[] index = {-1, -1};
        AbstractParticleOperator.Type type = AbstractParticleOperator.Type.NONE;

        // The categorical search is given the whole vectors, not just [begin, end); in the
        // parallel path each range repeats it and the reduction keeps the earliest event.
        MinimumTravelInformation categorical = findCategoricalBoundaryTime(position, velocity);
        if (categorical.time < minimumTime) {
            minimumTime = categorical.time;
            index = categorical.index.clone(); // copy: entries may be overwritten below
            type = AbstractParticleOperator.Type.CATE_BOUNDARY;
        }

        for (int i = begin; i < end; ++i) {
            double binaryTime = findBinaryBoundaryTime(i, position[i], velocity[i]);
            if (binaryTime < minimumTime) {
                minimumTime = binaryTime;
                index[0] = i;
                index[1] = -1;
                type = AbstractParticleOperator.Type.BINARY_BOUNDARY;
            }

            double gradientTime = findGradientRoot(action[i], gradient[i], momentum[i]);
            if (gradientTime < minimumTime) {
                minimumTime = gradientTime;
                index[0] = i;
                index[1] = -1;
                type = AbstractParticleOperator.Type.GRADIENT;
            }
        }
        return new MinimumTravelInformation(minimumTime, index, type);
    }

    /**
     * Draws an initial momentum vector.
     *
     * <p>Each component is a random sign times an Exponential(1) draw, scaled by the square
     * root of the corresponding mass entry. The mask is then applied when one is set.
     */
    @Override
    final WrappedVector drawInitialMomentum() {
        WrappedVector mass = preconditioning.mass;
        double[] momentum = new double[mass.getDim()];
        for (int i = 0; i < momentum.length; ++i) {
            int direction = MathUtils.nextDouble() > 0.5 ? 1 : -1;
            momentum[i] = (double) direction * MathUtils.nextExponential(1.0) * Math.sqrt(mass.get(i));
        }
        if (mask != null) {
            applyMask(momentum);
        }
        return new WrappedVector.Raw(momentum);
    }

    /**
     * Returns the stored velocity when refreshing is disabled and a stored value exists.
     *
     * <p>Otherwise each component is {@code sign(momentum[i]) / sqrt(mass[i])}.
     */
    @Override
    final WrappedVector drawInitialVelocity(WrappedVector momentum) {
        if (!refreshVelocity && storedVelocity != null) {
            return storedVelocity;
        }
        WrappedVector mass = preconditioning.mass;
        int dim = momentum.getDim();
        double[] velocity = new double[dim];
        for (int i = 0; i < dim; ++i) {
            velocity[i] = (double) sign(momentum.get(i)) / Math.sqrt(mass.get(i));
        }
        return new WrappedVector.Raw(velocity);
    }

    /**
     * Compares a Java-computed event with the native implementation's result.
     *
     * <p>On mismatch this prints both results to stderr and terminates the JVM. It is intended
     * only as a debugging harness.
     */
    private void testNative(MinimumTravelInformationBinary javaResult, WrappedVector position,
                            WrappedVector velocity, WrappedVector action, WrappedVector gradient,
                            WrappedVector momentum) {
        timer.startTimer("getNextC++");
        MinimumTravelInformationBinary nativeResult = nativeZigZag.getNextReversibleEvent(
                position.getBuffer(), velocity.getBuffer(), action.getBuffer(),
                gradient.getBuffer(), momentum.getBuffer());
        timer.stopTimer("getNextC++");
        if (!javaResult.equals(nativeResult)) {
            System.err.println(nativeResult + " ?= " + javaResult + "\n");
            System.exit(-1);
        }
    }

    /** Delegates the event search to the native zig-zag implementation, timing the call. */
    private MinimumTravelInformationBinary getNextBounceNative(WrappedVector position, WrappedVector velocity,
                                                              WrappedVector action, WrappedVector gradient,
                                                              WrappedVector momentum) {
        timer.startTimer("getNextC++");
        MinimumTravelInformationBinary result = nativeZigZag.getNextReversibleEvent(
                position.getBuffer(), velocity.getBuffer(), action.getBuffer(),
                gradient.getBuffer(), momentum.getBuffer());
        timer.stopTimer("getNextC++");
        return result;
    }

    /**
     * Advances all state arrays by {@code time} in place, then applies the velocity flip(s) to
     * {@code action}.
     *
     * <p>The updates are {@code position += t*v}, {@code momentum += t*g - t^2/2*a}, and
     * {@code gradient -= t*a}. Then {@code action} is reduced by {@code 2*v[j]*column_j} for each
     * event index {@code j}. This is consistent with {@code action} being the precision matrix
     * times velocity (TODO confirm). {@code column} is presumably the precision column for
     * {@code index[0]} (supplied by the caller).
     *
     * @param index event indices; {@code index[1]} is a second coordinate (categorical event)
     *     when non-negative, and {@code -1} when absent
     */
    private void updateDynamics(double[] position, double[] velocity, double[] action,
                                double[] gradient, double[] momentum, double[] column,
                                double time, int[] index) {
        final double halfTimeSquared = time * time / 2.0;
        final double twoVelocity1 = 2.0 * velocity[index[0]];

        // -1 is the "no second index" sentinel, so 0 is a valid second index. The previous
        // `> 0` test silently dropped it. The length guard tolerates single-entry index arrays.
        final boolean hasSecond = index.length > 1 && index[1] >= 0;
        final double[] column2 = hasSecond ? getPrecisionColumn(index[1]).getBuffer() : null;
        final double twoVelocity2 = hasSecond ? 2.0 * velocity[index[1]] : 0.0;

        for (int i = 0; i < position.length; ++i) {
            final double g = gradient[i];
            final double a = action[i];
            position[i] = position[i] + time * velocity[i];
            // Kept as (m + t*g) - h*a to preserve the original floating-point evaluation order.
            momentum[i] = momentum[i] + time * g - halfTimeSquared * a;
            gradient[i] = g - time * a;
            double newAction = a - twoVelocity1 * column[i];
            if (hasSecond) {
                newAction = newAction - twoVelocity2 * column2[i];
            }
            action[i] = newAction;
        }
    }

    /**
     * Advances the dynamics to the event and then applies the event's effect on momentum and
     * position.
     *
     * <p>The advance uses Java or native code, depending on {@code useNativeUpdateDynamics}. For
     * a binary boundary event, the momentum at {@code index[0]} is reflected and that position
     * is set to zero. For a categorical boundary event, the momenta at both indices are
     * reflected and the two positions are set equal. Any other type sets the momentum at
     * {@code index[0]} to zero.
     */
    @Override
    void updateDynamics(WrappedVector position, WrappedVector velocity, WrappedVector action,
                        WrappedVector gradient, WrappedVector momentum, WrappedVector column,
                        double time, int[] index, AbstractParticleOperator.Type type) {
        if (!nativeCodeOptions.useNativeUpdateDynamics) {
            updateDynamics(position.getBuffer(), velocity.getBuffer(), action.getBuffer(),
                    gradient.getBuffer(), momentum.getBuffer(), column.getBuffer(), time, index);
        } else {
            // NOTE(review): only index[0] reaches native code; for CATE_BOUNDARY events the
            // second column's action update may be missing on this path — confirm.
            nativeZigZag.updateReversibleDynamics(position.getBuffer(), velocity.getBuffer(),
                    action.getBuffer(), gradient.getBuffer(), momentum.getBuffer(),
                    column.getBuffer(), time, index[0], type.ordinal());
        }

        if (type == AbstractParticleOperator.Type.BINARY_BOUNDARY) {
            reflectMomentum(momentum, index[0]);
            setZeroPosition(position, index[0]);
        } else if (type == AbstractParticleOperator.Type.CATE_BOUNDARY) {
            reflectMomentum(momentum, index[0]);
            reflectMomentum(momentum, index[1]);
            setEqualPosition(position, index[0], index[1]);
        } else {
            setZeroMomentum(momentum, index[0]);
        }
    }

    /** Advances the position and momentum state by {@code time} without handling an event. */
    @Override
    void updatePositionAndMomentum(WrappedVector position, WrappedVector velocity, WrappedVector action,
                                   WrappedVector gradient, WrappedVector momentum, double time) {
        updatePosition(position.getBuffer(), velocity.getBuffer(), time);
        updateMomentum(action.getBuffer(), gradient.getBuffer(), momentum.getBuffer(), time);
    }

    /**
     * Integrates a trajectory of total length {@code time} and writes the final position into
     * the parameter.
     *
     * <p>When {@code direction == -1}, the momentum is negated before integration and negated
     * back afterwards. The third vector argument is not used.
     */
    @Override
    public void reversiblePositionMomentumUpdate(WrappedVector position, WrappedVector momentum,
                                                 WrappedVector unused, int direction, double time) {
        preconditioning.totalTravelTime = time;
        if (direction == -1) {
            negateVector(momentum);
        }
        integrateTrajectory(position, momentum);
        if (direction == -1) {
            negateVector(momentum);
        }
        ReadableVector.Utils.setParameter((ReadableVector) position, parameter);
    }

    /** Recomputes preconditioning at the parameter's current values. */
    @Override
    public void providerUpdatePreconditioning() {
        updatePreconditioning(new WrappedVector.Raw(getInitialPosition()));
    }

    /** Returns the parameter's current values. */
    @Override
    public double[] getInitialPosition() {
        return parameter.getParameterValues();
    }

    /** Returns 0: this provider applies no transform, so there is no Jacobian term. */
    @Override
    public double getParameterLogJacobian() {
        return 0.0;
    }

    @Override
    public int getNumGradientEvent() {
        return numGradientEvents;
    }

    @Override
    public int getNumBoundaryEvent() {
        return numBoundaryEvents;
    }

    @Override
    public double[] getMask() {
        return maskVector;
    }

    /** Returns {@code null}: no parameter transform is used. */
    @Override
    public Transform getTransform() {
        return null;
    }

    @Override
    public GradientWrtParameterProvider getGradientProvider() {
        return gradientProvider;
    }

    /** Copies {@code values} into the underlying parameter. */
    @Override
    public void setParameter(double[] values) {
        ReadableVector.Utils.setParameter(values, parameter);
    }

    /** Draws a fresh momentum; equivalent to {@link #drawInitialMomentum()}. */
    @Override
    public WrappedVector drawMomentum() {
        return drawInitialMomentum();
    }

    /** Returns log-likelihood minus kinetic energy minus the (zero) log-Jacobian. */
    @Override
    public double getJointProbability(WrappedVector momentum) {
        return gradientProvider.getLikelihood().getLogLikelihood()
                - getKineticEnergy(momentum) - getParameterLogJacobian();
    }

    @Override
    public double getLogLikelihood() {
        return gradientProvider.getLikelihood().getLogLikelihood();
    }

    /** Returns the L1 norm (sum of absolute values) of {@code momentum}. */
    @Override
    public double getKineticEnergy(ReadableVector momentum) {
        int dim = momentum.getDim();
        double total = 0.0;
        for (int i = 0; i < dim; ++i) {
            total += Math.abs(momentum.get(i));
        }
        return total;
    }

    /** Returns the total travel time last set by {@link #reversiblePositionMomentumUpdate}. */
    @Override
    public double getStepSize() {
        return preconditioning.totalTravelTime;
    }

    /** Negates every component of {@code vector} in place. */
    private void negateVector(WrappedVector vector) {
        int dim = vector.getDim();
        for (int i = 0; i < dim; ++i) {
            vector.set(i, -vector.get(i));
        }
    }
}

