#' @importFrom survival survfit Surv
#' @importFrom stats var cov pnorm
#' @import survival
#' @importFrom dplyr %>% select all_of rename left_join arrange

# Area under a right-continuous step function, integrated up to `tau`.
#
# `table` is a two-column numeric matrix (callers in this file build it with
# cbind()): column 1 holds the jump times (assumed sorted ascending) and
# column 2 the value taken from that time until the next jump. Integration
# starts at the first jump time that is >= `start.time`; any stretch between
# `start.time` and that first jump is not counted. Zero-width steps (tied
# times) add nothing, and the last step is extended from its jump time to
# `tau`.
#
# Returns 0 for an empty table, when `tau <= start.time`, when `start.time`
# lies beyond the last jump time, or when no jump falls in [start.time, tau].
auc.func <- function(table, start.time, tau) {
  if (!nrow(table)) {
    return(0)
  }
  jump_times <- table[, 1]
  if (tau <= start.time || max(jump_times, na.rm = TRUE) < start.time) {
    return(0)
  }
  upto_tau <- which(jump_times <= tau)
  if (length(upto_tau) == 0) {
    return(0)
  }
  from_start <- which(jump_times >= start.time)
  if (length(from_start) == 0) {
    return(0)
  }
  hi <- max(upto_tau)
  lo <- from_start[1]
  if (lo > hi) {
    return(0)
  }
  heights <- table[, 2]
  area <- 0
  # Complete steps [t_j, t_{j+1}) for j = lo, ..., hi - 1 (none if lo == hi).
  for (j in lo - 1L + seq_len(hi - lo)) {
    width <- jump_times[j + 1] - jump_times[j]
    if (width > 0) {
      area <- area + heights[j] * width
    }
  }
  # Final step, from the last jump at or before tau up to tau itself.
  area + heights[hi] * max(0, tau - jump_times[hi])
}

# Point estimate and standard error of a + b * A1 + c * A2.
#
# A competing-risks fit is obtained with
# survfit(Surv(etime, estatus) ~ 1, etype = etype2, data = data.w). Ak is the
# area under the cumulative incidence curve of cause k (columns 2 and 3 of
# fit$pstate) from the first reported time up to `tau`, computed with
# auc.func(). The standard error is a delta-method estimate summed over the
# reported times <= tau that have a non-zero risk set.
#
# data.w: data frame holding the `etime`, `estatus` and `etype2` columns used
#   in the survfit() call.
# tau: upper integration limit.
# a, b, c: coefficients of the linear combination.
# Returns c(estimate, standard error).
#
# NOTE(review): column 1 of fit$pstate / fit$n.risk is treated as the
# event-free state and columns 2-3 as causes 1-2. This relies on the state
# ordering produced by the installed `survival` version -- confirm.
auc.var.joint <- function(data.w, tau, a, b, c){
  fit <- survival::survfit(Surv(etime, estatus) ~ 1, etype=etype2, data=data.w)

  n_risk <- fit$n.risk
  if (is.matrix(n_risk)) {
    n_risk <- n_risk[, 1]
  }

  # Event counts: column 2 = cause 1, column 3 = cause 2. Pad with zero
  # columns when a cause is absent from the fit.
  n_event <- fit$n.event
  if (!is.matrix(n_event)) {
    n_event <- cbind(n_event)
  }
  if (ncol(n_event) < 3) {
    n_event <- cbind(n_event, matrix(0, nrow(n_event), 3 - ncol(n_event)))
  }
  n_cause1 <- n_event[, 2]
  n_cause2 <- n_event[, 3]
  # Events into any further states enter the estimand with weight 0.
  n_other <- rowSums(n_event[, -(1:3), drop = FALSE])

  # State probabilities, padded like n_event so a missing cause has CIF 0.
  state_probs <- fit$pstate
  if (is.matrix(state_probs) && ncol(state_probs) < 3) {
    state_probs <- cbind(
      state_probs,
      matrix(0, nrow(state_probs), 3 - ncol(state_probs))
    )
  }
  time <- fit$time
  # Built once; auc.func() is evaluated against these tables repeatedly below.
  cif1_table <- cbind(time, state_probs[, 2])
  cif2_table <- cbind(time, state_probs[, 3])

  est_start <- if (length(time)) min(time, na.rm = TRUE) else 0
  estimate <- a + b * auc.func(cif1_table, est_start, tau) +
    c * auc.func(cif2_table, est_start, tau)

  # Times that contribute to the variance: up to tau, with someone at risk.
  use <- which(time <= tau & n_risk != 0)
  if (!length(use)) {
    return(c(estimate, 0))
  }

  remaining <- tau - time[use]
  surv <- state_probs[use, 1]
  cif1 <- state_probs[use, 2]
  cif2 <- state_probs[use, 3]
  # Area under each CIF from every contributing time up to tau.
  auc1 <- vapply(time[use], function(s) auc.func(cif1_table, s, tau), numeric(1))
  auc2 <- vapply(time[use], function(s) auc.func(cif2_table, s, tau), numeric(1))

  # Each event contributes weight^2 / n_risk^2, where the weight is
  #   g                          for events of other causes,
  #   g - b * (tau - t) * S(t)   for cause 1,
  #   g - c * (tau - t) * S(t)   for cause 2.
  # Keeping the sum-of-squares form guarantees a non-negative variance, so
  # sqrt() cannot return NaN because of rounding.
  # NOTE(review): S is taken at t (after the jump), not at t-; confirm this
  # matches the intended reference formula.
  g <- b * (auc1 - remaining * cif1) + c * (auc2 - remaining * cif2)
  w_cause1 <- g - b * remaining * surv
  w_cause2 <- g - c * remaining * surv
  contrib <- g^2 * n_other[use] +
    w_cause1^2 * n_cause1[use] +
    w_cause2^2 * n_cause2[use]
  variance <- sum(contrib / n_risk[use]^2)

  c(estimate, sqrt(variance))
}

# Summary table comparing the auc.var.joint() estimand between two groups.
#
# data1, data2: data for the two groups, labelled "Group1 (trt=1)" and
#   "Group2 (trt=0)" in the output.
# tau, a, b, c: passed unchanged to auc.var.joint().
# Returns a 3 x 4 matrix with columns Estimate, Lower95, Upper95 and p-value.
# Intervals are estimate +/- 1.96 * SE. The "Difference" row is group 1 minus
# group 2 with a two-sided z-test. The group rows hold 0 as a placeholder in
# the p-value column.
table1_cif <- function(data1, data2, tau, a, b, c){
  diff1 <- auc.var.joint(data1, tau, a, b, c)
  diff2 <- auc.var.joint(data2, tau, a, b, c)

  line1 <- c(diff1[1], diff1[1] - 1.96*diff1[2], diff1[1] + 1.96*diff1[2], 0)
  line2 <- c(diff2[1], diff2[1] - 1.96*diff2[2], diff2[1] + 1.96*diff2[2], 0)
  psi_diff <- diff1[1] - diff2[1]
  # SEs combine in quadrature, i.e. the two groups are treated as independent.
  psi_se   <- sqrt(diff1[2]^2 + diff2[2]^2)
  if (is.na(psi_se)) {
    # An undefined SE (NA/NaN) supports neither an interval nor a test.
    psi_lower <- NA_real_
    psi_upper <- NA_real_
    p_value <- NA_real_
  } else if (psi_se < .Machine$double.eps) {
    # Zero-variance limit of 2 * pnorm(-|diff| / se): 1 when the difference is
    # itself zero (no evidence against equality), 0 otherwise.
    psi_lower <- psi_diff
    psi_upper <- psi_diff
    p_value <- as.numeric(abs(psi_diff) < .Machine$double.eps)
  } else {
    psi_lower <- psi_diff - 1.96*psi_se
    psi_upper <- psi_diff + 1.96*psi_se
    z_score <- psi_diff/psi_se
    p_value <- 2 * pnorm(-abs(z_score), 0, 1)
  }
  line3 <- c(psi_diff, psi_lower, psi_upper, p_value)
  out <- rbind(line1, line2, line3)
  rownames(out) <- c("Group1 (trt=1)", "Group2 (trt=0)", "Difference")
  colnames(out) <- c("Estimate", "Lower95", "Upper95", "p-value")
  return(out)
}
