/*
 * Decompiled with CFR 0.152.
 */
package umontreal.iro.lecuyer.probdistmulti;

import umontreal.iro.lecuyer.probdistmulti.DiscreteDistributionIntMulti;
import umontreal.iro.lecuyer.util.Num;

/**
 * Multinomial distribution with parameters {@code n} (number of trials) and a
 * probability vector {@code p = (p_1, ..., p_d)}.
 *
 * <p>The probability mass function is
 * {@code P(X = x) = n! / (x_1! ... x_d!) * p_1^x_1 * ... * p_d^x_d}
 * when every {@code x_i >= 0} and {@code x_1 + ... + x_d = n}, and 0 otherwise.
 */
public class MultinomialDist
extends DiscreteDistributionIntMulti {
    /**
     * Absolute tolerance used when checking that the components of {@code p}
     * sum to 1. An exact {@code == 1.0} test would reject valid vectors such as
     * ten copies of {@code 0.1}, whose floating-point sum is 0.9999999999999999.
     */
    private static final double SUM_TOLERANCE = 1.0e-10;

    /** Number of trials. */
    protected int n;
    /** Probability vector (private copy owned by this object). */
    protected double[] p;

    /**
     * Creates a multinomial distribution with {@code n} trials and
     * probability vector {@code p}.
     *
     * @param n number of trials, must be positive
     * @param p probability vector of length at least 2 whose entries lie in
     *          [0, 1] and sum to 1; it is copied
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public MultinomialDist(int n, double[] p) {
        this.setParams(n, p);
    }

    /**
     * Returns {@code P(X = x)} for this distribution.
     *
     * @param x observation vector, same length as {@code p}
     * @return the probability mass at {@code x}
     * @throws IllegalArgumentException if {@code x.length != p.length}
     */
    public double prob(int[] x) {
        return MultinomialDist.prob_(this.n, this.p, x);
    }

    /**
     * Returns {@code P(X_1 <= x_1, ..., X_d <= x_d)} for this distribution.
     *
     * @param x upper bounds, same length as {@code p}
     * @return the cumulative probability at {@code x}
     * @throws IllegalArgumentException if {@code x.length != p.length}
     */
    public double cdf(int[] x) {
        return MultinomialDist.cdf_(this.n, this.p, x);
    }

    /** Returns the mean vector {@code n * p}. */
    public double[] getMean() {
        return MultinomialDist.getMean_(this.n, this.p);
    }

    /** Returns the covariance matrix of this distribution. */
    public double[][] getCovariance() {
        return MultinomialDist.getCovariance_(this.n, this.p);
    }

    /** Returns the correlation matrix of this distribution. */
    public double[][] getCorrelation() {
        return MultinomialDist.getCorrelation_(this.n, this.p);
    }

    /**
     * Checks that every entry of {@code p} lies in [0, 1] and that the entries
     * sum to 1 within {@link #SUM_TOLERANCE}. NaN entries are rejected.
     */
    private static void checkProbabilityVector(double[] p) {
        double sumP = 0.0;
        for (double pi : p) {
            // Written as a negated range test so that NaN is rejected too.
            if (!(pi >= 0.0 && pi <= 1.0)) {
                throw new IllegalArgumentException("p is not a probability vector");
            }
            sumP += pi;
        }
        // Negated comparison so that a NaN sum is also rejected.
        if (!(Math.abs(sumP - 1.0) <= SUM_TOLERANCE)) {
            throw new IllegalArgumentException("p is not a probability vector");
        }
    }

    /** Validates the parameters of the static entry points. */
    private static void verifParam(int n, double[] p) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        MultinomialDist.checkProbabilityVector(p);
    }

    /** Unchecked probability mass function; parameters assumed valid. */
    private static double prob_(int n, double[] p, int[] x) {
        if (x.length != p.length) {
            throw new IllegalArgumentException("x and p must have the same dimension");
        }
        // Check support first: negative counts, or counts not summing to n,
        // have probability 0. A long accumulator avoids int overflow.
        long sumX = 0L;
        for (int xi : x) {
            if (xi < 0) {
                return 0.0;
            }
            sumX += xi;
        }
        if (sumX != n) {
            return 0.0;
        }
        double sumXFact = 0.0;
        double sumPX = 0.0;
        for (int i = 0; i < p.length; ++i) {
            sumXFact += Num.lnFactorial(x[i]);
            // p_i^0 == 1 contributes nothing. Skipping it avoids
            // 0 * log(0) == NaN when p[i] == 0.
            if (x[i] > 0) {
                sumPX += (double)x[i] * Math.log(p[i]);
            }
        }
        return Math.exp(Num.lnFactorial(n) - sumXFact + sumPX);
    }

    /**
     * Returns {@code P(X = x)} for a multinomial distribution with
     * parameters {@code n} and {@code p}.
     *
     * @throws IllegalArgumentException if the parameters are invalid or
     *         {@code x.length != p.length}
     */
    public static double prob(int n, double[] p, int[] x) {
        MultinomialDist.verifParam(n, p);
        return MultinomialDist.prob_(n, p, x);
    }

    /**
     * Unchecked cumulative distribution function. Sums the mass over every
     * integer vector {@code 0 <= is <= x}, enumerated like an odometer whose
     * first coordinate varies fastest.
     */
    private static double cdf_(int n, double[] p, int[] x) {
        if (x.length != p.length) {
            throw new IllegalArgumentException("x and p must have the same dimension");
        }
        int d = x.length;
        int[] upper = new int[d];
        for (int i = 0; i < d; ++i) {
            // No point has a negative coordinate. Without this early return the
            // odometer below would never terminate.
            if (x[i] < 0) {
                return 0.0;
            }
            // Any coordinate above n has zero mass, so clamping the bound
            // leaves the sum unchanged while shrinking the enumeration.
            upper[i] = Math.min(x[i], n);
        }
        int[] is = new int[d];
        double sum = 0.0;
        while (true) {
            sum += MultinomialDist.prob_(n, p, is);
            int j = 0;
            while (j < d && is[j] == upper[j]) {
                is[j++] = 0;
            }
            if (j == d) {
                return sum;
            }
            ++is[j];
        }
    }

    /**
     * Returns {@code P(X_1 <= x_1, ..., X_d <= x_d)} for a multinomial
     * distribution with parameters {@code n} and {@code p}.
     *
     * @throws IllegalArgumentException if the parameters are invalid or
     *         {@code x.length != p.length}
     */
    public static double cdf(int n, double[] p, int[] x) {
        MultinomialDist.verifParam(n, p);
        return MultinomialDist.cdf_(n, p, x);
    }

    /** Returns {@code n * p} componentwise. */
    private static double[] getMean_(int n, double[] p) {
        double[] mean = new double[p.length];
        for (int i = 0; i < p.length; ++i) {
            mean[i] = (double)n * p[i];
        }
        return mean;
    }

    /**
     * Returns the mean vector {@code n * p}.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[] getMean(int n, double[] p) {
        MultinomialDist.verifParam(n, p);
        return MultinomialDist.getMean_(n, p);
    }

    /**
     * Returns the covariance matrix: {@code n p_i (1 - p_i)} on the diagonal
     * and {@code -n p_i p_j} off the diagonal.
     */
    private static double[][] getCovariance_(int n, double[] p) {
        double[][] cov = new double[p.length][p.length];
        for (int i = 0; i < p.length; ++i) {
            for (int j = 0; j < p.length; ++j) {
                cov[i][j] = (double)(-n) * p[i] * p[j];
            }
            cov[i][i] = (double)n * p[i] * (1.0 - p[i]);
        }
        return cov;
    }

    /**
     * Returns the covariance matrix of a multinomial distribution.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[][] getCovariance(int n, double[] p) {
        MultinomialDist.verifParam(n, p);
        return MultinomialDist.getCovariance_(n, p);
    }

    /**
     * Returns the correlation matrix: 1 on the diagonal and
     * {@code -sqrt(p_i p_j / ((1 - p_i)(1 - p_j)))} off the diagonal.
     * The off-diagonal entries are undefined (NaN or infinite) when some
     * {@code p_i} equals 0 or 1, because that component then has zero variance.
     */
    private static double[][] getCorrelation_(int n, double[] p) {
        double[][] corr = new double[p.length][p.length];
        for (int i = 0; i < p.length; ++i) {
            for (int j = 0; j < p.length; ++j) {
                corr[i][j] = -Math.sqrt(p[i] * p[j] / ((1.0 - p[i]) * (1.0 - p[j])));
            }
            corr[i][i] = 1.0;
        }
        return corr;
    }

    /**
     * Returns the correlation matrix of a multinomial distribution.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[][] getCorrelation(int n, double[] p) {
        MultinomialDist.verifParam(n, p);
        return MultinomialDist.getCorrelation_(n, p);
    }

    /**
     * Computes the maximum-likelihood estimate of {@code p} from {@code m}
     * observations {@code x[0..m-1]}. Each observation is a {@code d}-vector of
     * counts from {@code n} trials. The estimate is
     * {@code p_c = (x[0][c] + ... + x[m-1][c]) / (m n)}.
     *
     * @param x the observations; only the first {@code m} rows and {@code d}
     *          columns are read
     * @param m number of observations, must be positive
     * @param d dimension, must be positive
     * @param n number of trials per observation, must be positive
     * @return the estimated probability vector of length {@code d}
     * @throws IllegalArgumentException if {@code m}, {@code d} or {@code n} is
     *         not positive, or if the total count differs from {@code m * n}
     */
    public static double[] getMLE(int[][] x, int m, int d, int n) {
        if (m <= 0) {
            throw new IllegalArgumentException("m <= 0");
        }
        if (d <= 0) {
            throw new IllegalArgumentException("d <= 0");
        }
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        // Integer column totals keep the consistency check below exact.
        long[] totals = new long[d];
        long grandTotal = 0L;
        for (int v = 0; v < m; ++v) {
            for (int c = 0; c < d; ++c) {
                totals[c] += x[v][c];
                grandTotal += x[v][c];
            }
        }
        if (grandTotal != (long)m * (long)n) {
            throw new IllegalArgumentException("n is not correct");
        }
        // Divide by the total number of trials over all observations.
        double trials = (double)m * (double)n;
        double[] parameters = new double[d];
        for (int c = 0; c < d; ++c) {
            parameters[c] = (double)totals[c] / trials;
        }
        return parameters;
    }

    /** Returns the number of trials {@code n}. */
    public int getN() {
        return this.n;
    }

    /**
     * Returns a copy of the probability vector {@code p}. Modifying the
     * returned array does not affect this distribution.
     */
    public double[] getP() {
        return this.p.clone();
    }

    /**
     * Sets the parameters of this distribution. All validation happens before
     * any field is assigned, so a rejected call leaves the object unchanged.
     *
     * @param n number of trials, must be positive
     * @param p probability vector of length at least 2 whose entries lie in
     *          [0, 1] and sum to 1; it is copied
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public void setParams(int n, double[] p) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        if (p.length < 2) {
            throw new IllegalArgumentException("p.length < 2");
        }
        MultinomialDist.checkProbabilityVector(p);
        this.n = n;
        this.dimension = p.length;
        this.p = p.clone();
    }
}

