/*
 * Decompiled with CFR 0.152.
 */
package org.ohdsi.metaAnalysis;

import dr.inference.loggers.ArrayLogFormatter;
import dr.inference.loggers.LogFormatter;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.Logger;
import dr.inference.loggers.MCLogger;
import dr.inference.mcmc.MCMC;
import dr.inference.mcmc.MCMCOptions;
import dr.inference.model.Likelihood;
import dr.inference.operators.OperatorSchedule;
import dr.inference.trace.Trace;
import dr.inference.trace.TraceCorrelation;
import dr.math.MathUtils;
import java.util.ArrayList;
import java.util.List;
import org.ohdsi.metaAnalysis.Analysis;

public class Runner {
    private final Likelihood joint;
    private final OperatorSchedule schedule;
    private final Logger[] logger;
    final long chainLength;
    final int burnIn;
    final int subSampleFrequency;
    int consoleWidth = 175;

    public Runner(Analysis analysis, int chainLength, int burnIn, int subSampleFrequency, double seed) {
        MathUtils.setSeed((long)Math.round(seed));
        this.joint = analysis.getJoint();
        this.schedule = analysis.getSchedule();
        this.logger = this.getLogger(analysis.getLoggerColumns(), subSampleFrequency);
        this.chainLength = chainLength;
        this.burnIn = burnIn;
        this.subSampleFrequency = subSampleFrequency;
    }

    private Logger[] getLogger(List<Loggable> columns, int subSampleFrequency) {
        ArrayLogFormatter formatter = new ArrayLogFormatter(false);
        MCLogger memory = new MCLogger((LogFormatter)formatter, (long)subSampleFrequency, false);
        for (Loggable column : columns) {
            memory.add(column);
        }
        Logger callback = new Logger(){

            public void startLogging() {
            }

            public void log(long iteration) {
                if (iteration % 10000L == 0L) {
                    this.progressPercentage(iteration, Runner.this.chainLength);
                }
            }

            private void progressPercentage(long done, long total) {
                if (done > total) {
                    throw new IllegalArgumentException();
                }
                int barSize = Runner.this.consoleWidth - 10;
                int donePercent = (int)(100L * done / total);
                int doneChars = (int)((long)barSize * done / total);
                String bar = "|" + new String(new char[doneChars]).replace('\u0000', '=') + new String(new char[barSize - doneChars]).replace('\u0000', ' ') + "|";
                System.out.print("\r" + bar + " " + donePercent + "%");
                if (done == total) {
                    System.out.print("\n");
                }
            }

            public void stopLogging() {
            }
        };
        return new Logger[]{memory, callback};
    }

    public void run() {
        MCMC mcmc = new MCMC("mcmc1");
        mcmc.setShowOperatorAnalysis(true);
        mcmc.init(Runner.getOptions(this.chainLength), this.joint, this.schedule, this.logger);
        mcmc.run();
    }

    public void setConsoleWidth(int consoleWidth) {
        this.consoleWidth = consoleWidth;
    }

    public String[] getParameterNames() {
        List<Trace> traces = this.getTraces(this.logger[0]);
        String[] parameterNames = new String[traces.size() - 1];
        int i = 1;
        while (i < traces.size()) {
            Trace trace = traces.get(i);
            parameterNames[i - 1] = trace.getName();
            ++i;
        }
        return parameterNames;
    }

    public double[] getTrace(int parameterIndex) {
        List<Trace> traces = this.getTraces(this.logger[0]);
        if (parameterIndex < 0 || parameterIndex >= traces.size()) {
            throw new RuntimeException("Parameter index out of range. Maximum index = " + traces.size());
        }
        Trace trace = traces.get(parameterIndex);
        int start = this.burnIn / this.subSampleFrequency + 1;
        List values = trace.getValues(start, trace.getValueCount());
        double[] primitives = new double[values.size()];
        int i = 0;
        while (i < values.size()) {
            primitives[i] = (Double)values.get(i);
            ++i;
        }
        return primitives;
    }

    public void processSamples() {
        List<Trace> traces = this.getTraces(this.logger[0]);
        int i = 1;
        while (i < traces.size()) {
            Trace trace = traces.get(i);
            int start = this.burnIn / this.subSampleFrequency + 1;
            List values = trace.getValues(start, trace.getValueCount());
            TraceCorrelation statistics = new TraceCorrelation(values, trace.getTraceType(), (long)this.subSampleFrequency);
            System.out.println(String.valueOf(trace.getName()) + " " + statistics.getMean() + " " + statistics.getStdError() + " " + statistics.getESS() + " " + statistics.getSize());
            ++i;
        }
    }

    private static MCMCOptions getOptions(long chainLength) {
        boolean useAdaptation = true;
        long adaptationDelay = chainLength / 100L;
        double adaptationTarget = 0.234;
        boolean useSmoothAcceptanceRatio = false;
        double temperature = 1.0;
        long fullEvaluationCount = 1000L;
        double evaluationTestThreshold = 0.1;
        int minOperatorCountForFullEvaluation = 1;
        return new MCMCOptions(chainLength, fullEvaluationCount, minOperatorCountForFullEvaluation, evaluationTestThreshold, useAdaptation, adaptationDelay, adaptationTarget, useSmoothAcceptanceRatio, temperature);
    }

    private List<Trace> getTraces(Logger logger) {
        ArrayList<Trace> traceList = new ArrayList<Trace>();
        if (logger instanceof MCLogger) {
            for (LogFormatter f : ((MCLogger)logger).getFormatters()) {
                if (!(f instanceof ArrayLogFormatter)) continue;
                traceList.addAll(((ArrayLogFormatter)f).getTraces());
            }
        }
        return traceList;
    }
}

