/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.ReversibleHMCProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.AbstractAdaptableOperator;
import dr.inference.operators.GeneralOperator;
import dr.inference.operators.hmc.SplitHMCtravelTimeMultiplier;
import dr.math.MathUtils;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;

/**
 * Split Hamiltonian Monte Carlo operator that jointly integrates two {@link ReversibleHMCProvider}s,
 * an {@code inner} and an {@code outer} one.
 *
 * <p>Each integration step is symmetric in the two providers:
 * <ol>
 *   <li>the outer provider advances half a step, divided into {@code nSplitOuter} sub-steps;</li>
 *   <li>the inner provider advances a full step scaled by {@code relativeScale};</li>
 *   <li>the outer gradient is refreshed;</li>
 *   <li>the outer provider advances a second half step.</li>
 * </ol>
 * Joint position, momentum and gradient vectors are laid out as
 * {@code [inner (dimInner entries) | outer (dimOuter entries)]}.
 */
public class SplitHamiltonianMonteCarloOperator extends AbstractAdaptableOperator
        implements GeneralOperator, ReversibleHMCProvider {

    /** Base integration step size (also used to scale the inner step via {@link #relativeScale}). */
    private double stepSize;
    /** Multiplier applied to the step size of the inner provider; may be reset by the travel-time multiplier. */
    public double relativeScale;
    /** Helper that updates sample covariances and supplies {@link #relativeScale}; may be {@code null}. */
    public final SplitHMCtravelTimeMultiplier travelTimeMultipler;
    /** Provider advanced once (full step) per integration step. */
    public ReversibleHMCProvider inner;
    /** Provider advanced in {@code nSplitOuter} sub-steps on each side of the inner step. */
    public ReversibleHMCProvider outer;
    /** Joint parameter; expected to have {@code dimInner + dimOuter} entries (checked in checkGradient). */
    protected final Parameter parameter;
    int dimInner;
    int dimOuter;
    private int nSteps;
    private int nSplitOuter;
    private int gradientCheckCount;
    private double gradientCheckTolerance;

    /**
     * Creates the operator.
     *
     * @param weight                 operator weight
     * @param inner                  provider advanced once per integration step
     * @param outer                  provider advanced in sub-steps around the inner step
     * @param parameter              joint parameter ({@code inner} entries first, then {@code outer})
     * @param stepSize               base integration step size
     * @param relativeScale          multiplier on {@code stepSize} for the inner provider
     * @param nSteps                 number of integration steps per proposal
     * @param nSplitOuter            number of outer sub-steps per half step
     * @param gradientCheckCount     number of initial operations during which gradients are checked numerically
     * @param gradientCheckTolerance tolerance for the numerical gradient check
     * @param travelTimeMultiplier   optional helper for tuning {@code relativeScale}; may be {@code null}
     */
    public SplitHamiltonianMonteCarloOperator(double weight,
                                              ReversibleHMCProvider inner,
                                              ReversibleHMCProvider outer,
                                              Parameter parameter,
                                              double stepSize,
                                              double relativeScale,
                                              int nSteps,
                                              int nSplitOuter,
                                              int gradientCheckCount,
                                              double gradientCheckTolerance,
                                              SplitHMCtravelTimeMultiplier travelTimeMultiplier) {
        setWeight(weight);
        this.inner = inner;
        this.outer = outer;
        this.dimInner = inner.getInitialPosition().length;
        this.dimOuter = outer.getInitialPosition().length;
        this.parameter = parameter;
        this.stepSize = stepSize;
        this.relativeScale = relativeScale;
        this.nSteps = nSteps;
        this.nSplitOuter = nSplitOuter;
        this.gradientCheckCount = gradientCheckCount;
        this.gradientCheckTolerance = gradientCheckTolerance;
        this.travelTimeMultipler = travelTimeMultiplier;
    }

    /**
     * Performs one split-HMC proposal.
     *
     * @param likelihood likelihood used for the numerical gradient check during the first
     *                   {@code gradientCheckCount} operations
     * @return the log Hastings-type ratio computed by {@link #mergedUpdate()}
     */
    @Override
    public double doOperation(Likelihood likelihood) {
        if (getCount() < gradientCheckCount) {
            checkGradient(likelihood);
        }
        updateRS();
        // updateRS() treats a null multiplier as "disabled"; do the same here instead of throwing an NPE.
        if (travelTimeMultipler != null && travelTimeMultipler.shouldGetMultiplier(getCount())) {
            relativeScale = travelTimeMultipler.getMultiplier();
        }
        return mergedUpdate();
    }

    /**
     * Compares the analytic joint gradient with a central-difference numerical gradient.
     * The parameter values are restored afterwards, even if the check fails.
     *
     * @param likelihood likelihood whose log density is differentiated numerically
     * @throws RuntimeException if dimensions disagree or the gradients do not match within tolerance
     */
    void checkGradient(final Likelihood likelihood) {
        if (parameter.getDimension() != dimInner + dimOuter) {
            throw new RuntimeException("Unequal dimensions");
        }

        final MultivariateFunction logDensity = new MultivariateFunction() {

            @Override
            public double evaluate(double[] point) {
                if (!anyTransform()) {
                    ReadableVector.Utils.setParameter(point, parameter);
                    return likelihood.getLogLikelihood();
                }
                // 'point' is in transformed coordinates: map it back, then subtract the outer log-Jacobian.
                // NOTE(review): only the outer transform is applied (see jointTransformInverse); the inner
                // block is assumed untransformed — confirm.
                final double[] untransformed = jointTransformInverse(point);
                ReadableVector.Utils.setParameter(untransformed, parameter);
                return likelihood.getLogLikelihood() - transformGetLogJacobian(untransformed);
            }

            @Override
            public int getNumArguments() {
                return parameter.getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return parameter.getBounds().getLowerLimit(n);
            }

            @Override
            public double getUpperBound(int n) {
                return parameter.getBounds().getUpperLimit(n);
            }
        };

        final double[] analytic = mergeGradient();
        final double[] savedValues = parameter.getParameterValues();

        try {
            if (!anyTransform()) {
                final double[] numerical = NumericalDerivative.gradient(logDensity, parameter.getParameterValues());
                if (!MathUtils.isClose(analytic, numerical, gradientCheckTolerance)) {
                    String message = "Gradients do not match:\n\tAnalytic: " + new WrappedVector.Raw(analytic)
                            + "\n\tNumeric : " + new WrappedVector.Raw(numerical) + "\n";
                    throw new RuntimeException(message);
                }
            } else {
                final double[] transformedPosition = getInitialPosition();
                // Apply the Jacobian correction BEFORE numerical differentiation: the finite-difference
                // evaluations call setParameter and may leave the parameter perturbed.
                final double[] analyticTransformed = transformupdateGradientLogDensity(analytic, parameter);
                final double[] numerical = NumericalDerivative.gradient(logDensity, transformedPosition);
                if (!MathUtils.isClose(analyticTransformed, numerical, gradientCheckTolerance)) {
                    String message = "Transformed Gradients do not match:\n\tAnalytic: "
                            + new WrappedVector.Raw(analyticTransformed)
                            + "\n\tNumeric : " + new WrappedVector.Raw(numerical)
                            + "\n\tParameter : " + new WrappedVector.Raw(savedValues)
                            + "\n\tTransformed Parameter : " + new WrappedVector.Raw(transformedPosition) + "\n";
                    throw new RuntimeException(message);
                }
            }
        } finally {
            ReadableVector.Utils.setParameter(savedValues, parameter);
        }
    }

    /**
     * Returns whether either provider reports a non-zero parameter log-Jacobian.
     * NOTE(review): a zero log-Jacobian is used as a proxy for "no transform"; a transform whose
     * log-Jacobian is exactly 0 at the current point would be treated as untransformed — confirm.
     */
    private boolean anyTransform() {
        return inner.getParameterLogJacobian() != 0.0 || outer.getParameterLogJacobian() != 0.0;
    }

    /**
     * Runs {@code nSteps} split leapfrog steps from freshly drawn momenta.
     *
     * @return (kinetic energy + log-Jacobians) before integration minus the same quantity after it
     */
    private double mergedUpdate() {
        final WrappedVector innerPosition = new WrappedVector.Raw(inner.getInitialPosition());
        final WrappedVector outerPosition = new WrappedVector.Raw(outer.getInitialPosition());
        final WrappedVector innerMomentum = inner.drawMomentum();
        final WrappedVector outerMomentum = outer.drawMomentum();
        final WrappedVector innerGradient = new WrappedVector.Raw(inner.getGradientProvider().getGradientLogDensity());
        final WrappedVector outerGradient = new WrappedVector.Raw(outer.getGradientProvider().getGradientLogDensity());

        final double before = energyTerm(innerMomentum, outerMomentum);

        for (int step = 0; step < nSteps; ++step) {
            outerHalfStep(outerPosition, outerMomentum, outerGradient, 1, stepSize);
            inner.reversiblePositionMomentumUpdate(innerPosition, innerMomentum, innerGradient, 1,
                    relativeScale * stepSize);
            // The inner move changes the state the outer gradient depends on, so refresh it.
            updateOuterGradient(outerGradient);
            outerHalfStep(outerPosition, outerMomentum, outerGradient, 1, stepSize);
        }

        final double after = energyTerm(innerMomentum, outerMomentum);
        return before - after;
    }

    /**
     * Advances the outer provider by {@code 0.5 * time}, split into {@code nSplitOuter} equal sub-steps.
     */
    private void outerHalfStep(WrappedVector position, WrappedVector momentum, WrappedVector gradient,
                               int direction, double time) {
        final double subStep = 0.5 * time / (double) nSplitOuter;
        for (int j = 0; j < nSplitOuter; ++j) {
            outer.reversiblePositionMomentumUpdate(position, momentum, gradient, direction, subStep);
        }
    }

    /**
     * Sum of both providers' kinetic energies and parameter log-Jacobians (evaluated inner, outer, inner, outer).
     */
    private double energyTerm(WrappedVector innerMomentum, WrappedVector outerMomentum) {
        return inner.getKineticEnergy(innerMomentum) + outer.getKineticEnergy(outerMomentum)
                + inner.getParameterLogJacobian() + outer.getParameterLogJacobian();
    }

    /**
     * Overwrites the leading entries of {@code gradient} with the outer provider's current log-density gradient.
     *
     * @param gradient vector with at least as many entries as the outer gradient
     */
    public void updateOuterGradient(WrappedVector gradient) {
        final double[] outerGradient = outer.getGradientProvider().getGradientLogDensity();
        for (int i = 0; i < outerGradient.length; ++i) {
            gradient.set(i, outerGradient[i]);
        }
    }

    /** Not supported: this operator must be invoked through {@link #doOperation(Likelihood)}. */
    @Override
    public double doOperation() {
        throw new RuntimeException("Should not be executed");
    }

    @Override
    public String getOperatorName() {
        return "Split HMC operator";
    }

    /** No adaptation is performed; the value is ignored. */
    @Override
    protected void setAdaptableParameterValue(double value) {
    }

    /** Constant placeholder; this operator does not adapt. */
    @Override
    protected double getAdaptableParameterValue() {
        return 1.0;
    }

    /** Constant placeholder; this operator does not adapt. */
    @Override
    public double getRawParameter() {
        return 1.0;
    }

    /** Returns {@code null}: there is no adaptable parameter. */
    @Override
    public String getAdaptableParameterName() {
        return null;
    }

    /**
     * Performs one split step on joint position/momentum vectors laid out as {@code [inner | outer]},
     * writing the results back into {@code position} and {@code momentum}.
     *
     * @param position  joint position, updated in place
     * @param momentum  joint momentum, updated in place
     * @param gradient  gradient vector passed to both providers; its leading entries are overwritten
     *                  with the outer gradient after the inner move
     * @param direction integration direction forwarded to the providers
     * @param time      step length
     */
    @Override
    public void reversiblePositionMomentumUpdate(WrappedVector position, WrappedVector momentum,
                                                 WrappedVector gradient, int direction, double time) {
        updateRS();

        final double[] innerPositionValues = new double[dimInner];
        final double[] outerPositionValues = new double[dimOuter];
        final double[] innerMomentumValues = new double[dimInner];
        final double[] outerMomentumValues = new double[dimOuter];
        splitWrappedVector(position, innerPositionValues, outerPositionValues);
        splitWrappedVector(momentum, innerMomentumValues, outerMomentumValues);

        final WrappedVector innerPosition = new WrappedVector.Raw(innerPositionValues);
        final WrappedVector outerPosition = new WrappedVector.Raw(outerPositionValues);
        final WrappedVector innerMomentum = new WrappedVector.Raw(innerMomentumValues);
        final WrappedVector outerMomentum = new WrappedVector.Raw(outerMomentumValues);

        // NOTE(review): the caller's gradient is handed to both providers unchanged, presumably because it
        // holds the outer gradient (getGradientProvider() returns the outer provider's) — confirm.
        outerHalfStep(outerPosition, outerMomentum, gradient, direction, time);
        inner.reversiblePositionMomentumUpdate(innerPosition, innerMomentum, gradient, direction,
                relativeScale * time);
        updateOuterGradient(gradient);
        outerHalfStep(outerPosition, outerMomentum, gradient, direction, time);

        updateMergedVector(innerPosition, outerPosition, position);
        updateMergedVector(innerMomentum, outerMomentum, momentum);
    }

    /** Forwards the preconditioning update to both providers (inner first). */
    @Override
    public void providerUpdatePreconditioning() {
        inner.providerUpdatePreconditioning();
        outer.providerUpdatePreconditioning();
    }

    /**
     * Maps a joint transformed vector back to parameter space. Only the outer block is passed through
     * {@code outer.getTransform().inverse}; the inner block is copied unchanged.
     *
     * @param transformed joint vector of length at least {@code dimInner + dimOuter}
     * @return a new array of length {@code dimInner + dimOuter}
     */
    public double[] jointTransformInverse(double[] transformed) {
        final double[] outerTransformed = new double[dimOuter];
        System.arraycopy(transformed, dimInner, outerTransformed, 0, dimOuter);
        final double[] outerValues = outer.getTransform().inverse(outerTransformed, 0, dimOuter);

        final double[] result = new double[dimInner + dimOuter];
        System.arraycopy(transformed, 0, result, 0, dimInner);
        System.arraycopy(outerValues, 0, result, dimInner, dimOuter);
        return result;
    }

    /**
     * Returns the outer transform's log-Jacobian evaluated on the outer block of {@code values}.
     *
     * @param values joint vector in parameter space
     */
    public double transformGetLogJacobian(double[] values) {
        final double[] outerValues = new double[dimOuter];
        System.arraycopy(values, dimInner, outerValues, 0, dimOuter);
        return outer.getTransform().logJacobian(outerValues, 0, dimOuter);
    }

    /**
     * Applies the outer transform's gradient correction to the outer block of {@code gradient};
     * the inner block is copied unchanged.
     *
     * @param gradient  joint gradient in parameter space
     * @param parameter parameter whose current outer values are passed to the transform
     * @return a new array of length {@code dimInner + dimOuter}
     */
    public double[] transformupdateGradientLogDensity(double[] gradient, Parameter parameter) {
        final double[] outerGradient = new double[dimOuter];
        final double[] outerValues = new double[dimOuter];
        System.arraycopy(gradient, dimInner, outerGradient, 0, dimOuter);
        System.arraycopy(parameter.getParameterValues(), dimInner, outerValues, 0, dimOuter);
        final double[] adjustedOuter =
                outer.getTransform().updateGradientLogDensity(outerGradient, outerValues, 0, dimOuter);

        final double[] result = new double[dimInner + dimOuter];
        System.arraycopy(gradient, 0, result, 0, dimInner);
        System.arraycopy(adjustedOuter, 0, result, dimInner, dimOuter);
        return result;
    }

    /** Concatenates the inner and outer initial positions. */
    @Override
    public double[] getInitialPosition() {
        final double[] position = new double[dimInner + dimOuter];
        System.arraycopy(inner.getInitialPosition(), 0, position, 0, dimInner);
        System.arraycopy(outer.getInitialPosition(), 0, position, dimInner, dimOuter);
        return position;
    }

    /** Sum of both providers' parameter log-Jacobians. */
    @Override
    public double getParameterLogJacobian() {
        return inner.getParameterLogJacobian() + outer.getParameterLogJacobian();
    }

    /** Sum of both providers' gradient-event counts. */
    @Override
    public int getNumGradientEvent() {
        return inner.getNumGradientEvent() + outer.getNumGradientEvent();
    }

    /** Sum of both providers' boundary-event counts. */
    @Override
    public int getNumBoundaryEvent() {
        return inner.getNumBoundaryEvent() + outer.getNumBoundaryEvent();
    }

    /** Returns an empty mask. */
    @Override
    public double[] getMask() {
        return new double[0];
    }

    /** Returns {@code null}: no joint transform is exposed. */
    @Override
    public Transform getTransform() {
        return null;
    }

    /** Returns the outer provider's gradient provider. */
    @Override
    public GradientWrtParameterProvider getGradientProvider() {
        return outer.getGradientProvider();
    }

    /** Concatenates the inner and outer log-density gradients as {@code [inner | outer]}. */
    private double[] mergeGradient() {
        final double[] gradient = new double[dimInner + dimOuter];
        System.arraycopy(inner.getGradientProvider().getGradientLogDensity(), 0, gradient, 0, dimInner);
        System.arraycopy(outer.getGradientProvider().getGradientLogDensity(), 0, gradient, dimInner, dimOuter);
        return gradient;
    }

    /**
     * Splits a joint vector and forwards each block to the corresponding provider (inner first).
     *
     * @param values joint vector of length at least {@code dimInner + dimOuter}
     */
    @Override
    public void setParameter(double[] values) {
        final double[] innerValues = new double[dimInner];
        final double[] outerValues = new double[dimOuter];
        System.arraycopy(values, 0, innerValues, 0, dimInner);
        System.arraycopy(values, dimInner, outerValues, 0, dimOuter);
        inner.setParameter(innerValues);
        outer.setParameter(outerValues);
    }

    /** Draws inner then outer momentum and concatenates them. */
    @Override
    public WrappedVector drawMomentum() {
        return mergeWrappedVector(inner.drawMomentum(), outer.drawMomentum());
    }

    /**
     * Sums the providers' joint probabilities for their momentum blocks and subtracts
     * {@code inner.getLogLikelihood()} once — presumably because both terms include the same
     * likelihood (TODO confirm).
     */
    @Override
    public double getJointProbability(WrappedVector momentum) {
        final double[] innerMomentum = new double[dimInner];
        final double[] outerMomentum = new double[dimOuter];
        splitWrappedVector(momentum, innerMomentum, outerMomentum);
        return inner.getJointProbability(new WrappedVector.Raw(innerMomentum))
                + outer.getJointProbability(new WrappedVector.Raw(outerMomentum))
                - inner.getLogLikelihood();
    }

    /** Returns the inner provider's log-likelihood. */
    @Override
    public double getLogLikelihood() {
        return inner.getLogLikelihood();
    }

    /**
     * Sum of the providers' kinetic energies for the two blocks of {@code momentum}.
     * Reads through {@link ReadableVector#get(int)}, so any {@code ReadableVector} implementation is
     * accepted (previously a non-{@code WrappedVector} argument caused a ClassCastException).
     */
    @Override
    public double getKineticEnergy(ReadableVector momentum) {
        final double[] innerMomentum = new double[dimInner];
        final double[] outerMomentum = new double[dimOuter];
        copyEntries(momentum, 0, innerMomentum, dimInner);
        copyEntries(momentum, dimInner, outerMomentum, dimOuter);
        return inner.getKineticEnergy(new WrappedVector.Raw(innerMomentum))
                + outer.getKineticEnergy(new WrappedVector.Raw(outerMomentum));
    }

    @Override
    public double getStepSize() {
        return stepSize;
    }

    /**
     * Concatenates two vectors into a new {@code WrappedVector.Raw}, reading through {@code get(i)}
     * rather than the raw backing buffer.
     */
    private WrappedVector mergeWrappedVector(WrappedVector first, WrappedVector second) {
        final int firstDim = first.getDim();
        final int secondDim = second.getDim();
        final double[] merged = new double[firstDim + secondDim];
        for (int i = 0; i < firstDim; ++i) {
            merged[i] = first.get(i);
        }
        for (int i = 0; i < secondDim; ++i) {
            merged[firstDim + i] = second.get(i);
        }
        return new WrappedVector.Raw(merged);
    }

    /** Copies entries {@code [0, dimInner)} into {@code innerPart} and the next {@code dimOuter} into {@code outerPart}. */
    private void splitWrappedVector(WrappedVector joint, double[] innerPart, double[] outerPart) {
        copyEntries(joint, 0, innerPart, dimInner);
        copyEntries(joint, dimInner, outerPart, dimOuter);
    }

    /** Copies {@code length} entries of {@code source}, starting at {@code offset}, into {@code dest[0..length)}. */
    private static void copyEntries(ReadableVector source, int offset, double[] dest, int length) {
        for (int i = 0; i < length; ++i) {
            dest[i] = source.get(offset + i);
        }
    }

    /** Writes {@code innerPart} then {@code outerPart} into {@code joint} using the {@code [inner | outer]} layout. */
    private void updateMergedVector(WrappedVector innerPart, WrappedVector outerPart, WrappedVector joint) {
        for (int i = 0; i < dimInner; ++i) {
            joint.set(i, innerPart.get(i));
        }
        for (int i = 0; i < dimOuter; ++i) {
            joint.set(dimInner + i, outerPart.get(i));
        }
    }

    /**
     * When a travel-time multiplier is configured and requests it, feeds the current inner and outer
     * positions into its inner/outer sample covariance updates.
     */
    private void updateRS() {
        if (travelTimeMultipler != null && travelTimeMultipler.shouldUpdateSCM(getCount())) {
            travelTimeMultipler.updateSCM(travelTimeMultipler.getInnerCov(), inner.getInitialPosition(), getCount());
            travelTimeMultipler.updateSCM(travelTimeMultipler.getOuterCov(), outer.getInitialPosition(), getCount());
        }
    }
}

