/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evomodel.substmodel.GlmSubstitutionModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.AbstractGlmSubstitutionModelGradient;
import dr.evomodel.treedatalikelihood.discrete.AbstractLogAdditiveSubstitutionModelGradient;
import dr.inference.distribution.GeneralizedLinearModel;
import dr.inference.model.DesignMatrix;
import dr.inference.model.MaskedParameter;
import dr.inference.model.Parameter;
import dr.util.Transform;

public class DesignMatrixSubstitutionModelGradient
extends AbstractGlmSubstitutionModelGradient {
    private final int whichEffect;
    private final MaskedParameter parameter;
    private final int[][] mapEffectToIndices;

    public DesignMatrixSubstitutionModelGradient(String string, TreeDataLikelihood treeDataLikelihood, BeagleDataLikelihoodDelegate beagleDataLikelihoodDelegate, GlmSubstitutionModel glmSubstitutionModel, DesignMatrix designMatrix, MaskedParameter maskedParameter, AbstractLogAdditiveSubstitutionModelGradient.ApproximationMode approximationMode) {
        super(string, treeDataLikelihood, beagleDataLikelihoodDelegate, glmSubstitutionModel, approximationMode);
        this.parameter = maskedParameter;
        Parameter parameter = maskedParameter.getUnmaskedParameter();
        int n = DesignMatrixSubstitutionModelGradient.getEffectDimension(parameter);
        this.whichEffect = this.findDesignMatrix(designMatrix);
        int n2 = this.stateCount * (this.stateCount - 1);
        if (n != n2) {
            if (this.getDimension() == n2 / 2) {
                throw new RuntimeException("Not yet implemented");
            }
            throw new IllegalArgumentException("Unable to determine random design matrix count");
        }
        this.mapEffectToIndices = this.makeDesignMap();
    }

    @Override
    String getType() {
        return "design";
    }

    private int findDesignMatrix(DesignMatrix designMatrix) {
        for (int i = 0; i < this.glm.getNumberOfFixedEffects(); ++i) {
            if (this.glm.getDesignMatrix(i) != designMatrix) continue;
            return i;
        }
        throw new IllegalArgumentException("Unable to find design matrix in GLM model");
    }

    @Override
    Double getReportTolerance() {
        return 0.01;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    AbstractGlmSubstitutionModelGradient.ParameterMap makeParameterMap(GeneralizedLinearModel generalizedLinearModel) {
        return null;
    }

    @Override
    protected double preProcessNormalization(double[] dArray, double[] dArray2, boolean bl) {
        double d = 0.0;
        if (bl) {
            for (int i = 0; i < this.stateCount; ++i) {
                for (int j = 0; j < this.stateCount; ++j) {
                    d += dArray[this.index(i, j)] * dArray2[this.index(i, j)];
                }
            }
        }
        return d;
    }

    @Override
    double processSingleGradientDimension(int n, double[] dArray, double[] dArray2, double[] dArray3, boolean bl, double d, double d2, Transform transform, boolean bl2) {
        int n2 = this.indexK(n);
        double d3 = this.glm.getFixedEffect(this.whichEffect).getParameterValue(n2);
        Parameter parameter = this.glm.getFixedEffectIndicator(this.whichEffect);
        if (parameter != null) {
            d3 *= parameter.getParameterValue(n2);
        }
        double d4 = dArray2[this.indexIJ(n)] * d3;
        double d5 = (dArray[this.indexIJ(n)] - dArray[this.indexII(n)]) * d4;
        if (bl) {
            d5 -= d4 * dArray3[this.indexI(n)] * d;
        }
        return d5;
    }

    private static int getEffectDimension(Parameter parameter) {
        return parameter instanceof DesignMatrix ? ((DesignMatrix)parameter).getRowDimension() : parameter.getDimension();
    }

    private int[][] makeDesignMap() {
        int[][] nArrayArray = new int[this.parameter.getDimension()][];
        int n = 0;
        int n2 = 0;
        Parameter parameter = this.glm.getFixedEffect(this.whichEffect);
        for (int i = 0; i < parameter.getDimension(); ++i) {
            int n3;
            int n4;
            for (n4 = 0; n4 < this.stateCount; ++n4) {
                for (n3 = n4 + 1; n3 < this.stateCount; ++n3) {
                    if (this.parameter.getParameterMaskValue(n2++) != 1.0) continue;
                    nArrayArray[n++] = new int[]{n4, n3, i};
                }
            }
            for (n4 = 0; n4 < this.stateCount; ++n4) {
                for (n3 = n4 + 1; n3 < this.stateCount; ++n3) {
                    if (this.parameter.getParameterMaskValue(n2++) != 1.0) continue;
                    nArrayArray[n++] = new int[]{n3, n4, i};
                }
            }
        }
        return nArrayArray;
    }

    private int indexIJ(int n) {
        int[] nArray = this.mapEffectToIndices[n];
        return nArray[0] * this.stateCount + nArray[1];
    }

    private int indexII(int n) {
        int[] nArray = this.mapEffectToIndices[n];
        return nArray[0] * this.stateCount + nArray[0];
    }

    private int indexI(int n) {
        int[] nArray = this.mapEffectToIndices[n];
        return nArray[0];
    }

    private int indexK(int n) {
        int[] nArray = this.mapEffectToIndices[n];
        return nArray[2];
    }
}

