/*
 * Decompiled with CFR 0.152.
 */
package moa.tasks;

import com.github.javacliparser.FileOption;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import moa.classifiers.MultiClassClassifier;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.StringUtils;
import moa.core.TimingUtils;
import moa.evaluation.LearningEvaluation;
import moa.evaluation.LearningPerformanceEvaluator;
import moa.evaluation.preview.LearningCurve;
import moa.learners.Learner;
import moa.options.ClassOption;
import moa.streams.CachedInstancesStream;
import moa.streams.ExampleStream;
import moa.tasks.ClassificationMainTask;
import moa.tasks.TaskMonitor;

/**
 * Task that trains a learner on a stream and, after every {@code sampleFrequency} training
 * examples, evaluates it on a held-out set of {@code testSize} examples.
 *
 * <p>When {@code cacheTest} is set, the test examples are read from the stream once, up front,
 * and replayed for every evaluation. Otherwise each evaluation consumes fresh examples from the
 * same stream that is used for training.
 */
public class EvaluatePeriodicHeldOutTest
extends ClassificationMainTask {
    private static final long serialVersionUID = 1L;
    public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", MultiClassClassifier.class, "moa.classifiers.trees.HoeffdingTree");
    public ClassOption streamOption = new ClassOption("stream", 's', "Stream to learn from.", ExampleStream.class, "generators.RandomTreeGenerator");
    public ClassOption evaluatorOption = new ClassOption("evaluator", 'e', "Learning performance evaluation method.", LearningPerformanceEvaluator.class, "BasicClassificationPerformanceEvaluator");
    public IntOption testSizeOption = new IntOption("testSize", 'n', "Number of testing examples.", 1000000, 0, Integer.MAX_VALUE);
    public IntOption trainSizeOption = new IntOption("trainSize", 'i', "Number of training examples, <1 = unlimited.", 0, 0, Integer.MAX_VALUE);
    public IntOption trainTimeOption = new IntOption("trainTime", 't', "Number of training seconds.", 36000, 0, Integer.MAX_VALUE);
    public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f', "Number of training examples between samples of learning performance.", 100000, 0, Integer.MAX_VALUE);
    public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true);
    public FlagOption cacheTestOption = new FlagOption("cacheTest", 'c', "Cache test instances in memory.");

    @Override
    public String getPurposeString() {
        return "Evaluates a classifier on a stream by periodically testing on a heldout set.";
    }

    /**
     * Alternates training periods with held-out test phases and records one learning-curve
     * entry per completed test phase.
     *
     * <p>Training stops when:
     * <ul>
     *   <li>the training budget ({@code trainSize}) is used up;</li>
     *   <li>the stream is exhausted;</li>
     *   <li>cumulative training CPU time exceeds {@code trainTime} seconds (that period is not
     *       tested);</li>
     *   <li>a test phase cannot obtain {@code testSize} examples.</li>
     * </ul>
     *
     * @param monitor receives progress updates; it is polled every 10 examples for abort requests
     * @param repository not used by this task
     * @return the {@link LearningCurve}, or {@code null} if the task was aborted
     */
    @Override
    protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
        Learner learner = (Learner)this.getPreparedClassOption(this.learnerOption);
        ExampleStream stream = (ExampleStream)this.getPreparedClassOption(this.streamOption);
        LearningPerformanceEvaluator evaluator = (LearningPerformanceEvaluator)this.getPreparedClassOption(this.evaluatorOption);
        learner.setModelContext(stream.getHeader());
        LearningCurve learningCurve = new LearningCurve("evaluation instances");
        int testSize = this.testSizeOption.getValue();
        int trainSize = this.trainSizeOption.getValue();
        // A frequency of 0 would never advance training (and, with a cached test set, would
        // loop forever appending curve entries), so train at least one example per period.
        long sampleFrequency = Math.max(1L, (long)this.sampleFrequencyOption.getValue());
        PrintStream immediateResultStream = openImmediateResultStream(this.dumpFileOption.getFile());
        try {
            ExampleStream testStream;
            if (this.cacheTestOption.isSet()) {
                testStream = cacheTestInstances(stream, testSize, monitor);
                if (testStream == null) {
                    return null;
                }
            } else {
                testStream = stream;
            }
            TimingUtils.enablePreciseTiming();
            long instancesProcessed = 0L;
            double totalTrainTime = 0.0;
            boolean firstDump = true;
            while ((trainSize < 1 || instancesProcessed < (long)trainSize) && stream.hasMoreInstances()) {
                // ---- training period ----
                monitor.setCurrentActivityDescription("Training...");
                long instancesTarget = instancesProcessed + sampleFrequency;
                if (trainSize > 0) {
                    // Do not overshoot the requested training budget in the final period.
                    instancesTarget = Math.min(instancesTarget, (long)trainSize);
                }
                long periodStart = instancesProcessed;
                long trainStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
                while (instancesProcessed < instancesTarget && stream.hasMoreInstances()) {
                    learner.trainOnInstance(stream.nextInstance());
                    if (++instancesProcessed % 10L != 0L) continue;
                    if (monitor.taskShouldAbort()) {
                        return null;
                    }
                    monitor.setCurrentActivityFractionComplete(this.trainingFraction(instancesProcessed));
                }
                double lastTrainTime = TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - trainStartTime);
                long lastTrainCount = instancesProcessed - periodStart;
                totalTrainTime += lastTrainTime;
                if (totalTrainTime > (double)this.trainTimeOption.getValue()) break;

                // ---- test phase ----
                if (this.cacheTestOption.isSet()) {
                    testStream.restart();
                }
                evaluator.reset();
                monitor.setCurrentActivityDescription(this.testingDescription(instancesProcessed));
                long testStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
                int instCount = 0;
                // Check the stream actually being read: with a cached test set the training
                // stream may be exhausted while the cache can still serve a full test.
                while (instCount < testSize && testStream.hasMoreInstances()) {
                    Example testInst = testStream.nextInstance();
                    double[] prediction = learner.getVotesForInstance(testInst);
                    evaluator.addResult(testInst, prediction);
                    if (++instCount % 10 != 0) continue;
                    if (monitor.taskShouldAbort()) {
                        return null;
                    }
                    monitor.setCurrentActivityFractionComplete((double)instCount / (double)testSize);
                }
                // A partial test set would not be comparable with earlier entries; stop here.
                if (instCount != testSize) break;
                double testTime = TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - testStartTime);

                // ---- record results ----
                ArrayList<Measurement> measurements = new ArrayList<Measurement>();
                measurements.add(new Measurement("evaluation instances", instancesProcessed));
                measurements.add(new Measurement("total train time", totalTrainTime));
                measurements.add(new Measurement("total train speed", (double)instancesProcessed / totalTrainTime));
                measurements.add(new Measurement("last train time", lastTrainTime));
                // Use the number of examples actually trained; the last period may be short.
                measurements.add(new Measurement("last train speed", (double)lastTrainCount / lastTrainTime));
                measurements.add(new Measurement("test time", testTime));
                measurements.add(new Measurement("test speed", (double)testSize / testTime));
                Collections.addAll(measurements, evaluator.getPerformanceMeasurements());
                Collections.addAll(measurements, learner.getModelMeasurements());
                learningCurve.insertEntry(new LearningEvaluation(measurements.toArray(new Measurement[measurements.size()])));
                if (immediateResultStream != null) {
                    if (firstDump) {
                        immediateResultStream.println(learningCurve.headerToString());
                        firstDump = false;
                    }
                    immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
                    immediateResultStream.flush();
                }
                if (monitor.resultPreviewRequested()) {
                    monitor.setLatestResultPreview(learningCurve.copy());
                }
            }
            return learningCurve;
        } finally {
            // Close the dump file on every exit path, including aborts and exceptions.
            if (immediateResultStream != null) {
                immediateResultStream.close();
            }
        }
    }

    /**
     * Opens the CSV dump file for appending.
     *
     * @param dumpFile target file, or {@code null} when no dump was requested
     * @return an auto-flushing stream, or {@code null} if {@code dumpFile} is {@code null}
     * @throws RuntimeException if the file cannot be opened (the cause is preserved)
     */
    private static PrintStream openImmediateResultStream(File dumpFile) {
        if (dumpFile == null) {
            return null;
        }
        try {
            // Append mode also creates the file when it does not exist yet.
            return new PrintStream(new FileOutputStream(dumpFile, true), true);
        }
        catch (Exception ex) {
            throw new RuntimeException("Unable to open immediate result file: " + dumpFile, ex);
        }
    }

    /**
     * Reads up to {@code testSize} examples from {@code stream} into memory.
     *
     * @return a replayable stream over the cached examples, or {@code null} if the task was aborted;
     *         it holds fewer than {@code testSize} examples if {@code stream} ran out first
     */
    private static ExampleStream cacheTestInstances(ExampleStream stream, int testSize, TaskMonitor monitor) {
        monitor.setCurrentActivity("Caching test examples...", -1.0);
        Instances testInstances = new Instances(stream.getHeader(), testSize);
        while (testInstances.numInstances() < testSize && stream.hasMoreInstances()) {
            testInstances.add((Instance)stream.nextInstance().getData());
            if (testInstances.numInstances() % 10 != 0) continue;
            if (monitor.taskShouldAbort()) {
                return null;
            }
            monitor.setCurrentActivityFractionComplete((double)testInstances.numInstances() / (double)testSize);
        }
        return new CachedInstancesStream(testInstances);
    }

    /**
     * Computes the fraction of the training budget consumed so far.
     *
     * @return the fraction, or -1.0 (the same sentinel passed when caching starts) when
     *         {@code trainSize} is unlimited, instead of dividing by zero
     */
    private double trainingFraction(long instancesProcessed) {
        int trainSize = this.trainSizeOption.getValue();
        return trainSize < 1 ? -1.0 : (double)instancesProcessed / (double)trainSize;
    }

    /**
     * Builds the monitor description shown while testing.
     *
     * <p>With an unlimited {@code trainSize} there is no percentage to report, so the absolute
     * count of training instances is shown instead.
     */
    private String testingDescription(long instancesProcessed) {
        int trainSize = this.trainSizeOption.getValue();
        if (trainSize < 1) {
            return "Testing (after " + instancesProcessed + " training instances)...";
        }
        return "Testing (after " + StringUtils.doubleToString((double)instancesProcessed / (double)trainSize * 100.0, 2) + "% training)...";
    }

    /**
     * Reports the type of object produced by {@link #doMainTask}.
     *
     * @return {@code LearningCurve.class}
     */
    @Override
    public Class<?> getTaskResultType() {
        return LearningCurve.class;
    }
}

