/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.exponentialfamily;

import eu.amidst.core.distribution.Multinomial;
import eu.amidst.core.exponentialfamily.EF_ConditionalDistribution;
import eu.amidst.core.exponentialfamily.EF_Dirichlet;
import eu.amidst.core.exponentialfamily.EF_Multinomial_Dirichlet;
import eu.amidst.core.exponentialfamily.EF_UnivariateDistribution;
import eu.amidst.core.exponentialfamily.ParameterVariables;
import eu.amidst.core.exponentialfamily.SufficientStatistics;
import eu.amidst.core.utils.ArrayVector;
import eu.amidst.core.utils.Utils;
import eu.amidst.core.utils.Vector;
import eu.amidst.core.variables.DistributionType;
import eu.amidst.core.variables.Variable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/**
 * Exponential-family form of a multinomial distribution over the states of a single
 * multinomial or indicator {@link Variable}.
 *
 * <p>Element {@code i} of the natural parameter vector holds an (unnormalised) log-probability of
 * state {@code i}. Element {@code i} of the moment parameter vector holds the probability of state
 * {@code i}. The sufficient statistics of a value {@code k} are a vector obtained from
 * {@code createZeroSufficientStatistics()} with {@code 1.0} written at index {@code k}.
 */
public class EF_Multinomial
extends EF_UnivariateDistribution {
    /**
     * Clamp used by {@link #updateNaturalFromMomentParameters()}. A moment parameter equal to
     * {@code 0.0} is mapped to {@code log(THRESHOLD)} instead of {@code -Infinity}. A moment
     * parameter equal to {@code 1.0} is mapped to {@code log(1 - THRESHOLD)}.
     */
    static double THRESHOLD = 1.0E-10;

    /**
     * Creates a uniform distribution over the states of {@code var}.
     *
     * @param var the multinomial or indicator variable this distribution is defined over.
     * @throws UnsupportedOperationException if {@code var} is neither multinomial nor an indicator.
     */
    public EF_Multinomial(Variable var) {
        if (!var.isMultinomial() && !var.isIndicator()) {
            throw new UnsupportedOperationException("Creating a Multinomial EF distribution for a non-multinomial variable.");
        }
        this.parents = new ArrayList<>();
        this.var = var;
        int nstates = var.getNumberOfStates();
        this.naturalParameters = this.createZeroNaturalParameters();
        this.momentParameters = this.createZeroMomentParameters();
        // Uniform start: log(1/n) as natural parameter, 1/n as probability.
        double logUniform = -Math.log(nstates);
        double uniform = 1.0 / nstates;
        for (int i = 0; i < nstates; ++i) {
            this.naturalParameters.set(i, logUniform);
            this.momentParameters.set(i, uniform);
        }
    }

    /**
     * {@inheritDoc}
     *
     * <p>This implementation always returns {@code 0.0}.
     */
    @Override
    public double computeLogBaseMeasure(double val) {
        return 0.0;
    }

    /**
     * Computes {@code log(sum_i exp(naturalParameters[i]))}.
     *
     * <p>Uses the log-sum-exp trick: the maximum natural parameter is subtracted before
     * exponentiating and added back afterwards, so large natural parameters do not overflow to
     * {@code Infinity}. An empty or all-{@code -Infinity} vector yields {@code -Infinity}. A
     * {@code +Infinity} entry yields {@code +Infinity}. {@code NaN} entries propagate.
     *
     * @return the log normalizer.
     */
    @Override
    public double computeLogNormalizer() {
        double max = this.maxNaturalParameter();
        if (Double.isInfinite(max) || Double.isNaN(max)) {
            // Nothing to stabilise: the result is fully determined by the extreme value.
            return max;
        }
        double sum = 0.0;
        for (int i = 0; i < this.naturalParameters.size(); ++i) {
            sum += Math.exp(this.naturalParameters.get(i) - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Creates a new vector whose length equals the number of states of the variable.
     *
     * @return a new {@link ArrayVector} of size {@code var.getNumberOfStates()}.
     */
    @Override
    public Vector createZeroVector() {
        return new ArrayVector(this.var.getNumberOfStates());
    }

    /**
     * Returns the sufficient statistics of a single observed state.
     *
     * @param val the observed state index; it is truncated to {@code int} by a cast.
     * @return a vector from {@code createZeroSufficientStatistics()} with {@code 1.0} at index
     *     {@code (int) val}. An out-of-range index presumably fails inside the underlying vector's
     *     {@code set} (not checked here).
     */
    @Override
    public SufficientStatistics getSufficientStatistics(double val) {
        SufficientStatistics vec = this.createZeroSufficientStatistics();
        vec.set((int)val, 1.0);
        return vec;
    }

    /**
     * Returns the moment parameters (the state probabilities).
     *
     * @return the internal moment-parameter vector itself, not a copy.
     */
    @Override
    public Vector getExpectedParameters() {
        return this.momentParameters;
    }

    /**
     * Computes the log-probability of a state as
     * {@code <naturalParameters, s(val)> + logBaseMeasure(val) - logNormalizer}.
     *
     * @param val the state index (truncated to {@code int}).
     * @return the log-probability of {@code val}.
     */
    @Override
    public double computeLogProbabilityOf(double val) {
        return this.naturalParameters.dotProduct(this.getSufficientStatistics(val)) + this.computeLogBaseMeasure(val) - this.computeLogNormalizer();
    }

    /**
     * Sets each natural parameter to the log of the corresponding moment parameter.
     *
     * <p>Exact {@code 0.0} and {@code 1.0} probabilities are clamped using {@link #THRESHOLD}, so a
     * zero probability never becomes {@code -Infinity}.
     */
    @Override
    public void updateNaturalFromMomentParameters() {
        int nstates = this.var.getNumberOfStates();
        for (int i = 0; i < nstates; ++i) {
            double probability = this.momentParameters.get(i);
            double natural;
            if (probability == 0.0) {
                natural = Math.log(THRESHOLD);
            } else if (probability == 1.0) {
                natural = Math.log(1.0 - THRESHOLD);
            } else {
                natural = Math.log(probability);
            }
            this.naturalParameters.set(i, natural);
        }
    }

    /**
     * Replaces the natural parameters with the result of {@code Utils.logNormalize} applied to
     * them.
     */
    @Override
    public void fixNumericalInstability() {
        this.naturalParameters = Utils.logNormalize(this.naturalParameters);
    }

    /**
     * Sets the moment parameters to the exponentiated natural parameters, normalised with
     * {@code Utils.normalize}.
     *
     * <p>The maximum natural parameter is subtracted before exponentiating so that large natural
     * parameters do not overflow to {@code Infinity}. That overflow would otherwise make the
     * normalised result {@code NaN}. This assumes {@code Utils.normalize} rescales the entries to
     * sum to one. Under that assumption the shift contributes a common factor {@code exp(-max)},
     * which cancels.
     */
    @Override
    public void updateMomentFromNaturalParameters() {
        int nstates = this.var.getNumberOfStates();
        double max = this.maxNaturalParameter();
        // Infinite/NaN maxima are left unshifted so degenerate inputs behave as before.
        double shift = (Double.isInfinite(max) || Double.isNaN(max)) ? 0.0 : max;
        for (int i = 0; i < nstates; ++i) {
            this.momentParameters.set(i, Math.exp(this.naturalParameters.get(i) - shift));
        }
        Utils.normalize(this.momentParameters);
    }

    /**
     * Returns the size of the sufficient-statistics vector, which is the number of states.
     *
     * @return {@code var.getNumberOfStates()}.
     */
    @Override
    public int sizeOfSufficientStatistics() {
        return this.var.getNumberOfStates();
    }

    /**
     * Creates a new distribution over {@code var} with this distribution's natural and moment
     * parameters copied into it.
     *
     * @param var the variable of the copy.
     * @return the copy.
     */
    @Override
    public EF_UnivariateDistribution deepCopy(Variable var) {
        EF_Multinomial copy = new EF_Multinomial(var);
        copy.getNaturalParameters().copy(this.getNaturalParameters());
        copy.getMomentParameters().copy(this.getMomentParameters());
        return copy;
    }

    /**
     * Initialises this distribution with random state probabilities.
     *
     * <p>The probabilities are drawn with {@code random.nextDouble()}, normalised with
     * {@code Utils.normalize}, and stored as moment parameters. The natural parameters are then
     * derived from them.
     *
     * @param random the source of randomness.
     * @return this distribution.
     */
    @Override
    public EF_UnivariateDistribution randomInitialization(Random random) {
        double[] probabilities = new double[this.var.getNumberOfStates()];
        for (int i = 0; i < probabilities.length; ++i) {
            probabilities[i] = random.nextDouble();
        }
        probabilities = Utils.normalize(probabilities);
        for (int i = 0; i < probabilities.length; ++i) {
            this.getMomentParameters().set(i, probabilities[i]);
        }
        this.updateNaturalFromMomentParameters();
        return this;
    }

    /**
     * Converts this distribution into a {@link Multinomial} whose state probabilities are the
     * current moment parameters.
     *
     * @return a new {@link Multinomial} over the same variable.
     */
    public Multinomial toUnivariateDistribution() {
        Multinomial multinomial = new Multinomial(this.getVariable());
        for (int i = 0; i < multinomial.getVariable().getNumberOfStates(); ++i) {
            multinomial.setProbabilityOfState(i, this.getMomentParameters().get(i));
        }
        return multinomial;
    }

    /**
     * Builds the distributions for the extended learning model.
     *
     * <p>A new Dirichlet parameter variable is registered in {@code variables}. The returned list
     * holds an {@link EF_Multinomial_Dirichlet} conditioned on that variable, followed by the
     * variable's own {@link EF_Dirichlet} distribution.
     *
     * @param variables the container in which the Dirichlet parameter variable is created.
     * @param nameSuffix suffix included in the name of the new parameter variable.
     * @return the multinomial-given-Dirichlet distribution followed by the Dirichlet distribution.
     */
    @Override
    public List<EF_ConditionalDistribution> toExtendedLearningDistribution(ParameterVariables variables, String nameSuffix) {
        Variable varDirichlet = variables.newDirichletParameter(this.var.getName() + "_DirichletParameter_" + nameSuffix + "_" + variables.getNumberOfVars(), this.var.getNumberOfStates());
        EF_Dirichlet uni = (EF_Dirichlet)((DistributionType)varDirichlet.getDistributionType()).newEFUnivariateDistribution();
        return Arrays.asList(new EF_Multinomial_Dirichlet(this.var, varDirichlet), uni);
    }

    /**
     * Creates initial sufficient statistics that put {@code 1 / nstates} on every state.
     *
     * @return a new {@link ArrayVector} filled with the uniform value.
     */
    @Override
    public SufficientStatistics createInitSufficientStatistics() {
        int nstates = this.sizeOfSufficientStatistics();
        ArrayVector vector = new ArrayVector(nstates);
        double uniform = 1.0 / nstates;
        for (int i = 0; i < nstates; ++i) {
            vector.set(i, uniform);
        }
        return vector;
    }

    /**
     * Returns the largest natural parameter, or {@code Double.NEGATIVE_INFINITY} when there are
     * none. A {@code NaN} entry makes the result {@code NaN} (per {@link Math#max(double, double)}).
     */
    private double maxNaturalParameter() {
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < this.naturalParameters.size(); ++i) {
            max = Math.max(max, this.naturalParameters.get(i));
        }
        return max;
    }
}

