/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.exponentialfamily;

import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.exponentialfamily.EF_UnivariateDistribution;
import eu.amidst.core.exponentialfamily.SufficientStatistics;
import eu.amidst.core.utils.ArrayVector;
import eu.amidst.core.utils.Vector;
import eu.amidst.core.variables.Variable;
import java.util.ArrayList;
import java.util.Random;
import org.apache.commons.math3.special.Gamma;

public class EF_Dirichlet
extends EF_UnivariateDistribution {
    int nOfStates;

    public EF_Dirichlet(Variable var1) {
        if (!var1.isDirichletParameter()) {
            throw new IllegalArgumentException("Non Dirichlet var");
        }
        this.var = var1;
        this.nOfStates = this.var.getNumberOfStates();
        this.naturalParameters = this.createZeroNaturalParameters();
        this.momentParameters = this.createZeroMomentParameters();
        this.parents = new ArrayList();
        for (int i = 0; i < this.nOfStates; ++i) {
            this.naturalParameters.set(i, 1.0);
        }
        this.fixNumericalInstability();
        this.updateMomentFromNaturalParameters();
    }

    public EF_Dirichlet(Variable var1, double scale) {
        if (!var1.isDirichletParameter()) {
            throw new IllegalArgumentException("The variable is not a Dirichlet parameter!");
        }
        this.var = var1;
        this.nOfStates = this.var.getNumberOfStates();
        this.naturalParameters = this.createZeroNaturalParameters();
        this.momentParameters = this.createZeroMomentParameters();
        this.parents = new ArrayList();
        for (int i = 0; i < this.nOfStates; ++i) {
            this.naturalParameters.set(i, scale - 1.0);
        }
        this.fixNumericalInstability();
        this.updateMomentFromNaturalParameters();
    }

    @Override
    public double computeLogBaseMeasure(double val) {
        return 0.0;
    }

    @Override
    public SufficientStatistics getSufficientStatistics(double val) {
        SufficientStatistics vec = this.createZeroSufficientStatistics();
        vec.set((int)val, Math.log(val));
        return vec;
    }

    @Override
    public Vector getExpectedParameters() {
        double sum = 0.0;
        for (int i = 0; i < this.nOfStates; ++i) {
            sum += this.naturalParameters.get(i) + 1.0;
        }
        ArrayVector vector = new ArrayVector(this.nOfStates);
        for (int i = 0; i < this.nOfStates; ++i) {
            vector.set(i, (this.naturalParameters.get(i) + 1.0) / sum);
        }
        return vector;
    }

    @Override
    public void fixNumericalInstability() {
    }

    @Override
    public EF_UnivariateDistribution deepCopy(Variable var) {
        EF_Dirichlet copy = new EF_Dirichlet(var);
        copy.getNaturalParameters().copy(this.getNaturalParameters());
        copy.getMomentParameters().copy(this.getMomentParameters());
        return copy;
    }

    @Override
    public EF_UnivariateDistribution randomInitialization(Random random) {
        for (int i = 0; i < this.nOfStates; ++i) {
            this.getNaturalParameters().set(i, 5.0 * random.nextDouble() + 1.0 + 1.0E-5);
        }
        this.fixNumericalInstability();
        this.updateMomentFromNaturalParameters();
        return this;
    }

    @Override
    public <E extends UnivariateDistribution> E toUnivariateDistribution() {
        throw new UnsupportedOperationException("Dirichlet is not included yet in the Distributions package.");
    }

    @Override
    public void updateNaturalFromMomentParameters() {
        throw new UnsupportedOperationException("No Implemented. EF_Dirichlet distribution should (right now) only be used for learning.");
    }

    @Override
    public void updateMomentFromNaturalParameters() {
        int i;
        double sumOfU_i = 0.0;
        for (i = 0; i < this.nOfStates; ++i) {
            sumOfU_i += this.naturalParameters.get(i);
        }
        for (i = 0; i < this.nOfStates; ++i) {
            this.momentParameters.set(i, Gamma.digamma(this.naturalParameters.get(i)) - Gamma.digamma(sumOfU_i));
        }
    }

    @Override
    public int sizeOfSufficientStatistics() {
        return this.nOfStates;
    }

    @Override
    public double computeLogNormalizer() {
        double sumOfU_i = 0.0;
        double sumLogGammaOfU_i = 0.0;
        for (int i = 0; i < this.nOfStates; ++i) {
            sumOfU_i += this.naturalParameters.get(i);
            sumLogGammaOfU_i += Gamma.logGamma(this.naturalParameters.get(i));
        }
        return sumLogGammaOfU_i - Gamma.logGamma(sumOfU_i);
    }

    @Override
    public Vector createZeroVector() {
        return new ArrayVector(this.nOfStates);
    }

    @Override
    public SufficientStatistics createInitSufficientStatistics() {
        ArrayVector vector = new ArrayVector(this.sizeOfSufficientStatistics());
        for (int i = 0; i < this.sizeOfSufficientStatistics(); ++i) {
            vector.set(i, Gamma.digamma(1.0) - Gamma.digamma(this.sizeOfSufficientStatistics()));
        }
        return vector;
    }
}

