/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.tree;

import dr.evolution.distance.DistanceMatrix;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * A {@link SimpleTree} constructed from an existing tree together with a pairwise distance
 * matrix, carrying the Rzhetsky and Nei (1993) ordinary-least-squares branch-length formulas.
 *
 * <p>Each taxon is identified by the number of its external node ({@link NodeRef#getNumber()}),
 * and that number is used directly as a row/column index into the distance matrix.
 * NOTE(review): this assumes external node numbers follow the distance matrix's taxon order;
 * confirm against callers.
 *
 * <p>NOTE(review): the constructor only gathers taxon sets. Nothing in this class invokes
 * {@code calculateBranchLengths}, so branch lengths are whatever {@code SimpleTree(Tree)}
 * produced. Confirm whether that call was intentionally removed.
 */
public class RzhetskyNeiBranchLengthsTree
extends SimpleTree {
    /** Pairwise distances between taxa, indexed by external node number. */
    private final DistanceMatrix distanceMatrix;
    /** Numbers of every external node reachable from the root of this tree. */
    private final Set<Integer> allTaxonSet;

    /**
     * Builds this tree from {@code tree} via {@code SimpleTree(Tree)} and records every taxon.
     *
     * @param tree the source tree, expected to be strictly bifurcating
     * @param distanceMatrix pairwise distances, indexed by external node number
     */
    public RzhetskyNeiBranchLengthsTree(Tree tree, DistanceMatrix distanceMatrix) {
        super(tree);
        this.distanceMatrix = distanceMatrix;
        // The per-node taxon sets are only needed by calculateBranchLengths, which is not called
        // here; only the root's set (all taxa) is kept.
        Map<NodeRef, Set<Integer>> taxonSets = new HashMap<NodeRef, Set<Integer>>();
        this.allTaxonSet = new HashSet<Integer>(this.getTaxonSets(this, this.getRoot(), taxonSets));
    }

    /**
     * Recursively collects the external-node numbers contained in each subtree.
     *
     * @param tree the tree to traverse
     * @param node the root of the subtree to visit
     * @param taxonSets receives, for {@code node} and each of its descendants, the set of
     *     external-node numbers in that node's subtree
     * @return the external-node numbers in the subtree rooted at {@code node}
     */
    private Set<Integer> getTaxonSets(Tree tree, NodeRef node, Map<NodeRef, Set<Integer>> taxonSets) {
        Set<Integer> taxa = new HashSet<Integer>();
        if (tree.isExternal(node)) {
            taxa.add(node.getNumber());
        } else {
            assert (tree.getChildCount(node) == 2) : "Must be a strictly bifurcating tree";
            for (int i = 0; i < tree.getChildCount(node); ++i) {
                taxa.addAll(this.getTaxonSets(tree, tree.getChild(node, i), taxonSets));
            }
        }
        taxonSets.put(node, taxa);
        return taxa;
    }

    /**
     * Sets the Rzhetsky-Nei OLS length of the branch above {@code node}, first recursing into
     * its children.
     *
     * <p>All distances between two groups of taxa are averaged over every pair drawn from them
     * (see {@code meanDistance}). When the "remaining taxa" group is empty the mean is 0/0
     * (NaN). This happens for the children of the root, whose sibling covers every other taxon.
     *
     * @param tree the tree supplying the structure (external test, children)
     * @param node the node whose parent branch length is set on this tree
     * @param sibling the other child of {@code node}'s parent
     * @param taxonSets per-node taxon sets as filled in by {@code getTaxonSets}
     */
    private void calculateBranchLengths(Tree tree, NodeRef node, NodeRef sibling, Map<NodeRef, Set<Integer>> taxonSets) {
        double length;
        if (tree.isExternal(node)) {
            // Pendant branch: leaf A, its sibling subtree B, and every other taxon C.
            Set<Integer> a = taxonSets.get(node);
            Set<Integer> b = taxonSets.get(sibling);
            Set<Integer> c = new HashSet<Integer>(this.allTaxonSet);
            c.removeAll(a);
            c.removeAll(b);
            length = 0.5 * (this.meanDistance(a, c) + this.meanDistance(a, b) - this.meanDistance(c, b));
        } else {
            NodeRef left = tree.getChild(node, 0);
            NodeRef right = tree.getChild(node, 1);
            this.calculateBranchLengths(tree, left, right, taxonSets);
            this.calculateBranchLengths(tree, right, left, taxonSets);
            // Internal branch separating {A, B} (the children of node) from {C, D}, where C is
            // node's sibling subtree and D is every other taxon.
            Set<Integer> a = taxonSets.get(left);
            Set<Integer> b = taxonSets.get(right);
            Set<Integer> c = taxonSets.get(sibling);
            Set<Integer> d = new HashSet<Integer>(this.allTaxonSet);
            d.removeAll(a);
            d.removeAll(b);
            d.removeAll(c);
            double nA = a.size();
            double nB = b.size();
            double nC = c.size();
            double nD = d.size();
            // Weight on the (A,D)/(B,C) pairings; the (A,C)/(B,D) pairings get 1 - gamma.
            double gamma = (nC * nA + nD * nB) / ((nD + nC) * (nA + nB));
            // Each term is a mean over all cross pairs. The previous form "sum / n1 * n2"
            // multiplied by n2 instead of dividing by it (Java evaluates / and * left to right).
            length = 0.5 * (gamma * (this.meanDistance(d, a) + this.meanDistance(c, b))
                    + (1.0 - gamma) * (this.meanDistance(c, a) + this.meanDistance(d, b))
                    - this.meanDistance(d, c) - this.meanDistance(a, b));
        }
        this.setBranchLength(node, length);
    }

    /**
     * Returns the average of {@code distanceMatrix.getElement(i, j)} over every {@code i} in
     * {@code first} and {@code j} in {@code second}. Yields NaN if either set is empty.
     */
    private double meanDistance(Set<Integer> first, Set<Integer> second) {
        // Multiply in double so large set sizes cannot overflow int.
        return this.getSumOfDistances(first, second) / ((double) first.size() * second.size());
    }

    /**
     * Sums {@code distanceMatrix.getElement(i, j)} over every {@code i} in {@code first} and
     * {@code j} in {@code second}.
     */
    private double getSumOfDistances(Set<Integer> first, Set<Integer> second) {
        double sum = 0.0;
        for (int i : first) {
            for (int j : second) {
                sum += this.distanceMatrix.getElement(i, j);
            }
        }
        return sum;
    }
}

