/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.RemoveUseless;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 * Multinomial logistic regression with a ridge (L2) penalty, after le Cessie and van
 * Houwelingen (1992).
 *
 * <p>Before training, missing values are replaced ({@link ReplaceMissingValues}), useless
 * attributes are removed ({@link RemoveUseless}) and nominal attributes are binarized
 * ({@link NominalToBinary}). The same fitted filters are applied to every instance passed to
 * {@link #distributionForInstance(Instance)}.
 *
 * <p>The last class value acts as the reference category (its linear predictor is fixed at 0),
 * so the model has {@code (numPredictors + 1) x (numClasses - 1)} coefficients. Row 0 of that
 * matrix holds the intercepts. The parameters are found with a quasi-Newton optimizer
 * ({@link Optimization}) on standardized data and then mapped back to the original attribute
 * scale.
 */
public class Logistic extends Classifier
        implements OptionHandler, WeightedInstancesHandler, TechnicalInformationHandler {
    static final long serialVersionUID = 3932117032546553727L;

    /** Ridge value used when none is given on the command line. */
    private static final double DEFAULT_RIDGE = 1.0E-8;

    /** Default iteration limit; -1 means "restart the optimizer until it converges". */
    private static final int DEFAULT_MAX_ITS = -1;

    /** Coefficients indexed {@code [predictor][class]}; row 0 is the intercept. Null until built. */
    protected double[][] m_Par;

    /**
     * Training matrix (column 0 is the constant 1.0) used by the optimizer. It is standardized
     * before optimization and set back to null when {@link #buildClassifier} finishes.
     */
    protected double[][] m_Data;

    /** Number of predictor attributes after filtering (attributes minus the class). */
    protected int m_NumPredictors;

    /** Class attribute index in the filtered data. */
    protected int m_ClassIndex;

    /** Number of class values. */
    protected int m_NumClasses;

    /** Ridge coefficient applied to the squared non-intercept parameters. */
    protected double m_Ridge = DEFAULT_RIDGE;

    private RemoveUseless m_AttFilter;
    private NominalToBinary m_NominalToBinary;
    private ReplaceMissingValues m_ReplaceMissingValues;

    /** Whether debugging output is printed to standard out. */
    protected boolean m_Debug;

    /** Negated minimum of the penalized objective reached by the optimizer. */
    protected double m_LL;

    /** Maximum optimizer iterations, or -1 to restart until convergence. */
    private int m_MaxIts = DEFAULT_MAX_ITS;

    /** Empty copy of the filtered training header, used for attribute names in toString(). */
    private Instances m_structure;

    /**
     * Returns a description of the classifier for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "Class for building and using a multinomial logistic regression model with a ridge estimator.\n\nThere are some modifications, however, compared to the paper of leCessie and van Houwelingen(1992): \n\nIf there are k classes for n instances with m attributes, the parameter matrix B to be calculated will be an m*(k-1) matrix.\n\nThe probability for class j with the exception of the last class is\n\nPj(Xi) = exp(XiBj)/((sum[j=1..(k-1)]exp(Xi*Bj))+1) \n\nThe last class has probability\n\n1-(sum[j=1..(k-1)]Pj(Xi)) \n\t= 1/((sum[j=1..(k-1)]exp(Xi*Bj))+1)\n\nThe (negative) multinomial log-likelihood is thus: \n\nL = -sum[i=1..n]{\n\tsum[j=1..(k-1)](Yij * ln(Pj(Xi)))\n\t+(1 - (sum[j=1..(k-1)]Yij)) \n\t* ln(1 - sum[j=1..(k-1)]Pj(Xi))\n\t} + ridge * (B^2)\n\nIn order to find the matrix B for which L is minimised, a Quasi-Newton Method is used to search for the optimized values of the m*(k-1) variables.  Note that before we use the optimization procedure, we 'squeeze' the matrix B into a m*(k-1) vector.  For details of the optimization procedure, please check weka.core.Optimization class.\n\nAlthough original Logistic Regression does not deal with instance weights, we modify the algorithm a little bit to handle the instance weights.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString() + "\n\n" + "Note: Missing values are replaced using a ReplaceMissingValuesFilter, and " + "nominal attributes are transformed into numeric attributes using a " + "NominalToBinaryFilter.";
    }

    /**
     * Returns the bibliographic reference for the method.
     *
     * @return the le Cessie and van Houwelingen (1992) article reference
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "le Cessie, S. and van Houwelingen, J.C.");
        result.setValue(TechnicalInformation.Field.YEAR, "1992");
        result.setValue(TechnicalInformation.Field.TITLE, "Ridge Estimators in Logistic Regression");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Applied Statistics");
        result.setValue(TechnicalInformation.Field.VOLUME, "41");
        result.setValue(TechnicalInformation.Field.NUMBER, "1");
        result.setValue(TechnicalInformation.Field.PAGES, "191-201");
        return result;
    }

    /**
     * Lists the supported command-line options: -D, -R and -M.
     *
     * @return an enumeration of {@link Option}s
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(3);
        newVector.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        newVector.addElement(new Option("\tSet the ridge in the log-likelihood.", "R", 1, "-R <ridge>"));
        newVector.addElement(new Option("\tSet the maximum number of iterations (default -1, until convergence).", "M", 1, "-M <number>"));
        return newVector.elements();
    }

    /**
     * Parses the options -D (debug), -R (ridge) and -M (max iterations).
     * Missing -R or -M values reset the corresponding setting to its default.
     *
     * @param options the option array; matched options are consumed by {@link Utils}
     * @throws Exception if a numeric value cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.setDebug(Utils.getFlag('D', options));
        String ridgeString = Utils.getOption('R', options);
        this.m_Ridge = ridgeString.length() != 0 ? Double.parseDouble(ridgeString) : DEFAULT_RIDGE;
        String maxItsString = Utils.getOption('M', options);
        this.m_MaxIts = maxItsString.length() != 0 ? Integer.parseInt(maxItsString) : DEFAULT_MAX_ITS;
    }

    /**
     * Returns the current settings as an option array, padded with empty strings to length 5.
     *
     * @return the current option settings
     */
    @Override
    public String[] getOptions() {
        String[] options = new String[5];
        int current = 0;
        if (this.getDebug()) {
            options[current++] = "-D";
        }
        options[current++] = "-R";
        options[current++] = "" + this.m_Ridge;
        options[current++] = "-M";
        options[current++] = "" + this.m_MaxIts;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** @return tip text for the debug property */
    @Override
    public String debugTipText() {
        return "Output debug information to the console.";
    }

    /** @param debug whether to print debugging output */
    @Override
    public void setDebug(boolean debug) {
        this.m_Debug = debug;
    }

    /** @return whether debugging output is printed */
    @Override
    public boolean getDebug() {
        return this.m_Debug;
    }

    /** @return tip text for the ridge property */
    public String ridgeTipText() {
        return "Set the Ridge value in the log-likelihood.";
    }

    /** @param ridge the ridge coefficient added to the log-likelihood */
    public void setRidge(double ridge) {
        this.m_Ridge = ridge;
    }

    /** @return the ridge coefficient */
    public double getRidge() {
        return this.m_Ridge;
    }

    /** @return tip text for the maxIts property */
    public String maxItsTipText() {
        return "Maximum number of iterations to perform.";
    }

    /** @return the maximum number of iterations, or -1 for "until convergence" */
    public int getMaxIts() {
        return this.m_MaxIts;
    }

    /** @param newMaxIts the maximum number of iterations, or -1 for "until convergence" */
    public void setMaxIts(int newMaxIts) {
        this.m_MaxIts = newMaxIts;
    }

    /**
     * Reports the data this classifier can handle: nominal, numeric and date attributes,
     * missing values, and a nominal class that may have missing values.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Fits the model.
     *
     * <p>Steps:
     * <ol>
     *   <li>Filter a copy of the data.</li>
     *   <li>Extract and standardize the predictors.</li>
     *   <li>Minimize the penalized negative log-likelihood.</li>
     *   <li>Map the coefficients back to the unstandardized scale.</li>
     * </ol>
     *
     * @param train the training data; it is copied, not modified
     * @throws Exception if the data fails the capability test, the instance weights sum to at
     *     most 1 (with more than one instance), or a filter or the optimizer fails
     */
    @Override
    public void buildClassifier(Instances train) throws Exception {
        this.getCapabilities().testWithFail(train);
        train = new Instances(train);
        train.deleteWithMissingClass();
        this.m_ReplaceMissingValues = new ReplaceMissingValues();
        this.m_ReplaceMissingValues.setInputFormat(train);
        train = Filter.useFilter(train, this.m_ReplaceMissingValues);
        this.m_AttFilter = new RemoveUseless();
        this.m_AttFilter.setInputFormat(train);
        train = Filter.useFilter(train, this.m_AttFilter);
        this.m_NominalToBinary = new NominalToBinary();
        this.m_NominalToBinary.setInputFormat(train);
        train = Filter.useFilter(train, this.m_NominalToBinary);
        this.m_structure = new Instances(train, 0);
        this.m_ClassIndex = train.classIndex();
        this.m_NumClasses = train.numClasses();

        final int nK = this.m_NumClasses - 1;                 // classes with own parameters
        final int nR = this.m_NumPredictors = train.numAttributes() - 1;
        final int nC = train.numInstances();
        final int dim = nR + 1;                               // parameters per class (incl. intercept)

        int[] Y = new int[nC];
        double[] weights = new double[nC];
        double[] xMean = new double[dim];
        double[] xSD = new double[dim];
        double[] sY = new double[nK + 1];                     // unweighted count per class
        double totWeights = 0.0;
        this.m_Par = new double[dim][nK];
        this.m_Data = new double[nC][dim];
        try {
            if (this.m_Debug) {
                System.out.println("Extracting data...");
            }
            for (int i = 0; i < nC; i++) {
                Instance current = train.instance(i);
                Y[i] = (int) current.classValue();
                weights[i] = current.weight();
                totWeights += weights[i];
                this.m_Data[i][0] = 1.0;
                int col = 1;
                for (int att = 0; att <= nR; att++) {
                    if (att == this.m_ClassIndex) {
                        continue;
                    }
                    double value = current.value(att);
                    this.m_Data[i][col] = value;
                    // Accumulate weighted sums for the mean and the sum of squares.
                    xMean[col] += weights[i] * value;
                    xSD[col] += weights[i] * value * value;
                    col++;
                }
                sY[Y[i]] += 1.0;
            }
            if (totWeights <= 1.0 && nC > 1) {
                throw new Exception("Sum of weights of instances less than 1, please reweight!");
            }

            // Column 0 is the intercept: leave it unscaled.
            xMean[0] = 0.0;
            xSD[0] = 1.0;
            for (int col = 1; col <= nR; col++) {
                xMean[col] /= totWeights;
                // Weighted sample standard deviation; abs() guards against tiny negative rounding.
                xSD[col] = totWeights > 1.0
                        ? Math.sqrt(Math.abs(xSD[col] - totWeights * xMean[col] * xMean[col]) / (totWeights - 1.0))
                        : 0.0;
            }
            if (this.m_Debug) {
                System.out.println("Descriptives...");
                for (int c = 0; c <= nK; c++) {
                    System.out.println(sY[c] + " cases have class " + c);
                }
                System.out.println("\n Variable     Avg       SD    ");
                for (int col = 1; col <= nR; col++) {
                    System.out.println(Utils.doubleToString(col, 8, 4) + Utils.doubleToString(xMean[col], 10, 4) + Utils.doubleToString(xSD[col], 10, 4));
                }
            }

            // Standardize; constant columns (SD == 0) are left untouched.
            for (int i = 0; i < nC; i++) {
                for (int col = 0; col <= nR; col++) {
                    if (xSD[col] != 0.0) {
                        this.m_Data[i][col] = (this.m_Data[i][col] - xMean[col]) / xSD[col];
                    }
                }
            }
            if (this.m_Debug) {
                System.out.println("\nIteration History...");
            }

            // Start from intercept-only log-odds (with add-one smoothing) vs. the reference class.
            // Slopes start at 0 (array default); no parameter is bounded (NaN bounds).
            double[] x = new double[dim * nK];
            double[][] b = new double[2][x.length];
            for (int p = 0; p < nK; p++) {
                x[p * dim] = Math.log(sY[p] + 1.0) - Math.log(sY[nK] + 1.0);
            }
            for (int idx = 0; idx < x.length; idx++) {
                b[0][idx] = Double.NaN;
                b[1][idx] = Double.NaN;
            }

            OptEng opt = new OptEng();
            opt.setDebug(this.m_Debug);
            opt.setWeights(weights);
            opt.setClassLabels(Y);
            if (this.m_MaxIts == -1) {
                // A null result means the optimizer stopped at its iteration limit without
                // converging; resume from its current values until it does converge.
                x = opt.findArgmin(x, b);
                while (x == null) {
                    x = opt.getVarbValues();
                    if (this.m_Debug) {
                        System.out.println("200 iterations finished, not enough!");
                    }
                    x = opt.findArgmin(x, b);
                }
                if (this.m_Debug) {
                    System.out.println(" -------------<Converged>--------------");
                }
            } else {
                opt.setMaxIteration(this.m_MaxIts);
                x = opt.findArgmin(x, b);
                if (x == null) {
                    x = opt.getVarbValues();
                }
            }
            this.m_LL = -opt.getMinFunction();

            // Undo the standardization so the coefficients apply to raw attribute values.
            for (int k = 0; k < nK; k++) {
                this.m_Par[0][k] = x[k * dim];
                for (int col = 1; col <= nR; col++) {
                    this.m_Par[col][k] = x[k * dim + col];
                    if (xSD[col] != 0.0) {
                        this.m_Par[col][k] /= xSD[col];
                        this.m_Par[0][k] -= this.m_Par[col][k] * xMean[col];
                    }
                }
            }
        } finally {
            // Release the training matrix even when training fails part-way.
            this.m_Data = null;
        }
    }

    /**
     * Computes the class probability distribution for an instance.
     *
     * <p>The instance is passed through the filters fitted in {@link #buildClassifier}.
     *
     * @param instance the instance to classify
     * @return one probability per class value
     * @throws Exception if filtering fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        this.m_ReplaceMissingValues.input(instance);
        instance = this.m_ReplaceMissingValues.output();
        this.m_AttFilter.input(instance);
        instance = this.m_AttFilter.output();
        this.m_NominalToBinary.input(instance);
        instance = this.m_NominalToBinary.output();

        double[] instDat = new double[this.m_NumPredictors + 1];
        instDat[0] = 1.0;                                     // intercept term
        int col = 1;
        for (int att = 0; att <= this.m_NumPredictors; att++) {
            if (att != this.m_ClassIndex) {
                instDat[col++] = instance.value(att);
            }
        }
        return this.evaluateProbability(instDat);
    }

    /**
     * Evaluates the softmax over the linear predictors, with the last class as reference (v = 0).
     *
     * <p>Each probability is computed as {@code 1 / sum_n exp(v[n] - v[m])}, which avoids
     * exponentiating large predictors directly.
     *
     * @param data predictor vector with {@code data[0] == 1.0} for the intercept
     * @return one probability per class value
     */
    private double[] evaluateProbability(double[] data) {
        double[] prob = new double[this.m_NumClasses];
        double[] v = new double[this.m_NumClasses];
        for (int j = 0; j < this.m_NumClasses - 1; j++) {
            for (int k = 0; k <= this.m_NumPredictors; k++) {
                v[j] += this.m_Par[k][j] * data[k];
            }
        }
        v[this.m_NumClasses - 1] = 0.0;
        for (int m = 0; m < this.m_NumClasses; m++) {
            double sum = 0.0;
            for (int n = 0; n < this.m_NumClasses - 1; n++) {
                sum += Math.exp(v[n] - v[m]);
            }
            // exp(-v[m]) is the reference class term exp(0 - v[m]).
            prob[m] = 1.0 / (sum + Math.exp(-v[m]));
        }
        return prob;
    }

    /**
     * Returns the fitted coefficient matrix (not a copy).
     *
     * @return coefficients indexed {@code [predictor][class]}, with row 0 as the intercept;
     *     null before training
     */
    public double[][] coefficients() {
        return this.m_Par;
    }

    /**
     * Renders the coefficient and odds-ratio tables.
     *
     * @return a human-readable description of the model, or a "No model built yet" notice
     */
    @Override
    public String toString() {
        StringBuilder temp = new StringBuilder();
        temp.append("Logistic Regression with ridge parameter of " + this.m_Ridge);
        if (this.m_Par == null) {
            // Keep the header; previously only ": No model built yet." was returned.
            return temp.append(": No model built yet.").toString();
        }
        final int numShown = this.m_NumClasses - 1;

        // Width of the left-hand "Variable" column.
        int attLength = Math.max("Intercept".length(), "Variable".length());
        for (int i = 0; i < this.m_structure.numAttributes(); i++) {
            if (i != this.m_structure.classIndex()) {
                attLength = Math.max(attLength, this.m_structure.attribute(i).name().length());
            }
        }
        attLength += 2;

        // Width of each class column. The intercept is included too: Utils.padLeft truncates
        // strings that are longer than the requested width.
        int colWidth = "Class".length();
        for (int c = 0; c < numShown; c++) {
            colWidth = Math.max(colWidth, this.m_structure.classAttribute().value(c).length());
        }
        for (int k = 0; k < numShown; k++) {
            colWidth = Math.max(colWidth, formatNumber(this.m_Par[0][k], 10).length());
            for (int j = 1; j <= this.m_NumPredictors; j++) {
                colWidth = Math.max(colWidth, formatNumber(this.m_Par[j][k], 12).length());
                colWidth = Math.max(colWidth, formatOddsRatio(this.m_Par[j][k]).length());
            }
        }
        colWidth += 2;

        temp.append("\nCoefficients...\n");
        this.appendTableHeader(temp, attLength, colWidth);
        this.appendAttributeRows(temp, attLength, colWidth, false);
        temp.append(Utils.padRight("Intercept", attLength));
        for (int k = 0; k < numShown; k++) {
            temp.append(Utils.padLeft(formatNumber(this.m_Par[0][k], 10), colWidth));
        }
        temp.append("\n");

        temp.append("\n\nOdds Ratios...\n");
        this.appendTableHeader(temp, attLength, colWidth);
        this.appendAttributeRows(temp, attLength, colWidth, true);
        return temp.toString();
    }

    /** Formats a number with 4 decimals using Weka's aligned formatting, then trims it. */
    private static String formatNumber(double value, int width) {
        return Utils.doubleToString(value, width, 4).trim();
    }

    /** Formats exp(coefficient); values above 1e10 use plain Double.toString. */
    private static String formatOddsRatio(double coefficient) {
        double ratio = Math.exp(coefficient);
        return (ratio > 1.0E10 ? "" + ratio : Utils.doubleToString(ratio, 12, 4)).trim();
    }

    /** Appends the "Class" line, the class-name header line and the '=' separator. */
    private void appendTableHeader(StringBuilder out, int attLength, int colWidth) {
        int numShown = this.m_NumClasses - 1;
        out.append(Utils.padLeft(" ", attLength)).append(Utils.padLeft("Class", colWidth)).append("\n");
        out.append(Utils.padRight("Variable", attLength));
        for (int c = 0; c < numShown; c++) {
            out.append(Utils.padLeft(this.m_structure.classAttribute().value(c), colWidth));
        }
        out.append("\n");
        int separatorL = attLength + numShown * colWidth;
        for (int i = 0; i < separatorL; i++) {
            out.append("=");
        }
        out.append("\n");
    }

    /** Appends one row per non-class attribute, with coefficients or odds ratios. */
    private void appendAttributeRows(StringBuilder out, int attLength, int colWidth, boolean oddsRatios) {
        int row = 1;                                          // m_Par row; 0 is the intercept
        for (int i = 0; i < this.m_structure.numAttributes(); i++) {
            if (i == this.m_structure.classIndex()) {
                continue;
            }
            out.append(Utils.padRight(this.m_structure.attribute(i).name(), attLength));
            for (int k = 0; k < this.m_NumClasses - 1; k++) {
                String cell = oddsRatios
                        ? formatOddsRatio(this.m_Par[row][k])
                        : formatNumber(this.m_Par[row][k], 12);
                out.append(Utils.padLeft(cell, colWidth));
            }
            out.append("\n");
            row++;
        }
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5523 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        Logistic.runClassifier(new Logistic(), argv);
    }

    /**
     * Optimizer for the penalized negative multinomial log-likelihood over the flattened
     * parameter vector. Parameters for class {@code c} occupy indices
     * {@code [c * dim, (c + 1) * dim)}, where {@code dim = m_NumPredictors + 1}.
     *
     * <p>This is an inner (non-static) class because it reads the enclosing instance's
     * training matrix and ridge setting.
     */
    private class OptEng extends Optimization {
        /** Instance weights, parallel to the rows of m_Data. */
        private double[] weights;
        /** Class label index for each training row. */
        private int[] cls;

        public void setWeights(double[] w) {
            this.weights = w;
        }

        public void setClassLabels(int[] c) {
            this.cls = c;
        }

        /**
         * Weighted negative log-likelihood plus the ridge penalty on the non-intercept terms.
         * Uses a max-shift (log-sum-exp) for numerical stability.
         */
        @Override
        protected double objectiveFunction(double[] x) {
            final int dim = Logistic.this.m_NumPredictors + 1;
            final int nK = Logistic.this.m_NumClasses - 1;
            final double[][] data = Logistic.this.m_Data;
            double nll = 0.0;
            for (int i = 0; i < this.cls.length; i++) {
                double[] exp = new double[nK];
                for (int c = 0; c < nK; c++) {
                    int base = c * dim;
                    for (int j = 0; j < dim; j++) {
                        exp[c] += data[i][j] * x[base + j];
                    }
                }
                double max = exp[Utils.maxIndex(exp)];
                double denom = Math.exp(-max);                // reference class term
                double num = this.cls[i] == nK ? -max : exp[this.cls[i]] - max;
                for (int c = 0; c < nK; c++) {
                    denom += Math.exp(exp[c] - max);
                }
                nll -= this.weights[i] * (num - Math.log(denom));
            }
            for (int c = 0; c < nK; c++) {
                for (int r = 1; r < dim; r++) {               // intercept (r == 0) not penalized
                    nll += Logistic.this.m_Ridge * x[c * dim + r] * x[c * dim + r];
                }
            }
            return nll;
        }

        /** Gradient of {@link #objectiveFunction(double[])} with respect to x. */
        @Override
        protected double[] evaluateGradient(double[] x) {
            final int dim = Logistic.this.m_NumPredictors + 1;
            final int nK = Logistic.this.m_NumClasses - 1;
            final double[][] data = Logistic.this.m_Data;
            double[] grad = new double[x.length];
            for (int i = 0; i < this.cls.length; i++) {
                double[] num = new double[nK];
                for (int c = 0; c < nK; c++) {
                    double exp = 0.0;
                    int base = c * dim;
                    for (int j = 0; j < dim; j++) {
                        exp += data[i][j] * x[base + j];
                    }
                    num[c] = exp;
                }
                double max = num[Utils.maxIndex(num)];
                double denom = Math.exp(-max);
                for (int c = 0; c < nK; c++) {
                    num[c] = Math.exp(num[c] - max);
                    denom += num[c];
                }
                Utils.normalize(num, denom);                  // num[c] is now P(class c | x_i)
                for (int c = 0; c < nK; c++) {
                    int base = c * dim;
                    double firstTerm = this.weights[i] * num[c];
                    for (int q = 0; q < dim; q++) {
                        grad[base + q] += firstTerm * data[i][q];
                    }
                }
                if (this.cls[i] != nK) {
                    int base = this.cls[i] * dim;
                    for (int p = 0; p < dim; p++) {
                        grad[base + p] -= this.weights[i] * data[i][p];
                    }
                }
            }
            for (int c = 0; c < nK; c++) {
                for (int r = 1; r < dim; r++) {
                    grad[c * dim + r] += 2.0 * Logistic.this.m_Ridge * x[c * dim + r];
                }
            }
            return grad;
        }

        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 5523 $");
        }
    }
}

