/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.util.ArrayList;
import java.util.List;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.options.ClassOption;

/**
 * Learn++.NSE ensemble for learning in non-stationary environments (Elwell and Polikar, 2011).
 *
 * <p>Incoming instances are collected into batches ("environments") of {@code period} instances.
 * When a batch is complete:
 * <ol>
 *   <li>the current ensemble is evaluated on the batch, and instances it classifies correctly are
 *       down-weighted by the ensemble's error rate, giving a distribution over the batch;</li>
 *   <li>a new, independent member is trained on the batch;</li>
 *   <li>every member is evaluated on the batch under that distribution, and its voting weight is
 *       recomputed as {@code log(1 / b)}, where {@code b} is a sigmoid-weighted average of the
 *       member's normalized errors over every batch since it was created;</li>
 *   <li>if a pruning strategy is selected and the ensemble exceeds {@code ensembleSize}, the
 *       oldest (AGE) or worst (ERROR) member is removed.</li>
 * </ol>
 */
public class LearnNSE extends AbstractClassifier implements MultiClassClassifier {

    /**
     * Lower bound applied to a member's (non-negative) averaged normalized error before its voting
     * weight {@code log(1 / error)} is taken. Without it, a member that never erred gets weight
     * {@code +Infinity}, and scaling its normalized vote by infinity yields NaN or tied entries in
     * the combined vote.
     */
    private static final double MIN_AVERAGED_ERROR = 1.0e-10;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "bayes.NaiveBayes");
    public IntOption periodOption = new IntOption("period", 'p', "Size of the environments.", 250, 1, Integer.MAX_VALUE);
    public FloatOption sigmoidSlopeOption = new FloatOption("sigmoidSlope", 'a', "Slope of the sigmoid function controlling the number of previous periods taken into account during weighting.", 0.5, 0.0, Float.MAX_VALUE);
    public FloatOption sigmoidCrossingPointOption = new FloatOption("sigmoidCrossingPoint", 'b', "Halfway crossing point of the sigmoid function controlling the number of previous periods taken into account during weighting.", 10.0, 0.0, Float.MAX_VALUE);
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 'e', "Ensemble size.", 15, 1, Integer.MAX_VALUE);
    public MultiChoiceOption pruningStrategyOption = new MultiChoiceOption("pruningStrategy", 's', "Classifiers pruning strategy to be used.", new String[]{"NO", "AGE", "ERROR"}, new String[]{"Don't prune classifiers", "Age-based", "Error-based"}, 0);

    /** Ensemble members in creation order. */
    protected List<Classifier> ensemble;
    /** Voting weight of each member, aligned with {@link #ensemble}. */
    protected List<Double> ensembleWeights;
    /** Per member: normalized error {@code e / (1 - e)} on every batch since its creation. */
    protected List<ArrayList<Double>> bkts;
    /** Per member: raw (un-normalized) sigmoid weight for every batch, aligned with {@link #bkts}. */
    protected List<ArrayList<Double>> wkts;
    /** Per member: value of {@link #periodCount} when the member was created. */
    protected List<Long> birthPeriods;
    /** Instances of the batch being collected; {@code null} until the batch's first instance. */
    protected Instances buffer;
    /** Number of instances added to the current batch. */
    protected long index;
    /** Number of completed batches since the last reset; used to compute member ages. */
    protected long periodCount;
    protected double slope;
    protected double crossingPoint;
    protected int pruning;
    protected int ensembleSize;

    @Override
    public void resetLearningImpl() {
        this.ensemble = new ArrayList<Classifier>();
        this.ensembleWeights = new ArrayList<Double>();
        this.bkts = new ArrayList<ArrayList<Double>>();
        this.wkts = new ArrayList<ArrayList<Double>>();
        this.birthPeriods = new ArrayList<Long>();
        this.index = 0L;
        this.periodCount = 0L;
        this.buffer = null;
        this.slope = this.sigmoidSlopeOption.getValue();
        this.crossingPoint = this.sigmoidCrossingPointOption.getValue();
        this.pruning = this.pruningStrategyOption.getChosenIndex();
        this.ensembleSize = this.ensembleSizeOption.getValue();
    }

    /**
     * Buffers {@code inst}. Once {@code period} instances have been collected, runs one
     * Learn++.NSE update on the buffered batch and starts a new batch.
     *
     * @param inst the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        ++this.index;
        if (this.buffer == null) {
            this.buffer = new Instances(inst.dataset());
        }
        this.buffer.add(inst);
        if (this.index % (long) this.periodOption.getValue() != 0L) {
            return;
        }
        this.index = 0L;
        this.learnFromBuffer();
        // Drop the processed batch. The next instance recreates the buffer from its own header,
        // so nothing depends on getModelContext() (null unless setModelContext() was called).
        this.buffer = null;
        ++this.periodCount;
    }

    /** Runs one Learn++.NSE step on the complete batch held in {@link #buffer}. */
    private void learnFromBuffer() {
        int mt = this.buffer.numInstances();
        // Must be computed before the new member is added: it uses the previous ensemble's votes.
        double[] distribution = this.computeInstanceDistribution(mt);

        // NOTE(review): as before, the base learner sees each instance once with weight 1; the
        // distribution is used only to evaluate members, not to train the new one.
        Classifier member = this.newBaseLearner();
        for (int i = 0; i < mt; ++i) {
            Instance trainingInstance = this.buffer.instance(i).copy();
            trainingInstance.setWeight(1.0);
            member.trainOnInstance(trainingInstance);
        }
        this.ensemble.add(member);
        this.bkts.add(new ArrayList<Double>());
        this.wkts.add(new ArrayList<Double>());
        this.birthPeriods.add(this.periodCount);

        int worstIndex = this.updateEnsembleWeights(distribution);
        this.pruneIfNeeded(worstIndex);
    }

    /** Returns a fresh, untrained copy of the configured base learner, independent of all members. */
    private Classifier newBaseLearner() {
        // getPreparedClassOption hands back the same prepared object on every call, so it must be
        // copied; otherwise every ensemble slot would alias (and reset) one shared classifier.
        Classifier learner = ((Classifier) this.getPreparedClassOption(this.baseLearnerOption)).copy();
        learner.resetLearning();
        return learner;
    }

    /**
     * Computes the distribution over the {@code mt} buffered instances (Learn++.NSE steps 1-2).
     * Instances the current ensemble classifies correctly get relative weight {@code et} (the
     * ensemble's error rate on the batch), misclassified ones get 1, and the result is normalized.
     * Returns the uniform distribution when the ensemble is empty or makes no errors (the latter
     * previously produced 0/0 = NaN weights). The buffered instances themselves are not modified.
     */
    private double[] computeInstanceDistribution(int mt) {
        double[] distribution = new double[mt];
        if (this.ensemble.isEmpty()) {
            for (int i = 0; i < mt; ++i) {
                distribution[i] = 1.0 / mt;
            }
            return distribution;
        }
        boolean[] correct = new boolean[mt];
        double et = 0.0;
        for (int i = 0; i < mt; ++i) {
            correct[i] = this.correctlyClassifies(this.buffer.instance(i));
            if (!correct[i]) {
                et += 1.0 / mt;
            }
        }
        double weightSum = 0.0;
        for (int i = 0; i < mt; ++i) {
            distribution[i] = 1.0 / mt * (correct[i] ? et : 1.0);
            weightSum += distribution[i];
        }
        for (int i = 0; i < mt; ++i) {
            distribution[i] = weightSum > 0.0 ? distribution[i] / weightSum : 1.0 / mt;
        }
        return distribution;
    }

    /**
     * Evaluates every member on the buffered batch under {@code distribution}, appends the
     * member's normalized error and sigmoid weight to its history, and recomputes all voting
     * weights (Learn++.NSE steps 4-6).
     *
     * @return index of the member with the largest (clamped) weighted error on this batch
     */
    private int updateEnsembleWeights(double[] distribution) {
        this.ensembleWeights.clear();
        int t = this.ensemble.size();
        int worstIndex = 0;
        double maxError = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < t; ++k) {
            Classifier member = this.ensemble.get(k);
            double ekt = 0.0;
            for (int i = 0; i < distribution.length; ++i) {
                if (!member.correctlyClassifies(this.buffer.instance(i))) {
                    ekt += distribution[i];
                }
            }
            if (k == t - 1 && ekt > 0.5) {
                // The newest member is worse than chance on its own batch: replace it with an
                // untrained learner (it is not retrained). Its error is recorded unclamped, so its
                // voting weight is negative for this batch and it does not vote.
                this.ensemble.set(k, this.newBaseLearner());
            } else if (ekt > 0.5) {
                // Older members worse than chance are capped at 0.5 (normalized error 1).
                ekt = 0.5;
            }
            if (ekt > maxError) {
                maxError = ekt;
                worstIndex = k;
            }
            ArrayList<Double> errorHistory = this.bkts.get(k);
            errorHistory.add(ekt / (1.0 - ekt));
            // Age in batches; unlike list position, stays correct after error-based pruning.
            long age = this.periodCount - this.birthPeriods.get(k);
            ArrayList<Double> sigmoidHistory = this.wkts.get(k);
            sigmoidHistory.add(1.0 / (1.0 + Math.exp(-this.slope * ((double) age - this.crossingPoint))));
            this.ensembleWeights.add(votingWeight(errorHistory, sigmoidHistory));
        }
        return worstIndex;
    }

    /**
     * Returns {@code log(1 / b)}, where {@code b} is the average of {@code errorHistory} weighted
     * by {@code sigmoidHistory} normalized to sum to one over the member's whole history. Uses
     * equal weights if every sigmoid value underflowed to zero, and floors a non-negative {@code b}
     * at {@link #MIN_AVERAGED_ERROR} so the weight stays finite.
     */
    private static double votingWeight(List<Double> errorHistory, List<Double> sigmoidHistory) {
        int n = sigmoidHistory.size();
        double sigmoidSum = 0.0;
        for (double s : sigmoidHistory) {
            sigmoidSum += s;
        }
        double averagedError = 0.0;
        for (int j = 0; j < n; ++j) {
            double share = sigmoidSum > 0.0 ? sigmoidHistory.get(j) / sigmoidSum : 1.0 / n;
            averagedError += share * errorHistory.get(j);
        }
        if (averagedError >= 0.0 && averagedError < MIN_AVERAGED_ERROR) {
            averagedError = MIN_AVERAGED_ERROR;
        }
        return Math.log(1.0 / averagedError);
    }

    /**
     * Removes one member when the ensemble exceeds {@link #ensembleSize}, according to the
     * selected pruning strategy (index 0 NO, 1 AGE, 2 ERROR).
     *
     * @param worstIndex member with the largest error on the current batch
     */
    private void pruneIfNeeded(int worstIndex) {
        if (this.ensemble.size() <= this.ensembleSize) {
            return;
        }
        int victim;
        switch (this.pruning) {
            case 1: // AGE: drop the oldest member
                victim = 0;
                break;
            case 2: // ERROR: drop the member with the largest error on this batch
                victim = worstIndex;
                break;
            default: // NO pruning
                return;
        }
        this.ensemble.remove(victim);
        this.ensembleWeights.remove(victim);
        this.bkts.remove(victim);
        this.wkts.remove(victim);
        this.birthPeriods.remove(victim);
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    /**
     * Combines the members' normalized votes, each scaled by its voting weight. Members with a
     * non-positive (or NaN) weight, or whose votes sum to zero, are skipped.
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        if (this.trainingWeightSeenByModel <= 0.0) {
            return combinedVote.getArrayRef();
        }
        for (int i = 0; i < this.ensemble.size(); ++i) {
            double weight = this.ensembleWeights.get(i);
            if (!(weight > 0.0)) {
                continue;
            }
            DoubleVector vote = new DoubleVector(this.ensemble.get(i).getVotesForInstance(inst));
            if (!(vote.sumOfValues() > 0.0)) {
                continue;
            }
            vote.normalize();
            vote.scaleValues(weight);
            combinedVote.addValues(vote);
        }
        return combinedVote.getArrayRef();
    }

    /** No textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /** Reports each member's current voting weight, or {@code null} before the first reset. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        if (this.ensembleWeights == null) {
            return null;
        }
        Measurement[] measurements = new Measurement[this.ensembleWeights.size()];
        for (int i = 0; i < measurements.length; ++i) {
            measurements[i] = new Measurement("member weight " + (i + 1), this.ensembleWeights.get(i));
        }
        return measurements;
    }
}

