/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.ht;

import java.io.Serializable;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import weka.classifiers.trees.ht.ConditionalSufficientStats;
import weka.classifiers.trees.ht.SplitCandidate;
import weka.classifiers.trees.ht.SplitMetric;
import weka.classifiers.trees.ht.UnivariateNominalMultiwaySplit;
import weka.classifiers.trees.ht.WeightMass;
import weka.core.Utils;

/**
 * Conditional sufficient statistics for a nominal attribute.
 *
 * <p>For every class value seen, a {@link ValueDistribution} accumulates the
 * weight observed for each attribute value index. These distributions are stored
 * in {@code m_classLookup}, keyed by class value. They are used to estimate
 * P(attribute value | class) and to evaluate a multiway split on the attribute.
 */
public class NominalConditionalSufficientStats
extends ConditionalSufficientStats
implements Serializable {
    private static final long serialVersionUID = -669902060601313488L;

    /** Total weight passed to {@link #update}, including instances whose attribute value was missing. */
    protected double m_totalWeight;

    /** Weight of instances whose attribute value was missing. */
    protected double m_missingWeight;

    /**
     * Records one observation of this attribute for a given class value.
     *
     * <p>A missing attribute value only contributes to the missing weight. In
     * every case the weight is also added to the total weight.
     *
     * @param attVal the attribute value; presumably the nominal value's index,
     *        truncated to {@code int} (TODO confirm against callers)
     * @param classVal the class value the observation belongs to
     * @param weight the weight of the observation
     */
    @Override
    public void update(double attVal, String classVal, double weight) {
        if (Utils.isMissingValue(attVal)) {
            this.m_missingWeight += weight;
        } else {
            ValueDistribution valDist = (ValueDistribution) this.m_classLookup.get(classVal);
            if (valDist == null) {
                valDist = new ValueDistribution();
                this.m_classLookup.put(classVal, valDist);
            }
            valDist.add((int) attVal, weight);
        }
        this.m_totalWeight += weight;
    }

    /**
     * Estimates P(attVal | classVal) from the accumulated weights.
     *
     * @param attVal the attribute value index (truncated to {@code int})
     * @param classVal the class value to condition on
     * @return the weight for {@code attVal} divided by the total weight recorded
     *         for {@code classVal}; 0.0 if {@code classVal} has never been seen,
     *         and 0.0 if {@code attVal} was never seen with that class
     */
    @Override
    public double probabilityOfAttValConditionedOnClass(double attVal, String classVal) {
        ValueDistribution valDist = (ValueDistribution) this.m_classLookup.get(classVal);
        if (valDist == null) {
            return 0.0;
        }
        return valDist.getWeight((int) attVal) / valDist.sum();
    }

    /**
     * Builds the class distribution for each branch of a multiway split on this
     * attribute.
     *
     * <p>There is one distribution per attribute value index that has been
     * observed with at least one class. Each distribution maps class value to the
     * weight accumulated for that (attribute value, class) pair. The list is
     * ordered by ascending attribute value index. Indices that were never observed
     * have no entry, so list positions are not necessarily equal to value indices.
     * NOTE(review): confirm how callers map list positions to branches.
     *
     * @return the per-branch class distributions, in ascending attribute-value order
     */
    protected List<Map<String, WeightMass>> classDistsAfterSplit() {
        // TreeMap (instead of HashMap) so branch order is defined by the attribute
        // value index rather than by unspecified hash-table iteration order.
        Map<Integer, Map<String, WeightMass>> splitDists =
                new TreeMap<Integer, Map<String, WeightMass>>();
        for (Map.Entry<?, ?> cls : this.m_classLookup.entrySet()) {
            String classVal = (String) cls.getKey();
            ValueDistribution attDist = (ValueDistribution) cls.getValue();
            for (Map.Entry<Integer, WeightMass> att : attDist.m_dist.entrySet()) {
                Integer attVal = att.getKey();
                Map<String, WeightMass> clsDist = splitDists.get(attVal);
                if (clsDist == null) {
                    clsDist = new HashMap<String, WeightMass>();
                    splitDists.put(attVal, clsDist);
                }
                WeightMass clsCount = clsDist.get(classVal);
                if (clsCount == null) {
                    clsCount = new WeightMass();
                    clsDist.put(classVal, clsCount);
                }
                clsCount.m_weight += att.getValue().m_weight;
            }
        }
        return new LinkedList<Map<String, WeightMass>>(splitDists.values());
    }

    /**
     * Evaluates a multiway split on this attribute.
     *
     * @param splitMetric the metric used to score the split
     * @param preSplitDist the class distribution before the split
     * @param attName the name of the attribute being split on
     * @return a candidate holding a {@link UnivariateNominalMultiwaySplit} on
     *         {@code attName}, the post-split class distributions and their merit
     */
    @Override
    public SplitCandidate bestSplit(SplitMetric splitMetric, Map<String, WeightMass> preSplitDist, String attName) {
        List<Map<String, WeightMass>> postSplitDists = this.classDistsAfterSplit();
        double merit = splitMetric.evaluateSplit(preSplitDist, postSplitDists);
        return new SplitCandidate(new UnivariateNominalMultiwaySplit(attName), postSplitDists, merit);
    }

    /**
     * Weight accumulated per attribute value index for a single class value.
     * Entries are kept in the order the values were first seen.
     */
    protected class ValueDistribution
    implements Serializable {
        private static final long serialVersionUID = -61711544350888154L;

        /** Attribute value index to accumulated weight. */
        protected final Map<Integer, WeightMass> m_dist = new LinkedHashMap<Integer, WeightMass>();

        /** Sum of all weights in {@link #m_dist}, including the initial 1.0 of each entry. */
        private double m_sum;

        protected ValueDistribution() {
        }

        /**
         * Adds weight to an attribute value.
         *
         * <p>The first time a value is seen, its weight starts at 1.0, and the sum
         * is increased by 1.0 to match. This is presumably an add-one (Laplace-style)
         * prior for observed values.
         *
         * @param val the attribute value index
         * @param weight the weight to add
         */
        public void add(int val, double weight) {
            WeightMass count = this.m_dist.get(val);
            if (count == null) {
                count = new WeightMass();
                count.m_weight = 1.0;
                this.m_sum += 1.0;
                this.m_dist.put(val, count);
            }
            count.m_weight += weight;
            this.m_sum += weight;
        }

        /**
         * Subtracts weight from an attribute value. Does nothing if the value has
         * never been added. The entry is kept even if its weight drops to zero.
         *
         * @param val the attribute value index
         * @param weight the weight to subtract
         */
        public void delete(int val, double weight) {
            WeightMass count = this.m_dist.get(val);
            if (count != null) {
                count.m_weight -= weight;
                this.m_sum -= weight;
            }
        }

        /**
         * @param val the attribute value index
         * @return the weight accumulated for {@code val}, or 0.0 if it was never added
         */
        public double getWeight(int val) {
            WeightMass count = this.m_dist.get(val);
            if (count != null) {
                return count.m_weight;
            }
            return 0.0;
        }

        /** @return the sum of the weights of all attribute values in this distribution */
        public double sum() {
            return this.m_sum;
        }
    }
}

