/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.branchratemodel.BranchSpecificFixedEffects;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.branchratemodel.NodeRateMap;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.DoubleBinaryOperator;

public class ArbitraryBranchRates
extends AbstractBranchRateModel
implements DifferentiableBranchRates,
Citable {
    private final TreeParameterModel rates;
    private final Parameter rateParameter;
    private final BranchRateTransform transform;
    List<Citation> citations = new ArrayList<Citation>();
    public static Citation CITATION = new Citation(new Author[]{new Author("X", "Ji"), new Author("P", "Lemey"), new Author("MA", "Suchard")}, Citation.Status.IN_PREPARATION);

    /**
     * Builds a model named {@code "arbitraryBranchRates"} whose rate parameter
     * is laid out with {@link TreeParameterModel.Type#WITHOUT_ROOT}.
     *
     * @param treeModel           tree whose branches carry the rates
     * @param parameter           raw (untransformed) per-branch rate values
     * @param branchRateTransform map from raw values to branch rates
     * @param bl                  when {@code true}, every element of {@code parameter} is reset to
     *                            {@code branchRateTransform.center()}
     */
    public ArbitraryBranchRates(TreeModel treeModel, Parameter parameter, BranchRateTransform branchRateTransform, boolean bl) {
        this("arbitraryBranchRates", treeModel, parameter, branchRateTransform, bl,
                TreeParameterModel.Type.WITHOUT_ROOT);
    }

    /**
     * Builds a named model whose rate parameter is laid out with
     * {@link TreeParameterModel.Type#WITHOUT_ROOT}.
     *
     * @param string              model name passed to the superclass
     * @param treeModel           tree whose branches carry the rates
     * @param parameter           raw (untransformed) per-branch rate values
     * @param branchRateTransform map from raw values to branch rates
     * @param bl                  when {@code true}, every element of {@code parameter} is reset to
     *                            {@code branchRateTransform.center()}
     */
    public ArbitraryBranchRates(String string, TreeModel treeModel, Parameter parameter, BranchRateTransform branchRateTransform, boolean bl) {
        this(
                string,
                treeModel,
                parameter,
                branchRateTransform,
                bl,
                TreeParameterModel.Type.WITHOUT_ROOT);
    }

    /**
     * Builds a model named {@code "arbitraryBranchRates"} with an explicit
     * tree-parameter layout.
     *
     * @param treeModel           tree whose branches carry the rates
     * @param parameter           raw (untransformed) per-branch rate values
     * @param branchRateTransform map from raw values to branch rates
     * @param bl                  when {@code true}, every element of {@code parameter} is reset to
     *                            {@code branchRateTransform.center()}
     * @param type                layout handed to the {@link TreeParameterModel}
     */
    public ArbitraryBranchRates(TreeModel treeModel, Parameter parameter, BranchRateTransform branchRateTransform, boolean bl, TreeParameterModel.Type type) {
        this(
                "arbitraryBranchRates",
                treeModel,
                parameter,
                branchRateTransform,
                bl,
                type);
    }

    /**
     * Primary constructor: wires the transform and the per-branch rate
     * parameter into this model.
     *
     * <p>Steps, in order: register the transform as a sub-model when it is a
     * {@link Model}; optionally reset every raw rate to the transform's
     * center; add bounds taken from the transform's {@code lower()} and
     * {@code upper()} to the parameter; wrap the parameter in a
     * {@link TreeParameterModel} and register that as a sub-model.
     *
     * @param string              model name passed to the superclass
     * @param treeModel           tree whose branches carry the rates
     * @param parameter           raw (untransformed) per-branch rate values; bounds are added to it
     * @param branchRateTransform map from raw values to branch rates
     * @param bl                  when {@code true}, every element of {@code parameter} is reset to
     *                            {@code branchRateTransform.center()}
     * @param type                layout handed to the {@link TreeParameterModel}
     */
    public ArbitraryBranchRates(String string, TreeModel treeModel, Parameter parameter, BranchRateTransform branchRateTransform, boolean bl, TreeParameterModel.Type type) {
        super(string);
        transform = branchRateTransform;
        if (branchRateTransform instanceof Model) {
            // Listen to the transform so its changes propagate through this model.
            addModel((Model) branchRateTransform);
        }

        if (bl) {
            final double centerValue = branchRateTransform.center();
            int index = 0;
            while (index < parameter.getDimension()) {
                parameter.setValue(index, centerValue);
                index++;
            }
        }

        // DefaultBounds is called with the upper bound first, then the lower bound.
        final double lowerBound = branchRateTransform.lower();
        final double upperBound = branchRateTransform.upper();
        parameter.addBounds(new Parameter.DefaultBounds(upperBound, lowerBound, parameter.getDimension()));

        rates = new TreeParameterModel((MutableTreeModel)treeModel, parameter, type);
        rateParameter = parameter;
        addModel(rates);
    }

    /**
     * Stores a raw (untransformed) value for the branch identified by
     * {@code nodeRef}; the value is written as-is to the underlying
     * tree parameter model.
     *
     * @param tree    tree the node belongs to
     * @param nodeRef node identifying the branch
     * @param d       raw value to store
     */
    public void setBranchRate(Tree tree, NodeRef nodeRef, double d) {
        rates.setNodeValue(tree, nodeRef, d);
    }

    /**
     * Evaluates the transform's {@code differential} at the raw value stored
     * for the given branch.
     *
     * @param tree    tree the node belongs to
     * @param nodeRef node identifying the branch
     * @return {@code transform.differential(rawValue, tree, nodeRef)}
     */
    @Override
    public double getBranchRateDifferential(Tree tree, NodeRef nodeRef) {
        return transform.differential(rates.getNodeValue(tree, nodeRef), tree, nodeRef);
    }

    /**
     * @return the transform supplied at construction time
     */
    @Override
    public BranchRateTransform getTransform() {
        return transform;
    }

    /**
     * Pass-through: returns {@code dArray} itself (same reference, no copy);
     * the remaining arguments are not read.
     *
     * @param dArray  gradient array, returned unchanged
     * @param dArray2 ignored
     * @param n       ignored
     * @param n2      ignored
     * @return {@code dArray}
     */
    @Override
    public double[] updateGradientLogDensity(double[] dArray, double[] dArray2, int n, int n2) {
        final double[] unchangedGradient = dArray;
        return unchangedGradient;
    }

    /**
     * Pass-through: returns {@code dArray} itself (same reference, no copy);
     * the remaining arguments are not read.
     *
     * @param dArray  diagonal-Hessian array, returned unchanged
     * @param dArray2 ignored
     * @param dArray3 ignored
     * @param n       ignored
     * @param n2      ignored
     * @return {@code dArray}
     */
    @Override
    public double[] updateDiagonalHessianLogDensity(double[] dArray, double[] dArray2, double[] dArray3, int n, int n2) {
        final double[] unchangedDiagonal = dArray;
        return unchangedDiagonal;
    }

    /**
     * Hands {@code nodeRateMap} to the underlying tree parameter model's
     * {@code forEach}.
     *
     * @param nodeRateMap callback forwarded to the rate storage
     */
    @Override
    public void forEachOverRates(NodeRateMap nodeRateMap) {
        rates.forEach(nodeRateMap);
    }

    /**
     * Delegates a map/reduce over the stored rates to the underlying tree
     * parameter model.
     *
     * @param nodeRateMap          mapping step, forwarded unchanged
     * @param doubleBinaryOperator reduction step, forwarded unchanged
     * @param d                    initial value, forwarded unchanged
     * @return whatever the tree parameter model's {@code mapReduce} returns
     */
    @Override
    public double mapReduceOverRates(NodeRateMap nodeRateMap, DoubleBinaryOperator doubleBinaryOperator, double d) {
        final double reduced = rates.mapReduce(nodeRateMap, doubleBinaryOperator, d);
        return reduced;
    }

    /**
     * Returns the branch rate: the raw stored value pushed through the
     * transform.
     *
     * @param tree    tree the node belongs to
     * @param nodeRef node identifying the branch
     * @return {@code transform.transform(rawValue, tree, nodeRef)}
     */
    @Override
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        final double rawValue = getUntransformedBranchRate(tree, nodeRef);
        return transform.transform(rawValue, tree, nodeRef);
    }

    /**
     * Returns the raw value stored for the branch, before any transform is
     * applied.
     *
     * @param tree    tree the node belongs to
     * @param nodeRef node identifying the branch
     * @return the stored node value
     */
    @Override
    public double getUntransformedBranchRate(Tree tree, NodeRef nodeRef) {
        final double rawValue = rates.getNodeValue(tree, nodeRef);
        return rawValue;
    }

    /**
     * Looks up the rate-parameter index that corresponds to a node, using the
     * node's number.
     *
     * @param nodeRef node to look up
     * @return index reported by the tree parameter model for {@code nodeRef.getNumber()}
     */
    @Override
    public int getParameterIndexFromNode(NodeRef nodeRef) {
        final int nodeNumber = nodeRef.getNumber();
        return rates.getParameterIndexFromNodeNumber(nodeNumber);
    }

    /**
     * Inverse lookup of {@link #getParameterIndexFromNode(NodeRef)}: asks the
     * tree parameter model for the node number behind a parameter index.
     *
     * @param n index into the rate parameter
     * @return node number reported by the tree parameter model
     */
    public int getNodeNumberFromParameterIndex(int n) {
        final int nodeNumber = rates.getNodeNumberFromParameterIndex(n);
        return nodeNumber;
    }

    /**
     * @return the raw rate parameter supplied at construction time
     */
    @Override
    public Parameter getRateParameter() {
        return rateParameter;
    }

    /**
     * @return {@code true} when the configured transform is a
     *         {@link BranchRateTransform.Reciprocal}
     */
    public boolean usingReciprocal() {
        final boolean reciprocal = transform instanceof BranchRateTransform.Reciprocal;
        return reciprocal;
    }

    /**
     * Relays change notifications from the two registered sub-models.
     *
     * <ul>
     *   <li>rate storage changed: forward {@code object} and {@code n} unchanged;</li>
     *   <li>transform changed: fire an unqualified model-changed event;</li>
     *   <li>anything else: fail, since no other sub-model is registered here.</li>
     * </ul>
     *
     * @param model  the sub-model that changed
     * @param object payload forwarded when the source is the rate storage
     * @param n      index forwarded when the source is the rate storage
     * @throws RuntimeException if {@code model} is neither the rate storage nor the transform
     */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == rates) {
            fireModelChanged(object, n);
            return;
        }
        if (model == transform) {
            fireModelChanged();
            return;
        }
        throw new RuntimeException("Unknown model");
    }

    /**
     * Intentionally a no-op. The constructors visible in this class register
     * only sub-models (via {@code addModel}) and no variables, so changes are
     * expected to arrive through {@code handleModelChangedEvent} instead.
     *
     * @param variable   ignored
     * @param n          ignored
     * @param changeType ignored
     */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
    }

    /**
     * Intentionally a no-op: this class snapshots nothing itself.
     * NOTE(review): presumably the registered sub-models store their own
     * state - confirm against AbstractModel.
     */
    @Override
    protected void storeState() {
    }

    /**
     * Intentionally a no-op: nothing was snapshotted by {@code storeState},
     * so there is nothing to roll back here.
     */
    @Override
    protected void restoreState() {
    }

    /**
     * Intentionally a no-op: this class performs no work when a state is
     * accepted.
     */
    @Override
    protected void acceptState() {
    }

    /**
     * Convenience factory: same as the five-argument {@code make} with no
     * location effects and no scale parameter.
     *
     * @param bl  select the reciprocal transform
     * @param bl2 select the exponentiating transform (takes precedence over {@code bl})
     * @param bl3 select the multiply-by-location transform
     * @return the transform chosen by the five-argument factory
     */
    public static BranchRateTransform make(boolean bl, boolean bl2, boolean bl3) {
        final BranchSpecificFixedEffects noLocation = null;
        final Parameter noScale = null;
        return ArbitraryBranchRates.make(bl, bl2, bl3, noLocation, noScale);
    }

    /**
     * Factory for the transform applied to raw branch-rate values.
     *
     * <p>Selection order (first match wins):
     * <ol>
     *   <li>{@code bl2}: {@link BranchRateTransform.Exponentiate}</li>
     *   <li>{@code bl}: {@link BranchRateTransform.Reciprocal}</li>
     *   <li>{@code bl3}: {@link BranchRateTransform.MultiplyByLocation}</li>
     *   <li>a location or scale is given: {@link BranchRateTransform.LocationScaleLogNormal}</li>
     *   <li>otherwise: {@link BranchRateTransform.None}</li>
     * </ol>
     *
     * @param bl                         select the reciprocal transform
     * @param bl2                        select the exponentiating transform
     * @param bl3                        select the multiply-by-location transform
     * @param branchSpecificFixedEffects per-branch location effects, may be {@code null}
     *                                   unless {@code bl3} is the selected option
     * @param parameter                  scale parameter, may be {@code null}
     * @return the selected transform
     * @throws RuntimeException         if {@code bl} or {@code bl2} is combined with a location or scale
     * @throws IllegalArgumentException if the multiply-by-location transform is selected without a location
     */
    public static BranchRateTransform make(boolean bl, boolean bl2, boolean bl3, BranchSpecificFixedEffects branchSpecificFixedEffects, Parameter parameter) {
        final boolean hasLocationOrScale = branchSpecificFixedEffects != null || parameter != null;
        if ((bl || bl2) && hasLocationOrScale) {
            throw new RuntimeException("Not yet implemented");
        }
        if (bl2) {
            return new BranchRateTransform.Exponentiate();
        }
        if (bl) {
            return new BranchRateTransform.Reciprocal();
        }
        if (bl3) {
            // MultiplyByLocation dereferences its location on every evaluation; reject a
            // missing location here instead of failing later with a bare NullPointerException.
            if (branchSpecificFixedEffects == null) {
                throw new IllegalArgumentException(
                        "The multiply-by-location transform requires branch-specific location effects");
            }
            return new BranchRateTransform.MultiplyByLocation("arbitraryBranchRates", branchSpecificFixedEffects);
        }
        if (hasLocationOrScale) {
            return new BranchRateTransform.LocationScaleLogNormal("arbitraryBranchRates", branchSpecificFixedEffects, parameter);
        }
        return new BranchRateTransform.None();
    }

    /**
     * @return the tree model held by the underlying tree parameter model
     */
    @Override
    public Tree getTree() {
        final Tree tree = rates.getTreeModel();
        return tree;
    }

    /**
     * Evaluates the transform's {@code secondDifferential} at the raw value
     * stored for the given branch.
     *
     * @param tree    tree the node belongs to
     * @param nodeRef node identifying the branch
     * @return {@code transform.secondDifferential(rawValue, tree, nodeRef)}
     */
    @Override
    public double getBranchRateSecondDifferential(Tree tree, NodeRef nodeRef) {
        return transform.secondDifferential(rates.getNodeValue(tree, nodeRef), tree, nodeRef);
    }

    /**
     * @return the citation category for this model, {@code MOLECULAR_CLOCK}
     */
    @Override
    public Citation.Category getCategory() {
        final Citation.Category category = Citation.Category.MOLECULAR_CLOCK;
        return category;
    }

    /**
     * @return the fixed human-readable description used when citing this model
     */
    @Override
    public String getDescription() {
        final String description = "Location-scale relaxed clock";
        return description;
    }

    /**
     * Returns the citations for this model.
     *
     * @return the instance's own citation list when it is non-empty (the live
     *         list, not a copy); otherwise an immutable single-element list
     *         holding {@link #CITATION}
     */
    @Override
    public List<Citation> getCitations() {
        return citations.isEmpty()
                ? Collections.<Citation>singletonList(CITATION)
                : citations;
    }

    public static interface BranchRateTransform {
        /**
         * First derivative of {@link #transform} with respect to the raw value
         * (e.g. the implementations in this file return {@code 1} for the
         * identity and {@code -1/x^2} for the reciprocal).
         *
         * @param var1 raw (untransformed) value
         * @param var3 tree the branch belongs to; some implementations ignore it
         * @param var4 node identifying the branch; some implementations ignore it
         * @return derivative of the transformed rate at {@code var1}
         */
        public double differential(double var1, Tree var3, NodeRef var4);

        /**
         * Second derivative of {@link #transform} with respect to the raw value.
         *
         * @param var1 raw (untransformed) value
         * @param var3 tree the branch belongs to; some implementations ignore it
         * @param var4 node identifying the branch; some implementations ignore it
         * @return second derivative of the transformed rate at {@code var1}
         */
        public double secondDifferential(double var1, Tree var3, NodeRef var4);

        /**
         * Maps a raw stored value to the branch rate.
         *
         * @param var1 raw (untransformed) value
         * @param var3 tree the branch belongs to; some implementations ignore it
         * @param var4 node identifying the branch; some implementations ignore it
         * @return the transformed rate
         */
        public double transform(double var1, Tree var3, NodeRef var4);

        /**
         * @return the raw value every rate element is set to when the enclosing
         *         model is asked to reset its rates at construction
         */
        public double center();

        /**
         * @return lower bound installed on the raw rate parameter
         */
        public double lower();

        /**
         * @return upper bound installed on the raw rate parameter
         */
        public double upper();

        /**
         * Converts a supplied number into a raw value on this transform's scale;
         * the implementations in this file return {@code exp(x)}, {@code x} or
         * {@code exp(-x)}.
         * NOTE(review): callers are not visible in this file - presumably used
         * when randomizing initial rates; confirm.
         *
         * @param var1 input number
         * @return the corresponding raw value
         */
        public double randomize(double var1);

        /**
         * Location-scale transform built on a log-normal quantile map.
         *
         * <p>A raw value {@code x} is mapped to
         * {@code exp((sigmaT / sigmaB) * (log(x) - muB) + muT)} and then
         * multiplied by the branch's location effect when a location is present.
         * {@code (muB, sigmaB)} is the fixed base measure; {@code (muT, sigmaT)}
         * is derived lazily from the scale parameter with
         * {@code phi = scale^2}, {@code mu = -0.5 * log(1 + phi)} and
         * {@code sigma = sqrt(log(1 + phi))}.
         *
         * <p>Both the location and the scale are optional. Without a scale the
         * target measure equals the base measure, so the log-normal map is the
         * identity and only the location effect is applied.
         */
        public static class LocationScaleLogNormal
        extends AbstractModel
        implements BranchRateTransform {
            private final BranchSpecificFixedEffects location;
            private final Parameter scale;
            private final double baseMeasureMu;
            private final double baseMeasureSigma;
            // Cached target-measure parameters; valid only while transformKnown is true.
            private double transformMu;
            private double transformSigma;
            private boolean transformKnown;

            /**
             * Creates the transform with the base measure obtained from {@code phi = 1.0}.
             *
             * @param string                     model name
             * @param branchSpecificFixedEffects per-branch location effects, may be {@code null}
             * @param parameter                  scale parameter (element 0 is read), may be {@code null}
             */
            LocationScaleLogNormal(String string, BranchSpecificFixedEffects branchSpecificFixedEffects, Parameter parameter) {
                this(string, branchSpecificFixedEffects, parameter, LocationScaleLogNormal.getMuPhi(1.0), LocationScaleLogNormal.getSigmaPhi(1.0));
            }

            /**
             * Creates the transform with an explicit base measure.
             *
             * @param string                     model name
             * @param branchSpecificFixedEffects per-branch location effects, may be {@code null}
             * @param parameter                  scale parameter (element 0 is read), may be {@code null}
             * @param d                          base-measure mu
             * @param d2                         base-measure sigma
             */
            LocationScaleLogNormal(String string, BranchSpecificFixedEffects branchSpecificFixedEffects, Parameter parameter, double d, double d2) {
                super(string);
                this.baseMeasureMu = d;
                this.baseMeasureSigma = d2;
                this.location = branchSpecificFixedEffects;
                this.scale = parameter;
                if (branchSpecificFixedEffects instanceof Model) {
                    this.addModel((Model)((Object)branchSpecificFixedEffects));
                }
                if (parameter != null) {
                    this.addVariable(parameter);
                }
                this.transformKnown = false;
            }

            /**
             * @return the location effects object, or {@code null} if none was supplied
             */
            public BranchSpecificFixedEffects getLocationObject() {
                return this.location;
            }

            /**
             * @return the cached target-measure mu from the most recent refresh
             *         ({@code 0.0} before the transform has been evaluated once)
             */
            public double getTransformMu() {
                return this.transformMu;
            }

            /**
             * @return the cached target-measure sigma from the most recent refresh
             *         ({@code 0.0} before the transform has been evaluated once)
             */
            public double getTransformSigma() {
                return this.transformSigma;
            }

            /**
             * @param tree    tree the node belongs to
             * @param nodeRef node identifying the branch
             * @return the branch's location effect, or {@code 1.0} when no location is configured
             */
            public double getLocation(Tree tree, NodeRef nodeRef) {
                return this.location != null ? this.location.getEffect(tree, nodeRef) : 1.0;
            }

            /**
             * @param tree    unused
             * @param nodeRef unused
             * @return element 0 of the scale parameter
             * @throws NullPointerException if this transform was built without a scale parameter
             */
            public double getScale(Tree tree, NodeRef nodeRef) {
                return this.scale.getParameterValue(0);
            }

            /**
             * First derivative of the transform with respect to the raw value:
             * {@code rate * sigmaT / (x * sigmaB)} for {@code x > 0}, otherwise
             * positive infinity.
             */
            @Override
            public double differential(double d, Tree tree, NodeRef nodeRef) {
                // Evaluate the transform first: it refreshes transformSigma when stale.
                double d2 = this.transform(d, tree, nodeRef);
                return d > 0.0 ? d2 * this.transformSigma / (d * this.baseMeasureSigma) : Double.POSITIVE_INFINITY;
            }

            /**
             * Second derivative of the transform with respect to the raw value.
             * For {@code x > 0} this is
             * {@code rate * k / x^2 * (k - 1)} with {@code k = sigmaT / sigmaB};
             * for non-positive {@code x} the sign of {@code sigmaT - sigmaB}
             * selects +infinity, -infinity or zero.
             */
            @Override
            public double secondDifferential(double d, Tree tree, NodeRef nodeRef) {
                // Evaluate the transform first: it refreshes transformSigma when stale.
                double d2 = this.transform(d, tree, nodeRef);
                if (d > 0.0) {
                    return d2 * this.transformSigma / (d * d * this.baseMeasureSigma) * (this.transformSigma / this.baseMeasureSigma - 1.0);
                }
                if (this.transformSigma > this.baseMeasureSigma) {
                    return Double.POSITIVE_INFINITY;
                }
                if (this.transformSigma < this.baseMeasureSigma) {
                    return Double.NEGATIVE_INFINITY;
                }
                return 0.0;
            }

            /**
             * Maps a raw value to a rate, refreshing the cached target measure
             * first when it has been invalidated.
             */
            @Override
            public double transform(double d, Tree tree, NodeRef nodeRef) {
                if (!this.transformKnown) {
                    this.setupTransform();
                    this.transformKnown = true;
                }
                double d2 = LocationScaleLogNormal.logNormalTransform(d, this.baseMeasureMu, this.baseMeasureSigma, this.transformMu, this.transformSigma);
                if (this.location != null) {
                    d2 *= this.location.getEffect(tree, nodeRef);
                }
                return d2;
            }

            /** @return {@code 1.0} */
            @Override
            public double center() {
                return 1.0;
            }

            /** @return {@code 0.0} */
            @Override
            public double lower() {
                return 0.0;
            }

            /** @return positive infinity */
            @Override
            public double upper() {
                return Double.POSITIVE_INFINITY;
            }

            /** @return {@code exp(d)} */
            @Override
            public double randomize(double d) {
                return Math.exp(d);
            }

            /**
             * Invalidates the cached target measure and notifies listeners when
             * the location model changes.
             *
             * @throws RuntimeException if the event comes from any other model
             */
            @Override
            protected void handleModelChangedEvent(Model model, Object object, int n) {
                if (model != this.location) {
                    throw new RuntimeException("Unknown model");
                }
                this.transformKnown = false;
                this.fireModelChanged();
            }

            /**
             * Invalidates the cached target measure and notifies listeners when
             * the scale parameter changes.
             *
             * @throws RuntimeException if the event comes from any other variable
             */
            @Override
            protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
                if (variable != this.scale) {
                    throw new RuntimeException("Unknown variable");
                }
                this.transformKnown = false;
                this.fireModelChanged();
            }

            /** No-op: the cached target measure is recomputed on demand rather than stored. */
            @Override
            protected void storeState() {
            }

            /** Drops the cached target measure so it is recomputed on the next evaluation. */
            @Override
            protected void restoreState() {
                this.transformKnown = false;
            }

            /** No-op. */
            @Override
            protected void acceptState() {
            }

            /**
             * Recomputes the target measure from the current scale. When no scale
             * parameter was supplied the target measure is set to the base measure,
             * which makes the log-normal map the identity (previously this path
             * dereferenced the missing parameter and threw a NullPointerException).
             */
            private void setupTransform() {
                if (this.scale == null) {
                    this.transformMu = this.baseMeasureMu;
                    this.transformSigma = this.baseMeasureSigma;
                    return;
                }
                double d = this.scale.getParameterValue(0) * this.scale.getParameterValue(0);
                this.transformMu = LocationScaleLogNormal.getMuPhi(d);
                this.transformSigma = LocationScaleLogNormal.getSigmaPhi(d);
            }

            /**
             * @param d  raw value
             * @param d2 base-measure mu
             * @param d3 base-measure sigma
             * @param d4 target-measure mu
             * @param d5 target-measure sigma
             * @return {@code exp(d5 / d3 * (log(d) - d2) + d4)}
             */
            private static double logNormalTransform(double d, double d2, double d3, double d4, double d5) {
                return Math.exp(d5 / d3 * (Math.log(d) - d2) + d4);
            }

            /** @return {@code -0.5 * log(1 + d)} */
            private static double getMuPhi(double d) {
                return -0.5 * Math.log(1.0 + d);
            }

            /** @return {@code sqrt(log(1 + d))} */
            private static double getSigmaPhi(double d) {
                return Math.sqrt(Math.log(1.0 + d));
            }
        }

        /**
         * Transform {@code rate = location effect * exp(x)}. Because the map is
         * an exponential in {@code x}, its first and second derivatives equal
         * the transformed value itself. Raw values are unbounded and centered
         * at zero.
         */
        public static class LocationShrinkage
        extends AbstractModel
        implements BranchRateTransform {
            private final BranchSpecificFixedEffects location;

            /**
             * @param string                     model name
             * @param branchSpecificFixedEffects per-branch location effects; registered as a
             *                                   sub-model when it is a {@link Model}
             */
            public LocationShrinkage(String string, BranchSpecificFixedEffects branchSpecificFixedEffects) {
                super(string);
                location = branchSpecificFixedEffects;
                if (branchSpecificFixedEffects instanceof Model) {
                    addModel((Model) branchSpecificFixedEffects);
                }
            }

            /** @return the transformed value (derivative of {@code c * exp(x)} is itself) */
            @Override
            public double differential(double d, Tree tree, NodeRef nodeRef) {
                final double rate = transform(d, tree, nodeRef);
                return rate;
            }

            /** @return the transformed value (second derivative of {@code c * exp(x)} is itself) */
            @Override
            public double secondDifferential(double d, Tree tree, NodeRef nodeRef) {
                final double rate = transform(d, tree, nodeRef);
                return rate;
            }

            /** @return {@code location.getEffect(tree, nodeRef) * exp(d)} */
            @Override
            public double transform(double d, Tree tree, NodeRef nodeRef) {
                final double effect = location.getEffect(tree, nodeRef);
                return effect * Math.exp(d);
            }

            /** @return {@code 0.0} */
            @Override
            public double center() {
                return 0.0;
            }

            /** @return negative infinity */
            @Override
            public double lower() {
                return Double.NEGATIVE_INFINITY;
            }

            /** @return positive infinity */
            @Override
            public double upper() {
                return Double.POSITIVE_INFINITY;
            }

            /** @return {@code d} unchanged */
            @Override
            public double randomize(double d) {
                return d;
            }

            /** Re-broadcasts any sub-model change as a change of this transform. */
            @Override
            protected void handleModelChangedEvent(Model model, Object object, int n) {
                fireModelChanged();
            }

            /** Re-broadcasts any variable change as a change of this transform. */
            @Override
            protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
                fireModelChanged();
            }

            /** No-op: this transform holds no mutable state. */
            @Override
            protected void storeState() {
            }

            /** No-op: this transform holds no mutable state. */
            @Override
            protected void restoreState() {
            }

            /** No-op. */
            @Override
            protected void acceptState() {
            }
        }

        /**
         * Linear transform {@code rate = location effect * x}. The first
         * derivative is the location effect and the second derivative is zero.
         * Raw values are bounded below by zero and centered at one.
         */
        public static class MultiplyByLocation
        extends AbstractModel
        implements BranchRateTransform {
            private final BranchSpecificFixedEffects location;

            /**
             * @param string                     model name
             * @param branchSpecificFixedEffects per-branch location effects; registered as a
             *                                   sub-model when it is a {@link Model}
             */
            MultiplyByLocation(String string, BranchSpecificFixedEffects branchSpecificFixedEffects) {
                super(string);
                location = branchSpecificFixedEffects;
                if (branchSpecificFixedEffects instanceof Model) {
                    addModel((Model) branchSpecificFixedEffects);
                }
            }

            /** @return the branch's location effect (slope of the linear map) */
            @Override
            public double differential(double d, Tree tree, NodeRef nodeRef) {
                final double effect = location.getEffect(tree, nodeRef);
                return effect;
            }

            /** @return {@code 0.0}, the map being linear in the raw value */
            @Override
            public double secondDifferential(double d, Tree tree, NodeRef nodeRef) {
                return 0.0;
            }

            /** @return {@code location.getEffect(tree, nodeRef) * d} */
            @Override
            public double transform(double d, Tree tree, NodeRef nodeRef) {
                final double effect = location.getEffect(tree, nodeRef);
                return effect * d;
            }

            /** @return {@code 1.0} */
            @Override
            public double center() {
                return 1.0;
            }

            /** @return {@code 0.0} */
            @Override
            public double lower() {
                return 0.0;
            }

            /** @return positive infinity */
            @Override
            public double upper() {
                return Double.POSITIVE_INFINITY;
            }

            /** @return {@code exp(d)} */
            @Override
            public double randomize(double d) {
                return Math.exp(d);
            }

            /** Re-broadcasts any sub-model change as a change of this transform. */
            @Override
            protected void handleModelChangedEvent(Model model, Object object, int n) {
                fireModelChanged();
            }

            /**
             * Not supported: this class registers no variables, so any call here
             * is rejected.
             *
             * @throws RuntimeException always
             */
            @Override
            protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
                throw new RuntimeException("Not yet implemented");
            }

            /** No-op: this transform holds no mutable state. */
            @Override
            protected void storeState() {
            }

            /** No-op: this transform holds no mutable state. */
            @Override
            protected void restoreState() {
            }

            /** No-op. */
            @Override
            protected void acceptState() {
            }
        }

        /**
         * Transform {@code rate = exp(x)}; tree and node are ignored. Both
         * derivatives equal the transformed value. Raw values are unbounded and
         * centered at zero.
         */
        public static class Exponentiate
        implements BranchRateTransform {
            /** @return {@code exp(d)}, obtained through {@link #transform} */
            @Override
            public double differential(double d, Tree tree, NodeRef nodeRef) {
                // Tree and node play no part in this transform, so none are passed on.
                final double rate = transform(d, null, null);
                return rate;
            }

            /** @return {@code exp(d)}, obtained through {@link #transform} */
            @Override
            public double secondDifferential(double d, Tree tree, NodeRef nodeRef) {
                // Tree and node play no part in this transform, so none are passed on.
                final double rate = transform(d, null, null);
                return rate;
            }

            /** @return {@code exp(d)} */
            @Override
            public double transform(double d, Tree tree, NodeRef nodeRef) {
                final double rate = Math.exp(d);
                return rate;
            }

            /** @return {@code d} unchanged */
            @Override
            public double randomize(double d) {
                return d;
            }

            /** @return {@code 0.0} */
            @Override
            public double center() {
                return 0.0;
            }

            /** @return negative infinity */
            @Override
            public double lower() {
                return Double.NEGATIVE_INFINITY;
            }

            /** @return positive infinity */
            @Override
            public double upper() {
                return Double.POSITIVE_INFINITY;
            }
        }

        /**
         * Transform {@code rate = 1 / x}; tree and node are ignored. Bounds and
         * center are inherited from {@link Base}.
         */
        public static class Reciprocal
        extends Base {
            /** @return {@code -1 / d^2} */
            @Override
            public double differential(double d, Tree tree, NodeRef nodeRef) {
                final double squared = d * d;
                return -1.0 / squared;
            }

            /** @return {@code 2 / d^3} */
            @Override
            public double secondDifferential(double d, Tree tree, NodeRef nodeRef) {
                final double cubed = d * d * d;
                return 2.0 / cubed;
            }

            /** @return {@code 1 / d} */
            @Override
            public double transform(double d, Tree tree, NodeRef nodeRef) {
                final double reciprocal = 1.0 / d;
                return reciprocal;
            }

            /** @return {@code exp(-d)} */
            @Override
            public double randomize(double d) {
                final double negated = -d;
                return Math.exp(negated);
            }
        }

        /**
         * Identity transform: the raw value is the rate. Bounds, center and
         * randomization are inherited from {@link Base}.
         */
        public static class None
        extends Base {
            /** @return {@code 1.0}, the slope of the identity map */
            @Override
            public double differential(double d, Tree tree, NodeRef nodeRef) {
                final double slope = 1.0;
                return slope;
            }

            /** @return {@code 0.0}, the curvature of the identity map */
            @Override
            public double secondDifferential(double d, Tree tree, NodeRef nodeRef) {
                final double curvature = 0.0;
                return curvature;
            }

            /** @return {@code d} unchanged */
            @Override
            public double transform(double d, Tree tree, NodeRef nodeRef) {
                final double rate = d;
                return rate;
            }
        }

        /**
         * Shared defaults for transforms whose raw values live on the positive
         * half-line: center {@code 1}, bounds {@code [0, +infinity)}, and
         * {@code exp} as the randomization map. Subclasses supply the transform
         * and its derivatives.
         */
        public static abstract class Base
        implements BranchRateTransform {
            /** @return {@code 1.0} */
            @Override
            public double center() {
                final double center = 1.0;
                return center;
            }

            /** @return {@code 0.0} */
            @Override
            public double lower() {
                final double lowerBound = 0.0;
                return lowerBound;
            }

            /** @return positive infinity */
            @Override
            public double upper() {
                final double upperBound = Double.POSITIVE_INFINITY;
                return upperBound;
            }

            /** @return {@code exp(d)} */
            @Override
            public double randomize(double d) {
                final double positive = Math.exp(d);
                return positive;
            }
        }
    }
}

