/*
 * Decompiled with CFR 0.152.
 */
package keel.Algorithms.Lazy_Learning.CW;

import java.util.StringTokenizer;
import keel.Algorithms.Lazy_Learning.LazyAlgorithm;
import org.core.Files;

/**
 * Nearest-neighbour classifier that keeps one weight per (class, attribute)
 * pair and tunes those weights on the training set.
 *
 * <p>Distances to a training instance are weighted by the weight row of that
 * instance's class (see {@link #weightedDistance(double[], int)}).
 * {@link #calculateWeights()} repeatedly adjusts the weights until the
 * sigmoid-smoothed error estimate changes by no more than {@code epsilon}
 * between two passes.
 *
 * <p>The fields {@code trainData}, {@code trainOutput}, {@code nClasses},
 * {@code inputAtt} and {@code name} are inherited from {@link LazyAlgorithm}.
 */
public class CW extends LazyAlgorithm {
    /** Slope of the sigmoid used by the error estimate and its derivative. */
    double BETA;
    /** Step size applied to every weight update. */
    double MU;
    /** Convergence threshold on the change of the error estimate between passes. */
    double epsilon;
    /** Weights indexed as {@code [class][attribute]}; all start at 1.0. */
    double[][] cWeights;

    /**
     * Builds the classifier from a configuration script and initialises every
     * class/attribute weight to 1.0.
     *
     * @param script path of the configuration script; handed to the inherited
     *               {@code readDataFiles} (presumably it loads the data sets and
     *               invokes {@link #readParameters(String)} - confirm in
     *               LazyAlgorithm)
     */
    public CW(String script) {
        this.readDataFiles(script);
        this.name = "CW";
        this.cWeights = new double[this.nClasses][this.inputAtt];
        for (int c = 0; c < this.cWeights.length; ++c) {
            for (int j = 0; j < this.inputAtt; ++j) {
                this.cWeights[c][j] = 1.0;
            }
        }
        this.setInitialTime();
    }

    /**
     * Reads {@code BETA}, {@code MU} and {@code epsilon} from the script.
     *
     * <p>The first three non-empty lines are skipped; the next three lines must
     * be {@code name = value} pairs holding, in this order, BETA, MU and
     * epsilon.
     *
     * @param script path of the configuration script
     * @throws java.util.NoSuchElementException if a line or its value is missing
     * @throws NumberFormatException            if a value is not a valid double
     */
    @Override
    protected void readParameters(String script) {
        String file = Files.readFile(script);
        StringTokenizer fileLines = new StringTokenizer(file, "\n\r");
        // The three leading lines are not consumed by this class.
        fileLines.nextToken();
        fileLines.nextToken();
        fileLines.nextToken();
        this.BETA = parseCwParameter(fileLines.nextToken());
        this.MU = parseCwParameter(fileLines.nextToken());
        this.epsilon = parseCwParameter(fileLines.nextToken());
    }

    /**
     * Extracts the numeric value from a {@code name = value} line.
     *
     * <p>The value is trimmed rather than blindly dropping its first character,
     * so {@code name=10} and {@code name =  10} parse as 10 instead of silently
     * losing the leading digit or failing.
     */
    private static double parseCwParameter(String line) {
        StringTokenizer tokens = new StringTokenizer(line, "=");
        tokens.nextToken(); // parameter name, unused
        return Double.parseDouble(tokens.nextToken().trim());
    }

    /**
     * Tunes {@code cWeights} on the training set.
     *
     * <p>Each pass visits every training instance, looks up its nearest
     * same-class and nearest different-class neighbours, and moves the weight
     * rows of those two neighbours' classes in opposite directions. Passes
     * repeat until the error estimate changes by at most {@code epsilon}.
     *
     * <p>Instances lacking either neighbour, and steps whose magnitude is not
     * finite, are skipped so that a single degenerate pair cannot corrupt a
     * whole weight row.
     *
     * <p>NOTE(review): there is no iteration cap; termination relies solely on
     * the error estimate settling within {@code epsilon}.
     */
    public void calculateWeights() {
        double previousError = Double.MAX_VALUE;
        double error = this.errorEstimation();
        while (Math.abs(error - previousError) > this.epsilon) {
            previousError = error;
            for (int i = 0; i < this.trainData.length; ++i) {
                int same = this.findEqual(i);
                int diff = this.findNoEqual(i);
                if (same < 0 || diff < 0) {
                    // No valid neighbour pair (e.g. singleton class): nothing to learn from.
                    continue;
                }
                int classSame = this.trainOutput[same];
                int classDiff = this.trainOutput[diff];
                double distSame = this.weightedDistance(this.trainData[i], same);
                double distDiff = this.weightedDistance(this.trainData[i], diff);
                double distanceRatio = distSame / distDiff;
                double q = this.derivativeSigmoid(distanceRatio) * distanceRatio;
                if (Double.isNaN(q) || Double.isInfinite(q)) {
                    // Happens when distDiff is 0 or exp() overflows; applying it
                    // would turn every weight of both classes into NaN.
                    continue;
                }
                for (int j = 0; j < this.inputAtt; ++j) {
                    this.cWeights[classSame][j] -= this.MU * q * this.ratio(i, same, j, distSame) * this.cWeights[classSame][j];
                    this.cWeights[classDiff][j] += this.MU * q * this.ratio(i, diff, j, distDiff) * this.cWeights[classDiff][j];
                }
            }
            error = this.errorEstimation();
        }
    }

    /**
     * Averages, over all training instances, the sigmoid of
     * (distance to nearest same-class neighbour) / (distance to nearest
     * different-class neighbour).
     *
     * <p>Instances without one of the two neighbours contribute 0 to the sum;
     * the divisor is always the full training-set size.
     *
     * @return the error estimate; NaN if the training set is empty or some
     *         instance has both distances equal to 0
     */
    private double errorEstimation() {
        double total = 0.0;
        for (int i = 0; i < this.trainData.length; ++i) {
            int equal = this.findEqual(i);
            int notEqual = this.findNoEqual(i);
            if (equal < 0 || notEqual < 0) {
                continue;
            }
            double distSame = this.weightedDistance(this.trainData[i], equal);
            double distDiff = this.weightedDistance(this.trainData[i], notEqual);
            total += this.sigmoid(distSame / distDiff);
        }
        return total / (double) this.trainData.length;
    }

    /** Sigmoid centred at {@code z == 1} with slope {@code BETA}. */
    private double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(this.BETA * (1.0 - z)));
    }

    /** Derivative of {@link #sigmoid(double)} with respect to {@code z}. */
    private double derivativeSigmoid(double z) {
        double e = Math.exp(this.BETA * (1.0 - z));
        double denominator = 1.0 + e;
        return this.BETA * e / (denominator * denominator);
    }

    /**
     * Per-attribute factor of the weight update: squared difference of the two
     * instances on {@code feature} minus the squared distance {@code dist}.
     *
     * <p>NOTE(review): despite the name this is a difference, not a quotient.
     * If the update is meant to use the attribute's share of the squared
     * distance, {@code up * up / (dist * dist)} would be expected - TODO
     * confirm against the algorithm's reference before changing, since it
     * alters the learned weights.
     */
    private double ratio(int instance1, int instance2, int feature, double dist) {
        double up = this.trainData[instance1][feature] - this.trainData[instance2][feature];
        return up * up - dist * dist;
    }

    /**
     * Finds the nearest other training instance that shares the class of
     * {@code instance}.
     *
     * @return index of that neighbour, or -1 if no candidate has a distance
     *         below {@code Double.MAX_VALUE} (e.g. the class has one member)
     */
    private int findEqual(int instance) {
        double best = Double.MAX_VALUE;
        int insClass = this.trainOutput[instance];
        int selected = -1;
        for (int i = 0; i < this.trainData.length; ++i) {
            if (i == instance || this.trainOutput[i] != insClass) {
                continue;
            }
            double candidate = this.weightedDistance(this.trainData[i], instance);
            if (candidate < best) {
                best = candidate;
                selected = i;
            }
        }
        return selected;
    }

    /**
     * Finds the nearest training instance whose class differs from that of
     * {@code instance}.
     *
     * <p>NOTE(review): candidates are ranked with the weights of
     * {@code instance}'s class (it is passed as the reference), whereas
     * {@link #calculateWeights()} later measures the chosen pair with the
     * neighbour's class weights - confirm this asymmetry is intended.
     *
     * @return index of that neighbour, or -1 if no candidate has a distance
     *         below {@code Double.MAX_VALUE} (e.g. only one class is present)
     */
    private int findNoEqual(int instance) {
        double best = Double.MAX_VALUE;
        int insClass = this.trainOutput[instance];
        int selected = -1;
        for (int i = 0; i < this.trainData.length; ++i) {
            if (this.trainOutput[i] == insClass) {
                continue;
            }
            double candidate = this.weightedDistance(this.trainData[i], instance);
            if (candidate < best) {
                best = candidate;
                selected = i;
            }
        }
        return selected;
    }

    /**
     * Classifies {@code example} with the class of its nearest training
     * instance under the weighted distance.
     *
     * <p>Training instances at distance exactly 0 are ignored (presumably so a
     * training instance is never matched with itself - confirm against the
     * caller). If no instance qualifies, the class of training instance 0 is
     * returned.
     *
     * @param example attribute values of the instance to classify
     * @return predicted class index
     */
    @Override
    protected int evaluate(double[] example) {
        int nearest = 0;
        double min = Double.MAX_VALUE;
        for (int i = 0; i < this.trainData.length; ++i) {
            double candidate = this.weightedDistance(example, i);
            if (candidate < min && candidate != 0.0) {
                min = candidate;
                nearest = i;
            }
        }
        return this.trainOutput[nearest];
    }

    /**
     * Distance from {@code example} to training instance {@code reference}:
     * square root of the sum of squared attribute differences, each scaled by
     * the weight that {@code reference}'s class assigns to that attribute.
     *
     * @param example   attribute values of the query
     * @param reference index into the training set
     * @return the weighted distance; NaN if the weighted sum is negative
     */
    private double weightedDistance(double[] example, int reference) {
        double[] weights = this.cWeights[this.trainOutput[reference]];
        double[] target = this.trainData[reference];
        double sum = 0.0;
        for (int i = 0; i < this.inputAtt; ++i) {
            double delta = example[i] - target[i];
            sum += weights[i] * delta * delta;
        }
        return Math.sqrt(sum);
    }
}

