/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.AbstractParticleOperator;
import dr.inference.operators.hmc.AbstractZigZagOperator;
import dr.inference.operators.hmc.MassPreconditionScheduler;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.inference.operators.hmc.MinimumTravelInformation;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.WrappedVector;
import dr.xml.Reportable;

/**
 * Irreversible zig-zag operator.
 *
 * <p>Each velocity coordinate is drawn as +1 or -1 (see {@link #drawInitialVelocity}). Between
 * events the state evolves linearly in time (see {@link #updateDynamics}). The next event is the
 * earliest, over all coordinates, of a binary-boundary time and a gradient-driven switching time
 * whose rate for coordinate {@code i} is {@code max(0, -v_i * g_i + t * v_i * a_i)}.
 *
 * <p>NOTE(review): the argument names used below (position, velocity, action, gradient, momentum,
 * column) are reconstructed from how this class uses the vectors ({@code position += t * velocity},
 * {@code gradient -= t * action}); confirm against {@code AbstractZigZagOperator}.
 */
public class IrreversibleZigZagOperator extends AbstractZigZagOperator implements Reportable {

    /** Package-visible switch, currently {@code false}; not read inside this class. */
    static final boolean CPP_NEXT_BOUNCE = false;
    /** Currently {@code false} and not read in this class (presumably a feature toggle). */
    private static final boolean NEW_WAY = false;
    /** Currently {@code false} and not read in this class (presumably a feature toggle). */
    private static final boolean NOT_YET_IMPLEMENTED = false;

    /** Forwards every argument unchanged to the {@link AbstractZigZagOperator} constructor. */
    public IrreversibleZigZagOperator(GradientWrtParameterProvider gradientWrtParameterProvider, PrecisionMatrixVectorProductProvider precisionMatrixVectorProductProvider, PrecisionColumnProvider precisionColumnProvider, double d, AbstractParticleOperator.Options options, AbstractParticleOperator.NativeCodeOptions nativeCodeOptions, boolean bl, Parameter parameter, Parameter parameter2, int n, MassPreconditioner massPreconditioner, MassPreconditionScheduler.Type type) {
        super(gradientWrtParameterProvider, precisionMatrixVectorProductProvider, precisionColumnProvider, d, options, nativeCodeOptions, bl, parameter, parameter2, n, massPreconditioner, type);
    }

    /**
     * Returns the velocity used for the next trajectory.
     *
     * <p>If velocity refreshment is off and a stored velocity exists, the stored vector is returned
     * as-is. Otherwise each coordinate is set to +1 when {@code MathUtils.nextDouble()} exceeds 0.5
     * and to -1 otherwise. The dimension is taken from the preconditioning mass vector, and the mask
     * (if any) is applied.
     *
     * @param ignored not read by this implementation
     * @return the stored velocity or a freshly drawn sign vector
     */
    @Override
    WrappedVector drawInitialVelocity(WrappedVector ignored) {
        if (!this.refreshVelocity && this.storedVelocity != null) {
            return this.storedVelocity;
        }
        int dim = this.preconditioning.mass.getDim();
        double[] velocity = new double[dim];
        for (int i = 0; i < dim; ++i) {
            velocity[i] = MathUtils.nextDouble() > 0.5 ? 1.0 : -1.0;
        }
        if (this.mask != null) {
            this.applyMask(velocity);
        }
        return new WrappedVector.Raw(velocity);
    }

    /**
     * Advances the position along the velocity for {@code time}. Only the first two vectors are
     * used; the remaining vectors are not touched.
     */
    @Override
    void updatePositionAndMomentum(WrappedVector position, WrappedVector velocity, WrappedVector action, WrappedVector gradient, WrappedVector momentum, double time) {
        IrreversibleZigZagOperator.updatePosition(position.getBuffer(), velocity.getBuffer(), time);
    }

    /** Computes the next event with the native implementation, timing it under "getNextC++". */
    private MinimumTravelInformation getNextBounceNative(WrappedVector position, WrappedVector velocity, WrappedVector action, WrappedVector gradient) {
        this.timer.startTimer("getNextC++");
        MinimumTravelInformation next = this.nativeZigZag.getNextIrreversibleEvent(position.getBuffer(), velocity.getBuffer(), action.getBuffer(), gradient.getBuffer());
        this.timer.stopTimer("getNextC++");
        return next;
    }

    /**
     * Debug check: recomputes the next event natively and compares it with {@code expected}.
     * On mismatch it prints both results to stderr and terminates the JVM with status -1.
     */
    private void testNative(MinimumTravelInformation expected, WrappedVector position, WrappedVector velocity, WrappedVector action, WrappedVector gradient) {
        this.timer.startTimer("getNextC++");
        MinimumTravelInformation nativeResult = this.nativeZigZag.getNextIrreversibleEvent(position.getBuffer(), velocity.getBuffer(), action.getBuffer(), gradient.getBuffer());
        this.timer.stopTimer("getNextC++");
        if (!expected.equals(nativeResult)) {
            System.err.println(nativeResult + " ?= " + expected + "\n");
            System.exit(-1);
        }
    }

    /**
     * Finds the next event. The native implementation is used when
     * {@code nativeCodeOptions.useNativeFindNextBounce} is set; otherwise the Java implementation is
     * used. When {@code nativeCodeOptions.testNativeFindNextBounce} is set, the result is also
     * cross-checked against native code. The {@code momentum} argument is not read.
     */
    @Override
    MinimumTravelInformation getNextBounce(WrappedVector position, WrappedVector velocity, WrappedVector action, WrappedVector gradient, WrappedVector momentum) {
        this.timer.startTimer("getNext");
        MinimumTravelInformation next = this.nativeCodeOptions.useNativeFindNextBounce
                ? this.getNextBounceNative(position, velocity, action, gradient)
                : this.getNextBounceImpl(position.getBuffer(), velocity.getBuffer(), action.getBuffer(), gradient.getBuffer());
        this.timer.stopTimer("getNext");
        if (this.nativeCodeOptions.testNativeFindNextBounce) {
            this.testNative(next, position, velocity, action, gradient);
        }
        return next;
    }

    /**
     * Java implementation of the next-event search.
     *
     * <p>For every coordinate, it considers the binary-boundary time and a gradient switching time
     * drawn with one {@code Exp(1)} threshold, and keeps the overall minimum. Ties keep the earlier
     * candidate, because comparisons are strict.
     *
     * @return the minimum time, its coordinate and its event type; the time is {@code +Infinity},
     *     the index -1 and the type {@code NONE} if no candidate is finite
     */
    private MinimumTravelInformation getNextBounceImpl(double[] position, double[] velocity, double[] action, double[] gradient) {
        double minimumTime = Double.POSITIVE_INFINITY;
        int index = -1;
        AbstractParticleOperator.Type type = AbstractParticleOperator.Type.NONE;
        for (int i = 0, len = position.length; i < len; ++i) {
            double boundaryTime = this.findBinaryBoundaryTime(i, position[i], velocity[i]);
            if (boundaryTime < minimumTime) {
                minimumTime = boundaryTime;
                index = i;
                type = AbstractParticleOperator.Type.BINARY_BOUNDARY;
            }
            // Exactly one Exp(1) draw per coordinate, in coordinate order, to keep the RNG stream stable.
            double threshold = MathUtils.nextExponential(1.0);
            double gradientTime = this.getSwitchTime(-velocity[i] * gradient[i], velocity[i] * action[i], threshold);
            if (gradientTime < minimumTime) {
                minimumTime = gradientTime;
                index = i;
                type = AbstractParticleOperator.Type.GRADIENT;
            }
        }
        return new MinimumTravelInformation(minimumTime, index, type);
    }

    /**
     * Returns, per coordinate, {@code gradient[i] / action[i]} clamped to be non-negative. Negative
     * values, {@code -Infinity} and NaN (e.g. {@code 0/0}) all map to 0.
     *
     * <p>NOTE(review): presumably the time at which the linear rate {@code -v*g + t*v*a} crosses
     * zero. This method is not called from the visible code.
     */
    private double[] getRoots(double[] action, double[] gradient) {
        double[] roots = new double[action.length];
        for (int i = 0; i < action.length; ++i) {
            double root = gradient[i] / action[i];
            roots[i] = root >= 0.0 ? root : 0.0;
        }
        return roots;
    }

    /**
     * Alternative switching-time computation that integrates the merged (summed) piecewise-linear
     * rate across root intervals until an {@code Exp(1)} threshold is reached.
     *
     * <p>NOTE(review): this method is not called from the visible code. {@code sortedRoots} is
     * assumed to be ascending and non-negative (e.g. sorted output of {@link #getRoots}); it must be
     * non-empty. The method returns -1 if no interval terminates the search.
     */
    private double getSwitchTimeByMergedProcesses(double[] action, double[] gradient, double[] velocity, double[] roots, double[] sortedRoots) {
        double threshold = MathUtils.nextExponential(1.0);
        double switchTime = -1.0;
        double accumulatedArea = 0.0;
        if (sortedRoots[sortedRoots.length - 1] == 0.0) {
            PiecewiseLinearEndpoints segment = this.getEndpointInfo(0.0, 0.0, velocity, gradient, action, roots);
            switchTime = this.integrateLinearFunctionToArea(segment, threshold, false);
        } else {
            double lower = sortedRoots[0];
            for (int n = 1; n < sortedRoots.length; ++n) {
                if (!(sortedRoots[n] > 0.0)) {
                    continue;
                }
                double upper = sortedRoots[n];
                PiecewiseLinearEndpoints segment = this.getEndpointInfo(lower, upper, velocity, gradient, action, roots);
                double area = this.getTrapezoidArea(segment);
                accumulatedArea += area;
                if (accumulatedArea > threshold) {
                    // The threshold is crossed inside this interval: integrate forward from its start.
                    double remaining = threshold - (accumulatedArea - area);
                    switchTime = this.integrateLinearFunctionToArea(segment, remaining, true);
                    break;
                }
                if (n == sortedRoots.length - 1) {
                    // Past the last root: extrapolate the final linear piece beyond the interval end.
                    double remaining = threshold - accumulatedArea;
                    switchTime = this.integrateLinearFunctionToArea(segment, remaining, false);
                    break;
                }
                lower = upper;
            }
        }
        return switchTime;
    }

    /**
     * Solves for the time {@code t} at which the integral of a linear rate, taken from one endpoint
     * of {@code segment}, equals {@code area}.
     *
     * <p>If {@code fromStart} is true, the line passes through {@code (c0, f0)} with slope
     * {@code slope0} and the integral starts at {@code c0}. Otherwise the line passes through
     * {@code (c1, f1)} with slope {@code slope1} and the integral starts at {@code c1}.
     *
     * @return the absolute time {@code t}, or {@code +Infinity} if the area is never reached
     */
    private double integrateLinearFunctionToArea(PiecewiseLinearEndpoints segment, double area, boolean fromStart) {
        if (fromStart) {
            double slope = segment.slope0;
            double intercept = segment.f0 - slope * segment.c0;
            return this.onlyPositiveRoot(slope * 0.5, intercept, -(slope * 0.5 * segment.c0 * segment.c0 + intercept * segment.c0 + area));
        }
        double slope = segment.slope1;
        double intercept = segment.f1 - slope * segment.c1;
        return this.onlyPositiveRoot(slope * 0.5, intercept, -(slope * 0.5 * segment.c1 * segment.c1 + intercept * segment.c1 + area));
    }

    /**
     * Returns the root of {@code a t^2 + b t + c = 0} at which the derivative {@code 2 a t + b} is
     * non-negative. For {@code a != 0} this is {@code (-b + sqrt(b^2 - 4ac)) / (2a)}.
     *
     * <p>Degenerate cases:
     * <ul>
     *   <li>{@code a == 0}: the equation is linear, and {@code -c / b} is returned when
     *       {@code b > 0}.</li>
     *   <li>Otherwise, including a negative discriminant, no such root exists and
     *       {@code +Infinity} is returned. The previous code produced NaN in these cases.</li>
     * </ul>
     */
    private double onlyPositiveRoot(double a, double b, double c) {
        if (a == 0.0) {
            return b > 0.0 ? -c / b : Double.POSITIVE_INFINITY;
        }
        double discriminant = b * b - 4.0 * a * c;
        if (discriminant < 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        double sqrtDiscriminant = Math.sqrt(discriminant);
        // For b > 0, "-b + sqrt(disc)" cancels catastrophically when |a| is small. The equivalent
        // form -2c / (b + sqrt(disc)) is stable there and tends to -c/b as a -> 0.
        if (b >= 0.0 && b + sqrtDiscriminant > 0.0) {
            return -2.0 * c / (b + sqrtDiscriminant);
        }
        return (-b + sqrtDiscriminant) / (2.0 * a);
    }

    /**
     * Builds the endpoint description of one interval {@code [lower, upper]}. The merged rate's
     * slope and intercept are accumulated separately at each endpoint using
     * {@link #accumulateCoefficients}.
     */
    private PiecewiseLinearEndpoints getEndpointInfo(double lower, double upper, double[] velocity, double[] gradient, double[] action, double[] roots) {
        double[] lowerCoefficients = new double[2];
        double[] upperCoefficients = new double[2];
        for (int i = 0; i < roots.length; ++i) {
            this.accumulateCoefficients(lower, velocity, gradient, action, roots, lowerCoefficients, i);
            this.accumulateCoefficients(upper, velocity, gradient, action, roots, upperCoefficients, i);
        }
        return new PiecewiseLinearEndpoints(lower, upper,
                lowerCoefficients[0] * lower + lowerCoefficients[1],
                upperCoefficients[0] * upper + upperCoefficients[1],
                lowerCoefficients[0], upperCoefficients[0]);
    }

    /**
     * Adds coordinate {@code n}'s linear rate (slope {@code v*a}, intercept {@code -v*g}) into
     * {@code coefficients} = {slope, intercept}. The rate is added only when it is active at
     * {@code time}: either the root lies at or after {@code time} with a non-positive slope, or at
     * or before {@code time} with a non-negative slope.
     */
    private void accumulateCoefficients(double time, double[] velocity, double[] gradient, double[] action, double[] roots, double[] coefficients, int n) {
        double slope = velocity[n] * action[n];
        if (roots[n] >= time && slope <= 0.0 || roots[n] <= time && slope >= 0.0) {
            coefficients[0] += slope;
            coefficients[1] += -velocity[n] * gradient[n];
        }
    }

    /** Trapezoid area {@code (f0 + f1) * (c1 - c0) / 2} over the segment. */
    private double getTrapezoidArea(PiecewiseLinearEndpoints segment) {
        return (segment.f0 + segment.f1) * (segment.c1 - segment.c0) / 2.0;
    }

    /**
     * Picks a coordinate with probability proportional to its positive rate
     * {@code max(0, t*v_i*a_i - v_i*g_i)} at {@code time}. Exactly one {@code MathUtils.nextDouble()}
     * is drawn per call.
     *
     * <p>Coordinates whose rate is zero are never selected. If cumulative roundoff leaves the draw
     * past every partial sum, the last positive-rate coordinate is returned.
     *
     * @return the selected index, or -1 when every rate is zero
     */
    private int getEventDimension(double[] velocity, double[] gradient, double[] action, double time) {
        double[] rates = new double[velocity.length];
        double totalRate = 0.0;
        for (int i = 0; i < velocity.length; ++i) {
            double rate = time * velocity[i] * action[i] - velocity[i] * gradient[i];
            rates[i] = rate > 0.0 ? rate : 0.0;
            totalRate += rates[i];
        }
        double target = MathUtils.nextDouble() * totalRate;
        double cumulative = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < rates.length; ++i) {
            if (rates[i] > 0.0) {
                lastPositive = i;
                cumulative += rates[i];
                // Strict comparison: with target in [0, total) a zero-width slot can never match.
                if (target < cumulative) {
                    return i;
                }
            }
        }
        return lastPositive;
    }

    /**
     * Returns the first {@code t >= 0} at which the integral of {@code max(0, intercept + slope*s)}
     * over {@code s in [0, t]} reaches {@code threshold}, or {@code +Infinity} if it never does.
     *
     * <p>The floating-point expressions are intentionally left as-is; the native cross-check in
     * {@link #testNative} compares results for exact equality.
     */
    private double getSwitchTime(double intercept, double slope, double threshold) {
        if (slope > 0.0) {
            if (intercept < 0.0) {
                // The rate is zero until -intercept/slope, then grows linearly.
                return -intercept / slope + Math.sqrt(2.0 * threshold / slope);
            }
            return -intercept / slope + Math.sqrt(intercept * intercept / (slope * slope) + 2.0 * threshold / slope);
        }
        if (slope == 0.0) {
            if (intercept > 0.0) {
                return threshold / intercept;
            }
            return Double.POSITIVE_INFINITY;
        }
        if (intercept <= 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        // Decreasing rate: positive only until peakTime, so the reachable area is bounded.
        double peakTime = -intercept / slope;
        if (threshold <= intercept * peakTime + slope * peakTime * peakTime / 2.0) {
            return peakTime - Math.sqrt(peakTime * peakTime + 2.0 * threshold / slope);
        }
        return Double.POSITIVE_INFINITY;
    }

    /**
     * Applies the dynamics for an event at {@code time} on coordinate {@code eventIndex[0]}.
     * For every {@code i}:
     * <ul>
     *   <li>{@code position[i] += time * velocity[i]}</li>
     *   <li>{@code gradient[i] -= time * action[i]}</li>
     *   <li>{@code action[i] -= 2 * velocity[j] * column[i]}, where {@code j = eventIndex[0]}</li>
     * </ul>
     * The velocity itself is not modified here, and {@code momentum} and {@code eventType} are not
     * read.
     */
    @Override
    void updateDynamics(WrappedVector position, WrappedVector velocity, WrappedVector action, WrappedVector gradient, WrappedVector momentum, WrappedVector column, double time, int[] eventIndex, AbstractParticleOperator.Type eventType) {
        double[] p = position.getBuffer();
        double[] v = velocity.getBuffer();
        double[] a = action.getBuffer();
        double[] g = gradient.getBuffer();
        double[] c = column.getBuffer();
        double twoEventVelocity = 2.0 * v[eventIndex[0]];
        for (int i = 0, len = p.length; i < len; ++i) {
            double oldAction = a[i];
            p[i] = p[i] + time * v[i];
            g[i] = g[i] - time * oldAction;
            a[i] = oldAction - twoEventVelocity * c[i];
        }
    }

    @Override
    public String getOperatorName() {
        return "Irreversible zig-zag operator";
    }

    /**
     * Endpoints of one linear piece: times {@code c0, c1}, rate values {@code f0, f1}, and the
     * slopes in effect at each end. Static, because it needs no enclosing operator instance.
     */
    private static final class PiecewiseLinearEndpoints {
        final double c0;
        final double c1;
        final double f0;
        final double f1;
        final double slope0;
        final double slope1;

        private PiecewiseLinearEndpoints(double c0, double c1, double f0, double f1, double slope0, double slope1) {
            this.c0 = c0;
            this.c1 = c1;
            this.f0 = f0;
            this.f1 = f1;
            this.slope0 = slope0;
            this.slope1 = slope1;
        }
    }
}

