/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.coalescent.OldAbstractCoalescentLikelihood;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.MathUtils;
import java.util.ArrayList;
import java.util.List;
import no.uib.cipr.matrix.BandCholesky;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.NotConvergedException;
import no.uib.cipr.matrix.SymmTridiagEVD;
import no.uib.cipr.matrix.SymmTridiagMatrix;
import no.uib.cipr.matrix.UpperSPDBandMatrix;
import no.uib.cipr.matrix.UpperTriangBandMatrix;

/**
 * Coalescent likelihood used for a Gaussian-process ("GP") based estimate of an
 * effective population size trajectory (see {@link #initializationReport()}).
 *
 * <p>The likelihood returned by {@link #getLogLikelihood()} is the sum of three terms:
 * the data term of {@link #calculateLogLikelihood}, the Gaussian field term of
 * {@link #calculateLogGP()}, and a term depending on the lambda and lambda-bound
 * parameters. Only a single tree is supported (see {@link #setTree(List)}).
 *
 * <p>Several getters hand out the internal arrays without copying; callers that modify
 * those arrays modify the state of this likelihood.
 */
public class GaussianProcessSkytrackLikelihood extends OldAbstractCoalescentLikelihood {
    /**
     * Value used for log(2*pi) in {@link #calculateLogGP()}.
     * NOTE(review): truncated to seven significant digits; kept as-is because the constant
     * is public and changing it would shift reported likelihood values.
     */
    public static final double LOG_TWO_TIMES_PI = 1.837877;

    /** Diagonal jitter (multiplied by the precision) used by {@link #getQmatrix}. */
    private static final double Q_MATRIX_JITTER = 1.0E-11;

    /** Diagonal jitter (multiplied by the precision) used by {@link #setupQmatrix}. */
    private static final double WEIGHT_MATRIX_JITTER = 1.0E-6;

    /** Scalar precision that scales the tridiagonal matrix of the Gaussian field. */
    protected Parameter precisionParameter;
    /** Scalar multiplier applied to the coalescent factors in the data term. */
    protected Parameter lambda_boundParameter;
    /** Scalar passed as the first argument of the lambda prior term. */
    protected Parameter lambdaParameter;
    /** Stored and exposed through {@link #getBetaParameter()}; not otherwise used here. */
    protected Parameter betaParameter;
    /** Stored and exposed through {@link #getAlphaParameter()}; not otherwise used here. */
    protected Parameter alphaParameter;
    /** Per change point type value; entries equal to 1.0 are counted as coalescent points. */
    protected Parameter GPtype;
    /** Per interval number of change points that fall into the interval. */
    protected Parameter GPcounts;
    /** Per coalescent event value k(k-1)/2, with k the lineage count of its interval. */
    protected Parameter coalfactor;
    /** Field values, one per change point (logged under the name "Gvalues"). */
    protected Parameter popSizeParameter;
    /** Locations (times) at which the field is evaluated. */
    protected Parameter changePoints;
    /** Scalar set to the time of the last coalescent event found. */
    protected Parameter Tmrca;
    /** Per coalescent event number of change points up to it, minus one. */
    protected Parameter CoalCounts;
    /** Scalar whose first entry is initialised to the number of coalescent points. */
    protected Parameter numPoints;

    /** Per interval value k(k-1)/2, with k the lineage count of the interval. */
    protected double[] GPcoalfactor;
    protected double[] storedGPcoalfactor;
    /** Per coalescent event waiting time since the previous coalescent event. */
    protected double[] GPCoalInterval;
    protected double[] storedGPCoalInterval;
    /** Allocated in the constructor; not read or written anywhere else in this class. */
    protected double[] backupIntervals;
    /** Per coalescent event index of the interval that ends in that event. */
    protected int[] CoalPosIndicator;
    protected int[] storedCoalPosIndicator;
    /** Per coalescent event cumulative time at which it occurs. */
    protected double[] CoalTime;
    protected double[] storedCoalTime;

    /** Number of intervals at construction time. */
    protected int numintervals;
    /** Number of coalescent points at construction time (external node count minus one). */
    protected int numcoalpoints;
    /** Sum over intervals of coalescent factor times interval length. */
    protected double constlik;
    protected double storedconstlik;
    /** Last value computed by {@link #calculateLogLikelihood}. */
    protected double logGPLikelihood;
    /** Matrix built by {@link #setupQmatrix(double)}. */
    protected SymmTridiagMatrix weightMatrix;
    /** Stored from the constructor argument; not read anywhere in this class. */
    protected boolean rescaleByRootHeight;

    /** Set when the tree reports a change; triggers a rebuild of intervals and counts. */
    private boolean flagForJulia = false;

    /** Wraps a single tree into a one-element list. */
    private static List<Tree> wrapTree(Tree tree) {
        List<Tree> trees = new ArrayList<Tree>(1);
        trees.add(tree);
        return trees;
    }

    /**
     * Single-tree convenience constructor; see the list-based constructor for the meaning
     * of the individual arguments.
     */
    public GaussianProcessSkytrackLikelihood(Tree tree, Parameter parameter, boolean bl, Parameter parameter2, Parameter parameter3, Parameter parameter4, Parameter parameter5, Parameter parameter6, Parameter parameter7, Parameter parameter8, Parameter parameter9, Parameter parameter10, Parameter parameter11, Parameter parameter12, Parameter parameter13) {
        this(GaussianProcessSkytrackLikelihood.wrapTree(tree), parameter, bl, parameter2, parameter3, parameter4, parameter5, parameter6, parameter7, parameter8, parameter9, parameter10, parameter11, parameter12, parameter13);
    }

    /**
     * Creates an instance with the given model name only. None of the parameters or
     * arrays of this class are initialised by this constructor.
     */
    public GaussianProcessSkytrackLikelihood(String string) {
        super(string);
    }

    /**
     * Creates the likelihood, resizes the supplied parameters to match the tree and
     * initialises them from the tree's intervals.
     *
     * @param list        trees; must contain exactly one element
     * @param parameter   precision parameter
     * @param bl          stored as {@link #rescaleByRootHeight}
     * @param parameter2  lambda-bound parameter
     * @param parameter3  lambda parameter
     * @param parameter4  field values (population size parameter)
     * @param parameter5  alpha parameter
     * @param parameter6  beta parameter
     * @param parameter7  change points
     * @param parameter8  GP type indicator
     * @param parameter9  GP counts
     * @param parameter10 coalescent factors
     * @param parameter11 coalescent counts
     * @param parameter12 number of points
     * @param parameter13 Tmrca
     * @throws RuntimeException if {@code list} does not contain exactly one tree
     */
    public GaussianProcessSkytrackLikelihood(List<Tree> list, Parameter parameter, boolean bl, Parameter parameter2, Parameter parameter3, Parameter parameter4, Parameter parameter5, Parameter parameter6, Parameter parameter7, Parameter parameter8, Parameter parameter9, Parameter parameter10, Parameter parameter11, Parameter parameter12, Parameter parameter13) {
        super("gpSkytrackLikelihood");

        this.popSizeParameter = parameter4;
        this.Tmrca = parameter13;
        this.changePoints = parameter7;
        this.numPoints = parameter12;
        this.precisionParameter = parameter;
        this.lambdaParameter = parameter3;
        this.betaParameter = parameter6;
        this.alphaParameter = parameter5;
        this.rescaleByRootHeight = bl;
        this.lambda_boundParameter = parameter2;
        this.GPcounts = parameter9;
        this.GPtype = parameter8;
        this.coalfactor = parameter10;
        this.CoalCounts = parameter11;

        // Registration order is preserved from the original implementation. Note that
        // lambda, alpha, beta and Tmrca are not registered as variables here.
        this.addVariable(this.precisionParameter);
        this.addVariable(this.popSizeParameter);
        this.addVariable(this.changePoints);
        this.addVariable(this.numPoints);
        this.addVariable(this.GPcounts);
        this.addVariable(this.GPtype);
        this.addVariable(this.coalfactor);
        this.addVariable(this.lambda_boundParameter);
        this.addVariable(this.CoalCounts);

        this.setTree(list);
        this.wrapSetupIntervals();

        this.numintervals = this.getIntervalCount();
        this.numcoalpoints = this.getCorrectFieldLength();

        this.GPcoalfactor = new double[this.numintervals];
        this.backupIntervals = new double[this.numintervals];
        this.GPCoalInterval = new double[this.numcoalpoints];
        this.storedGPCoalInterval = new double[this.numcoalpoints];
        this.CoalPosIndicator = new int[this.numcoalpoints];
        this.storedCoalPosIndicator = new int[this.numcoalpoints];
        this.CoalTime = new double[this.numcoalpoints];
        this.storedCoalTime = new double[this.numcoalpoints];
        this.storedGPcoalfactor = new double[this.numintervals];

        this.GPcounts.setDimension(this.numintervals);
        this.CoalCounts.setDimension(this.numcoalpoints);
        this.GPtype.setDimension(this.numcoalpoints);
        this.numPoints.setParameterValue(0, this.numcoalpoints);
        this.popSizeParameter.setDimension(this.numcoalpoints);
        this.changePoints.setDimension(this.numcoalpoints);
        this.coalfactor.setDimension(this.numcoalpoints);

        this.initializationReport();
        this.setupSufficientStatistics();
        this.setupGPvalues();
    }

    /**
     * Marks intervals and counts for recomputation when the tree reports a change.
     *
     * @throws IllegalArgumentException if the tree sends an event object that is neither
     *                                  a {@link TreeChangedEvent} nor a {@link Parameter}
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        super.handleModelChangedEvent(model, object, n);
        if (model != this.tree) {
            return;
        }
        if (object instanceof TreeChangedEvent || object instanceof Parameter) {
            this.flagForJulia = true;
            return;
        }
        // Guard against a null event object so the intended exception is raised instead
        // of a NullPointerException while building its message.
        String eventType = object == null ? "null" : object.getClass().toString();
        throw new IllegalArgumentException("Not sure what type of model changed event occurred: " + eventType);
    }

    /**
     * Returns two log columns that print the full, variable-length contents of the change
     * points and of the field values.
     */
    @Override
    public LogColumn[] getColumns() {
        return new LogColumn[]{
            new VariableLengthColumn("changePoints", this.changePoints),
            new VariableLengthColumn("Gvalues", this.popSizeParameter)
        };
    }

    /**
     * Installs the single tree this likelihood works on and, if it is a
     * {@link TreeModel}, registers it as a sub-model.
     *
     * @throws RuntimeException if {@code list} does not contain exactly one tree
     */
    protected void setTree(List<Tree> list) {
        if (list.size() != 1) {
            throw new RuntimeException("GP-based method only implemented for one tree");
        }
        this.tree = list.get(0);
        this.treesSet = null;
        if (this.tree instanceof TreeModel) {
            this.addModel((TreeModel)this.tree);
        }
    }

    /** Rebuilds the intervals and marks them as known. */
    protected void wrapSetupIntervals() {
        this.setupIntervals();
        this.intervalsKnown = true;
    }

    /**
     * Computes the data term of the likelihood and stores it in {@link #logGPLikelihood}.
     *
     * <p>With {@code b = parameter4[0]} the value is
     * {@code -b * constlik + sum_i parameter2[i] * log(b * dArray[i]) - sum_j log(1 + exp(-parameter3[j] * parameter[j]))},
     * where the first sum only includes entries with {@code dArray[i] > 0}.
     *
     * @param parameter  field values; its size drives the second sum
     * @param parameter2 counts; its size drives the first sum and must not exceed
     *                   {@code dArray.length}
     * @param parameter3 type values, indexed in parallel with {@code parameter}
     * @param parameter4 scalar multiplier (only entry 0 is read)
     * @param dArray     per-interval factors
     * @return the computed value
     */
    public double calculateLogLikelihood(Parameter parameter, Parameter parameter2, Parameter parameter3, Parameter parameter4, double[] dArray) {
        double bound = parameter4.getParameterValue(0);
        this.logGPLikelihood = -bound * this.getConstlik();

        for (int i = 0; i < parameter2.getSize(); ++i) {
            if (dArray[i] > 0.0) {
                if (parameter2.getParameterValue(i) < 0.0) {
                    // A negative count is reported but still enters the sum.
                    System.err.println("WARNING");
                }
                this.logGPLikelihood += parameter2.getParameterValue(i) * Math.log(bound * dArray[i]);
            }
        }

        double[] fieldValues = parameter.getParameterValues();
        for (int i = 0; i < parameter.getSize(); ++i) {
            this.logGPLikelihood += -Math.log(1.0 + Math.exp(-parameter3.getParameterValue(i) * fieldValues[i]));
        }
        return this.logGPLikelihood;
    }

    /** Returns the current sum of coalescent factor times interval length. */
    public double getConstlik() {
        return this.constlik;
    }

    /**
     * Returns the cached log likelihood, recomputing it first if it is marked unknown.
     * If the tree changed since the last evaluation, intervals and derived counts are
     * rebuilt before the likelihood is evaluated.
     */
    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            if (this.flagForJulia) {
                System.err.println("recalculating intervals and counts");
                this.wrapSetupIntervals();
                this.recomputeValues();
                this.flagForJulia = false;
            }
            double dataTerm = this.calculateLogLikelihood(this.popSizeParameter, this.GPcounts, this.GPtype, this.lambda_boundParameter, this.GPcoalfactor);
            double fieldTerm = this.calculateLogGP();
            double lambdaTerm = this.getLogPriorLambda(this.lambdaParameter.getParameterValue(0), 0.01, this.lambda_boundParameter.getParameterValue(0));
            this.logLikelihood = dataTerm + fieldTerm + lambdaTerm;
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    /**
     * Builds the symmetric tridiagonal matrix for the given locations.
     *
     * <p>Off-diagonal entry {@code i} is {@code -d / (dArray[i+1] - dArray[i])}; each
     * diagonal entry is minus the sum of its neighbouring off-diagonal entries plus
     * {@code d * 1e-11}.
     *
     * @param d      precision that scales every entry
     * @param dArray locations; at least two are required
     * @throws IllegalArgumentException if fewer than two locations are given
     */
    protected SymmTridiagMatrix getQmatrix(double d, double[] dArray) {
        return buildPrecisionMatrix(d, dArray, Q_MATRIX_JITTER);
    }

    /**
     * Shared construction of the tridiagonal matrix used by {@link #getQmatrix} and
     * {@link #setupQmatrix}.
     */
    private static SymmTridiagMatrix buildPrecisionMatrix(double precision, double[] points, double jitter) {
        int n = points.length;
        if (n < 2) {
            throw new IllegalArgumentException("At least two change points are required to build the precision matrix, got " + n);
        }
        double[] offDiagonal = new double[n - 1];
        double[] diagonal = new double[n];
        for (int i = 0; i < n - 1; ++i) {
            offDiagonal[i] = precision * (-1.0 / (points[i + 1] - points[i]));
            if (i < n - 2) {
                // Interior rows: contribution of the left and the right neighbour.
                diagonal[i + 1] = -offDiagonal[i] + precision * (1.0 / (points[i + 2] - points[i + 1]) + jitter);
            }
        }
        // Boundary rows only have one neighbour.
        diagonal[0] = -offDiagonal[0] + precision * jitter;
        diagonal[n - 1] = -offDiagonal[n - 2] + precision * jitter;
        return new SymmTridiagMatrix(diagonal, offDiagonal);
    }

    /**
     * Evaluates the Gaussian field term
     * {@code -0.5 * logDet(Q) - 0.5 * x'Qx - 0.5 * (n - 1) * LOG_TWO_TIMES_PI}, where
     * {@code Q} is the matrix of {@link #getQmatrix} at the current precision and change
     * points, {@code x} the current field values and {@code n} their number.
     */
    protected double calculateLogGP() {
        SymmTridiagMatrix q = this.getQmatrix(this.precisionParameter.getParameterValue(0), this.changePoints.getParameterValues());
        DenseVector qTimesField = new DenseVector(this.popSizeParameter.getSize());
        DenseVector field = new DenseVector(this.popSizeParameter.getParameterValues());
        q.mult(field, qTimesField);
        return -0.5 * GaussianProcessSkytrackLikelihood.logGeneralizedDeterminant(q)
                - 0.5 * field.dot(qTimesField)
                - 0.5 * (double)(this.popSizeParameter.getSize() - 1) * LOG_TWO_TIMES_PI;
    }

    /**
     * Returns the lambda term added to the likelihood.
     *
     * <p>NOTE(review): despite the name, the first branch returns {@code d2 / d} without
     * taking a logarithm, and the second multiplies {@code log(1 - d2)} by an exponential
     * density rather than adding log terms. The formula is reproduced unchanged because
     * altering it would change the model; confirm against the intended prior.
     *
     * @param d  lambda value
     * @param d2 weight (called with 0.01)
     * @param d3 lambda bound
     */
    private double getLogPriorLambda(double d, double d2, double d3) {
        if (d3 < d) {
            return d2 * (1.0 / d);
        }
        return Math.log(1.0 - d2) * (1.0 / d) * Math.exp(-(1.0 / d) * (d3 - d));
    }

    /**
     * Returns the sum of the logarithms of all eigenvalues of the matrix that are larger
     * than 1e-5; smaller eigenvalues are skipped.
     *
     * @throws RuntimeException if the eigenvalue decomposition does not converge; the
     *                          original {@link NotConvergedException} is attached as cause
     */
    public static double logGeneralizedDeterminant(SymmTridiagMatrix symmTridiagMatrix) {
        SymmTridiagEVD decomposition = new SymmTridiagEVD(symmTridiagMatrix.numRows(), false);
        try {
            decomposition.factor(symmTridiagMatrix);
        }
        catch (NotConvergedException notConvergedException) {
            // Keep the cause so the underlying failure is visible in the stack trace.
            throw new RuntimeException("Not converged error in generalized determinate calculation.\n" + notConvergedException.getMessage(), notConvergedException);
        }
        double logDeterminant = 0.0;
        for (double eigenvalue : decomposition.getEigenvalues()) {
            if (eigenvalue > 1.0E-5) {
                logDeterminant += Math.log(eigenvalue);
            }
        }
        return logDeterminant;
    }

    /** Any change of a registered variable invalidates the cached likelihood. */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        this.likelihoodKnown = false;
    }

    /** Restores the arrays and {@link #constlik} saved by {@link #storeState()}. */
    @Override
    protected void restoreState() {
        super.restoreState();
        System.arraycopy(this.storedGPcoalfactor, 0, this.GPcoalfactor, 0, this.storedGPcoalfactor.length);
        System.arraycopy(this.storedCoalTime, 0, this.CoalTime, 0, this.storedCoalTime.length);
        System.arraycopy(this.storedGPCoalInterval, 0, this.GPCoalInterval, 0, this.storedGPCoalInterval.length);
        System.arraycopy(this.storedCoalPosIndicator, 0, this.CoalPosIndicator, 0, this.storedCoalPosIndicator.length);
        this.constlik = this.storedconstlik;
    }

    /** Saves the arrays and {@link #constlik} so that {@link #restoreState()} can roll back. */
    @Override
    protected void storeState() {
        super.storeState();
        System.arraycopy(this.GPcoalfactor, 0, this.storedGPcoalfactor, 0, this.GPcoalfactor.length);
        System.arraycopy(this.CoalTime, 0, this.storedCoalTime, 0, this.CoalTime.length);
        System.arraycopy(this.GPCoalInterval, 0, this.storedGPCoalInterval, 0, this.GPCoalInterval.length);
        System.arraycopy(this.CoalPosIndicator, 0, this.storedCoalPosIndicator, 0, this.CoalPosIndicator.length);
        this.storedconstlik = this.constlik;
    }

    /** Returns {@code id(logLikelihood)}; note that this may trigger a likelihood evaluation. */
    @Override
    public String toString() {
        return this.getId() + "(" + Double.toString(this.getLogLikelihood()) + ")";
    }

    /** Prints a short banner with the citation request to standard output. */
    public void initializationReport() {
        System.out.println("Creating a GP based estimation of effective population trajectories:");
        System.out.println("\tIf you publish results using this model, please reference: Minin, Palacios, Suchard (XXXX), AAA");
    }

    /**
     * Redraws the height of every internal node except the root at a random position
     * between the larger of the heights of its children 0 and 1 and the height of its
     * parent, then pushes a tree-changed event.
     *
     * <p>Only children 0 and 1 are inspected. The position is
     * {@code lower + MathUtils.nextDouble() * (parentHeight - lower)}.
     */
    public static void checkTree(TreeModel treeModel) {
        for (int i = 0; i < treeModel.getInternalNodeCount(); ++i) {
            NodeRef node = treeModel.getInternalNode(i);
            if (node == treeModel.getRoot()) {
                continue;
            }
            double parentHeight = treeModel.getNodeHeight(treeModel.getParent(node));
            double firstChildHeight = treeModel.getNodeHeight(treeModel.getChild(node, 0));
            double secondChildHeight = treeModel.getNodeHeight(treeModel.getChild(node, 1));
            double lower = firstChildHeight;
            if (secondChildHeight > lower) {
                lower = secondChildHeight;
            }
            double newHeight = lower + MathUtils.nextDouble() * (parentHeight - lower);
            treeModel.setNodeHeight(node, newHeight);
        }
        treeModel.pushTreeChangedEvent();
    }

    /**
     * Recomputes all interval-derived quantities from the current intervals while keeping
     * the current change points: per-interval counts and factors, {@link #constlik},
     * per-coalescent-event indices, counts, times, waiting times and factors, and finally
     * the Tmrca parameter. Inconsistencies between the type indicators, the coalescent
     * counts and the number of change points are reported on standard error.
     */
    protected void recomputeValues() {
        double currentTime = 0.0;
        double lastCoalTime = 0.0;
        int coalIndex = 0;
        int intervalScanStart = 0;
        int coalScanStart = 0;
        this.constlik = 0.0;

        for (int i = 0; i < this.getIntervalCount(); ++i) {
            currentTime += this.getInterval(i);

            // Count change points not yet assigned to an earlier interval that lie at or
            // before the end of this interval. The scan always runs to the end of the
            // change-point vector.
            double pointsInInterval = 0.0;
            for (int j = intervalScanStart; j < this.changePoints.getSize(); ++j) {
                if (this.changePoints.getParameterValue(j) <= currentTime) {
                    ++intervalScanStart;
                    pointsInInterval += 1.0;
                }
            }
            this.GPcounts.setParameterValue(i, pointsInInterval);
            this.GPcoalfactor[i] = (double)this.getLineageCount(i) * ((double)this.getLineageCount(i) - 1.0) / 2.0;
            this.constlik += this.GPcoalfactor[i] * this.getInterval(i);

            if (this.getIntervalType(i) != OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) {
                continue;
            }

            this.CoalPosIndicator[coalIndex] = i;

            // Count change points since the previous coalescent event, stopping at the
            // first one that lies beyond the current time.
            double pointsSinceLastCoal = 0.0;
            for (int j = coalScanStart; j < this.changePoints.getSize(); ++j) {
                if (!(this.changePoints.getParameterValue(j) <= currentTime)) {
                    break;
                }
                ++coalScanStart;
                pointsSinceLastCoal += 1.0;
            }
            this.CoalCounts.setParameterValue(coalIndex, pointsSinceLastCoal - 1.0);
            this.CoalTime[coalIndex] = currentTime;
            this.GPCoalInterval[coalIndex] = currentTime - lastCoalTime;
            this.coalfactor.setParameterValue(coalIndex, (double)(this.getLineageCount(i) * (this.getLineageCount(i) - 1)) / 2.0);
            ++coalIndex;
            lastCoalTime = currentTime;
        }

        this.reportInconsistencies();
        this.Tmrca.setParameterValue(0, this.CoalTime[coalIndex - 1]);
    }

    /**
     * Prints a warning to standard error if the number of change points typed 1.0 differs
     * from the number of coalescent counts, or if the coalescent counts do not add up to
     * the number of remaining change points.
     */
    private void reportInconsistencies() {
        int coalescentTyped = 0;
        for (int i = 0; i < this.changePoints.getSize(); ++i) {
            if (this.GPtype.getParameterValue(i) == 1.0) {
                ++coalescentTyped;
            }
        }
        int latentTotal = 0;
        for (int i = 0; i < this.CoalCounts.getSize(); ++i) {
            // Truncation after every addition is kept from the original implementation.
            latentTotal = (int)((double)latentTotal + this.CoalCounts.getParameterValue(i));
        }
        if (coalescentTyped != this.CoalCounts.getSize()) {
            System.err.println("WARNING CONSISTENCY 1");
        }
        if (latentTotal != this.changePoints.getSize() - this.CoalCounts.getSize()) {
            System.err.println("WARNING CONSISTENCY 2:" + latentTotal + "and changePts size" + this.changePoints.getSize());
        }
    }

    /**
     * Initialises all interval-derived quantities from the current intervals and places
     * one change point, typed 1.0, at every coalescent event. Also sets the Tmrca
     * parameter to the time of the last coalescent event.
     */
    protected void setupSufficientStatistics() {
        double currentTime = 0.0;
        double lastCoalTime = 0.0;
        int coalIndex = 0;
        this.constlik = 0.0;

        for (int i = 0; i < this.getIntervalCount(); ++i) {
            currentTime += this.getInterval(i);
            this.GPcounts.setParameterValue(i, 0.0);
            this.GPcoalfactor[i] = (double)this.getLineageCount(i) * ((double)this.getLineageCount(i) - 1.0) / 2.0;
            this.constlik += this.GPcoalfactor[i] * this.getInterval(i);

            if (this.getIntervalType(i) != OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) {
                continue;
            }

            this.GPcounts.setParameterValue(i, 1.0);
            this.GPtype.setParameterValue(coalIndex, 1.0);
            this.CoalPosIndicator[coalIndex] = i;
            this.changePoints.setParameterValue(coalIndex, currentTime);
            this.CoalCounts.setParameterValue(coalIndex, 0.0);
            this.CoalTime[coalIndex] = currentTime;
            this.GPCoalInterval[coalIndex] = currentTime - lastCoalTime;
            this.coalfactor.setParameterValue(coalIndex, (double)(this.getLineageCount(i) * (this.getLineageCount(i) - 1)) / 2.0);
            ++coalIndex;
            lastCoalTime = currentTime;
        }
        this.Tmrca.setParameterValue(0, this.CoalTime[coalIndex - 1]);
    }

    /** Returns the number of external nodes of the tree minus one. */
    protected int getCorrectFieldLength() {
        return this.tree.getExternalNodeCount() - 1;
    }

    /**
     * Builds {@link #weightMatrix} from the current change points, using the same
     * construction as {@link #getQmatrix} but with a diagonal jitter of {@code d * 1e-6}.
     *
     * <p>The matrix dimension is the number of change points. Earlier code mixed that
     * size with {@link #getCorrectFieldLength()} when filling the diagonal, which only
     * worked while both were equal; the result is unchanged in that case.
     *
     * @param d precision that scales every entry
     * @throws IllegalArgumentException if there are fewer than two change points
     */
    protected void setupQmatrix(double d) {
        this.weightMatrix = buildPrecisionMatrix(d, this.changePoints.getParameterValues(), WEIGHT_MATRIX_JITTER);
    }

    /**
     * Builds the weight matrix at the current precision and sets every field value to 1.0.
     *
     * <p>NOTE(review): a standard normal vector is drawn and solved against the Cholesky
     * factor of the weight matrix, but the result is never used. The computation is kept
     * because it advances the random number generator and validates the factorisation;
     * confirm whether the draw was meant to initialise the field.
     */
    protected void setupGPvalues() {
        this.setupQmatrix(this.precisionParameter.getParameterValue(0));
        int fieldLength = this.getCorrectFieldLength();

        DenseVector standardNormal = new DenseVector(fieldLength);
        DenseVector fieldDraw = new DenseVector(fieldLength);
        for (int i = 0; i < fieldLength; ++i) {
            standardNormal.set(i, MathUtils.nextGaussian());
        }
        UpperSPDBandMatrix bandedWeights = new UpperSPDBandMatrix(this.weightMatrix, 1);
        BandCholesky cholesky = new BandCholesky(fieldLength, 1, true);
        cholesky.factor(bandedWeights);
        UpperTriangBandMatrix upperFactor = cholesky.getU();
        upperFactor.solve(standardNormal, fieldDraw);

        for (int i = 0; i < fieldLength; ++i) {
            this.popSizeParameter.setParameterValue(i, 1.0);
        }
    }

    /** Returns the precision parameter. */
    public Parameter getPrecisionParameter() {
        return this.precisionParameter;
    }

    /** Returns the parameter holding the field values. */
    public Parameter getPopSizeParameter() {
        return this.popSizeParameter;
    }

    /** Returns the number-of-points parameter. */
    public Parameter getNumPoints() {
        return this.numPoints;
    }

    /** Returns the lambda parameter. */
    public Parameter getLambdaParameter() {
        return this.lambdaParameter;
    }

    /** Returns the lambda-bound parameter. */
    public Parameter getLambdaBoundParameter() {
        return this.lambda_boundParameter;
    }

    /** Returns the change-point parameter. */
    public Parameter getChangePoints() {
        return this.changePoints;
    }

    /** Returns the first entry of the alpha parameter. */
    public double getAlphaParameter() {
        return this.alphaParameter.getParameterValue(0);
    }

    /** Returns the first entry of the beta parameter. */
    public double getBetaParameter() {
        return this.betaParameter.getParameterValue(0);
    }

    /** Returns the internal per-interval factor array (not a copy). */
    public double[] getGPcoalfactor() {
        return this.GPcoalfactor;
    }

    /** Returns the per-coalescent-event factor parameter. */
    public Parameter getcoalfactor() {
        return this.coalfactor;
    }

    /** Returns the coalescent-count parameter. */
    public Parameter getCoalCounts() {
        return this.CoalCounts;
    }

    /** Returns the type-indicator parameter. */
    public Parameter getGPtype() {
        return this.GPtype;
    }

    /** Returns the per-interval count parameter. */
    public Parameter getGPcounts() {
        return this.GPcounts;
    }

    /** Returns a copy of the weight matrix. */
    public SymmTridiagMatrix getWeightMatrix() {
        return this.weightMatrix.copy();
    }

    /** Returns the internal array of waiting times between coalescent events (not a copy). */
    public double[] getGPCoalInterval() {
        return this.GPCoalInterval;
    }

    /** Returns the internal array of coalescent times (not a copy). */
    public double[] getCoalTime() {
        return this.CoalTime;
    }

    /** Returns the waiting time that ends in coalescent event {@code n}. */
    public double getGPCoalInterval(int n) {
        return this.GPCoalInterval[n];
    }

    /** Returns the internal array of interval indices of the coalescent events (not a copy). */
    public int[] getCoalPosIndicator() {
        return this.CoalPosIndicator;
    }

    /**
     * Log column that prints all current values of a parameter as one
     * brace-enclosed, comma-separated string, e.g. <code>{1.0,2.0}</code>.
     * Static because it does not need the enclosing likelihood instance.
     */
    private static final class VariableLengthColumn extends LogColumn.Abstract {
        private static final String OPEN = "{";
        private static final String CLOSE = "}";
        private static final String DELIMIT = ",";

        private final Parameter param;

        public VariableLengthColumn(String string, Parameter parameter) {
            super(string);
            this.param = parameter;
        }

        @Override
        protected String getFormattedValue() {
            return this.convertToDelimited(this.param.getParameterValues());
        }

        /** Joins the values with {@link #DELIMIT} between {@link #OPEN} and {@link #CLOSE}. */
        private String convertToDelimited(double[] dArray) {
            StringBuilder text = new StringBuilder(OPEN);
            for (int i = 0; i < dArray.length; ++i) {
                if (i > 0) {
                    text.append(DELIMIT);
                }
                text.append(Double.toString(dArray[i]));
            }
            text.append(CLOSE);
            return text.toString();
        }
    }
}

