/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.continuous.MultivariateDiffusionModel;
import dr.evomodel.treedatalikelihood.BufferIndexHelper;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.DiffusionProcessDelegate;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import java.io.Serializable;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Base {@link DiffusionProcessDelegate} backed by a single {@link MultivariateDiffusionModel}.
 *
 * <p>Keeps two {@link BufferIndexHelper}s whose offsets are flipped on update and
 * stored/restored together with the model state:
 * <ul>
 *   <li>one slot for the diffusion precision ("eigen") buffer, and</li>
 *   <li>one slot per tree node for the per-branch diffusion matrices.</li>
 * </ul>
 * Subclasses supply the per-branch drift rates through {@link #getDriftRates(int[], int)}.
 * By default the process reports no drift, no actualization and no integration.
 */
public abstract class AbstractDiffusionModelDelegate
        extends AbstractModel
        implements DiffusionProcessDelegate, Serializable {

    final Tree tree;
    private final MultivariateDiffusionModel diffusionModel;
    /** Index helper for the single diffusion-precision ("eigen") buffer. */
    private final BufferIndexHelper eigenBufferHelper;
    /** Index helper for the per-node diffusion-matrix buffers (one per node of {@link #tree}). */
    private final BufferIndexHelper matrixBufferHelper;
    /** Trait dimension: column count of the diffusion model's precision parameter. */
    protected final int dim;

    /**
     * Creates the delegate and registers {@code diffusionModel} as a sub-model, so that its
     * change events are forwarded through {@link #handleModelChangedEvent}.
     *
     * @param tree            tree whose node count sizes the matrix-buffer helper
     * @param diffusionModel  the (single) diffusion model driving this process
     * @param partitionNumber forwarded as the third argument of both {@link BufferIndexHelper}
     *                        constructors (presumably identifies the partition / buffer set —
     *                        confirm against {@code BufferIndexHelper})
     */
    AbstractDiffusionModelDelegate(Tree tree, MultivariateDiffusionModel diffusionModel, int partitionNumber) {
        super("AbstractDiffusionModelDelegate");
        this.tree = tree;
        this.diffusionModel = diffusionModel;
        addModel(diffusionModel);
        this.dim = diffusionModel.getPrecisionParameter().getColumnDimension();
        this.eigenBufferHelper = new BufferIndexHelper(1, 0, partitionNumber);
        this.matrixBufferHelper = new BufferIndexHelper(tree.getNodeCount(), 0, partitionNumber);
    }

    /** @return number of buffers managed for the diffusion precision. */
    @Override
    public int getEigenBufferCount() {
        return eigenBufferHelper.getBufferCount();
    }

    /** @return current buffer index of the given precision slot (only slot 0 exists). */
    @Override
    public int getEigenBufferOffsetIndex(int index) {
        return eigenBufferHelper.getOffsetIndex(index);
    }

    /** @return number of buffers managed for the per-node diffusion matrices. */
    @Override
    public int getMatrixBufferCount() {
        return matrixBufferHelper.getBufferCount();
    }

    /** @return current buffer index of the diffusion matrix belonging to the given node number. */
    @Override
    public int getMatrixBufferOffsetIndex(int index) {
        return matrixBufferHelper.getOffsetIndex(index);
    }

    /** Flips the buffer offset of the diffusion matrix belonging to the given node number. */
    @Override
    public void flipMatrixBufferOffset(int index) {
        matrixBufferHelper.flipOffset(index);
    }

    /** @return always {@code 1}: this delegate wraps exactly one diffusion model. */
    @Override
    public int getDiffusionModelCount() {
        return 1;
    }

    /**
     * Returns the wrapped diffusion model.
     *
     * @param index must be {@code 0} (checked only when assertions are enabled)
     */
    @Override
    public MultivariateDiffusionModel getDiffusionModel(int index) {
        assert index == 0;
        return diffusionModel;
    }

    /** @return current buffer index of the diffusion matrix belonging to the given node number. */
    @Override
    public int getMatrixIndex(int nodeNumber) {
        return matrixBufferHelper.getOffsetIndex(nodeNumber);
    }

    /**
     * Pushes the model's precision matrix and the log of its determinant into the integrator.
     *
     * @param cdi  integrator receiving the precision
     * @param flip if {@code true}, flip the precision buffer offset first so that the previous
     *             buffer is kept intact (e.g. for a later {@link #restoreState()})
     */
    @Override
    public void setDiffusionModels(ContinuousDiffusionIntegrator cdi, boolean flip) {
        if (flip) {
            eigenBufferHelper.flipOffset(0);
        }
        cdi.setDiffusionPrecision(eigenBufferHelper.getOffsetIndex(0),
                diffusionModel.getPrecisionmatrixAsVector(),
                Math.log(diffusionModel.getDeterminantPrecisionMatrix()));
    }

    /**
     * Asks the integrator to recompute the Brownian diffusion matrices of the given branches.
     *
     * @param cdi           integrator performing the update
     * @param branchIndices node numbers of the branches to update; only the first
     *                      {@code updateCount} entries are read
     * @param edgeLengths   passed through unchanged to the integrator (presumably the branch
     *                      lengths aligned with {@code branchIndices})
     * @param updateCount   number of branches to update
     * @param flip          if {@code true}, flip each branch's matrix buffer offset before
     *                      reading its index
     */
    @Override
    public void updateDiffusionMatrices(ContinuousDiffusionIntegrator cdi, int[] branchIndices,
                                        double[] edgeLengths, int updateCount, boolean flip) {
        final int[] matrixIndices = new int[updateCount];
        for (int i = 0; i < updateCount; ++i) {
            final int branch = branchIndices[i];
            if (flip) {
                matrixBufferHelper.flipOffset(branch);
            }
            matrixIndices[i] = matrixBufferHelper.getOffsetIndex(branch);
        }
        cdi.updateBrownianDiffusionMatrices(eigenBufferHelper.getOffsetIndex(0), matrixIndices,
                edgeLengths, getDriftRates(branchIndices, updateCount), updateCount);
    }

    /**
     * Supplies the drift rates handed to the integrator in {@link #updateDiffusionMatrices}.
     *
     * @param branchIndices node numbers of the branches being updated
     * @param updateCount   number of valid entries in {@code branchIndices}
     * @return drift rates for those branches, in the layout the integrator expects
     */
    protected abstract double[] getDriftRates(int[] branchIndices, int updateCount);

    /** @return {@code false}: no drift by default. */
    @Override
    public boolean hasDrift() {
        return false;
    }

    /** @return {@code false}: no actualization by default. */
    @Override
    public boolean hasActualization() {
        return false;
    }

    /** @return {@code false}: no diagonal actualization by default. */
    @Override
    public boolean hasDiagonalActualization() {
        return false;
    }

    /** @return {@code false}: not an integrated process by default. */
    @Override
    public boolean isIntegratedProcess() {
        return false;
    }

    /**
     * Forwards change events from the wrapped diffusion model to this model's listeners.
     *
     * @throws RuntimeException if the event comes from any other model
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model != diffusionModel) {
            throw new RuntimeException("Unknown model");
        }
        fireModelChanged(model);
    }

    /** No variables are registered directly on this model, so there is nothing to do. */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
    }

    /** Stores the precision and matrix buffer offsets. */
    @Override
    public void storeState() {
        eigenBufferHelper.storeState();
        matrixBufferHelper.storeState();
    }

    /** Restores the precision and matrix buffer offsets saved by {@link #storeState()}. */
    @Override
    public void restoreState() {
        eigenBufferHelper.restoreState();
        matrixBufferHelper.restoreState();
    }

    /** Nothing to accept: all state lives in the buffer helpers. */
    @Override
    protected void acceptState() {
    }

    /**
     * Converts a gradient w.r.t. the branch variance into one w.r.t. the diffusion variance by
     * scaling it with the node's scalar (see {@link #scaleGradient(NodeRef,
     * ContinuousDiffusionIntegrator, ContinuousDataLikelihoodDelegate, DenseMatrix64F)}).
     */
    @Override
    public DenseMatrix64F getGradientVarianceWrtVariance(NodeRef node, ContinuousDiffusionIntegrator cdi,
                                                         ContinuousDataLikelihoodDelegate likelihoodDelegate,
                                                         DenseMatrix64F gradient) {
        return scaleGradient(node, cdi, likelihoodDelegate, gradient);
    }

    /**
     * Returns a scaled copy of {@code gradient}; the input is left untouched.
     *
     * <p>The scale factor is {@code 1 / pseudoObservations} of the root prior for the root node,
     * and the branch length stored in the integrator for any other node.
     */
    DenseMatrix64F scaleGradient(NodeRef node, ContinuousDiffusionIntegrator cdi,
                                 ContinuousDataLikelihoodDelegate likelihoodDelegate,
                                 DenseMatrix64F gradient) {
        return scaleGradient(getScalarNode(node, cdi, likelihoodDelegate), gradient);
    }

    /** Returns {@code factor * gradient} as a new matrix. */
    private DenseMatrix64F scaleGradient(double factor, DenseMatrix64F gradient) {
        final DenseMatrix64F scaled = gradient.copy();
        if (factor == 0.0) {
            // Fill rather than scale: 0 * Infinity (or NaN) is NaN, but a zero factor must
            // yield an exact zero matrix.
            CommonOps.fill(scaled, 0.0);
        } else {
            CommonOps.scale(factor, scaled);
        }
        return scaled;
    }

    /** Scale factor for {@code node}: 1 / root pseudo-observations at the root, else its branch length. */
    private double getScalarNode(NodeRef node, ContinuousDiffusionIntegrator cdi,
                                 ContinuousDataLikelihoodDelegate likelihoodDelegate) {
        if (tree.isRoot(node)) {
            return 1.0 / likelihoodDelegate.getRootProcessDelegate().getPseudoObservations();
        }
        return cdi.getBranchLength(getMatrixIndex(node.getNumber()));
    }

    /**
     * Routes the gradient w.r.t. the root displacement to the node it belongs to.
     *
     * <p>If the root prior has infinitely many pseudo-observations (a fixed root), the gradient
     * is attributed to the children of the root. Otherwise it is attributed to the root itself.
     * Every other node gets a zero array of length {@code gradient.getNumRows()}.
     *
     * <p>NOTE(review): the non-zero case returns {@code gradient}'s backing array, not a copy.
     * EJML allows that array to be longer than {@code rows * cols}, so callers should not mutate
     * it and should not rely on its length. When the root is fixed, this also relies on
     * {@code Tree.isRoot} tolerating the parent of the root node — confirm.
     */
    @Override
    public double[] getGradientDisplacementWrtRoot(NodeRef node, ContinuousDiffusionIntegrator cdi,
                                                   ContinuousDataLikelihoodDelegate likelihoodDelegate,
                                                   DenseMatrix64F gradient) {
        final boolean fixedRoot =
                likelihoodDelegate.getRootProcessDelegate().getPseudoObservations() == Double.POSITIVE_INFINITY;
        final boolean ownsDisplacement = fixedRoot
                ? tree.isRoot(tree.getParent(node))
                : tree.isRoot(node);
        return ownsDisplacement ? gradient.getData() : new double[gradient.getNumRows()];
    }
}

