/*
 * Decompiled with CFR 0.152.
 */
package opennlp.maxent;

import opennlp.maxent.DataIndexer;
import opennlp.maxent.EvalParameters;
import opennlp.maxent.EventStream;
import opennlp.maxent.GISModel;
import opennlp.maxent.MutableContext;
import opennlp.maxent.OnePassDataIndexer;
import opennlp.maxent.Prior;
import opennlp.maxent.UniformPrior;

/**
 * Trains a {@link GISModel} with Generalized Iterative Scaling (GIS) over the events supplied by
 * a {@link DataIndexer}.
 *
 * <p>All intermediate training state (events, expectations, parameters) is kept in instance
 * fields and is overwritten by each call to {@code trainModel}. Several large tables are
 * released (set to {@code null}) once the iterations finish.
 */
class GISTrainer {
    /** When true, every predicate gets a parameter for every outcome; see {@link #setSmoothing}. */
    private boolean useSimpleSmoothing = false;
    /**
     * When true, a slack (correction) parameter is learned as well. This class has no setter
     * for it, so it stays false unless changed here.
     */
    private boolean useSlackParameter = false;
    /**
     * When true, parameters are updated with the Newton iteration in {@link #gaussianUpdate}
     * instead of the plain GIS log-ratio. This class has no setter for it.
     */
    private boolean useGaussianSmoothing = false;
    /** Divisor of the penalty term used by {@link #gaussianUpdate}. */
    private double sigma = 2.0;
    /** Observed expectation given to unseen predicate/outcome pairs under simple smoothing. */
    private double _smoothingObservation = 0.1;
    /** Whether progress is printed to standard output. */
    private boolean printMessages = false;
    private int numUniqueEvents;
    private int numPreds;
    private int numOutcomes;
    /** Per-event predicate indices. */
    private int[][] contexts;
    /** Per-event feature values parallel to {@link #contexts}; null (or null rows) means binary. */
    private float[][] values;
    /** Per-event outcome index. */
    private int[] outcomeList;
    /** Per-event count, used as a weight throughout training. */
    private int[] numTimesEventsSeen;
    /** Per-predicate count, compared against {@link #cutoff}. */
    private int[] predicateCounts;
    private int cutoff;
    private String[] outcomeLabels;
    private String[] predLabels;
    /** Observed expectation of each (predicate, active outcome) feature. */
    private MutableContext[] observedExpects;
    /** Learned weight of each (predicate, active outcome) feature. */
    private MutableContext[] params;
    /** Model expectation of each feature, accumulated during one iteration and then reset. */
    private MutableContext[] modelExpects;
    private Prior prior;
    /** Log of the observed slack-feature mass; only set when {@link #useSlackParameter} is true. */
    private double cfObservedExpect;
    /** Model slack-feature mass accumulated during the current iteration. */
    private double CFMOD;
    /** Stand-in observed slack mass used when the real mass is zero, keeping its log finite. */
    private static final double NEAR_ZERO = 0.01;
    /** Training stops once the per-iteration log-likelihood gain falls below this value. */
    private static final double LL_THRESHOLD = 1.0E-4;
    /** Scratch buffer filled by the prior and {@code GISModel.eval} for the current event. */
    double[] modelDistribution;
    /** Allocated with one slot per outcome in {@code trainModel}; not otherwise used here. */
    int[] numfeats;
    /** Evaluation parameters passed to {@code GISModel.eval}; wraps {@link #params}. */
    EvalParameters evalParams;

    /** Creates a trainer that prints no progress messages. */
    GISTrainer() {
    }

    /**
     * Creates a trainer.
     *
     * @param printMessages whether to print progress messages to standard output
     */
    GISTrainer(boolean printMessages) {
        this();
        this.printMessages = printMessages;
    }

    /**
     * Enables or disables simple smoothing. When enabled, every predicate is given a parameter
     * for every outcome, and pairs never observed together receive the smoothing observation
     * (see {@link #setSmoothingObservation}) as their observed expectation.
     *
     * @param smooth true to enable simple smoothing
     */
    public void setSmoothing(boolean smooth) {
        this.useSimpleSmoothing = smooth;
    }

    /**
     * Sets the observed expectation used for unseen predicate/outcome pairs when simple
     * smoothing is enabled.
     *
     * @param timesSeen the pseudo-count to use
     */
    public void setSmoothingObservation(double timesSeen) {
        this._smoothingObservation = timesSeen;
    }

    /**
     * Indexes {@code eventStream} with a {@link OnePassDataIndexer} and trains a model with a
     * uniform prior.
     *
     * @param eventStream the training events
     * @param iterations maximum number of GIS iterations
     * @param cutoff predicate count threshold, used both for indexing and for training
     * @return the trained model
     */
    public GISModel trainModel(EventStream eventStream, int iterations, int cutoff) {
        return trainModel(iterations, new OnePassDataIndexer(eventStream, cutoff), cutoff);
    }

    /**
     * Trains a model from already-indexed data with a {@link UniformPrior}.
     *
     * @param iterations maximum number of GIS iterations
     * @param di the indexed training data
     * @param cutoff predicate count threshold
     * @return the trained model
     */
    public GISModel trainModel(int iterations, DataIndexer di, int cutoff) {
        return trainModel(iterations, di, new UniformPrior(), cutoff);
    }

    /**
     * Trains a model from already-indexed data.
     *
     * @param iterations maximum number of GIS iterations; training may stop earlier when the
     *     log-likelihood decreases or its gain drops below {@link #LL_THRESHOLD}
     * @param di the indexed training data
     * @param modelPrior prior distribution combined with the model during evaluation
     * @param cutoff predicate count threshold
     * @return the trained model
     */
    public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff) {
        display("Incorporating indexed data for training...  \n");
        contexts = di.getContexts();
        values = di.getValues();
        this.cutoff = cutoff;
        predicateCounts = di.getPredCounts();
        numTimesEventsSeen = di.getNumTimesEventsSeen();
        numUniqueEvents = contexts.length;
        prior = modelPrior;

        int correctionConstant = computeCorrectionConstant();
        display("done.\n");

        outcomeLabels = di.getOutcomeLabels();
        outcomeList = di.getOutcomeList();
        numOutcomes = outcomeLabels.length;
        predLabels = di.getPredLabels();
        prior.setLabels(outcomeLabels, predLabels);
        numPreds = predLabels.length;

        display("\tNumber of Event Tokens: " + numUniqueEvents + "\n");
        display("\t    Number of Outcomes: " + numOutcomes + "\n");
        display("\t  Number of Predicates: " + numPreds + "\n");

        float[][] predCount = computePredCounts();
        di = null; // everything needed has been copied out of the indexer

        params = new MutableContext[numPreds];
        modelExpects = new MutableContext[numPreds];
        observedExpects = new MutableContext[numPreds];
        // Created before the tables are filled; it holds a reference to the params array.
        evalParams = new EvalParameters(params, 0.0, correctionConstant, numOutcomes);
        initializeParameters(predCount);

        if (useSlackParameter) {
            cfObservedExpect = computeSlackObservedExpect(correctionConstant);
        }
        predCount = null; // drop the count table before the (long) iteration phase

        display("...done.\n");
        modelDistribution = new double[numOutcomes];
        numfeats = new int[numOutcomes];
        display("Computing model parameters...\n");
        findParameters(iterations);
        return new GISModel(params, predLabels, outcomeLabels, correctionConstant, evalParams.correctionParam);
    }

    /**
     * Returns the GIS correction constant: the largest total feature mass of any event, never
     * less than 1. A binary event contributes its context length; a real-valued event
     * contributes the ceiling of the sum of its values.
     */
    private int computeCorrectionConstant() {
        int correctionConstant = 1;
        for (int ci = 0; ci < contexts.length; ci++) {
            if (values == null || values[ci] == null) {
                correctionConstant = Math.max(correctionConstant, contexts[ci].length);
            } else {
                // Sum from zero (not from values[ci][0]) so an empty value array cannot throw.
                float total = 0.0f;
                for (float v : values[ci]) {
                    total += v;
                }
                if (total > (float) correctionConstant) {
                    correctionConstant = (int) Math.ceil(total);
                }
            }
        }
        return correctionConstant;
    }

    /**
     * Returns a {@code numPreds x numOutcomes} table whose entry {@code [p][o]} is the weighted
     * number of times predicate {@code p} occurred in an event with outcome {@code o}.
     */
    private float[][] computePredCounts() {
        float[][] predCount = new float[numPreds][numOutcomes];
        for (int ti = 0; ti < numUniqueEvents; ti++) {
            int outcome = outcomeList[ti];
            float timesSeen = (float) numTimesEventsSeen[ti];
            boolean realValued = values != null && values[ti] != null;
            for (int j = 0; j < contexts[ti].length; j++) {
                float mass = realValued ? timesSeen * values[ti][j] : timesSeen;
                predCount[contexts[ti][j]][outcome] += mass;
            }
        }
        return predCount;
    }

    /**
     * Fills {@link #params}, {@link #modelExpects} and {@link #observedExpects}. For each
     * predicate it chooses which outcomes receive a parameter, and seeds the observed
     * expectations from {@code predCount}. Under simple smoothing, unseen pairs are seeded with
     * the smoothing observation instead.
     */
    private void initializeParameters(float[][] predCount) {
        double smoothingObservation = _smoothingObservation;
        int[] activeOutcomes = new int[numOutcomes];
        int[] allOutcomesPattern = new int[numOutcomes];
        for (int oi = 0; oi < numOutcomes; oi++) {
            allOutcomesPattern[oi] = oi;
        }

        for (int pi = 0; pi < numPreds; pi++) {
            int[] outcomePattern;
            if (useSimpleSmoothing) {
                outcomePattern = allOutcomesPattern;
            } else {
                int numActiveOutcomes = 0;
                for (int oi = 0; oi < numOutcomes; oi++) {
                    // NOTE(review): this uses "> cutoff" while nextIteration uses ">= cutoff",
                    // so a predicate whose count equals the cutoff gets no parameters.
                    // Confirm the intended threshold before changing either side.
                    if (predCount[pi][oi] > 0.0f && predicateCounts[pi] > cutoff) {
                        activeOutcomes[numActiveOutcomes++] = oi;
                    }
                }
                if (numActiveOutcomes == numOutcomes) {
                    outcomePattern = allOutcomesPattern; // share the common full pattern
                } else {
                    outcomePattern = new int[numActiveOutcomes];
                    System.arraycopy(activeOutcomes, 0, outcomePattern, 0, numActiveOutcomes);
                }
            }

            int numActive = outcomePattern.length;
            params[pi] = new MutableContext(outcomePattern, new double[numActive]);
            modelExpects[pi] = new MutableContext(outcomePattern, new double[numActive]);
            observedExpects[pi] = new MutableContext(outcomePattern, new double[numActive]);
            for (int aoi = 0; aoi < numActive; aoi++) {
                int oi = outcomePattern[aoi];
                params[pi].setParameter(aoi, 0.0);
                modelExpects[pi].setParameter(aoi, 0.0);
                if (predCount[pi][oi] > 0.0f) {
                    observedExpects[pi].setParameter(aoi, predCount[pi][oi]);
                } else if (useSimpleSmoothing) {
                    observedExpects[pi].setParameter(aoi, smoothingObservation);
                }
            }
        }
    }

    /**
     * Returns the log of the observed slack-feature mass. For each event this is the weighted
     * number of its predicates with no parameter for the event's outcome, plus the weighted gap
     * between the correction constant and the event's context length. Uses
     * {@link #NEAR_ZERO} when that mass is zero.
     */
    private double computeSlackObservedExpect(int correctionConstant) {
        int cfvalSum = 0;
        for (int ti = 0; ti < numUniqueEvents; ti++) {
            for (int pi : contexts[ti]) {
                if (!modelExpects[pi].contains(outcomeList[ti])) {
                    cfvalSum += numTimesEventsSeen[ti];
                }
            }
            cfvalSum += (correctionConstant - contexts[ti].length) * numTimesEventsSeen[ti];
        }
        return cfvalSum == 0 ? Math.log(NEAR_ZERO) : Math.log(cfvalSum);
    }

    /**
     * Runs up to {@code iterations} GIS iterations. It stops early when the log-likelihood
     * decreases (reported on standard error) or improves by less than {@link #LL_THRESHOLD}.
     * Afterwards it releases the per-event and expectation tables.
     */
    private void findParameters(int iterations) {
        double prevLL = 0.0;
        display("Performing " + iterations + " iterations.\n");
        for (int i = 1; i <= iterations; i++) {
            // Right-align the iteration number to three characters.
            if (i < 10) {
                display("  " + i + ":  ");
            } else if (i < 100) {
                display(" " + i + ":  ");
            } else {
                display(i + ":  ");
            }
            double currLL = nextIteration();
            if (i > 1) {
                if (prevLL > currLL) {
                    System.err.println("Model Diverging: loglikelihood decreased");
                    break;
                }
                if (currLL - prevLL < LL_THRESHOLD) {
                    break;
                }
            }
            prevLL = currLL;
        }
        observedExpects = null;
        modelExpects = null;
        numTimesEventsSeen = null;
        contexts = null;
    }

    /**
     * Solves for the parameter update of one feature under the Gaussian-smoothing objective.
     * It runs at most 50 Newton steps and stops once a step moves by less than 1e-6 or the
     * derivative is zero.
     *
     * @param predicate predicate index
     * @param oid position of the outcome within the predicate's active outcomes
     * @param n unused
     * @param correctionConstant the GIS correction constant
     * @return the increment to add to the current parameter
     */
    private double gaussianUpdate(int predicate, int oid, int n, double correctionConstant) {
        double param = params[predicate].getParameters()[oid];
        double modelValue = modelExpects[predicate].getParameters()[oid];
        double observedValue = observedExpects[predicate].getParameters()[oid];
        double x0 = 0.0;
        for (int i = 0; i < 50; i++) {
            double tmp = modelValue * Math.exp(correctionConstant * x0);
            double f = tmp + (param + x0) / sigma - observedValue;
            double fp = tmp * correctionConstant + 1.0 / sigma;
            if (fp == 0.0) {
                break;
            }
            double x = x0 - f / fp;
            boolean converged = Math.abs(x - x0) < 1.0E-6;
            x0 = x;
            if (converged) {
                break;
            }
        }
        return x0;
    }

    /**
     * Performs one GIS iteration over the training data:
     * <ol>
     *   <li>accumulates model expectations under the current parameters;</li>
     *   <li>updates every parameter from the observed/model expectation ratio, or with
     *       {@link #gaussianUpdate};</li>
     *   <li>resets the model expectations;</li>
     *   <li>adjusts the correction parameter if the slack parameter is enabled.</li>
     * </ol>
     *
     * @return the weighted training-data log-likelihood under the parameters used in this pass
     */
    private double nextIteration() {
        double loglikelihood = 0.0;
        CFMOD = 0.0;
        int numEvents = 0;
        int numCorrect = 0;

        for (int ei = 0; ei < numUniqueEvents; ei++) {
            int[] context = contexts[ei];
            int timesSeen = numTimesEventsSeen[ei];
            if (values != null) {
                prior.logPrior(modelDistribution, context, values[ei]);
                GISModel.eval(context, values[ei], modelDistribution, evalParams);
            } else {
                prior.logPrior(modelDistribution, context);
                GISModel.eval(context, modelDistribution, evalParams);
            }

            for (int j = 0; j < context.length; j++) {
                int pi = context[j];
                if (predicateCounts[pi] < cutoff) {
                    continue;
                }
                int[] activeOutcomes = modelExpects[pi].getOutcomes();
                for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
                    int oi = activeOutcomes[aoi];
                    if (values != null && values[ei] != null) {
                        modelExpects[pi].updateParameter(aoi, modelDistribution[oi] * (double) values[ei][j] * (double) timesSeen);
                    } else {
                        modelExpects[pi].updateParameter(aoi, modelDistribution[oi] * (double) timesSeen);
                    }
                }
                if (useSlackParameter) {
                    // Outcomes without a parameter for this predicate feed the slack feature.
                    for (int oi = 0; oi < numOutcomes; oi++) {
                        if (!modelExpects[pi].contains(oi)) {
                            CFMOD += modelDistribution[oi] * (double) timesSeen;
                        }
                    }
                }
            }
            if (useSlackParameter) {
                CFMOD += (evalParams.correctionConstant - (double) context.length) * (double) timesSeen;
            }

            loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * (double) timesSeen;
            numEvents += timesSeen;
            if (printMessages) {
                // Training accuracy is only tracked when it will be reported.
                int max = 0;
                for (int oi = 1; oi < numOutcomes; oi++) {
                    if (modelDistribution[oi] > modelDistribution[max]) {
                        max = oi;
                    }
                }
                if (max == outcomeList[ei]) {
                    numCorrect += timesSeen;
                }
            }
        }
        display(".");

        for (int pi = 0; pi < numPreds; pi++) {
            double[] observed = observedExpects[pi].getParameters();
            double[] model = modelExpects[pi].getParameters();
            int[] activeOutcomes = params[pi].getOutcomes();
            for (int aoi = 0; aoi < activeOutcomes.length; aoi++) {
                if (useGaussianSmoothing) {
                    params[pi].updateParameter(aoi, gaussianUpdate(pi, aoi, numEvents, evalParams.correctionConstant));
                } else {
                    if (model[aoi] == 0.0) {
                        // aoi indexes the active-outcome list; map it back to the outcome id.
                        System.err.println("Model expects == 0 for " + predLabels[pi] + " " + outcomeLabels[activeOutcomes[aoi]]);
                    }
                    params[pi].updateParameter(aoi, Math.log(observed[aoi]) - Math.log(model[aoi]));
                }
                modelExpects[pi].setParameter(aoi, 0.0); // reset for the next iteration
            }
        }

        if (CFMOD > 0.0 && useSlackParameter) {
            evalParams.correctionParam += cfObservedExpect - Math.log(CFMOD);
        }
        display(". loglikelihood=" + loglikelihood + "\t" + (double) numCorrect / (double) numEvents + "\n");
        return loglikelihood;
    }

    /** Prints {@code s} to standard output when progress messages are enabled. */
    private void display(String s) {
        if (printMessages) {
            System.out.print(s);
        }
    }
}

