/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.tree;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.tree.TreeStatistic;
import dr.stats.DiscreteStatistics;

/**
 * Tree statistic summarising the residuals between the branch rates of an
 * {@link ArbitraryBranchRates} model and the per-branch locations reported by
 * its {@code LocationScaleLogNormal} transform.
 *
 * <p>For every selected branch (external and/or non-root internal nodes) the
 * residual is {@code rate - location}; when the log-scale flag is set both
 * values are passed through {@link Math#log(double)} before subtracting.
 * The reported value depends on the mode string:
 * <ul>
 *   <li>{@code "meanOfResiduals"} &ndash; mean of the residuals;</li>
 *   <li>{@code "varianceOfResiduals"} &ndash; variance of the residuals;</li>
 *   <li>{@code "proportionOfVarianceExplained"} &ndash;
 *       {@code 1 - variance(residuals) / variance(rates)}.</li>
 * </ul>
 */
public class MixedEffectsRateStatistic extends TreeStatistic {
    private static final String MODE_MEAN_OF_RESIDUALS = "meanOfResiduals";
    private static final String MODE_VARIANCE_OF_RESIDUALS = "varianceOfResiduals";
    private static final String MODE_PROPORTION_EXPLAINED = "proportionOfVarianceExplained";

    // Replaceable through setTree(), hence not final.
    private Tree tree;
    private final ArbitraryBranchRates branchRateModel;
    private final boolean internal;
    private final boolean external;
    private final String mode;
    private final boolean logScale;

    // Scratch buffers, rebuilt on every call to getStatisticValue().
    private double[] rates;
    private double[] locations;

    /**
     * Creates the statistic.
     *
     * @param name            name handed to the {@code TreeStatistic} superclass
     * @param tree            tree whose branches are inspected
     * @param branchRateModel rate model; its transform must be a
     *                        {@code LocationScaleLogNormal}
     * @param external        whether branches above external nodes are included
     * @param internal        whether branches above non-root internal nodes are included
     * @param logScale        whether rates and locations are log-transformed
     *                        before residuals are formed
     * @param mode            one of {@code "meanOfResiduals"},
     *                        {@code "varianceOfResiduals"} or
     *                        {@code "proportionOfVarianceExplained"}; it is only
     *                        checked when the statistic is evaluated
     * @throws RuntimeException if the model's transform is not a
     *                          {@code LocationScaleLogNormal}
     */
    public MixedEffectsRateStatistic(String name, Tree tree, ArbitraryBranchRates branchRateModel,
                                     boolean external, boolean internal, boolean logScale, String mode) {
        super(name);
        this.tree = tree;
        this.branchRateModel = branchRateModel;
        this.internal = internal;
        this.external = external;
        this.mode = mode;
        this.logScale = logScale;
        if (!(branchRateModel.getTransform() instanceof ArbitraryBranchRates.BranchRateTransform.LocationScaleLogNormal)) {
            throw new RuntimeException("MixedEffectsRateStatistic currently only supports LocationScaleLogNormal models.");
        }
    }

    @Override
    public void setTree(Tree tree) {
        this.tree = tree;
    }

    @Override
    public Tree getTree() {
        return this.tree;
    }

    /** The statistic is a single scalar. */
    @Override
    public int getDimension() {
        return 1;
    }

    /**
     * Refills {@link #rates} and {@link #locations} with the current values of
     * every selected branch: external nodes first (if enabled), followed by
     * the internal nodes other than the root (if enabled).
     */
    private void prepareForComputation() {
        int externalCount = 0;
        int branchCount = 0;
        if (this.external) {
            externalCount = this.tree.getExternalNodeCount();
            branchCount = externalCount;
        }
        if (this.internal) {
            // The root is skipped below, hence one slot fewer than internal nodes.
            branchCount += this.tree.getInternalNodeCount() - 1;
        }
        this.rates = new double[branchCount];
        this.locations = new double[branchCount];

        ArbitraryBranchRates.BranchRateTransform.LocationScaleLogNormal transform =
                (ArbitraryBranchRates.BranchRateTransform.LocationScaleLogNormal) this.branchRateModel.getTransform();

        // externalCount stays 0 when external branches are excluded, so this loop is then a no-op.
        for (int i = 0; i < externalCount; ++i) {
            NodeRef node = this.tree.getExternalNode(i);
            this.rates[i] = this.branchRateModel.getBranchRate(this.tree, node);
            this.locations[i] = transform.getLocation(this.tree, node);
        }

        if (this.internal) {
            int internalCount = this.tree.getInternalNodeCount();
            int slot = externalCount;
            for (int i = 0; i < internalCount; ++i) {
                NodeRef node = this.tree.getInternalNode(i);
                if (this.tree.isRoot(node)) {
                    continue;
                }
                this.rates[slot] = this.branchRateModel.getBranchRate(this.tree, node);
                this.locations[slot] = transform.getLocation(this.tree, node);
                ++slot;
            }
        }
    }

    /** Replaces every entry of {@link #rates} and {@link #locations} by its natural logarithm. */
    private void takeLogs() {
        for (int i = 0; i < this.locations.length; ++i) {
            this.locations[i] = Math.log(this.locations[i]);
            this.rates[i] = Math.log(this.rates[i]);
        }
    }

    /**
     * @return a new array holding {@code rates[i] - locations[i]} for every selected branch
     */
    private double[] getResiduals() {
        double[] residuals = new double[this.locations.length];
        for (int i = 0; i < residuals.length; ++i) {
            residuals[i] = this.rates[i] - this.locations[i];
        }
        return residuals;
    }

    /**
     * Recomputes the rates and locations from the current tree and returns the
     * summary selected by the mode given at construction.
     *
     * @param dim ignored; the statistic has a single dimension
     * @return the mean or variance of the residuals, or
     *         {@code 1 - variance(residuals) / variance(rates)}, depending on the mode
     * @throws IllegalArgumentException if the mode is {@code null} or not one
     *                                  of the three supported names
     */
    @Override
    public double getStatisticValue(int dim) {
        this.prepareForComputation();
        if (this.logScale) {
            this.takeLogs();
        }
        double[] residuals = this.getResiduals();

        // Constant-first comparisons so that a null mode is reported as an
        // invalid argument rather than surfacing as a NullPointerException.
        if (MODE_MEAN_OF_RESIDUALS.equals(this.mode)) {
            return DiscreteStatistics.mean(residuals);
        }
        if (MODE_VARIANCE_OF_RESIDUALS.equals(this.mode)) {
            return DiscreteStatistics.variance(residuals);
        }
        if (MODE_PROPORTION_EXPLAINED.equals(this.mode)) {
            return 1.0 - DiscreteStatistics.variance(residuals) / DiscreteStatistics.variance(this.rates);
        }
        throw new IllegalArgumentException("Unknown mode \"" + this.mode + "\" for MixedEffectsRateStatistic; expected one of \""
                + MODE_MEAN_OF_RESIDUALS + "\", \"" + MODE_VARIANCE_OF_RESIDUALS + "\" or \""
                + MODE_PROPORTION_EXPLAINED + "\".");
    }
}

