/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions.gp;

import dr.inference.distribution.RandomField;
import dr.inference.model.AbstractModel;
import dr.inference.model.DesignMatrix;
import dr.inference.model.GradientProvider;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.RandomFieldDistribution;
import dr.math.distributions.gp.GaussianProcessKernel;
import dr.xml.Reportable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Supplier;
import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionCommon_D64;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.linsol.LinearSolver;

/**
 * Multivariate normal distribution over a {@code dim}-dimensional field whose covariance is built
 * additively from Gaussian-process kernels.
 *
 * <p>The covariance is {@code K + diag(nugget)}. {@code K} (the Gramian) is the sum over all
 * {@link BasisDimension}s of {@code scale * k(x_i, x'_j)}. When the basis carries a
 * {@link WeightFunction}, each term is also multiplied by {@code w(x_i) * w(x'_j)}. The precision
 * matrix and log-determinant come from a Cholesky-based solve. They are cached until a kernel, the
 * nugget or a design matrix reports a change.
 *
 * <p>Only an {@code orderVariance} of dimension 1 is supported. Its value does not currently enter
 * the Gramian.
 *
 * <p>A field parameter can be registered via {@link #passParameter(Parameter)}. In that case the
 * residual {@code x - mean} and {@code -P (x - mean)} are cached between calls and recomputed only
 * after the field (or the mean) changes.
 */
public class AdditiveGaussianProcessDistribution
extends RandomFieldDistribution
implements Reportable {
    public static final String TYPE = "GaussianProcess";
    private final int order;
    private final int dim;
    private final Parameter orderVariance;
    // May be null: the mean is then identically zero.
    private final Parameter meanParameter;
    // May be null: no nugget is added to the diagonal.
    private final Parameter nuggetParameter;
    private final List<BasisDimension> bases;
    private double[] mean;
    private double[] storedPrecision;
    private double[] storedMean;
    private double storedLogDeterminant = 0.0;
    // Kernel sum without nugget.
    private final DenseMatrix64F gramian;
    // Inverse of variance.
    private final DenseMatrix64F precision;
    // gramian + diag(nugget).
    private final DenseMatrix64F variance;
    // x - mean for the most recently evaluated point.
    private double[] diff;
    // -precision * diff, i.e. the gradient of the log density with respect to x.
    private double[] precisionDiff;
    private double logDeterminant;
    private boolean meanKnown = false;
    private boolean precisionAndDeterminantKnown = false;
    private boolean gramianAndVarianceKnown = false;
    private boolean precisionDiffKnown = false;
    private static final boolean USE_CHOLESKY = true;
    private Parameter field = null;
    private boolean fieldUpdated = true;
    private final boolean DEBUG = false;
    double[] storedPrecisionDiff;
    double[] storedDiff;

    /**
     * Registers the field parameter at which this distribution is evaluated, and listens to it.
     *
     * <p>With a non-null field, {@code diff} / {@code precisionDiff} are cached and refreshed only
     * when the field or the mean changes. With {@code null}, they are recomputed on every call.
     *
     * @param parameter the field parameter, or {@code null}
     */
    public void passParameter(Parameter parameter) {
        this.field = parameter;
        if (parameter != null) {
            this.addVariable(parameter);
        }
    }

    /**
     * Creates the distribution and registers all kernels, parameters and design matrices as
     * dependencies.
     *
     * @param name            model name passed to the superclass
     * @param dim             dimension of the field
     * @param orderVariance   per-order variances; must currently have dimension 1
     * @param meanParameter   mean (dimension 1 or {@code dim}), or {@code null} for a zero mean
     * @param nuggetParameter diagonal nugget (dimension 1 or {@code dim}), or {@code null}
     * @param bases           additive kernel components
     * @throws RuntimeException if {@code orderVariance} does not have dimension 1
     */
    public AdditiveGaussianProcessDistribution(String name, int dim, Parameter orderVariance,
                                               Parameter meanParameter, Parameter nuggetParameter,
                                               List<BasisDimension> bases) {
        super(name);
        this.order = orderVariance.getDimension();
        if (this.order != 1) {
            throw new RuntimeException("Not yet implemented");
        }
        this.dim = dim;
        this.orderVariance = orderVariance;
        this.meanParameter = meanParameter;
        this.nuggetParameter = nuggetParameter;
        this.bases = bases;
        this.mean = new double[dim];
        this.diff = new double[dim];
        this.gramian = new DenseMatrix64F(dim, dim);
        this.variance = new DenseMatrix64F(dim, dim);
        this.precision = new DenseMatrix64F(dim, dim);
        this.storedPrecision = new double[dim * dim];
        this.storedMean = new double[dim];
        this.precisionDiff = new double[dim];
        this.addVariable(orderVariance);
        if (meanParameter != null) {
            this.addVariable(meanParameter);
        }
        if (nuggetParameter != null) {
            this.addVariable(nuggetParameter);
        }
        for (BasisDimension basis : bases) {
            GaussianProcessKernel kernel = basis.getKernel();
            if (kernel instanceof AbstractModel) {
                this.addModel((AbstractModel) kernel);
            }
            this.addVariable(basis.getDesignMatrix1());
            this.addVariable(basis.getDesignMatrix2());
        }
    }

    public int getOrder() {
        return this.order;
    }

    public Parameter getOrderVariance() {
        return this.orderVariance;
    }

    public List<BasisDimension> getBases() {
        return this.bases;
    }

    /** Fills the Gramian from the bases and sets {@code variance = gramian + diag(nugget)}. */
    private void computeGramianAndVariance() {
        AdditiveGaussianProcessDistribution.computeAdditiveGramian(this.gramian, this.bases, this.orderVariance);
        this.variance.set(this.gramian);
        if (this.nuggetParameter != null) {
            for (int i = 0; i < this.dim; ++i) {
                this.variance.add(i, i, this.getNugget(i));
            }
        }
    }

    /**
     * Inverts the variance into {@link #precision} and computes {@code log|variance|}.
     *
     * @throws RuntimeException if the variance is not symmetric positive definite
     */
    private void computePrecisionAndDeterminant() {
        DenseMatrix64F covariance = this.getVariance();
        LinearSolver<DenseMatrix64F> solver = LinearSolverFactory.symmPosDef(this.dim);
        if (!solver.setA(covariance)) {
            throw new RuntimeException("Unable to decompose matrix");
        }
        solver.invert(this.precision);
        // log|V| = 2 * sum(log diag(L)) for the Cholesky factor L.
        // NOTE(review): the cast assumes the factory returned a solver backed by
        // CholeskyDecompositionCommon_D64; confirm for very large dim, where EJML may switch to a
        // block implementation.
        CholeskyDecompositionCommon_D64 cholesky = (CholeskyDecompositionCommon_D64) solver.getDecomposition();
        this.logDeterminant = 2.0 * AdditiveGaussianProcessDistribution.computeLogDeterminantFromTriangularMatrix(cholesky.getT());
    }

    /** Returns the sum of the logs of the diagonal entries of a square (triangular) matrix. */
    private static double computeLogDeterminantFromTriangularMatrix(DenseMatrix64F triangular) {
        int size = triangular.numCols;
        double[] data = triangular.getData();
        double sum = 0.0;
        int end = size * size;
        // Row-major storage: consecutive diagonal entries are size + 1 apart.
        for (int k = 0; k < end; k += size + 1) {
            sum += Math.log(data[k]);
        }
        return sum;
    }

    /** Ensures precision and log-determinant reflect the current state. */
    private void ensurePrecisionAndDeterminant() {
        if (!this.precisionAndDeterminantKnown) {
            this.computePrecisionAndDeterminant();
            this.precisionAndDeterminantKnown = true;
        }
    }

    /** Ensures the Gramian and variance reflect the current state. */
    private void ensureGramianAndVariance() {
        if (!this.gramianAndVarianceKnown) {
            this.computeGramianAndVariance();
            this.gramianAndVarianceKnown = true;
        }
    }

    /** Returns the row-major backing array of the precision matrix (not a copy). */
    protected double[] getPrecision() {
        return this.getPrecisionAsMatrix().getData();
    }

    protected DenseMatrix64F getPrecisionAsMatrix() {
        this.ensurePrecisionAndDeterminant();
        return this.precision;
    }

    private double getLogDeterminant() {
        this.ensurePrecisionAndDeterminant();
        return this.logDeterminant;
    }

    private DenseMatrix64F getGramian() {
        this.ensureGramianAndVariance();
        return this.gramian;
    }

    private DenseMatrix64F getVariance() {
        this.ensureGramianAndVariance();
        return this.variance;
    }

    /** Returns the nugget for index {@code n}; a 1-dimensional nugget is shared by all entries. */
    protected double getNugget(int n) {
        return this.nuggetParameter.getDimension() == 1 ? this.nuggetParameter.getParameterValue(0) : this.nuggetParameter.getParameterValue(n);
    }

    /**
     * Returns the (cached) mean vector.
     *
     * <p>The mean is zero when there is no mean parameter. It is a constant when the parameter has
     * dimension 1, and element-wise otherwise. The returned array is internal state.
     */
    @Override
    public double[] getMean() {
        if (!this.meanKnown) {
            if (this.meanParameter == null) {
                Arrays.fill(this.mean, 0.0);
            } else if (this.meanParameter.getDimension() == 1) {
                Arrays.fill(this.mean, this.meanParameter.getParameterValue(0));
            } else {
                for (int i = 0; i < this.mean.length; ++i) {
                    this.mean[i] = this.meanParameter.getParameterValue(i);
                }
            }
            this.meanKnown = true;
        }
        return this.mean;
    }

    /** Returns {@code -P (x - mean)} via the caching delegate (internal array, not a copy). */
    protected double[] getPrecisionDiff(double[] x) {
        this.computingDelegate(x);
        return this.precisionDiff;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    @Override
    public double[][] getScaleMatrix() {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public Variable<Double> getLocationVariable() {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public int getDimension() {
        return this.dim;
    }

    /**
     * Returns the multivariate-normal log density
     * {@code -0.5 * (dim * log(2 pi) + log|V| + (x - mean)' P (x - mean))}.
     */
    @Override
    public double logPdf(double[] x) {
        double quadratic = 0.0;
        if (this.field == null) {
            this.computeDiff(x);
            double[] prec = this.getPrecision();
            for (int i = 0; i < this.dim; ++i) {
                for (int j = 0; j < this.dim; ++j) {
                    quadratic += this.diff[i] * prec[i * this.dim + j] * this.diff[j];
                }
            }
        } else {
            this.computingDelegate(x);
            // precisionDiff = -P diff, so diff . precisionDiff is the negated quadratic form.
            for (int i = 0; i < this.dim; ++i) {
                quadratic += this.diff[i] * this.precisionDiff[i];
            }
            quadratic *= -1.0;
        }
        return -0.5 * ((double) this.dim * Math.log(Math.PI * 2) + this.getLogDeterminant()) - 0.5 * quadratic;
    }

    /**
     * Returns the gradient of the log density with respect to the field, {@code -P (x - mean)}.
     * The result is a fresh array, so callers may modify it without corrupting the cache.
     */
    @Override
    public double[] getGradientLogDensity(Object object) {
        if (this.field == null) {
            return MultivariateNormalDistribution.gradLogPdf((double[]) object, this.getMean(), this.getPrecision());
        }
        this.computingDelegate((double[]) object);
        return this.precisionDiff.clone();
    }

    /**
     * Brings {@code diff} and {@code precisionDiff} up to date for the point {@code x}.
     *
     * <p>With a registered field, {@code x} is only read when the field or the mean has changed
     * since the last call. Otherwise the cached residual is reused.
     */
    public void computingDelegate(double[] x) {
        if (this.field == null) {
            this.fieldUpdated = true;
        }
        if (this.fieldUpdated || !this.meanKnown) {
            this.computeDiff(x);
            this.precisionDiffKnown = false;
            this.fieldUpdated = false;
        }
        if (!this.precisionDiffKnown) {
            this.computePrecisionDiff();
            this.precisionDiffKnown = true;
        }
    }

    /** Sets {@code diff = x - mean}. */
    public void computeDiff(double[] x) {
        double[] mu = this.getMean();
        for (int i = 0; i < this.dim; ++i) {
            this.diff[i] = x[i] - mu[i];
        }
    }

    /** Sets {@code precisionDiff = -P * diff}. */
    public void computePrecisionDiff() {
        double[] prec = this.getPrecision();
        for (int i = 0; i < this.dim; ++i) {
            double sum = 0.0;
            int rowOffset = i * this.dim;
            for (int j = 0; j < this.dim; ++j) {
                sum -= prec[rowOffset + j] * this.diff[j];
            }
            this.precisionDiff[i] = sum;
        }
    }

    private boolean containsKernel(Model model) {
        for (BasisDimension basis : this.bases) {
            if (model == basis.getKernel()) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code candidate} is one of the design matrices of any basis. */
    private boolean isDesignMatrix(Object candidate) {
        for (BasisDimension basis : this.bases) {
            if (candidate == basis.getDesignMatrix1() || candidate == basis.getDesignMatrix2()) {
                return true;
            }
        }
        return false;
    }

    /** Invalidates every cache that depends on the covariance. */
    private void invalidateCovariance() {
        this.precisionAndDeterminantKnown = false;
        this.gramianAndVarianceKnown = false;
        this.precisionDiffKnown = false;
    }

    /**
     * Invalidates the covariance caches when a kernel changes.
     *
     * @throws IllegalArgumentException if {@code model} is not one of this distribution's kernels
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (!this.containsKernel(model)) {
            throw new IllegalArgumentException("Unknown model");
        }
        this.invalidateCovariance();
        this.fireModelChanged();
    }

    /**
     * Invalidates caches affected by a variable change.
     *
     * <p>A mean change invalidates the mean. A nugget or design-matrix change invalidates the
     * covariance. A field change invalidates the cached residual.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable == this.meanParameter) {
            this.meanKnown = false;
            this.fireModelChanged();
        } else if (variable == this.nuggetParameter) {
            this.invalidateCovariance();
            this.fireModelChanged();
        } else if (variable == this.field) {
            this.fieldUpdated = true;
            this.precisionDiffKnown = false;
        } else if (this.isDesignMatrix(variable)) {
            // The Gramian is built from the design matrices, so it must be recomputed.
            this.invalidateCovariance();
            this.fireModelChanged();
        }
    }

    @Override
    protected void storeState() {
        this.storedPrecision = Arrays.copyOf(this.getPrecision(), this.dim * this.dim);
        this.storedLogDeterminant = this.getLogDeterminant();
        this.storedMean = Arrays.copyOf(this.getMean(), this.dim);
        if (this.field != null) {
            this.storedDiff = Arrays.copyOf(this.diff, this.dim);
            this.storedPrecisionDiff = Arrays.copyOf(this.precisionDiff, this.dim);
        }
    }

    /** Restores the stored snapshot by swapping buffers with the current ones. */
    @Override
    protected void restoreState() {
        double[] swap = this.storedPrecision;
        this.storedPrecision = this.precision.getData();
        this.precision.setData(swap);
        this.logDeterminant = this.storedLogDeterminant;
        swap = this.storedMean;
        this.storedMean = this.mean;
        this.mean = swap;
        if (this.field != null) {
            swap = this.storedDiff;
            this.storedDiff = this.diff;
            this.diff = swap;
            swap = this.storedPrecisionDiff;
            this.storedPrecisionDiff = this.precisionDiff;
            this.precisionDiff = swap;
        }
    }

    @Override
    protected void acceptState() {
    }

    /**
     * Returns a gradient provider for {@code parameter}.
     *
     * <p>For the mean parameter, the gradient is {@code d log p / d mean = P (x - mean)}, which is
     * the negation of the field gradient {@code precisionDiff}. For a scalar mean it is summed over
     * all entries.
     *
     * @throws RuntimeException         for the nugget or for kernel parameters (unsupported here)
     * @throws IllegalArgumentException for unrelated parameters
     */
    @Override
    public GradientProvider getGradientWrt(final Parameter parameter) {
        if (this.meanParameter != null && parameter == this.meanParameter) {
            return new GradientProvider(){

                @Override
                public int getDimension() {
                    return AdditiveGaussianProcessDistribution.this.meanParameter.getDimension();
                }

                @Override
                public double[] getGradientLogDensity(Object object) {
                    AdditiveGaussianProcessDistribution outer = AdditiveGaussianProcessDistribution.this;
                    outer.computingDelegate((double[]) object);
                    int meanDimension = outer.meanParameter.getDimension();
                    if (meanDimension == outer.dim) {
                        double[] gradient = new double[outer.dim];
                        for (int i = 0; i < outer.dim; ++i) {
                            gradient[i] = -outer.precisionDiff[i];
                        }
                        return gradient;
                    }
                    if (meanDimension == 1) {
                        double sum = 0.0;
                        for (int i = 0; i < outer.dim; ++i) {
                            sum -= outer.precisionDiff[i];
                        }
                        return new double[]{sum};
                    }
                    throw new IllegalArgumentException("Unknown mean parameter structure");
                }
            };
        }
        if (this.nuggetParameter != null && parameter == this.nuggetParameter) {
            throw new RuntimeException("Not implemented");
        }
        for (BasisDimension basis : this.bases) {
            if (parameter == basis.getDesignMatrix1() || parameter == basis.getDesignMatrix2()) {
                return new GradientProvider(){

                    @Override
                    public int getDimension() {
                        return parameter.getDimension();
                    }

                    @Override
                    public double[] getGradientLogDensity(Object object) {
                        throw new RuntimeException("Not yet implemented (DesignMatrix Gradient)");
                    }
                };
            }
            for (Parameter kernelParameter : basis.getKernel().getParameters()) {
                if (parameter == kernelParameter) {
                    throw new RuntimeException("Use GaussianProcessKernelGradient");
                }
            }
        }
        throw new IllegalArgumentException("Unknown parameter");
    }

    @Override
    public double[] getDiagonalHessianLogDensity(Object object) {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public double[][] getHessianLogDensity(Object object) {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public double[] nextRandom() {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Reports {@code -P (0 - mean)}, the mean and the row-major precision matrix.
     *
     * <p>The first vector is computed in a local buffer. This leaves the cached residual and the
     * field-update flag untouched, so a later {@code logPdf} on the real field is not affected.
     */
    @Override
    public String getReport() {
        double[] meanValues = this.getMean();
        double[] precisionValues = this.getPrecision();
        double[] precisionDiffAtZero = new double[this.dim];
        for (int i = 0; i < this.dim; ++i) {
            double sum = 0.0;
            for (int j = 0; j < this.dim; ++j) {
                sum -= precisionValues[i * this.dim + j] * (0.0 - meanValues[j]);
            }
            precisionDiffAtZero[i] = sum;
        }
        StringBuilder report = new StringBuilder();
        // NOTE(review): the label reads "predictionDiff" although the values are precisionDiff;
        // it is left unchanged in case report consumers match on it.
        AdditiveGaussianProcessDistribution.appendValues(report, "predictionDiff:", precisionDiffAtZero);
        report.append("\n");
        AdditiveGaussianProcessDistribution.appendValues(report, "mean:", meanValues);
        report.append("\n");
        AdditiveGaussianProcessDistribution.appendValues(report, "precision:", precisionValues);
        return report.toString();
    }

    /** Appends {@code label} followed by each value preceded by a single space. */
    private static void appendValues(StringBuilder target, String label, double[] values) {
        target.append(label);
        for (double value : values) {
            target.append(" ").append(value);
        }
    }

    /**
     * Overwrites {@code gramian} with the additive kernel matrix.
     *
     * <p>Entry {@code (i, j)} is the sum over bases of
     * {@code scale * k(x1_i, x2_j) [* w(x1_i) * w(x2_j)]}. Here {@code x1} and {@code x2} are
     * column 0 of the basis' two design matrices.
     *
     * @param gramian       output matrix; its shape fixes the number of rows and columns used
     * @param bases         additive kernel components
     * @param orderVariance currently unused (only first-order models are supported)
     */
    public static void computeAdditiveGramian(DenseMatrix64F gramian, List<BasisDimension> bases, Parameter orderVariance) {
        gramian.zero();
        int rows = gramian.getNumRows();
        int cols = gramian.getNumCols();
        for (BasisDimension basis : bases) {
            GaussianProcessKernel kernel = basis.getKernel();
            DesignMatrix design1 = basis.getDesignMatrix1();
            DesignMatrix design2 = basis.getDesignMatrix2();
            WeightFunction weightFunction = basis.getWeightFunction();
            double scale = kernel.getScale();
            // Column-side locations and weights are reused by every row, so read them once.
            double[] columnLocations = new double[cols];
            double[] columnWeights = new double[cols];
            for (int j = 0; j < cols; ++j) {
                columnLocations[j] = design2.getParameterValue(j, 0);
                if (weightFunction != null) {
                    columnWeights[j] = weightFunction.getWeight(columnLocations[j]);
                }
            }
            for (int i = 0; i < rows; ++i) {
                double rowLocation = design1.getParameterValue(i, 0);
                double rowWeight = weightFunction != null ? weightFunction.getWeight(rowLocation) : 1.0;
                for (int j = 0; j < cols; ++j) {
                    double covariance = scale * kernel.getUnscaledCovariance(rowLocation, columnLocations[j]);
                    if (weightFunction != null) {
                        covariance *= rowWeight * columnWeights[j];
                    }
                    gramian.add(i, j, covariance);
                }
            }
        }
    }

    /**
     * One additive component: a kernel evaluated between the first column of two design matrices,
     * with an optional weight function.
     */
    public static class BasisDimension {
        private final GaussianProcessKernel kernel;
        private final DesignMatrix design1;
        private final DesignMatrix design2;
        // May be null: no weighting.
        private final WeightFunction weightFunction;

        public BasisDimension(GaussianProcessKernel gaussianProcessKernel, DesignMatrix designMatrix, DesignMatrix designMatrix2) {
            this(gaussianProcessKernel, designMatrix, designMatrix2, null);
        }

        public BasisDimension(GaussianProcessKernel gaussianProcessKernel, DesignMatrix designMatrix, DesignMatrix designMatrix2, WeightFunction weightFunction) {
            this.kernel = gaussianProcessKernel;
            this.design1 = designMatrix;
            this.design2 = designMatrix2;
            this.weightFunction = weightFunction;
        }

        /** Uses the same design matrix for rows and columns. */
        public BasisDimension(GaussianProcessKernel gaussianProcessKernel, DesignMatrix designMatrix) {
            this(gaussianProcessKernel, designMatrix, designMatrix);
        }

        /**
         * Builds the design matrix from a weight provider.
         * NOTE(review): that matrix's {@code getParameterValue} throws, so the Gramian cannot yet
         * be computed for bases created this way.
         */
        public BasisDimension(GaussianProcessKernel gaussianProcessKernel, RandomField.WeightProvider weightProvider) {
            this(gaussianProcessKernel, BasisDimension.makeDesignMatrixFromWeights(weightProvider));
        }

        public GaussianProcessKernel getKernel() {
            return this.kernel;
        }

        DesignMatrix getDesignMatrix1() {
            return this.design1;
        }

        DesignMatrix getDesignMatrix2() {
            return this.design2;
        }

        WeightFunction getWeightFunction() {
            return this.weightFunction;
        }

        private static DesignMatrix makeDesignMatrixFromWeights(final RandomField.WeightProvider weightProvider) {
            return new DesignMatrix("weights", false){

                @Override
                public double getParameterValue(int n, int n2) {
                    throw new RuntimeException("Not yet implemented");
                }

                @Override
                public int getDimension() {
                    return weightProvider.getDimension();
                }

                @Override
                public int getRowDimension() {
                    return weightProvider.getDimension();
                }

                @Override
                public int getColumnDimension() {
                    return 1;
                }

                @Override
                public Parameter getParameter(int n) {
                    throw new IllegalArgumentException("Not allowed");
                }
            };
        }
    }

    /** Scalar weight {@code w(x)} applied multiplicatively to kernel entries. */
    public static interface WeightFunction {
        public double getWeight(double var1);

        /** Reads optional settings from {@code var1}; missing keys keep their defaults. */
        public void configure(Map<String, Double> var1);

        /** Case-insensitive registry of named weight-function constructors. */
        public static class WeightFunctionFactory {
            private static final Map<String, Supplier<WeightFunction>> registry = new HashMap<String, Supplier<WeightFunction>>();

            /** Registers {@code supplier} under {@code string}, matched case-insensitively. */
            public static void register(String string, Supplier<WeightFunction> supplier) {
                // Locale.ROOT keeps keys stable regardless of the JVM's default locale.
                registry.put(string.toLowerCase(Locale.ROOT), supplier);
            }

            /**
             * Creates and configures the weight function registered under {@code string}.
             *
             * @throws IllegalArgumentException if no function is registered under that name
             */
            public static WeightFunction create(String string, Map<String, Double> map) {
                Supplier<WeightFunction> supplier = registry.get(string.toLowerCase(Locale.ROOT));
                if (supplier == null) {
                    throw new IllegalArgumentException("Unknown weight function: " + string);
                }
                WeightFunction weightFunction = supplier.get();
                weightFunction.configure(map);
                return weightFunction;
            }

            static {
                WeightFunctionFactory.register("sigmoid", SigmoidWeightFunction::new);
                WeightFunctionFactory.register("sigmoidComplement", SigmoidComplementWeightFunction::new);
                WeightFunctionFactory.register("identity", IdentityWeightFunction::new);
                WeightFunctionFactory.register("linear", LinearWeightFunction::new);
            }
        }

        /** {@code w(x) = slope * x + intercept} (defaults 1 and 0). */
        public static class LinearWeightFunction
        implements WeightFunction {
            private double slope = 1.0;
            private double intercept = 0.0;

            @Override
            public void configure(Map<String, Double> map) {
                if (map.containsKey("slope")) {
                    this.slope = map.get("slope");
                }
                if (map.containsKey("intercept")) {
                    this.intercept = map.get("intercept");
                }
            }

            @Override
            public double getWeight(double d) {
                return this.slope * d + this.intercept;
            }
        }

        /** {@code w(x) = 1}. */
        public static class IdentityWeightFunction
        implements WeightFunction {
            @Override
            public void configure(Map<String, Double> map) {
            }

            @Override
            public double getWeight(double d) {
                return 1.0;
            }
        }

        /** {@code w(x) = 1 - 1 / (1 + exp(-scale * (x - location)))} (defaults 1 and 0). */
        public static class SigmoidComplementWeightFunction
        implements WeightFunction {
            private double location = 0.0;
            private double scale = 1.0;

            @Override
            public void configure(Map<String, Double> map) {
                if (map.containsKey("scale")) {
                    this.scale = map.get("scale");
                }
                if (map.containsKey("location")) {
                    this.location = map.get("location");
                }
            }

            @Override
            public double getWeight(double d) {
                return 1.0 - 1.0 / (1.0 + Math.exp(-this.scale * (d - this.location)));
            }
        }

        /** {@code w(x) = 1 / (1 + exp(-scale * (x - location)))} (defaults 1 and 0). */
        public static class SigmoidWeightFunction
        implements WeightFunction {
            private double location = 0.0;
            private double scale = 1.0;

            @Override
            public void configure(Map<String, Double> map) {
                if (map.containsKey("scale")) {
                    this.scale = map.get("scale");
                }
                if (map.containsKey("location")) {
                    this.location = map.get("location");
                }
            }

            @Override
            public double getWeight(double d) {
                return 1.0 / (1.0 + Math.exp(-this.scale * (d - this.location)));
            }
        }
    }
}

