/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.utils;

import com.google.common.base.Stopwatch;
import eu.amidst.core.datastream.Attribute;
import eu.amidst.core.datastream.Attributes;
import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.distribution.ConditionalDistribution;
import eu.amidst.core.io.BayesianNetworkLoader;
import eu.amidst.core.io.DataStreamWriter;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.utils.AmidstOptionsHandler;
import eu.amidst.core.utils.LocalRandomGenerator;
import eu.amidst.core.utils.Utils;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.HashMapAssignment;
import eu.amidst.core.variables.StateSpaceType;
import eu.amidst.core.variables.Variable;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class BayesianNetworkSampler
implements AmidstOptionsHandler,
Serializable {
    private static final long serialVersionUID = 4107783324901370839L;
    private BayesianNetwork network;
    private List<Variable> causalOrder;
    private int seed = 0;
    private Random random = new Random(this.seed);
    private Map<Variable, Boolean> hiddenVars = new HashMap<Variable, Boolean>();
    private Map<Variable, Double> marNoise = new HashMap<Variable, Double>();

    public BayesianNetworkSampler(BayesianNetwork network1) {
        this.network = network1;
        this.causalOrder = Utils.getTopologicalOrder(this.network.getDAG());
    }

    private Stream<Assignment> getSampleStream(int nSamples) {
        LocalRandomGenerator randomGenerator = new LocalRandomGenerator(this.seed);
        return IntStream.range(0, nSamples).mapToObj(i -> BayesianNetworkSampler.sample(this.network, this.causalOrder, randomGenerator.current())).map(this::filter);
    }

    public void setHiddenVar(Variable var) {
        this.hiddenVars.put(var, true);
    }

    public void setMARVar(Variable var, double noiseProb) {
        this.marNoise.put(var, noiseProb);
    }

    private Assignment filter(Assignment assignment) {
        this.hiddenVars.keySet().stream().forEach(var -> assignment.setValue((Variable)var, Utils.missingValue()));
        this.marNoise.entrySet().forEach(e -> {
            if (this.random.nextDouble() < (Double)e.getValue()) {
                assignment.setValue((Variable)e.getKey(), Utils.missingValue());
            }
        });
        return assignment;
    }

    private List<Assignment> getSampleList(int nSamples) {
        return this.getSampleStream(nSamples).collect(Collectors.toList());
    }

    private Iterable<Assignment> getSampleIterator(final int nSamples) {
        class I
        implements Iterable<Assignment> {
            I() {
            }

            @Override
            public Iterator<Assignment> iterator() {
                return BayesianNetworkSampler.this.getSampleStream(nSamples).iterator();
            }
        }
        return new I();
    }

    public void setSeed(int seed) {
        this.seed = seed;
        this.random = new Random(seed);
    }

    public DataStream<DataInstance> sampleToDataStream(int nSamples) {
        this.random = new Random(this.seed);
        class TemporalDataStream
        implements DataStream<DataInstance>,
        Serializable {
            private static final long serialVersionUID = -3436599636425587512L;
            Attributes atts;
            BayesianNetworkSampler sampler;
            int nSamples;

            TemporalDataStream(BayesianNetworkSampler sampler1, int nSamples1) {
                this.sampler = sampler1;
                this.nSamples = nSamples1;
                List<Attribute> list = this.sampler.network.getVariables().getListOfVariables().stream().map((? super T var) -> new Attribute(var.getVarID(), var.getName(), (StateSpaceType)var.getStateSpaceType())).collect(Collectors.toList());
                this.atts = new Attributes(list);
            }

            @Override
            public Attributes getAttributes() {
                return this.atts;
            }

            @Override
            public Stream<DataInstance> stream() {
                return this.sampler.getSampleStream(this.nSamples).map((? super T a) -> {
                    class TemporalDataInstance
                    implements DataInstance,
                    Serializable {
                        private static final long serialVersionUID = -3436599636425587512L;
                        Assignment assignment;
                        Attributes attributes;
                        List<Variable> variables;

                        TemporalDataInstance(Assignment assignment1, Attributes atts) {
                            this.assignment = assignment1;
                            this.attributes = atts;
                            this.variables = sampler.network.getVariables().getListOfVariables();
                        }

                        @Override
                        public double getValue(Variable var) {
                            return this.assignment.getValue(var);
                        }

                        @Override
                        public void setValue(Variable var, double value) {
                            this.assignment.setValue(var, value);
                        }

                        @Override
                        public Attributes getAttributes() {
                            return this.attributes;
                        }

                        @Override
                        public Set<Variable> getVariables() {
                            return this.assignment.getVariables();
                        }

                        @Override
                        public double getValue(Attribute att) {
                            return this.assignment.getValue(this.variables.get(att.getIndex()));
                        }

                        @Override
                        public void setValue(Attribute att, double value) {
                            if (!att.isSpecialAttribute()) {
                                this.assignment.setValue(this.variables.get(att.getIndex()), value);
                            }
                        }

                        @Override
                        public double[] toArray() {
                            int numAtts = this.attributes.getNumberOfAttributes();
                            double[] values = new double[numAtts];
                            for (int att = 0; att < numAtts; ++att) {
                                values[att] = this.getValue(this.attributes.getFullListOfAttributes().get(att));
                            }
                            return values;
                        }

                        public String toString() {
                            return this.outputString();
                        }
                    }
                    return new TemporalDataInstance((Assignment)a, this.atts);
                });
            }

            @Override
            public void close() {
            }

            @Override
            public boolean isRestartable() {
                return false;
            }

            @Override
            public void restart() {
            }
        }
        return new TemporalDataStream(this, nSamples);
    }

    private static Assignment sample(BayesianNetwork network, List<Variable> causalOrder, Random random) {
        HashMapAssignment assignment = new HashMapAssignment(network.getNumberOfVars());
        for (Variable var : causalOrder) {
            double sampledValue = ((ConditionalDistribution)network.getConditionalDistribution(var)).getUnivariateDistribution(assignment).sample(random);
            assignment.setValue(var, sampledValue);
        }
        return assignment;
    }

    @Override
    public String listOptions() {
        return this.classNameID() + ",\\" + "-seed, 0, seed for random number generator\\";
    }

    @Override
    public String listOptionsRecursively() {
        BayesianNetworkSampler bayesianNetworkSampler = this;
        return this.listOptions() + "\n" + bayesianNetworkSampler.network.listOptionsRecursively();
    }

    @Override
    public void loadOptions() {
        this.seed = this.getIntOption("-seed");
    }

    @Override
    public String classNameID() {
        return "BayesianNetworkSampler";
    }

    public static void main(String[] args) throws Exception {
        Stopwatch watch = Stopwatch.createStarted();
        BayesianNetwork network = BayesianNetworkLoader.loadFromFile("./networks/dataWeka/asia.bn");
        BayesianNetworkSampler sampler = new BayesianNetworkSampler(network);
        sampler.setSeed(0);
        DataStream<DataInstance> dataStream = sampler.sampleToDataStream(100);
        DataStreamWriter.writeDataToFile(dataStream, "./datasets/simulated/asisa-samples.arff");
        System.out.println(watch.stop());
        for (Assignment assignment : sampler.getSampleIterator(10)) {
            System.out.println(assignment.outputString());
        }
        System.out.println();
        for (Assignment assignment : sampler.getSampleList(2)) {
            System.out.println(assignment.outputString());
        }
        System.out.println();
        sampler.getSampleStream(2).forEach(e -> System.out.println(e.outputString()));
    }
}

