/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.multidimensionalscaling.mm;

import dr.inference.model.MatrixParameterInterface;
import dr.inference.multidimensionalscaling.MultiDimensionalScalingLikelihood;
import dr.inference.multidimensionalscaling.mm.MMAlgorithm;
import dr.inference.operators.EllipticalSliceOperator;
import dr.math.distributions.GaussianProcessRandomGenerator;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Mode finder for a {@link MultiDimensionalScalingLikelihood}, optionally combined with a
 * {@link GaussianProcessRandomGenerator} whose precision matrix contributes extra terms to
 * every update.
 *
 * <p>The iteration itself is driven by {@code findMode} inherited from {@link MMAlgorithm};
 * this class supplies the per-iteration update in {@link #mmUpdate(double[], double[])}.
 * Locations are stored as a flat array in which coordinate {@code k} of location {@code i}
 * lives at index {@code i * P + k}.
 */
public class MultiDimensionalScalingMM
extends MMAlgorithm {
    private final MultiDimensionalScalingLikelihood likelihood;
    /** Optional Gaussian process; {@code null} when none was supplied. */
    private final GaussianProcessRandomGenerator gp;
    /** Dimension of the MDS space, as reported by the likelihood. */
    private final int P;
    /** Number of locations, as reported by the likelihood. */
    private final int Q;
    /** Scratch Q x Q (row-major) matrix of inner products between locations; lazily allocated. */
    private double[] XtX = null;
    /** Scratch Q x Q (row-major) matrix of current pairwise distances; lazily allocated. */
    private double[] D = null;
    /** Observed distances fetched from the likelihood on first use, indexed as {@code i * Q + j}. */
    private double[] distance = null;
    private final double tolerance;
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        public static final String MDS_STARTING_VALUES = "mdsModeFinder";
        public static final String TOLERANCE = "tolerance";
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{new ElementRule(MultiDimensionalScalingLikelihood.class), new ElementRule(GaussianProcessRandomGenerator.class, true), AttributeRule.newDoubleRule(TOLERANCE, true)};

        @Override
        public String getParserName() {
            return MDS_STARTING_VALUES;
        }

        /**
         * Builds the mode finder from its XML element and runs it immediately, so the
         * returned object has already attempted the optimisation.
         */
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            MultiDimensionalScalingLikelihood mdsLikelihood = (MultiDimensionalScalingLikelihood)xMLObject.getChild(MultiDimensionalScalingLikelihood.class);
            // NOTE(review): the generator rule looks optional, in which case this is null -- confirm.
            GaussianProcessRandomGenerator generator = (GaussianProcessRandomGenerator)xMLObject.getChild(GaussianProcessRandomGenerator.class);
            double parsedTolerance = xMLObject.getAttribute(TOLERANCE, 0.001);
            MultiDimensionalScalingMM modeFinder = new MultiDimensionalScalingMM(mdsLikelihood, generator, parsedTolerance);
            modeFinder.run();
            return modeFinder;
        }

        @Override
        public String getParserDescription() {
            return "Provides a mode finder for a MDS model on a tree";
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public Class getReturnType() {
            return MMAlgorithm.class;
        }
    };
    /** Precision matrix taken from the Gaussian process; {@code null} when unused. */
    private double[][] precision = null;
    /** For each row of {@link #precision}, the sum of absolute off-diagonal entries. */
    private double[] precisionStatistics = null;
    /** When {@code true}, {@link #setPrecision(double[][])} is a no-op. Never set in this class. */
    private boolean ignoreGP = false;
    /** Weight applied to the precision terms; set to {@code 1 / MDS precision} at the start of a run. */
    private double weightTree;

    public MultiDimensionalScalingLikelihood getLikelihood() {
        return this.likelihood;
    }

    /**
     * @return the Gaussian process supplied at construction, or {@code null} if there was none
     */
    public GaussianProcessRandomGenerator getGaussianProcess() {
        return this.gp;
    }

    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * @param multiDimensionalScalingLikelihood likelihood whose location matrix is optimised
     * @param gaussianProcessRandomGenerator    optional Gaussian process; may be {@code null}
     * @param d                                 convergence tolerance handed to {@code findMode}
     */
    public MultiDimensionalScalingMM(MultiDimensionalScalingLikelihood multiDimensionalScalingLikelihood, GaussianProcessRandomGenerator gaussianProcessRandomGenerator, double d) {
        this.likelihood = multiDimensionalScalingLikelihood;
        this.gp = gaussianProcessRandomGenerator;
        this.P = multiDimensionalScalingLikelihood.getMdsDimension();
        this.Q = multiDimensionalScalingLikelihood.getLocationCount();
        this.tolerance = d;
    }

    /** Runs the mode finder with a budget of 100000 iterations. */
    public void run() {
        this.run(100000);
    }

    /**
     * Runs the mode finder and, on success, writes the result back into the likelihood's
     * matrix parameter. Progress is reported on {@code System.err}.
     *
     * <p>If the search does not converge the failure is reported and the parameter is left
     * at its starting values.
     *
     * @param n iteration budget passed to {@code findMode}; {@code 0} makes this a no-op
     */
    public void run(int n) {
        if (n == 0) {
            return;
        }
        if (this.gp != null) {
            this.setPrecision(this.gp.getPrecisionMatrix());
        }
        this.weightTree = 1.0 / this.likelihood.getMDSPrecision();
        double[] start = this.likelihood.getMatrixParameter().getParameterValues();
        System.err.println("Start: " + this.printArray(start));
        double before = this.printLogObjective();
        double[] mode;
        try {
            mode = this.findMode(this.likelihood.getMatrixParameter().getParameterValues(), this.tolerance, n);
        }
        catch (MMAlgorithm.NotConvergedException notConvergedException) {
            // No usable result: writing a null array back would only trade this failure
            // for a NullPointerException, so keep the starting values instead.
            notConvergedException.printStackTrace();
            System.err.println("Mode finder did not converge; parameter values left unchanged");
            return;
        }
        this.setParameterValues(this.likelihood.getMatrixParameter(), mode);
        double after = this.printLogObjective();
        System.err.println("Move: " + before + " -> " + after + " : " + (after - before));
    }

    /**
     * Prints and returns the current objective: the MDS log-likelihood plus, when
     * {@link #weightTree} is non-zero, the Gaussian-process log-likelihood. A missing
     * Gaussian process contributes 0.
     */
    private double printLogObjective() {
        double mdsLogLikelihood = this.likelihood.getLogLikelihood();
        // gp may be null (run(int) guards for it too), so do not dereference it blindly.
        double gpLogLikelihood = this.gp != null ? this.gp.getLikelihood().getLogLikelihood() : 0.0;
        double objective = mdsLogLikelihood;
        if (this.weightTree != 0.0) {
            objective += gpLogLikelihood;
        }
        System.err.println("obj: " + objective + " = " + mdsLogLikelihood + " + " + gpLogLikelihood);
        return objective;
    }

    /**
     * Overwrites all values of the matrix parameter quietly, then re-sets element (0, 0)
     * through the notifying setter (presumably so listeners see a single change event
     * covering the whole matrix -- confirm against MatrixParameterInterface).
     */
    private void setParameterValues(MatrixParameterInterface matrixParameterInterface, double[] dArray) {
        matrixParameterInterface.setAllParameterValuesQuietly(dArray, 0);
        matrixParameterInterface.setParameterValueNotifyChangedAll(0, 0, dArray[0]);
    }

    private double[] getDistanceMatrix() {
        return this.likelihood.getObservations();
    }

    /**
     * Stores the precision matrix and caches, per row, the sum of absolute off-diagonal
     * entries. Does nothing when {@link #ignoreGP} is set.
     *
     * @throws IllegalArgumentException if the matrix does not have {@code Q * P} rows
     */
    private void setPrecision(double[][] dArray) {
        if (this.ignoreGP) {
            return;
        }
        int size = dArray.length;
        if (size != this.Q * this.P) {
            throw new IllegalArgumentException("Invalid dimensions");
        }
        this.precision = dArray;
        this.precisionStatistics = new double[size];
        for (int row = 0; row < size; ++row) {
            double offDiagonalAbsSum = 0.0;
            for (int col = 0; col < size; ++col) {
                if (row != col) {
                    offDiagonalAbsSum += Math.abs(this.precision[row][col]);
                }
            }
            this.precisionStatistics[row] = offDiagonalAbsSum;
        }
    }

    /**
     * Computes one update of the location coordinates.
     *
     * <p>Steps: (1) lazily allocate scratch storage and fetch the observed distances;
     * (2) fill {@link #XtX} with inner products between all pairs of locations; (3) derive
     * the current pairwise distances into {@link #D}, floored at 1e-10 so they can be used
     * as divisors; (4) write each new coordinate into {@code dArray2} as a ratio of
     * accumulated terms, adding precision-weighted terms when a precision matrix is set;
     * (5) pass the result through {@code EllipticalSliceOperator.transformPoint}.
     *
     * <p>A NaN in step (3) or (4) prints diagnostics and terminates the JVM.
     *
     * @param dArray  current coordinates, read only
     * @param dArray2 receives the updated coordinates
     */
    @Override
    protected void mmUpdate(double[] dArray, double[] dArray2) {
        if (this.XtX == null) {
            this.XtX = new double[this.Q * this.Q];
        }
        if (this.D == null) {
            this.D = new double[this.Q * this.Q];
            // Only the diagonal is initialised here; off-diagonals are rewritten on every call.
            for (int diag = 0; diag < this.Q; ++diag) {
                this.D[diag * this.Q + diag] = 1.0;
            }
        }
        if (this.distance == null) {
            this.distance = this.getDistanceMatrix();
        }

        // (2) Inner products between locations; computed for the upper triangle and mirrored.
        for (int row = 0; row < this.Q; ++row) {
            for (int col = row; col < this.Q; ++col) {
                double inner = 0.0;
                for (int k = 0; k < this.P; ++k) {
                    inner += dArray[row * this.P + k] * dArray[col * this.P + k];
                }
                this.XtX[row * this.Q + col] = inner;
                this.XtX[col * this.Q + row] = inner;
            }
        }

        // (3) Current distances from the inner products: |xi - xj|^2 = xi.xi + xj.xj - 2 xi.xj.
        for (int row = 0; row < this.Q; ++row) {
            for (int col = row + 1; col < this.Q; ++col) {
                double squared = this.XtX[row * this.Q + row] + this.XtX[col * this.Q + col] - 2.0 * this.XtX[row * this.Q + col];
                // Round-off can push the squared distance slightly negative; clamp before sqrt.
                double root = squared > 0.0 ? Math.sqrt(squared) : 0.0;
                double floored = Math.max(root, 1.0E-10);
                this.D[row * this.Q + col] = floored;
                this.D[col * this.Q + row] = floored;
                if (Double.isNaN(this.D[row * this.Q + col])) {
                    System.err.println("D NaN");
                    System.err.println(this.XtX[row * this.Q + row]);
                    System.err.println(this.XtX[col * this.Q + col]);
                    System.err.println(2.0 * this.XtX[row * this.Q + col]);
                    System.err.println(squared);
                    System.err.println(root);
                    System.exit(-1);
                }
            }
        }

        // (4) New value for every coordinate of every location.
        for (int loc = 0; loc < this.Q; ++loc) {
            for (int axis = 0; axis < this.P; ++axis) {
                int flat = loc * this.P + axis;
                double numerator = 0.0;
                for (int other = 0; other < this.Q; ++other) {
                    double term = 0.0;
                    if (loc != other) {
                        term += this.distance[loc * this.Q + other] * (dArray[loc * this.P + axis] - dArray[other * this.P + axis]) / this.D[loc * this.Q + other] + (dArray[loc * this.P + axis] + dArray[other * this.P + axis]);
                    }
                    if (Double.isNaN(term)) {
                        System.err.println("Bomb at " + loc + " " + axis + " " + other);
                        System.err.println("Distance = " + this.distance[loc * this.Q + other]);
                        System.err.println("Ci = " + dArray[loc * this.P + axis]);
                        System.err.println("Cj = " + dArray[other * this.P + axis]);
                        System.err.println("D = " + this.D[loc * this.Q + other]);
                        System.exit(-1);
                    }
                    if (this.precision != null) {
                        // Precision-weighted terms against every coordinate of `other`
                        // (this loop also runs when other == loc).
                        for (int otherAxis = 0; otherAxis < this.P; ++otherAxis) {
                            int otherFlat = other * this.P + otherAxis;
                            double entry = this.precision[flat][otherFlat];
                            // Zero entries take sign -1 but are multiplied by |entry| == 0 below.
                            int sign = entry > 0.0 ? 1 : -1;
                            term += this.weightTree * Math.abs(entry) * (dArray[loc * this.P + axis] - (double)sign * dArray[other * this.P + otherAxis]);
                        }
                    }
                    numerator += term;
                }
                double denominator = 2 * (this.Q - 1);
                if (this.precision != null) {
                    denominator += this.weightTree * (2.0 * this.precision[flat][flat] + this.precisionStatistics[flat]);
                }
                dArray2[loc * this.P + axis] = numerator / denominator;
            }
        }

        // (5) NOTE(review): presumably normalises the configuration (e.g. translation /
        // rotation); the meaning of the two flags is defined in EllipticalSliceOperator -- confirm.
        EllipticalSliceOperator.transformPoint(dArray2, true, true, this.P);
    }
}

