/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.ReversibleHMCProvider;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;

/**
 * A {@link ReversibleHMCProvider} that splits its state into two consecutive blocks and delegates
 * each block to its own provider.
 *
 * <p>Combined position/momentum vectors have length {@code dimA + dimB}. Indices
 * {@code [0, dimA)} are handled by provider A and indices {@code [dimA, dimA + dimB)} by provider B.
 * Each dimension is taken from the length of that provider's initial position at construction time.
 */
public class SplitHamiltonianMonteCarlo
implements ReversibleHMCProvider {
    /** Number of coordinates owned by provider A (leading block). */
    private final int dimA;
    /** Number of coordinates owned by provider B (trailing block). */
    private final int dimB;
    /** Step size reported by {@link #getStepSize()}. */
    private final double stepSize;
    /** Multiplier applied to the integration time handed to provider A. */
    private final double relativeScale;
    private final ReversibleHMCProvider reversibleHMCProviderA;
    private final ReversibleHMCProvider reversibleHMCProviderB;

    /**
     * Creates a split provider.
     *
     * @param providerA     provider for the leading block of coordinates
     * @param providerB     provider for the trailing block of coordinates
     * @param stepSize      value returned by {@link #getStepSize()}
     * @param relativeScale factor by which the integration time passed to provider A is multiplied
     */
    public SplitHamiltonianMonteCarlo(ReversibleHMCProvider providerA, ReversibleHMCProvider providerB, double stepSize, double relativeScale) {
        this.reversibleHMCProviderA = providerA;
        this.reversibleHMCProviderB = providerB;
        // Block sizes are fixed here from each provider's current initial position.
        this.dimA = providerA.getInitialPosition().length;
        this.dimB = providerB.getInitialPosition().length;
        this.stepSize = stepSize;
        this.relativeScale = relativeScale;
    }

    /**
     * Advances position and momentum in place using a symmetric B-A-B composition of the two
     * sub-providers' updates.
     *
     * <p>Both vectors are split into A/B blocks, updated through the sub-providers, and written
     * back into {@code position} and {@code momentum}.
     *
     * @param position  combined position vector; overwritten with the updated values
     * @param momentum  combined momentum vector; overwritten with the updated values
     * @param direction passed through unchanged to both sub-providers
     * @param time      passed to provider B as-is, and to provider A multiplied by {@code relativeScale}
     */
    @Override
    public void reversiblePositionMomentumUpdate(WrappedVector position, WrappedVector momentum, int direction, double time) {
        final double[] positionA = new double[this.dimA];
        final double[] positionB = new double[this.dimB];
        final double[] momentumA = new double[this.dimA];
        final double[] momentumB = new double[this.dimB];
        this.splitVector(position, positionA, positionB);
        this.splitVector(momentum, momentumA, momentumB);

        final WrappedVector.Raw wrappedPositionA = new WrappedVector.Raw(positionA);
        final WrappedVector.Raw wrappedPositionB = new WrappedVector.Raw(positionB);
        final WrappedVector.Raw wrappedMomentumA = new WrappedVector.Raw(momentumA);
        final WrappedVector.Raw wrappedMomentumB = new WrappedVector.Raw(momentumB);

        // Palindromic ordering (B, A, B) keeps the composed update symmetric.
        // NOTE(review): B is advanced by the full `time` on both sides, whereas a textbook Strang
        // splitting would use time / 2 for each B half-step -- confirm this scaling is intended.
        this.reversibleHMCProviderB.reversiblePositionMomentumUpdate(wrappedPositionB, wrappedMomentumB, direction, time);
        this.reversibleHMCProviderA.reversiblePositionMomentumUpdate(wrappedPositionA, wrappedMomentumA, direction, this.relativeScale * time);
        this.reversibleHMCProviderB.reversiblePositionMomentumUpdate(wrappedPositionB, wrappedMomentumB, direction, time);

        this.updateMergedVector(wrappedPositionA, wrappedPositionB, position);
        this.updateMergedVector(wrappedMomentumA, wrappedMomentumB, momentum);
    }

    /**
     * Returns a fresh array holding provider A's initial position followed by provider B's.
     */
    @Override
    public double[] getInitialPosition() {
        final double[] combined = new double[this.dimA + this.dimB];
        System.arraycopy(this.reversibleHMCProviderA.getInitialPosition(), 0, combined, 0, this.dimA);
        System.arraycopy(this.reversibleHMCProviderB.getInitialPosition(), 0, combined, this.dimA, this.dimB);
        return combined;
    }

    /** Returns the sum of both sub-providers' parameter log-Jacobians. */
    @Override
    public double getParameterLogJacobian() {
        return this.reversibleHMCProviderA.getParameterLogJacobian() + this.reversibleHMCProviderB.getParameterLogJacobian();
    }

    /**
     * Concatenates the first {@code dimA} entries of {@code first} with the first {@code dimB}
     * entries of {@code second} into a newly allocated vector.
     */
    private WrappedVector mergeWrappedVector(ReadableVector first, ReadableVector second) {
        final WrappedVector merged = new WrappedVector.Raw(new double[this.dimA + this.dimB]);
        this.updateMergedVector(first, second, merged);
        return merged;
    }

    /**
     * Copies entries {@code [0, dimA)} of {@code source} into {@code first} and entries
     * {@code [dimA, dimA + dimB)} into {@code second}.
     *
     * <p>Reads through {@link ReadableVector#get(int)} rather than a raw backing buffer, so it does
     * not assume the vector's storage starts at array index 0 and accepts any {@code ReadableVector}.
     */
    private void splitVector(ReadableVector source, double[] first, double[] second) {
        for (int i = 0; i < this.dimA; ++i) {
            first[i] = source.get(i);
        }
        for (int i = 0; i < this.dimB; ++i) {
            second[i] = source.get(this.dimA + i);
        }
    }

    /**
     * Writes the first {@code dimA} entries of {@code first} followed by the first {@code dimB}
     * entries of {@code second} into {@code target}, in ascending index order.
     */
    private void updateMergedVector(ReadableVector first, ReadableVector second, WrappedVector target) {
        for (int i = 0; i < this.dimA; ++i) {
            target.set(i, first.get(i));
        }
        for (int i = 0; i < this.dimB; ++i) {
            target.set(this.dimA + i, second.get(i));
        }
    }

    /**
     * Splits {@code position} into its A and B blocks and forwards each to the matching provider
     * (A first, then B).
     */
    @Override
    public void setParameter(double[] position) {
        final double[] positionA = new double[this.dimA];
        final double[] positionB = new double[this.dimB];
        System.arraycopy(position, 0, positionA, 0, this.dimA);
        System.arraycopy(position, this.dimA, positionB, 0, this.dimB);
        this.reversibleHMCProviderA.setParameter(positionA);
        this.reversibleHMCProviderB.setParameter(positionB);
    }

    /** Draws momentum from provider A, then provider B, and returns their concatenation. */
    @Override
    public WrappedVector drawMomentum() {
        return this.mergeWrappedVector(this.reversibleHMCProviderA.drawMomentum(), this.reversibleHMCProviderB.drawMomentum());
    }

    /**
     * Returns A's joint probability plus B's joint probability (each evaluated on its own momentum
     * block) minus provider A's log-likelihood.
     */
    @Override
    public double getJointProbability(WrappedVector momentum) {
        final double[] momentumA = new double[this.dimA];
        final double[] momentumB = new double[this.dimB];
        this.splitVector(momentum, momentumA, momentumB);
        // NOTE(review): subtracting A's log-likelihood presumably removes a likelihood term that both
        // providers' joint probabilities include -- confirm the two providers share one likelihood.
        return this.reversibleHMCProviderA.getJointProbability(new WrappedVector.Raw(momentumA))
                + this.reversibleHMCProviderB.getJointProbability(new WrappedVector.Raw(momentumB))
                - this.reversibleHMCProviderA.getLogLikelihood();
    }

    /** Returns provider A's log-likelihood. */
    @Override
    public double getLogLikelihood() {
        return this.reversibleHMCProviderA.getLogLikelihood();
    }

    /**
     * Returns the sum of both providers' kinetic energies, each evaluated on its own block of
     * {@code momentum}. Accepts any {@link ReadableVector}, not only {@link WrappedVector}.
     */
    @Override
    public double getKineticEnergy(ReadableVector momentum) {
        final double[] momentumA = new double[this.dimA];
        final double[] momentumB = new double[this.dimB];
        this.splitVector(momentum, momentumA, momentumB);
        return this.reversibleHMCProviderA.getKineticEnergy(new WrappedVector.Raw(momentumA))
                + this.reversibleHMCProviderB.getKineticEnergy(new WrappedVector.Raw(momentumB));
    }

    /** Returns the step size supplied to the constructor. */
    @Override
    public double getStepSize() {
        return this.stepSize;
    }
}

