/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.inference;

import eu.amidst.core.distribution.ConditionalDistribution;
import eu.amidst.core.distribution.Distribution;
import eu.amidst.core.distribution.Multinomial;
import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.exponentialfamily.EF_Distribution;
import eu.amidst.core.exponentialfamily.EF_UnivariateDistribution;
import eu.amidst.core.exponentialfamily.SufficientStatistics;
import eu.amidst.core.inference.InferenceAlgorithm;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.utils.ArrayVector;
import eu.amidst.core.utils.BayesianNetworkGenerator;
import eu.amidst.core.utils.LocalRandomGenerator;
import eu.amidst.core.utils.Serialization;
import eu.amidst.core.utils.Utils;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.HashMapAssignment;
import eu.amidst.core.variables.Variable;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class ImportanceSamplingRobust
implements InferenceAlgorithm,
Serializable {
    private static final long serialVersionUID = 8587756877237341367L;
    private BayesianNetwork model;
    private BayesianNetwork samplingModel;
    private boolean sameSamplingModel;
    private List<Variable> causalOrder;
    private int seed = 0;
    private int sampleSize = 1000;
    private Stream<WeightedAssignment> weightedSampleStream;
    private List<Variable> variablesAPosteriori;
    private List<SufficientStatistics> SSvariablesAPosteriori;
    private Assignment evidence;
    private double logProbOfEvidence;
    private boolean parallelMode = true;

    @Override
    public void setParallelMode(boolean parallelMode_) {
        this.parallelMode = parallelMode_;
    }

    @Override
    public void setSeed(int seed) {
        this.seed = seed;
    }

    @Override
    public void setModel(BayesianNetwork model_) {
        this.samplingModel = this.model = Serialization.deepCopy(model_);
        this.causalOrder = Utils.getTopologicalOrder(this.model.getDAG());
        this.sameSamplingModel = true;
        this.evidence = null;
        this.weightedSampleStream = null;
        this.variablesAPosteriori = this.model.getVariables().getListOfVariables();
        this.SSvariablesAPosteriori = new ArrayList<SufficientStatistics>(this.variablesAPosteriori.size());
        this.variablesAPosteriori.stream().forEachOrdered(variable -> {
            Object ef_univariateDistribution = ((UnivariateDistribution)variable.newUnivariateDistribution()).toEFUnivariateDistribution();
            ArrayVector arrayVector = new ArrayVector(((EF_Distribution)ef_univariateDistribution).sizeOfSufficientStatistics());
            this.SSvariablesAPosteriori.add(arrayVector);
        });
    }

    @Override
    public void setEvidence(Assignment evidence_) {
        this.evidence = evidence_;
        this.weightedSampleStream = null;
    }

    @Override
    public BayesianNetwork getOriginalModel() {
        return this.model;
    }

    public BayesianNetwork getSamplingModel() {
        return this.samplingModel;
    }

    @Override
    public double getLogProbabilityOfEvidence() {
        return this.logProbOfEvidence;
    }

    public void setSamplingModel(BayesianNetwork samplingModel_) {
        this.samplingModel = new BayesianNetwork(samplingModel_.getDAG(), Serialization.deepCopy(samplingModel_.getConditionalDistributions()));
        this.causalOrder = Utils.getTopologicalOrder(this.samplingModel.getDAG());
        this.sameSamplingModel = this.samplingModel.equalBNs(this.model, 1.0E-10);
    }

    public void setSampleSize(int sampleSize) {
        this.sampleSize = sampleSize;
    }

    public void setVariablesAPosteriori(List<Variable> variablesAPosterior) {
        this.variablesAPosteriori = variablesAPosterior;
        this.SSvariablesAPosteriori = new ArrayList<SufficientStatistics>();
        variablesAPosterior.stream().forEachOrdered(variable -> {
            Object ef_univariateDistribution = ((UnivariateDistribution)variable.newUnivariateDistribution()).toEFUnivariateDistribution();
            ArrayVector arrayVector = new ArrayVector(((EF_Distribution)ef_univariateDistribution).sizeOfSufficientStatistics());
            this.SSvariablesAPosteriori.add(arrayVector);
        });
    }

    private void updatePosteriorDistributions(Assignment sample, double logWeight) {
        int nVarsAPosteriori = this.variablesAPosteriori.size();
        IntStream.range(0, nVarsAPosteriori).forEach(i -> {
            ArrayVector SSsample;
            Variable variable = this.variablesAPosteriori.get(i);
            Object ef_univariateDistribution = ((UnivariateDistribution)variable.newUnivariateDistribution()).toEFUnivariateDistribution();
            ArrayVector SSposterior = new ArrayVector(((EF_Distribution)ef_univariateDistribution).sizeOfSufficientStatistics());
            SSposterior.copy(this.SSvariablesAPosteriori.get(i));
            if (variable.isMultinomial()) {
                SSsample = new ArrayVector(((EF_Distribution)ef_univariateDistribution).sizeOfSufficientStatistics());
                SSsample.copy(((EF_UnivariateDistribution)ef_univariateDistribution).getSufficientStatistics(sample));
                if (this.evidence != null) {
                    SSsample.multiplyBy(logWeight);
                }
            } else {
                throw new UnsupportedOperationException("ImportanceSamplingRobust.updatePosteriorDistributions() works only for multinomials");
            }
            ArrayVector newSSposterior = this.robustSumOfMultinomialSufficientStatistics(SSposterior, SSsample);
            this.SSvariablesAPosteriori.set(i, newSSposterior);
        });
    }

    private WeightedAssignment generateSampleSameModel(Random random) {
        HashMapAssignment sample = new HashMapAssignment(this.model.getNumberOfVars());
        double logWeight = 0.0;
        for (Variable samplingVar : this.causalOrder) {
            double simulatedValue;
            Object samplingDistribution = this.model.getConditionalDistribution(samplingVar);
            UnivariateDistribution univariateSamplingDistribution = ((ConditionalDistribution)samplingDistribution).getUnivariateDistribution(sample);
            if (this.evidence != null && !Double.isNaN(this.evidence.getValue(samplingVar))) {
                simulatedValue = this.evidence.getValue(samplingVar);
                logWeight += univariateSamplingDistribution.getLogProbability(simulatedValue);
            } else {
                simulatedValue = univariateSamplingDistribution.sample(random);
            }
            sample.setValue(samplingVar, simulatedValue);
        }
        return new WeightedAssignment(sample, logWeight);
    }

    private WeightedAssignment generateSample(Random random) {
        if (this.sameSamplingModel) {
            return this.generateSampleSameModel(random);
        }
        HashMapAssignment samplingAssignment = new HashMapAssignment(1);
        HashMapAssignment modelAssignment = new HashMapAssignment(1);
        double numerator = 0.0;
        double denominator = 0.0;
        for (Variable samplingVar : this.causalOrder) {
            double simulatedValue;
            Variable modelVar = this.model.getVariables().getVariableById(samplingVar.getVarID());
            if (this.evidence != null && !Double.isNaN(this.evidence.getValue(samplingVar))) {
                simulatedValue = this.evidence.getValue(samplingVar);
                UnivariateDistribution univariateModelDistribution = ((ConditionalDistribution)this.model.getConditionalDistribution(modelVar)).getUnivariateDistribution(modelAssignment);
                numerator += univariateModelDistribution.getLogProbability(simulatedValue);
            } else {
                Object samplingDistribution = this.samplingModel.getConditionalDistribution(samplingVar);
                UnivariateDistribution univariateSamplingDistribution = ((ConditionalDistribution)samplingDistribution).getUnivariateDistribution(samplingAssignment);
                simulatedValue = univariateSamplingDistribution.sample(random);
                denominator += univariateSamplingDistribution.getLogProbability(simulatedValue);
                UnivariateDistribution univariateModelDistribution = ((ConditionalDistribution)this.model.getConditionalDistribution(modelVar)).getUnivariateDistribution(modelAssignment);
                numerator += univariateModelDistribution.getLogProbability(simulatedValue);
            }
            modelAssignment.setValue(modelVar, simulatedValue);
            samplingAssignment.setValue(samplingVar, simulatedValue);
        }
        double logWeight = numerator - denominator;
        return new WeightedAssignment(samplingAssignment, logWeight);
    }

    @Override
    public double getExpectedValue(Variable var, Function<Double, Double> function) {
        if (this.parallelMode) {
            this.weightedSampleStream.parallel();
        }
        List sum = this.weightedSampleStream.map(ws -> Arrays.asList(Math.exp(((WeightedAssignment)ws).logWeight), Math.exp(((WeightedAssignment)ws).logWeight) * (Double)function.apply(((WeightedAssignment)ws).assignment.getValue(var)))).filter(array -> Double.isFinite((Double)array.get(0)) && Double.isFinite((Double)array.get(1))).reduce(Arrays.asList(new Double(0.0), new Double(0.0)), (e1, e2) -> Arrays.asList((Double)e1.get(0) + (Double)e2.get(0), (Double)e1.get(1) + (Double)e2.get(1)));
        return (Double)sum.get(1) / (Double)sum.get(0);
    }

    private double robustSumOfLogarithms(double log_x1, double log_x2) {
        double result;
        if (log_x1 != 0.0 && log_x2 != 0.0) {
            double aux_max = Math.max(log_x1, log_x2);
            double aux_min = Math.min(log_x1, log_x2);
            double aux = Math.exp(aux_min - aux_max);
            double tail = aux < 0.5 ? Math.log1p(aux) : Math.log(1.0 + aux);
            result = aux_max + (Double.isFinite(tail) ? tail : 0.0);
        } else {
            result = log_x1 == 0.0 ? log_x2 : log_x1;
        }
        return result;
    }

    private ArrayVector robustSumOfMultinomialSufficientStatistics(ArrayVector ss1, ArrayVector ss2) {
        double[] ss1_values = ss1.toArray();
        double[] ss2_values = ss2.toArray();
        double[] ss_result = new double[ss1_values.length];
        for (int i = 0; i < ss_result.length; ++i) {
            double log_a = ss1_values[i];
            double log_b = ss2_values[i];
            ss_result[i] = this.robustSumOfLogarithms(log_a, log_b);
        }
        return new ArrayVector(ss_result);
    }

    private ArrayVector robustSumOfNormalSufficientStatistics(ArrayVector ss1, ArrayVector ss2) {
        return new ArrayVector(new double[]{ss1.get(0)});
    }

    private ArrayVector robustNormalizationOfLogProbabilitiesVector(ArrayVector ss) {
        double log_sumProbabilities = 0.0;
        double[] logProbabilities = ss.toArray();
        for (int i = 0; i < logProbabilities.length; ++i) {
            double log_b = logProbabilities[i];
            log_sumProbabilities = this.robustSumOfLogarithms(log_sumProbabilities, log_b);
        }
        double[] normalizedProbabilities = new double[logProbabilities.length];
        for (int i = 0; i < logProbabilities.length; ++i) {
            double result = Math.exp(logProbabilities[i] - log_sumProbabilities);
            normalizedProbabilities[i] = 0.0 <= result && result <= 1.0 ? result : 0.0;
        }
        return new ArrayVector(normalizedProbabilities);
    }

    @Override
    public <E extends UnivariateDistribution> E getPosterior(Variable variable) {
        Object ef_univariateDistribution = ((UnivariateDistribution)variable.newUnivariateDistribution()).toEFUnivariateDistribution();
        if (!variable.isMultinomial()) {
            throw new UnsupportedOperationException("ImportanceSamplingRobust.getPosterior() not supported yet for non-multinomial distributions");
        }
        ArrayVector sumSS = new ArrayVector(((EF_Distribution)ef_univariateDistribution).sizeOfSufficientStatistics());
        sumSS.copy(this.SSvariablesAPosteriori.get(this.variablesAPosteriori.indexOf(variable)));
        sumSS = this.robustNormalizationOfLogProbabilitiesVector(sumSS);
        ((EF_Distribution)ef_univariateDistribution).setMomentParameters(sumSS);
        Multinomial posteriorDistribution = (Multinomial)((EF_UnivariateDistribution)ef_univariateDistribution).toUnivariateDistribution();
        posteriorDistribution.setProbabilities(Utils.normalize(posteriorDistribution.getParameters()));
        return (E)posteriorDistribution;
    }

    @Override
    public void runInference() {
        LocalRandomGenerator randomGenerator = new LocalRandomGenerator(this.seed);
        IntStream weightedSampleStream = IntStream.range(0, this.sampleSize).parallel();
        if (!this.parallelMode) {
            weightedSampleStream = weightedSampleStream.sequential();
        }
        double logSumWeights = weightedSampleStream.mapToDouble(i -> {
            WeightedAssignment weightedSample = this.generateSample(randomGenerator.current());
            this.updatePosteriorDistributions(weightedSample.assignment, weightedSample.logWeight);
            return weightedSample.logWeight;
        }).reduce(this::robustSumOfLogarithms).getAsDouble();
        this.logProbOfEvidence = this.evidence != null ? logSumWeights - Math.log(this.sampleSize) : 0.0;
    }

    public static void main(String[] args) throws IOException, ClassNotFoundException {
        BayesianNetworkGenerator.setNumberOfGaussianVars(0);
        BayesianNetworkGenerator.setNumberOfMultinomialVars(60, 2);
        BayesianNetworkGenerator.setNumberOfLinks(100);
        BayesianNetworkGenerator.setSeed(1);
        BayesianNetwork bn = BayesianNetworkGenerator.generateBayesianNetwork();
        System.out.println(bn);
        ImportanceSamplingRobust importanceSampling = new ImportanceSamplingRobust();
        importanceSampling.setModel(bn);
        importanceSampling.setParallelMode(true);
        importanceSampling.setSampleSize(5000);
        importanceSampling.setSeed(57457);
        List<Variable> causalOrder = importanceSampling.causalOrder;
        Variable varPosterior = causalOrder.get(0);
        ArrayList<Variable> variablesPosteriori = new ArrayList<Variable>(1);
        variablesPosteriori.add(varPosterior);
        importanceSampling.setVariablesAPosteriori(variablesPosteriori);
        importanceSampling.runInference();
        System.out.println("Posterior (IS) of " + varPosterior.getName() + ":" + ((Distribution)importanceSampling.getPosterior(varPosterior)).toString());
        System.out.println(((Distribution)bn.getConditionalDistribution(varPosterior)).toString());
        System.out.println("Log-Prob. of Evidence: " + importanceSampling.getLogProbabilityOfEvidence());
        Variable variableEvidence = causalOrder.get(1);
        int varEvidenceValue = 0;
        System.out.println("Evidence: Variable " + variableEvidence.getName() + " = " + varEvidenceValue);
        System.out.println();
        HashMapAssignment assignment = new HashMapAssignment(1);
        assignment.setValue(variableEvidence, varEvidenceValue);
        importanceSampling.setEvidence(assignment);
        long time_start = System.nanoTime();
        importanceSampling.runInference();
        long time_end = System.nanoTime();
        double execution_time = ((double)time_end - (double)time_start) / 1.0E9;
        System.out.println("Execution time: " + execution_time + " s");
        System.out.println("Posterior of " + varPosterior.getName() + " (IS with Evidence) :" + ((Distribution)importanceSampling.getPosterior(varPosterior)).toString());
        System.out.println("Log-Prob. of Evidence: " + importanceSampling.getLogProbabilityOfEvidence());
        System.out.println("Prob of Evidence: " + Math.exp(importanceSampling.getLogProbabilityOfEvidence()));
    }

    private class WeightedAssignment {
        private HashMapAssignment assignment;
        private double logWeight;

        public WeightedAssignment(HashMapAssignment assignment_, double weight_) {
            this.assignment = assignment_;
            this.logWeight = weight_;
        }

        public String toString() {
            StringBuilder str = new StringBuilder();
            str.append("[ ");
            for (Map.Entry<Variable, Double> entry : this.assignment.entrySet()) {
                str.append(entry.getKey().getName() + " = " + entry.getValue());
                str.append(", ");
            }
            str.append("Weight = " + this.logWeight + " ]");
            return str.toString();
        }
    }
}

