/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.dynamic.learning.parametric.bayesian;

import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.exponentialfamily.EF_LearningBayesianNetwork;
import eu.amidst.core.exponentialfamily.EF_UnivariateDistribution;
import eu.amidst.core.variables.Variable;
import eu.amidst.dynamic.datastream.DynamicDataInstance;
import eu.amidst.dynamic.learning.parametric.bayesian.BayesianLearningAlgorithm;
import eu.amidst.dynamic.learning.parametric.bayesian.PlateauStructure;
import eu.amidst.dynamic.models.DynamicBayesianNetwork;
import eu.amidst.dynamic.models.DynamicDAG;
import java.util.List;
import java.util.stream.Stream;

/**
 * Bayesian parameter learner for dynamic Bayesian networks that consumes a
 * {@link DataStream} of {@link DynamicDataInstance}s window by window.
 *
 * <p>For every window, inference is run on a {@link PlateauStructure}. The resulting
 * parameter posteriors then replace the current parameter distributions. As a result,
 * each window is learnt on top of everything seen before it (presumably Streaming
 * Variational Bayes, as the class name suggests). The time-0 slice and the time-T
 * slice are updated separately.
 */
public class SVB implements BayesianLearningAlgorithm {
    /** Learning network for the time-0 slice; captured from the plateau in {@link #initLearning()}. */
    EF_LearningBayesianNetwork ef_extendedBNTime0;
    /** Learning network for the time-T slice; captured from the plateau in {@link #initLearning()}. */
    EF_LearningBayesianNetwork ef_extendedBNTimeT;
    /** Replicated structure on which inference is run for each window. */
    PlateauStructure plateauStructure = new PlateauStructure();
    /** Structure of the dynamic model being learnt. */
    DynamicDAG dag;
    /** Source of the training data. */
    DataStream<DynamicDataInstance> dataStream;
    /** Sum of the per-window log-probabilities computed by the last {@link #runLearning()}. */
    double elbo;
    /** When true, {@link #runLearning()} does not process any data (see note there). */
    boolean parallelMode = false;
    /** Number of instances per window; also the number of plateau repetitions. */
    int windowsSize = 100;
    /** Seed handed to the plateau structure in {@link #initLearning()}. */
    int seed = 0;

    /**
     * Returns the plateau structure used for inference.
     *
     * @return the internal plateau structure (not a copy)
     */
    public PlateauStructure getPlateauStructure() {
        return this.plateauStructure;
    }

    /** Creates a learner whose plateau is sized to the default window size. */
    public SVB() {
        this.plateauStructure.setNRepetitions(this.windowsSize);
    }

    /**
     * Returns the seed that will be passed to the plateau structure.
     *
     * @return the current seed
     */
    public int getSeed() {
        return this.seed;
    }

    @Override
    public void setSeed(int seed) {
        this.seed = seed;
    }

    /**
     * Returns the sum of the per-window log-probabilities accumulated by the last
     * sequential {@link #runLearning()}.
     *
     * @return the accumulated log-probability
     */
    @Override
    public double getLogMarginalProbability() {
        return this.elbo;
    }

    @Override
    public int getWindowsSize() {
        return this.windowsSize;
    }

    /**
     * Sets the window size and resizes the plateau to the same number of repetitions.
     *
     * @param windowsSize number of instances per window
     */
    @Override
    public void setWindowsSize(int windowsSize) {
        this.windowsSize = windowsSize;
        this.plateauStructure.setNRepetitions(windowsSize);
    }

    /**
     * Sets the iteration limit on both the time-0 and the time-T inference engines.
     *
     * @param maxIter maximum number of iterations
     */
    public void setMaxIter(int maxIter) {
        this.plateauStructure.getVMPTime0().setMaxIter(maxIter);
        this.plateauStructure.getVMPTimeT().setMaxIter(maxIter);
    }

    /**
     * Sets the convergence threshold on both the time-0 and the time-T inference engines.
     *
     * @param threshold convergence threshold
     */
    public void setThreshold(double threshold) {
        this.plateauStructure.getVMPTime0().setThreshold(threshold);
        this.plateauStructure.getVMPTimeT().setThreshold(threshold);
    }

    /**
     * Enables or disables output on both the time-0 and the time-T inference engines.
     *
     * @param output whether output is enabled
     */
    @Override
    public void setOutput(boolean output) {
        this.plateauStructure.getVMPTime0().setOutput(output);
        this.plateauStructure.getVMPTimeT().setOutput(output);
    }

    /**
     * Initialises learning and, in sequential mode, feeds every window of the data
     * stream through {@link #updateModel(DataOnMemory)}. The per-window
     * log-probabilities are summed into {@link #getLogMarginalProbability()}.
     */
    @Override
    public void runLearning() {
        this.initLearning();
        if (!this.parallelMode) {
            // Each window's posteriors overwrite the current parameter distributions,
            // so the windows must be consumed strictly in order.
            this.elbo = this.dataStream.streamOfBatches(this.windowsSize)
                    .sequential()
                    .mapToDouble(this::updateModel)
                    .sum();
        }
        // NOTE(review): parallel mode is not implemented here; when enabled, no data is
        // processed and elbo keeps its previous value.
    }

    @Override
    public void setParallelMode(boolean parallelMode) {
        this.parallelMode = parallelMode;
    }

    /**
     * Updates the model with one window of data.
     *
     * <p>If the first instance has time ID 0, it updates the time-0 slice. All
     * remaining instances (or the whole window, when it does not start at time 0)
     * update the time-T slice. The caller's batch is not modified.
     *
     * @param batch window of consecutive dynamic instances
     * @return log-probability of the evidence in this window; {@code 0.0} for an empty batch
     */
    @Override
    public double updateModel(DataOnMemory<DynamicDataInstance> batch) {
        List<DynamicDataInstance> data = batch.getList();
        if (data.isEmpty()) {
            return 0.0;
        }
        double logprob = 0.0;
        if (batch.getDataInstance(0).getTimeID() == 0L) {
            logprob += this.updateModelTime0(batch.getDataInstance(0));
            // Use a view that skips the time-0 instance instead of removing it from the
            // batch's own list, which would silently mutate the caller's data.
            data = data.subList(1, data.size());
            if (data.isEmpty()) {
                return logprob;
            }
        }
        return logprob + this.updateModelTimeT(data);
    }

    /**
     * Runs time-0 inference on a single instance. Each parameter posterior is copied
     * and installed as that parameter's distribution, both in the time-0 learning
     * network and in the corresponding plateau node.
     *
     * @param dataInstance the time-0 instance
     * @return log-probability of the time-0 evidence
     */
    private double updateModelTime0(DynamicDataInstance dataInstance) {
        this.plateauStructure.setEvidenceTime0(dataInstance);
        this.plateauStructure.runInferenceTime0();
        for (Variable var : this.plateauStructure.getEFLearningBNTime0().getParametersVariables()) {
            EF_UnivariateDistribution uni = ((EF_UnivariateDistribution) this.plateauStructure.getEFParameterPosteriorTime0(var)).deepCopy();
            this.plateauStructure.getEFLearningBNTime0().setDistribution(var, uni);
            this.plateauStructure.getNodeOfVarTime0(var).setPDist(uni);
        }
        return this.plateauStructure.getLogProbabilityOfEvidenceTime0();
    }

    /**
     * Runs time-T inference on a window of instances. Each parameter posterior is
     * copied and installed as that parameter's distribution, both in the time-T
     * learning network and in plateau node replica 0.
     *
     * @param batch time-T instances of the window
     * @return log-probability of the time-T evidence
     */
    private double updateModelTimeT(List<DynamicDataInstance> batch) {
        this.plateauStructure.setEvidenceTimeT(batch);
        this.plateauStructure.runInferenceTimeT();
        for (Variable var : this.plateauStructure.getEFLearningBNTimeT().getParametersVariables()) {
            EF_UnivariateDistribution uni = ((EF_UnivariateDistribution) this.plateauStructure.getEFParameterPosteriorTimeT(var)).deepCopy();
            this.plateauStructure.getEFLearningBNTimeT().setDistribution(var, uni);
            this.plateauStructure.getNodeOfVarTimeT(var, 0).setPDist(uni);
        }
        return this.plateauStructure.getLogProbabilityOfEvidenceTimeT();
    }

    @Override
    public void setDynamicDAG(DynamicDAG dag) {
        this.dag = dag;
    }

    /**
     * Prepares the plateau structure by setting the seed, building it from the DAG, and
     * resetting its Q distributions. It then caches the time-0 and time-T learning networks.
     */
    @Override
    public void initLearning() {
        this.plateauStructure.setSeed(this.seed);
        this.plateauStructure.setDBNModel(this.dag);
        this.plateauStructure.resetQs();
        this.ef_extendedBNTime0 = this.plateauStructure.getEFLearningBNTime0();
        this.ef_extendedBNTimeT = this.plateauStructure.getEFLearningBNTimeT();
    }

    @Override
    public void setDataStream(DataStream<DynamicDataInstance> data) {
        this.dataStream = data;
    }

    /**
     * Builds a dynamic Bayesian network from the DAG and the current learning networks.
     *
     * @return the learnt dynamic Bayesian network
     * @throws IllegalStateException if learning has not been initialised yet
     */
    @Override
    public DynamicBayesianNetwork getLearntDBN() {
        if (this.ef_extendedBNTime0 == null || this.ef_extendedBNTimeT == null) {
            throw new IllegalStateException("initLearning() or runLearning() must be called before getLearntDBN()");
        }
        return new DynamicBayesianNetwork(this.dag, this.ef_extendedBNTime0.toConditionalDistribution(), this.ef_extendedBNTimeT.toConditionalDistribution());
    }
}

