/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RBooleanVector;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RStringVector;

/**
 * Converts an R neural network object into a PMML {@link NeuralNetwork} regression model.
 *
 * <p>The R object is read through the following named elements: {@code model.list}
 * (with {@code response} and {@code variables}), {@code act.fct} (with a {@code type}
 * attribute), {@code linear.output} and {@code weights}.
 * NOTE(review): these names match the object produced by the R {@code neuralnet} package;
 * confirm against the converter registration.
 */
public class NNConverter extends ModelConverter<RGenericVector> {

    /**
     * @param nn the R neural network object to convert
     */
    public NNConverter(RGenericVector nn) {
        super(nn);
    }

    /**
     * Declares the data fields of the model on the encoder.
     *
     * <p>The single {@code response} name becomes the label, and every entry of
     * {@code variables} becomes a feature. All fields are declared as continuous doubles.
     *
     * @param encoder the encoder that receives the label and the features
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector nn = getObject();
        RGenericVector modelList = nn.getGenericElement("model.list");
        RStringVector response = modelList.getStringElement("response");
        RStringVector variables = modelList.getStringElement("variables");

        DataField labelField = encoder.createDataField(FieldName.create(response.asScalar()), OpType.CONTINUOUS, DataType.DOUBLE);
        encoder.setLabel(labelField);

        for (int i = 0; i < variables.size(); i++) {
            DataField featureField = encoder.createDataField(FieldName.create(variables.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE);
            encoder.addFeature(featureField);
        }
    }

    /**
     * Builds the PMML neural network from the stored layer weight matrices.
     *
     * <p>Every element of the selected weights list is one layer; the last one is the
     * output layer, all preceding ones are hidden layers. Hidden layers always use the
     * activation function named by {@code act.fct}. The output layer uses it only when
     * {@code linear.output} is present and {@code false}; otherwise the layer is left
     * without an explicit activation function, and the network itself is constructed
     * with {@code IDENTITY}.
     *
     * @param schema the schema holding the continuous label and the features
     * @return the encoded regression neural network
     * @throws IllegalArgumentException if the activation function type is neither
     *         {@code "logistic"} nor {@code "tanh"}
     */
    @Override
    public Model encodeModel(Schema schema) {
        RGenericVector nn = getObject();
        RExp actFct = nn.getElement("act.fct");
        RBooleanVector linearOutput = nn.getBooleanElement("linear.output");
        RGenericVector weights = nn.getGenericElement("weights");
        RStringVector actFctType = actFct.getStringAttribute("type");

        // Only the first element of "weights" is converted.
        // NOTE(review): presumably one element per training repetition - confirm.
        weights = (RGenericVector) weights.getValue(0);

        NeuralNetwork.ActivationFunction activationFunction = parseActivationFunction(actFctType.asScalar());

        ContinuousLabel continuousLabel = (ContinuousLabel) schema.getLabel();
        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(schema.getFeatures(), DataType.DOUBLE);
        List<NeuralLayer> neuralLayers = new ArrayList<NeuralLayer>();

        // The entities feeding the layer currently being built: the neural inputs at first,
        // then the neurons of the previous layer.
        // NOTE(review): kept as a raw List because it alternates between the two element
        // types; it could be narrowed to their common supertype once that type is imported.
        List entities = neuralInputs.getNeuralInputs();

        int layerCount = weights.size();
        for (int i = 0; i < layerCount; i++) {
            boolean hidden = i < layerCount - 1;

            NeuralLayer neuralLayer = new NeuralLayer();
            if (hidden || (linearOutput != null && !linearOutput.asScalar())) {
                neuralLayer.setActivationFunction(activationFunction);
            }

            RDoubleVector layerWeights = (RDoubleVector) weights.getValue(i);
            RIntegerVector layerDim = layerWeights.dim();
            int layerRows = layerDim.getValue(0);
            int layerColumns = layerDim.getValue(1);

            // One neuron per matrix column.
            for (int j = 0; j < layerColumns; j++) {
                List<Double> neuronWeights = FortranMatrixUtil.getColumn(layerWeights.getValues(), layerRows, layerColumns, j);
                String id = hidden ? "hidden/" + i + "/" + j : "output/" + j;

                // The first entry of the column is passed separately from the remaining
                // entries, which are the weights of the incoming connections.
                Neuron neuron = NeuralNetworkUtil.createNeuron(entities, neuronWeights.subList(1, neuronWeights.size()), neuronWeights.get(0)).setId(id);
                neuralLayer.addNeurons(neuron);
            }

            neuralLayers.add(neuralLayer);
            entities = neuralLayer.getNeurons();
        }

        // After the loop, "entities" holds the neurons of the output layer.
        return new NeuralNetwork(MiningFunction.REGRESSION, NeuralNetwork.ActivationFunction.IDENTITY, ModelUtil.createMiningSchema(continuousLabel), neuralInputs, neuralLayers)
            .setNeuralOutputs(NeuralNetworkUtil.createRegressionNeuralOutputs(entities, continuousLabel));
    }

    /** Maps the R activation function type name to its PMML counterpart. */
    private static NeuralNetwork.ActivationFunction parseActivationFunction(String type) {
        switch (type) {
            case "logistic":
                return NeuralNetwork.ActivationFunction.LOGISTIC;
            case "tanh":
                return NeuralNetwork.ActivationFunction.TANH;
            default:
                throw new IllegalArgumentException("Unsupported activation function type: " + type);
        }
    }
}

