/*
 * Decompiled with CFR 0.152.
 */
package gov.lanl.yadas;

import gov.lanl.yadas.TunablePerturber;
import java.util.Random;

/**
 * Perturber that changes one entry of a candidate vector at a time and rescales the
 * other entries that carry the same group label, then renormalises that group by its sum.
 *
 * <p>Each entry {@code j} of the vector has an integer label {@code labels[j]}; entries with
 * equal labels form a group. A turn multiplies the odds {@code x / (1 - x)} of the selected
 * entry by {@code exp(mss[j] * z)}, where {@code z} is a standard Gaussian draw, and scales
 * the remaining members of the group by {@code (1 - new) / (1 - old)}.
 *
 * <p>The factor reported by {@link #jacobian()} is recomputed on every call to
 * {@link #perturb(double[][], int)}.
 */
public class GroupedSumToOnePerturber implements TunablePerturber {
    /** Per-entry step sizes; the array passed to the constructor is kept by reference. */
    double[] mss;
    /** Adjustment factor produced by the most recent perturbation (1.0 before any). */
    double adj = 1.0;
    /** Multiplicative odds factor drawn in the most recent perturbation. */
    double scale;
    /** Group label of every entry; kept by reference. */
    private int[] labels;
    /** Number of groups, i.e. largest label plus one. */
    private int N;
    /** Label of the entry perturbed most recently. */
    private int whichlabel;
    /** {@code labelmat[g]} lists, in ascending order, the positions whose label is {@code g}. */
    private int[][] labelmat;
    /** Shared random source, seeded from the wall clock when the class is loaded. */
    static Random rand = new Random(System.currentTimeMillis());

    /**
     * Builds the perturber and indexes the positions belonging to each label.
     *
     * @param mss    step size for each entry (stored, not copied)
     * @param labels group label for each entry (stored, not copied)
     */
    public GroupedSumToOnePerturber(double[] mss, int[] labels) {
        this.mss = mss;
        this.labels = labels;
        this.N = GroupedSumToOnePerturber.max(labels) + 1;
        this.labelmat = new int[this.N][];
        for (int group = 0; group < this.N; ++group) {
            this.labelmat[group] = GroupedSumToOnePerturber.positionsOf(labels, group);
        }
    }

    /** Returns the ascending positions in {@code labels} whose value equals {@code group}. */
    private static int[] positionsOf(int[] labels, int group) {
        int size = 0;
        for (int label : labels) {
            if (label == group) {
                size++;
            }
        }
        int[] positions = new int[size];
        int next = 0;
        for (int pos = 0; pos < labels.length; ++pos) {
            if (labels[pos] == group) {
                positions[next++] = pos;
            }
        }
        return positions;
    }

    /**
     * Perturbs entry {@code whoseTurn} of {@code candarray[0]} in place and rescales the rest
     * of its group.
     *
     * @param candarray candidate values; only row 0 is read and written
     * @param whoseTurn position of the entry to perturb
     */
    public void perturb(double[][] candarray, int whoseTurn) {
        this.whichlabel = this.labels[whoseTurn];
        double[] row = candarray[0];
        double previous = row[whoseTurn];

        // Multiply the odds of the selected entry by a log-normal factor.
        this.scale = Math.exp(this.mss[whoseTurn] * rand.nextGaussian());
        double proposed = 1.0 / (1.0 + 1.0 / this.scale * (1.0 - previous) / previous);
        row[whoseTurn] = proposed;

        // Looked up only after the write above, matching the original statement order.
        int[] group = this.labelmat[this.whichlabel];
        double keptAfter = 1.0 - proposed;
        double keptBefore = 1.0 - previous;

        // Shrink or grow the other group members to absorb the change.
        for (int member : group) {
            if (member != whoseTurn) {
                row[member] = row[member] * keptAfter / keptBefore;
            }
        }

        // Renormalise the whole group by its current sum.
        double total = GroupedSumToOnePerturber.somesum(row, group);
        for (int member : group) {
            row[member] = row[member] / total;
        }

        double exponent = (double) group.length - 1.0;
        this.adj = row[whoseTurn] / previous
                * Math.pow((1.0 - row[whoseTurn]) / keptBefore, exponent);
    }

    /**
     * Adds up every element of {@code arr}, left to right.
     *
     * @param arr values to add
     * @return the total, or 0.0 for an empty array
     */
    public static double sum(double[] arr) {
        double total = 0.0;
        for (double value : arr) {
            total += value;
        }
        return total;
    }

    /**
     * Adds up the elements of {@code arr} at the positions listed in {@code subset}.
     *
     * @param arr    values to pick from
     * @param subset positions into {@code arr}, visited in the given order
     * @return the total of the selected elements
     */
    public static double somesum(double[] arr, int[] subset) {
        double total = 0.0;
        for (int position : subset) {
            total += arr[position];
        }
        return total;
    }

    /**
     * Finds the largest element of a non-empty array.
     *
     * @param vec values to scan; must contain at least one element
     * @return the maximum value
     */
    public static int max(int[] vec) {
        int best = vec[0];
        for (int value : vec) {
            if (value > best) {
                best = value;
            }
        }
        return best;
    }

    /** @return the number of entries that can be perturbed, one per step size */
    public int numTurns() {
        return this.mss.length;
    }

    /** @return the adjustment factor stored by the most recent perturbation */
    public double jacobian() {
        return this.adj;
    }

    /** @return a fresh copy of the current step sizes */
    public double[] getStepSizes() {
        return this.mss.clone();
    }

    /**
     * Replaces a single step size.
     *
     * @param s new step size
     * @param i position to update
     */
    public void setStepSize(double s, int i) {
        this.mss[i] = s;
    }

    /**
     * Overwrites the leading step sizes with the contents of {@code s}.
     *
     * @param s new step sizes, copied into positions {@code 0 .. s.length - 1}
     */
    public void setStepSizes(double[] s) {
        int count = s.length;
        System.arraycopy(s, 0, this.mss, 0, count);
    }
}

