# Session setup for the examples below.
# Sets the R option `stremr.verbose` to TRUE (presumably turns on progress
# messages from stremr -- confirm against the package documentation).
options(stremr.verbose = TRUE)
# data.table provides as.data.table(), shift() and `:=` used further down.
require("data.table")
# Package-wide default fitting back end: speedglm with a GLM. Later sections
# call this function again to switch to other back ends.
set_all_stremr_options(fit.package = "speedglm", fit.algorithm = "glm")

# ----------------------------------------------------------------------
# Simulated Data
# ----------------------------------------------------------------------
# Load the OdataNoCENS example data and convert it to a data.table keyed by
# subject ID and time point. The key is given as a character vector of column
# names: an unquoted c(ID, t) would be evaluated as ordinary R code, where ID
# is not a defined object and t is the base transpose function.
data(OdataNoCENS)
OdataDT <- as.data.table(OdataNoCENS, key = c("ID", "t"))

# define lagged N, first value is always 1 (always monitored at the first time point):
OdataDT[, ("N.tminus1") := shift(get("N"), n = 1L, type = "lag", fill = 1L), by = ID]
# define lagged TI the same way; the first value within each ID is filled with 1:
OdataDT[, ("TI.tminus1") := shift(get("TI"), n = 1L, type = "lag", fill = 1L), by = ID]

# ----------------------------------------------------------------------
# Define interventions (always treated and never treated):
# ----------------------------------------------------------------------
# Add two constant columns in one assignment: TI.set1 is 1 for every row and
# TI.set0 is 0 for every row. They are referenced later by name through the
# intervened_TRT argument.
OdataDT[, c("TI.set1", "TI.set0") := list(1L, 0L)]

# ----------------------------------------------------------------------
# Import Data
# ----------------------------------------------------------------------
# Hand the long-format data to stremr, naming the columns that hold the
# subject identifier, time index, covariates, censoring, treatment,
# monitoring and outcome.
OData <- importData(
  OdataDT,
  ID = "ID",
  t = "t",
  covars = c("highA1c", "lastNat1", "N.tminus1"),
  CENS = "C",
  TRT = "TI",
  MONITOR = "N",
  OUTCOME = "Y.tplus1"
)

# ----------------------------------------------------------------------
# Model the Propensity Scores
# ----------------------------------------------------------------------
# Regression formulas, given as strings, for the censoring (C), treatment (TI)
# and monitoring (N) mechanisms.
gform_CENS <- "C ~ highA1c + lastNat1"
gform_TRT <- "TI ~ CVD + highA1c + N.tminus1"
gform_MONITOR <- "N ~ 1"
# Logical subset expressions on t, listed under the censoring variable name C.
stratify_CENS <- list(C = c("t < 16", "t == 16"))

# ----------------------------------------------------------------------
# Fit Propensity Scores
# ----------------------------------------------------------------------
# Fit the three propensity score models with the formulas defined above; the
# result is assigned back to OData.
OData <- fitPropensity(
  OData,
  gform_CENS = gform_CENS,
  gform_TRT = gform_TRT,
  gform_MONITOR = gform_MONITOR,
  stratify_CENS = stratify_CENS
)

# ----------------------------------------------------------------------
# IPW Adjusted KM or Saturated MSM
# ----------------------------------------------------------------------
# magrittr supplies the forward pipe and the exposition pipe used below.
require("magrittr")
# Weights under the TI.set1 intervention are piped into survNPMSM(); the
# exposition pipe then pulls the IPW_estimates element out of its result.
AKME.St.1 <-
  getIPWeights(OData, intervened_TRT = "TI.set1") %>%
  survNPMSM(OData) %$%
  IPW_estimates
AKME.St.1

# ----------------------------------------------------------------------
# Bounded IPW
# ----------------------------------------------------------------------
# Same weights as above, this time passed on to survDirectIPW().
IPW.St.1 <-
  getIPWeights(OData, intervened_TRT = "TI.set1") %>%
  survDirectIPW(OData)
# Trailing [] forces the result to print.
IPW.St.1[]

# ----------------------------------------------------------------------
# IPW-MSM for hazard
# ----------------------------------------------------------------------
# Weights for the two interventions defined earlier, each labelled with its
# own rule name so the two sets can be told apart in the combined fit.
wts.DT.1 <- getIPWeights(OData = OData, intervened_TRT = "TI.set1", rule_name = "TI1")
wts.DT.0 <- getIPWeights(OData = OData, intervened_TRT = "TI.set0", rule_name = "TI0")
# Fit the MSM on both weight sets at once.
# t_breaks evaluates to 0, 1, ..., 7, 11, 15.
survMSM_res <- survMSM(list(wts.DT.1, wts.DT.0), OData, t_breaks = c(1:8, 12, 16) - 1)
survMSM_res$St

# ----------------------------------------------------------------------
# Sequential G-COMP
# ----------------------------------------------------------------------
# Time points 0 through 15, with the same Q formula repeated once per time
# point (max(t.surv) + 1 copies).
t.surv <- 0:15
Qforms <- rep.int(
  "Q.kplus1 ~ CVD + highA1c + N + lastNat1 + TI + TI.tminus1",
  max(t.surv) + 1
)
# Fitting back end used for the Q regressions.
params <- list(fit.package = "speedglm", fit.algorithm = "glm")

\dontrun{
gcomp_est <- fitSeqGcomp(
  OData,
  t_periods = t.surv,
  intervened_TRT = "TI.set1",
  Qforms = Qforms,
  params_Q = params,
  stratifyQ_by_rule = FALSE
)
gcomp_est[]
}
# ----------------------------------------------------------------------
# TMLE
# ----------------------------------------------------------------------
\dontrun{
# Same time points, formulas and back end as the G-computation example, but
# fitted with fitTMLE() and with stratifyQ_by_rule switched on.
tmle_est <- fitTMLE(
  OData,
  t_periods = t.surv,
  intervened_TRT = "TI.set1",
  Qforms = Qforms,
  params_Q = params,
  stratifyQ_by_rule = TRUE
)
tmle_est[]
}

# ----------------------------------------------------------------------
# Running IPW-Adjusted KM with optional user-specified weights:
# ----------------------------------------------------------------------
# Build a table of extra weights: one row per (ID, t) pair of the input data,
# plus a column new.wts drawn at random from 0.1, 0.2, ..., 1.0.
addedWts_DT <- OdataDT[, c("ID", "t"), with = FALSE]
addedWts_DT[, new.wts := sample.int(10, .N, replace = TRUE) / 10]
# Supply the extra weights through the `weights` argument.
survNP_res_addedWts <- survNPMSM(wts.DT.1, OData, weights = addedWts_DT)

# ----------------------------------------------------------------------
# Multivariate Propensity Score Regressions
# ----------------------------------------------------------------------
# A formula with several variables on the left-hand side (C + TI + N).
# Note that this overwrites the gform_CENS defined earlier.
gform_CENS <- "C + TI + N ~ highA1c + lastNat1"
OData <- fitPropensity(
  OData,
  gform_CENS = gform_CENS,
  gform_TRT = gform_TRT,
  gform_MONITOR = gform_MONITOR
)

# ----------------------------------------------------------------------
# Fitting Propensity scores with Random Forests:
# ----------------------------------------------------------------------
\dontrun{
# Switch the package-wide default back end to h2o random forests.
set_all_stremr_options(fit.package = "h2o", fit.algorithm = "randomForest")
require("h2o")
# Start (or connect to) a local h2o instance; nthreads = -1 asks h2o to use
# all available cores (per the h2o.init documentation).
h2o::h2o.init(nthreads = -1)
# Reset the censoring formula to the single-outcome version, replacing the
# multivariate formula assigned in the previous section.
gform_CENS <- "C ~ highA1c + lastNat1"
OData <- fitPropensity(OData, gform_CENS = gform_CENS,
                        gform_TRT = gform_TRT,
                        gform_MONITOR = gform_MONITOR,
                        stratify_CENS = stratify_CENS)

# The calls below show alternative h2o algorithms. They are issued one after
# another, so presumably only the last setting stays in effect -- pick the one
# you need.
# For Gradient Boosting machines:
set_all_stremr_options(fit.package = "h2o", fit.algorithm = "gbm")
# Use `H2O-3` distributed implementation of GLM
set_all_stremr_options(fit.package = "h2o", fit.algorithm = "glm")
# Use Deep Neural Nets:
set_all_stremr_options(fit.package = "h2o", fit.algorithm = "deeplearning")
}

# ----------------------------------------------------------------------
# Fitting different models with different algorithms
# Fine tuning modeling with optional tuning parameters.
# ----------------------------------------------------------------------
\dontrun{
# Each mechanism gets its own parameter list: h2o GBM with tuning parameters
# for treatment, speedglm GLM for censoring and monitoring.
# NOTE(review): gform_CENS is whatever was assigned last (the multivariate
# formula unless the random forest example above was run) -- confirm it is the
# one intended here.
params_TRT <- list(
  fit.package = "h2o",
  fit.algorithm = "gbm",
  ntrees = 50,
  learn_rate = 0.05,
  sample_rate = 0.8,
  col_sample_rate = 0.8,
  balance_classes = TRUE
)
params_CENS <- list(fit.package = "speedglm", fit.algorithm = "glm")
params_MONITOR <- list(fit.package = "speedglm", fit.algorithm = "glm")
OData <- fitPropensity(
  OData,
  gform_CENS = gform_CENS,
  stratify_CENS = stratify_CENS,
  params_CENS = params_CENS,
  gform_TRT = gform_TRT,
  params_TRT = params_TRT,
  gform_MONITOR = gform_MONITOR,
  params_MONITOR = params_MONITOR
)
}

# ----------------------------------------------------------------------
# Running TMLE based on the previous fit of the propensity scores.
# Also applying Random Forest to estimate the sequential outcome model
# ----------------------------------------------------------------------
\dontrun{
# Shorter follow-up (time points 0 through 5), one Q formula per time point.
t.surv <- 0:5
Qforms <- rep.int(
  "Q.kplus1 ~ CVD + highA1c + N + lastNat1 + TI + TI.tminus1",
  max(t.surv) + 1
)
# Random forest settings for the Q regressions.
# NOTE(review): learn_rate and col_sample_rate appear only in the GBM grid
# further down, not in the random forest grid -- confirm they are meaningful
# for randomForest.
params_Q <- list(
  fit.package = "h2o",
  fit.algorithm = "randomForest",
  ntrees = 100,
  learn_rate = 0.05,
  sample_rate = 0.8,
  col_sample_rate = 0.8,
  balance_classes = TRUE
)
tmle_est <- fitTMLE(
  OData,
  t_periods = t.surv,
  intervened_TRT = "TI.set1",
  Qforms = Qforms,
  params_Q = params_Q,
  stratifyQ_by_rule = TRUE
)
}

\dontrun{
# Same random forest TMLE set-up as the preceding example, this time with
# stratifyQ_by_rule switched off.
t.surv <- 0:5
Qforms <- rep.int(
  "Q.kplus1 ~ CVD + highA1c + N + lastNat1 + TI + TI.tminus1",
  max(t.surv) + 1
)
params_Q <- list(
  fit.package = "h2o",
  fit.algorithm = "randomForest",
  ntrees = 100,
  learn_rate = 0.05,
  sample_rate = 0.8,
  col_sample_rate = 0.8,
  balance_classes = TRUE
)
tmle_est <- fitTMLE(
  OData,
  t_periods = t.surv,
  intervened_TRT = "TI.set1",
  Qforms = Qforms,
  params_Q = params_Q,
  stratifyQ_by_rule = FALSE
)
}

# ----------------------------------------------------------------------
# Ensemble Learning with SuperLearner (using h2oEnsemble R package):
# ----------------------------------------------------------------------
\dontrun{
require("h2oEnsemble")
# ----------------------------------------------------------------------
# Define many learners at once via grid search over hyperparameter spaces
# ----------------------------------------------------------------------
# 1. Runs h2o.grid in the background for each individual learner and saves
#    the cross-validated risks.
# 2. Calls h2o.stack from the h2oEnsemble package to evaluate the final
#    SuperLearner.
# ----------------------------------------------------------------------
# GLM: random search limited to 5 models over alpha and lambda.
GLM_hyper_params <- list(
  search_criteria = list(strategy = "RandomDiscrete", max_models = 5),
  alpha = c(0, 1, seq(0.1, 0.9, 0.1)),
  lambda = c(0, 1e-7, 1e-5, 1e-3, 1e-1)
)

# Search criteria shared by the random forest and GBM grids: random search,
# at most 5 models, at most one hour of run time.
search_criteria <- list(
  strategy = "RandomDiscrete",
  max_models = 5,
  max_runtime_secs = 60 * 60
)

# Random forest grid.
RF_hyper_params <- list(
  search_criteria = search_criteria,
  ntrees = c(100, 200, 300, 500),
  mtries = 1:4,
  max_depth = c(5, 10, 15, 20, 25),
  sample_rate = c(0.7, 0.8, 0.9, 1.0),
  col_sample_rate_per_tree = c(0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8),
  balance_classes = c(TRUE, FALSE)
)

# Gradient boosting machine (gbm) grid.
GBM_hyper_params <- list(
  search_criteria = search_criteria,
  ntrees = c(100, 200, 300, 500),
  learn_rate = c(0.005, 0.01, 0.03, 0.06),
  max_depth = c(3, 4, 5, 6, 9),
  sample_rate = c(0.7, 0.8, 0.9, 1.0),
  col_sample_rate = c(0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8),
  balance_classes = c(TRUE, FALSE)
)

# Individual learners added to the library on top of the grid models. Each is
# a thin wrapper around h2o.glm.wrapper with a different default alpha; the
# second one also restricts the predictors to highA1c.
h2o.glm.1 <- function(..., alpha = 0.0) h2o.glm.wrapper(..., alpha = alpha)
h2o.glm.2 <- function(..., x = "highA1c", alpha = 0.0) h2o.glm.wrapper(..., x = x, alpha = alpha)
h2o.glm.3 <- function(..., alpha = 1.0) h2o.glm.wrapper(..., alpha = alpha)

# ----------------------------------------------------------------------
# Combine the grids and the individual learners into one ensemble definition
# ----------------------------------------------------------------------
SLparams <- list(
  fit.package = "h2o",
  fit.algorithm = "SuperLearner",
  grid.algorithm = c("glm", "randomForest", "gbm"),
  learner = c("h2o.glm.1", "h2o.glm.2", "h2o.glm.3"),
  metalearner = "h2o.glm_nn",
  nfolds = 10,
  seed = 23,
  glm = GLM_hyper_params,
  randomForest = RF_hyper_params,
  gbm = GBM_hyper_params
)

# ----------------------------------------------------------------------
# Saving the final SuperLearner fit (save.ensemble and ensemble.dir.path)
# ----------------------------------------------------------------------
# The individual learner fits are saved in the same directory.
# Only one directory per single SuperLearner fit is allowed.
# A saved fit can be loaded later (load.ensemble = TRUE) to avoid refitting.

params_TRT <- c(
  SLparams,
  save.ensemble = TRUE,
  ensemble.dir.path = "./h2o-ensemble-model-TRT"
)
params_CENS <- list(fit.package = "speedglm", fit.algorithm = "glm")
params_MONITOR <- list(fit.package = "speedglm", fit.algorithm = "glm")

OData <- fitPropensity(
  OData,
  gform_CENS = gform_CENS,
  stratify_CENS = stratify_CENS,
  params_CENS = params_CENS,
  gform_TRT = gform_TRT,
  params_TRT = params_TRT,
  gform_MONITOR = gform_MONITOR,
  params_MONITOR = params_MONITOR
)

# ----------------------------------------------------------------------
# Re-using the previously saved SuperLearner fit
# ----------------------------------------------------------------------
params_TRT <- c(
  SLparams,
  load.ensemble = TRUE,
  ensemble.dir.path = "./h2o-ensemble-model-TRT"
)
params_CENS <- list(fit.package = "speedglm", fit.algorithm = "glm")
params_MONITOR <- list(fit.package = "speedglm", fit.algorithm = "glm")

OData <- fitPropensity(
  OData,
  gform_CENS = gform_CENS,
  stratify_CENS = stratify_CENS,
  params_CENS = params_CENS,
  gform_TRT = gform_TRT,
  params_TRT = params_TRT,
  gform_MONITOR = gform_MONITOR,
  params_MONITOR = params_MONITOR
)
}
