/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.classify.Classification;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import java.util.logging.Logger;

/**
 * A {@link RankedFeatureVector} that scores every feature of an {@link InstanceList}'s data
 * alphabet by comparing, class by class, how the feature co-occurs with the instances' true
 * labelings versus how it co-occurs with a model's predicted label distributions.
 *
 * <p>Only binary features are supported: the computation asserts that every feature value is
 * {@code 1.0}. For each class {@code c} and feature {@code f} the score term has the form
 * {@code alpha * p[c][f] - log(qeag[c][f])}. This looks like the feature-induction gain
 * {@code alpha * E_p[g] - log E_q[exp(alpha * g)]}; presumably that is the intended
 * interpretation. TODO confirm against the original derivation.
 */
public class KLGain extends RankedFeatureVector {
    private static final Logger logger = MalletLogger.getLogger(KLGain.class.getName());

    /**
     * Computes the gain of every feature in {@code ilist}'s data alphabet.
     *
     * @param ilist instances whose data are {@link FeatureVector}s with binary (1.0) values and
     *     whose labelings share {@code ilist}'s target alphabet; asserted to be non-empty
     * @param classifications model label distributions, one per instance of {@code ilist}, in the
     *     same order
     * @return an array indexed by feature index holding each feature's summed gain
     * @throws IllegalArgumentException if fewer classifications than instances are supplied
     */
    private static double[] calcKLGains(InstanceList ilist, LabelVector[] classifications) {
        int numInstances = ilist.size();
        int numClasses = ilist.getTargetAlphabet().size();
        int numFeatures = ilist.getDataAlphabet().size();
        assert (numInstances > 0);
        if (classifications.length < numInstances) {
            throw new IllegalArgumentException("Expected at least " + numInstances
                    + " classifications (one per instance), got " + classifications.length);
        }
        logger.info("Starting klgains, #instances=" + numInstances);

        double[][] p = new double[numClasses][numFeatures];
        double[][] q = new double[numClasses][numFeatures];
        accumulateLabelFeatureWeights(ilist, classifications, p, q, numClasses, numFeatures);
        double[][] alphas = computeAlphas(p, q, numClasses, numFeatures);
        double[][] qeag = computeQeag(ilist, classifications, alphas, numClasses, numFeatures);
        double[] klgains = sumGains(p, alphas, qeag, numClasses, numFeatures);
        logSample(ilist, p, q, alphas, qeag, klgains);
        return klgains;
    }

    /**
     * Fills {@code p} (true labels) and {@code q} (model labels) with smoothed class-by-feature
     * weights, normalized by {@code N + 1} where {@code N} is the number of instances.
     *
     * <p>Every cell starts at {@code 1 / ((N + 1) * F * C)}, so smoothing contributes a total mass
     * of {@code 1 / (N + 1)}. Each instance then adds its per-class label weight, divided by
     * {@code N + 1}, to every feature it contains. When every labeling and every classification
     * sums to 1, the total weight is 1. The trailing assertions check this.
     */
    private static void accumulateLabelFeatureWeights(InstanceList ilist,
            LabelVector[] classifications, double[][] p, double[][] q,
            int numClasses, int numFeatures) {
        int numInstances = ilist.size();
        double numInExpectation = (double) numInstances + 1.0;
        double smoothing = 1.0 / (numInExpectation * (double) numFeatures * (double) numClasses);
        double trueLabelWeightSum = 0.0;
        double modelLabelWeightSum = 0.0;
        for (int c = 0; c < numClasses; ++c) {
            for (int f = 0; f < numFeatures; ++f) {
                p[c][f] = smoothing;
                q[c][f] = smoothing;
                trueLabelWeightSum += smoothing;
                modelLabelWeightSum += smoothing;
            }
        }
        for (int i = 0; i < numInstances; ++i) {
            assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet());
            Instance inst = (Instance) ilist.get(i);
            Labeling labeling = inst.getLabeling();
            FeatureVector fv = (FeatureVector) inst.getData();
            for (int li = 0; li < numClasses; ++li) {
                double trueLabelWeight = labeling.value(li) / numInExpectation;
                double modelLabelWeight = classifications[i].value(li) / numInExpectation;
                trueLabelWeightSum += trueLabelWeight;
                modelLabelWeightSum += modelLabelWeight;
                if (trueLabelWeight == 0.0 && modelLabelWeight == 0.0) {
                    continue; // neither p nor q would change for this class
                }
                for (int fl = 0; fl < fv.numLocations(); ++fl) {
                    int fli = fv.indexAtLocation(fl);
                    assert (fv.valueAtLocation(fl) == 1.0);
                    p[li][fli] += trueLabelWeight;
                    q[li][fli] += modelLabelWeight;
                }
            }
        }
        assert (Math.abs(trueLabelWeightSum - 1.0) < 0.001) : "trueLabelWeightSum should be 1.0, it was " + trueLabelWeightSum;
        assert (Math.abs(modelLabelWeightSum - 1.0) < 0.001) : "modelLabelWeightSum should be 1.0, it was " + modelLabelWeightSum;
    }

    /**
     * Returns the per-class, per-feature log odds ratio {@code log(p * (1 - q) / (q * (1 - p)))}.
     * Entries may be infinite or NaN when a weight reaches 0 or 1; those are skipped when the
     * gains are summed.
     */
    private static double[][] computeAlphas(double[][] p, double[][] q,
            int numClasses, int numFeatures) {
        double[][] alphas = new double[numClasses][numFeatures];
        for (int c = 0; c < numClasses; ++c) {
            for (int f = 0; f < numFeatures; ++f) {
                alphas[c][f] = Math.log(p[c][f] * (1.0 - q[c][f]) / (q[c][f] * (1.0 - p[c][f])));
            }
        }
        return alphas;
    }

    /**
     * Builds {@code qeag[c][f]}. Each instance's model weight for class {@code c} is
     * {@code w = classification / N}; note the divisor is N, not N + 1.
     *
     * <p>Every cell receives the total model weight, which is the contribution each instance
     * would make if the feature were absent ({@code e^0 = 1}). For instances that contain
     * feature {@code f}, that assumed {@code w} is replaced by {@code w * e^alpha}; hence the
     * {@code w * exp(alpha) - w} correction term.
     *
     * <p>NOTE(review): the baseline added to every row is the model weight summed over all
     * classes (about 1 when classifications are normalized), not only over class {@code c}.
     * Confirm this is intended.
     */
    private static double[][] computeQeag(InstanceList ilist, LabelVector[] classifications,
            double[][] alphas, int numClasses, int numFeatures) {
        int numInstances = ilist.size();
        double[][] qeag = new double[numClasses][numFeatures];
        double modelLabelWeightSum = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            Instance inst = (Instance) ilist.get(i);
            FeatureVector fv = (FeatureVector) inst.getData();
            for (int li = 0; li < numClasses; ++li) {
                double modelLabelWeight = classifications[i].value(li) / (double) numInstances;
                modelLabelWeightSum += modelLabelWeight;
                for (int fl = 0; fl < fv.numLocations(); ++fl) {
                    int fli = fv.indexAtLocation(fl);
                    qeag[li][fli] += modelLabelWeight * Math.exp(alphas[li][fli]) - modelLabelWeight;
                }
            }
        }
        for (int c = 0; c < numClasses; ++c) {
            for (int f = 0; f < numFeatures; ++f) {
                qeag[c][f] += modelLabelWeightSum;
            }
        }
        return qeag;
    }

    /**
     * Sums {@code alpha * p - log(qeag)} over classes for each feature. Only classes whose alpha
     * is finite and strictly positive contribute; NaN, negative, and infinite alphas are skipped.
     */
    private static double[] sumGains(double[][] p, double[][] alphas, double[][] qeag,
            int numClasses, int numFeatures) {
        double[] klgains = new double[numFeatures];
        for (int c = 0; c < numClasses; ++c) {
            for (int f = 0; f < numFeatures; ++f) {
                double alpha = alphas[c][f];
                if (alpha > 0.0 && !Double.isInfinite(alpha)) {
                    klgains[f] += alpha * p[c][f] - Math.log(qeag[c][f]);
                }
            }
        }
        return klgains;
    }

    /**
     * Logs, at INFO level, the intermediate quantities for roughly one feature in every hundred.
     */
    private static void logSample(InstanceList ilist, double[][] p, double[][] q,
            double[][] alphas, double[][] qeag, double[] klgains) {
        int numFeatures = klgains.length;
        int numClasses = p.length;
        logger.info("klgains.length=" + klgains.length);
        // numFeatures / 100 is 0 when there are fewer than 100 features. Clamp it to 1 so every
        // feature is logged instead of the modulus throwing ArithmeticException.
        int stride = Math.max(1, numFeatures / 100);
        for (int j = 0; j < numFeatures; j += stride) {
            Object feature = ilist.getDataAlphabet().lookupObject(j);
            for (int c = 0; c < numClasses; ++c) {
                logger.info("c=" + c + " p[" + feature + "] = " + p[c][j]);
                logger.info("c=" + c + " q[" + feature + "] = " + q[c][j]);
                logger.info("c=" + c + " alphas[" + feature + "] = " + alphas[c][j]);
                logger.info("c=" + c + " qeag[" + feature + "] = " + qeag[c][j]);
            }
            logger.info("klgains[" + feature + "] = " + klgains[j]);
        }
    }

    /**
     * Ranks the features of {@code ilist} against the given per-instance model label
     * distributions.
     *
     * @param ilist non-empty instances with binary {@link FeatureVector} data and labelings over
     *     {@code ilist}'s target alphabet
     * @param classifications one model label distribution per instance, in instance order
     */
    public KLGain(InstanceList ilist, LabelVector[] classifications) {
        super(ilist.getDataAlphabet(), KLGain.calcKLGains(ilist, classifications));
    }

    /** Extracts the {@link LabelVector} of each classification, preserving order. */
    private static LabelVector[] getLabelVectorsFromClassifications(Classification[] c) {
        LabelVector[] ret = new LabelVector[c.length];
        for (int i = 0; i < c.length; ++i) {
            ret[i] = c[i].getLabelVector();
        }
        return ret;
    }

    /**
     * Convenience constructor: takes each classification's label vector and ranks the features
     * of {@code ilist} exactly as {@link #KLGain(InstanceList, LabelVector[])} does.
     *
     * @param ilist non-empty instances with binary {@link FeatureVector} data
     * @param classifications one classification per instance, in instance order
     */
    public KLGain(InstanceList ilist, Classification[] classifications) {
        super(ilist.getDataAlphabet(), KLGain.calcKLGains(ilist, KLGain.getLabelVectorsFromClassifications(classifications)));
    }
}

