/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import com.google.common.math.DoubleMath;
import com.google.common.primitives.UnsignedLong;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureImportanceMap;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.rexp.Formula;
import org.jpmml.rexp.FormulaUtil;
import org.jpmml.rexp.HasFeatureImportances;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RFactorVector;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RNumberVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.RVectorUtil;
import org.jpmml.rexp.TreeModelConverter;
import org.jpmml.rexp.XLevelsFormulaContext;
import org.jpmml.rexp.visitors.RandomForestCompactor;

/**
 * Converts an R {@code randomForest} object (given as an {@link RGenericVector}) into a PMML {@link MiningModel}.
 *
 * <p>Trees are read from the {@code forest} element of the R object.
 * <ul>
 *   <li>For {@code type == "regression"}, the trees are combined with {@code AVERAGE}.</li>
 *   <li>For {@code type == "classification"}, the trees are combined with {@code MAJORITY_VOTE},
 *   and a probability output is attached.</li>
 * </ul>
 *
 * <p>When the {@code compact} option is enabled (the default), every generated tree is post-processed with
 * {@link RandomForestCompactor}.
 */
public class RandomForestConverter
extends TreeModelConverter<RGenericVector>
implements HasFeatureImportances {
    /** Whether every tree is post-processed with {@link RandomForestCompactor}; option {@code compact}, default {@code true}. */
    private final boolean compact = this.getOption("compact", Boolean.TRUE);

    private static final UnsignedLong TWO = UnsignedLong.valueOf(2L);

    /** 2^63: the smallest integral double that no longer fits into a signed {@code long}. */
    private static final double TWO_POW_63 = 0x1p63;

    /** 2^64: the exclusive upper bound of the unsigned 64-bit range. */
    private static final double TWO_POW_64 = 0x1p64;

    public RandomForestConverter(RGenericVector randomForest) {
        super(randomForest);
    }

    /**
     * Registers the label and the predictor fields with the encoder.
     *
     * <p>Objects that carry a {@code terms} element (fitted via the formula interface) are encoded from the formula.
     * All other objects are encoded from {@code y}, {@code xNames} and the forest's {@code ncat}/{@code xlevels}.
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector randomForest = (RGenericVector)this.getObject();

        if (randomForest.hasElement("terms")) {
            this.encodeFormula(encoder);
        } else {
            this.encodeNonFormula(encoder);
        }
    }

    /**
     * Encodes the forest as a segmented {@link MiningModel}.
     *
     * @throws IllegalArgumentException if the {@code type} element is neither {@code "regression"} nor {@code "classification"}
     */
    public MiningModel encodeModel(Schema schema) {
        RGenericVector randomForest = (RGenericVector)this.getObject();

        RStringVector type = randomForest.getStringElement("type");
        RGenericVector forest = randomForest.getGenericElement("forest");

        String typeValue = (String)type.asScalar();
        switch (typeValue) {
            case "regression":
                return this.encodeRegression(forest, schema);
            case "classification":
                return this.encodeClassification(forest, schema);
            default:
                throw new IllegalArgumentException("Unsupported random forest type: " + typeValue);
        }
    }

    /**
     * Returns the feature importances taken from the last column of the {@code importance} matrix.
     *
     * <p>Row {@code i} of that column is paired with the {@code i}-th schema feature.
     *
     * @return the importance map, or {@code null} when the R object has no {@code importance} element
     */
    @Override
    public FeatureImportanceMap getFeatureImportances(Schema schema) {
        RGenericVector randomForest = (RGenericVector)this.getObject();

        RDoubleVector importance = randomForest.getDoubleElement("importance", false);
        if (importance == null) {
            return null;
        }

        RStringVector importanceColumns = importance.dimnames(1);
        RIntegerVector importanceDim = importance.dim();

        int rows = importanceDim.getValue(0);
        int columns = importanceDim.getValue(1);

        List<? extends Feature> features = schema.getFeatures();

        // The map is named after the last column, which is also the column whose values are exported
        FeatureImportanceMap result = new FeatureImportanceMap(importanceColumns.getDequotedValue(columns - 1));

        List<Double> defaultImportances = CMatrixUtil.getColumn(importance.getValues(), rows, columns, columns - 1);
        for (int i = 0; i < features.size(); ++i) {
            result.put(features.get(i), defaultImportances.get(i));
        }

        return result;
    }

    /**
     * Encodes the schema from the R model formula ({@code terms}).
     *
     * <p>A variable is treated as categorical only when its {@code ncat} entry is greater than 1.
     */
    private void encodeFormula(RExpEncoder encoder) {
        RGenericVector randomForest = (RGenericVector)this.getObject();

        RGenericVector forest = randomForest.getGenericElement("forest");
        RNumberVector<?> y = randomForest.getNumericElement("y", false);
        RExp terms = (RExp)randomForest.getElement("terms");

        final RNumberVector<?> ncat = forest.getNumericElement("ncat");
        RGenericVector xlevels = forest.getGenericElement("xlevels");

        XLevelsFormulaContext context = new XLevelsFormulaContext(xlevels){

            @Override
            public List<String> getCategories(String variable) {
                if (ncat != null && ncat.hasElement(variable) && ((Number)ncat.getElement(variable)).doubleValue() > 1.0) {
                    return super.getCategories(variable);
                }

                return null;
            }
        };

        Formula formula = FormulaUtil.createFormula(terms, context, encoder);

        // An integer-backed response (a factor) is handed over so that its levels define the categorical label
        if (y instanceof RIntegerVector) {
            FormulaUtil.setLabel(formula, terms, y, encoder);
        } else {
            FormulaUtil.setLabel(formula, terms, null, encoder);
        }

        FormulaUtil.addFeatures(formula, xlevels.names(), false, encoder);
    }

    /**
     * Encodes the schema for an object fitted without a formula.
     *
     * <p>The label field is always named {@code "_target"}. Predictor names come from {@code xNames},
     * falling back to the names of {@code forest$xlevels}.
     */
    private void encodeNonFormula(RExpEncoder encoder) {
        RGenericVector randomForest = (RGenericVector)this.getObject();

        RGenericVector forest = randomForest.getGenericElement("forest");
        RNumberVector<?> y = randomForest.getNumericElement("y", false);
        RStringVector xNames = randomForest.getStringElement("xNames", false);

        RNumberVector<?> ncat = forest.getNumericElement("ncat");
        RGenericVector xlevels = forest.getGenericElement("xlevels");

        if (xNames == null) {
            xNames = xlevels.names();
        }

        String labelName = "_target";

        DataField labelField;
        // An integer-backed response is read back as a factor, and its levels become the categorical label values
        if (y instanceof RIntegerVector) {
            RFactorVector factor = randomForest.getFactorElement("y");

            labelField = encoder.createDataField(labelName, OpType.CATEGORICAL, DataType.STRING, factor.getLevelValues());
        } else {
            labelField = encoder.createDataField(labelName, OpType.CONTINUOUS, DataType.DOUBLE);
        }

        encoder.setLabel(labelField);

        RVectorUtil.checkSize(ncat, xNames);

        for (int i = 0; i < ncat.size(); ++i) {
            String name = xNames.getValue(i);

            // Predictors with ncat > 1 are categorical (levels taken from xlevels); all others are continuous
            boolean categorical = ((Number)ncat.getValue(i)).doubleValue() > 1.0;

            DataField dataField;
            if (categorical) {
                RStringVector levels = xlevels.getStringValue(i);

                dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, levels.getValues());
            } else {
                dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
            }

            encoder.addFeature((Field<?>)dataField);
        }
    }

    /**
     * Encodes a regression forest as an {@code AVERAGE} ensemble.
     *
     * <p>Per-tree node arrays are stored as column-major ({@code nrnodes x ntree}) matrices.
     * The child indices come from {@code leftDaughter}/{@code rightDaughter}.
     */
    private MiningModel encodeRegression(RGenericVector forest, Schema schema) {
        RNumberVector<?> leftDaughter = forest.getNumericElement("leftDaughter");
        RNumberVector<?> rightDaughter = forest.getNumericElement("rightDaughter");
        RDoubleVector nodepred = forest.getDoubleElement("nodepred");
        RNumberVector<?> bestvar = forest.getNumericElement("bestvar");
        RDoubleVector xbestsplit = forest.getDoubleElement("xbestsplit");
        RIntegerVector nrnodes = forest.getIntegerElement("nrnodes");
        RNumberVector<?> ntree = forest.getNumericElement("ntree");

        int rows = (Integer)nrnodes.asScalar();
        int columns = ValueUtil.asInt((Number)ntree.asScalar());

        List<? extends Number> leftDaughterValues = leftDaughter.getValues();
        List<? extends Number> rightDaughterValues = rightDaughter.getValues();
        List<Double> nodepredValues = nodepred.getValues();
        List<? extends Number> bestvarValues = bestvar.getValues();
        List<Double> xbestsplitValues = xbestsplit.getValues();

        // Regression scores are stored on the leaves as-is
        ScoreEncoder<Double> scoreEncoder = value -> value;

        Schema segmentSchema = schema.toAnonymousSchema();

        List<TreeModel> treeModels = new ArrayList<>();
        for (int i = 0; i < columns; ++i) {
            TreeModel treeModel = this.encodeTreeModel(MiningFunction.REGRESSION, scoreEncoder,
                FortranMatrixUtil.getColumn(leftDaughterValues, rows, columns, i),
                FortranMatrixUtil.getColumn(rightDaughterValues, rows, columns, i),
                FortranMatrixUtil.getColumn(nodepredValues, rows, columns, i),
                FortranMatrixUtil.getColumn(bestvarValues, rows, columns, i),
                FortranMatrixUtil.getColumn(xbestsplitValues, rows, columns, i),
                segmentSchema);

            treeModels.add(treeModel);
        }

        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
            .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels));

        return miningModel;
    }

    /**
     * Encodes a classification forest as a {@code MAJORITY_VOTE} ensemble with a probability output.
     *
     * <p>Child indices come from {@code treemap}. Per tree, {@code treemap} holds {@code 2 * nrnodes} values,
     * laid out as a column-major {@code nrnodes x 2} matrix: left children, then right children.
     */
    private MiningModel encodeClassification(RGenericVector forest, Schema schema) {
        RNumberVector<?> bestvar = forest.getNumericElement("bestvar");
        RNumberVector<?> treemap = forest.getNumericElement("treemap");
        RIntegerVector nodepred = forest.getIntegerElement("nodepred");
        RDoubleVector xbestsplit = forest.getDoubleElement("xbestsplit");
        RIntegerVector nrnodes = forest.getIntegerElement("nrnodes");
        // Read as a generic numeric vector, consistent with encodeRegression (accepts double and integer storage)
        RNumberVector<?> ntree = forest.getNumericElement("ntree");

        int rows = (Integer)nrnodes.asScalar();
        int columns = ValueUtil.asInt((Number)ntree.asScalar());

        final CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        // nodepred holds 1-based class codes; CategoricalLabel values are 0-based
        ScoreEncoder<Integer> scoreEncoder = value -> categoricalLabel.getValue(value - 1);

        List<? extends Number> treemapValues = treemap.getValues();
        List<Integer> nodepredValues = nodepred.getValues();
        List<? extends Number> bestvarValues = bestvar.getValues();
        List<Double> xbestsplitValues = xbestsplit.getValues();

        Schema segmentSchema = schema.toAnonymousSchema();

        List<TreeModel> treeModels = new ArrayList<>();
        for (int i = 0; i < columns; ++i) {
            List<? extends Number> daughters = FortranMatrixUtil.getColumn(treemapValues, 2 * rows, columns, i);

            TreeModel treeModel = this.encodeTreeModel(MiningFunction.CLASSIFICATION, scoreEncoder,
                FortranMatrixUtil.getColumn(daughters, rows, 2, 0),
                FortranMatrixUtil.getColumn(daughters, rows, 2, 1),
                FortranMatrixUtil.getColumn(nodepredValues, rows, columns, i),
                FortranMatrixUtil.getColumn(bestvarValues, rows, columns, i),
                FortranMatrixUtil.getColumn(xbestsplitValues, rows, columns, i),
                segmentSchema);

            treeModels.add(treeModel);
        }

        MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
            .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels))
            .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

        return miningModel;
    }

    /**
     * Builds one binary-split {@link TreeModel} rooted at node index 0.
     *
     * <p>Missing values yield a null prediction. The tree is compacted when the {@code compact} option is on.
     */
    private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema) {
        Node root = this.encodeNode(True.INSTANCE, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, new CategoryManager(), schema);

        TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

        if (this.compact) {
            RandomForestCompactor visitor = new RandomForestCompactor();

            visitor.applyTo(treeModel);
        }

        return treeModel;
    }

    /**
     * Recursively encodes the node at 0-based index {@code i}.
     *
     * <p>A {@code bestvar} entry of 0 marks a leaf. Otherwise {@code bestvar} is the 1-based index of the split feature.
     * The daughter arrays hold 1-based child indices, where 0 means "no child". The node id is {@code i + 1}.
     *
     * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
     */
    private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
        Integer id = i + 1;

        int var = ValueUtil.asInt(bestvar.get(i));
        if (var == 0) {
            P prediction = nodepred.get(i);

            Node leaf = new LeafNode(scoreEncoder.encode(prediction), predicate)
                .setId(id);

            return leaf;
        }

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        Predicate leftPredicate;
        Predicate rightPredicate;

        Feature feature = schema.getFeature(var - 1);
        Double split = xbestsplit.get(i);

        if (feature instanceof BooleanFeature) {
            BooleanFeature booleanFeature = (BooleanFeature)feature;

            if (split != 0.5) {
                throw new IllegalArgumentException("Expected a split value of 0.5 for a boolean feature, got " + split);
            }

            leftPredicate = this.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
            rightPredicate = this.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
        } else if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;

            String name = categoricalFeature.getName();
            List<?> values = categoricalFeature.getValues();

            // Only categories still reachable on this path (per the category manager) are considered
            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            List<Object> leftValues = RandomForestConverter.selectValues(values, valueFilter, split, true);
            List<Object> rightValues = RandomForestConverter.selectValues(values, valueFilter, split, false);

            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);

            leftPredicate = this.createPredicate(categoricalFeature, leftValues);
            rightPredicate = this.createPredicate(categoricalFeature, rightValues);
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            leftPredicate = this.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split);
            rightPredicate = this.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split);
        }

        Node result = new BranchNode(null, predicate)
            .setId(id);

        List<Node> children = result.getNodes();

        int left = ValueUtil.asInt(leftDaughter.get(i));
        if (left != 0) {
            Node leftChild = this.encodeNode(leftPredicate, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema);

            children.add(leftChild);
        }

        int right = ValueUtil.asInt(rightDaughter.get(i));
        if (right != 0) {
            Node rightChild = this.encodeNode(rightPredicate, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema);

            children.add(rightChild);
        }

        return result;
    }

    /**
     * Selects the categories that go to one side of a categorical split.
     *
     * <p>{@code split} is an integral bitmask read least-significant bit first. Bit {@code k} corresponds to
     * {@code values.get(k)}. A set bit sends the value left, a cleared bit sends it right. Values rejected by
     * {@code valueFilter} are skipped.
     *
     * @param left {@code true} to collect the left-hand values, {@code false} for the right-hand values
     * @throws IllegalArgumentException if {@code split} is not a non-negative integer below 2^64
     */
    public static List<Object> selectValues(List<?> values, java.util.function.Predicate<Object> valueFilter, Double split, boolean left) {
        UnsignedLong bits = RandomForestConverter.toUnsignedLong(split);

        List<Object> result = new ArrayList<>();
        for (Object value : values) {
            boolean bitSet = bits.mod(TWO).equals(UnsignedLong.ONE);

            if (bitSet == left && valueFilter.test(value)) {
                result.add(value);
            }

            bits = bits.dividedBy(TWO);
        }

        return result;
    }

    /**
     * Converts an integral double into an unsigned 64-bit value.
     *
     * @throws IllegalArgumentException if {@code value} is not a mathematical integer (including NaN and infinities),
     * or lies outside {@code [0, 2^64)}
     */
    public static UnsignedLong toUnsignedLong(double value) {
        if (!DoubleMath.isMathematicalInteger(value)) {
            throw new IllegalArgumentException("Expected an integral value, got " + value);
        }

        if (value < 0d || value >= TWO_POW_64) {
            throw new IllegalArgumentException("Value " + value + " is outside of the unsigned 64-bit range");
        }

        if (value >= TWO_POW_63) {
            // A plain (long) cast would saturate at Long.MAX_VALUE here. Shift into the signed range
            // (the subtraction is exact), then restore the top bit via two's complement wrap-around.
            return UnsignedLong.fromLongBits((long)(value - TWO_POW_63) + Long.MIN_VALUE);
        }

        return UnsignedLong.fromLongBits((long)value);
    }

    /**
     * Maps a raw {@code nodepred} entry to the score stored on a PMML leaf node.
     */
    @FunctionalInterface
    private interface ScoreEncoder<V extends Number> {
        Object encode(V value);
    }
}

