# Longitudinal attribution functions
# These functions currently support only Copula2 models with semiparametric additive-hazard margins.


# Estimates total and cause-specific disability prevalence at age (i.e., time) t0 among survivors.
Longitudinal_disability_attribution <- function(object, type.attrib = "both", t0 = 0) {
  # Estimates the disability prevalence at age (time) t0 among individuals
  # alive at t0 (timeD > t0). It attributes that prevalence to the baseline
  # disability hazard ("Background") and to each time-varying covariate in
  # object$var_list.
  #
  # Arguments
  #   object      fitted model object. Fields read: data, u, estimates, p1,
  #               p2, m1, m2, var_list. data must contain timeD, the time_*
  #               visit columns and one <var>_change column per var_list
  #               entry. A weight column is optional (defaults to 1).
  #   type.attrib "rel", "abs" or "both": which attribution vector(s) to
  #               return.
  #   t0          age (time) at which prevalence is evaluated.
  # Value
  #   list(att_final, Total_Cause_People, Total_People).
  #   Total_Cause_People is the weighted sum of the individual conditional
  #   disability probabilities. Total_People is the weighted number of
  #   survivors at t0.
  # Depends on bern(), bern_derivative() and C_Q(), which are defined
  # elsewhere.
  type.attrib <- match.arg(type.attrib, c("rel", "abs", "both"))

  data <- object$data
  time_list <- grep("^time_", names(data), value = TRUE)
  var_list <- object$var_list
  u <- as.numeric(object$u)
  p <- as.numeric(object$estimates)
  p1 <- as.numeric(object$p1)
  p2 <- as.numeric(object$p2)
  m1 <- as.numeric(object$m1)
  m2 <- as.numeric(object$m2)

  # Parameter layout:
  #   - two parameters passed to C_Q (alpha on logit scale, kappa on log scale);
  #   - m1 + 1 disability baseline coefficients;
  #   - p1 disability regression coefficients;
  #   - m2 + 1 death baseline coefficients;
  #   - p2 death regression coefficients.
  alpha <- exp(p[1]) / (1 + exp(p[1]))
  kappa <- exp(p[2])
  phi <- p[3:(m1 + 3)]
  Coeff <- p[(m1 + 4):(m1 + 3 + p1)]
  dead_phi <- p[(m1 + 4 + p1):(m1 + 4 + p1 + m2)]
  dead_Coeff <- p[(m1 + 4 + p1 + m2 + 1):(m1 + 4 + p1 + m2 + p2)]
  ep2 <- cumsum(exp(phi))
  dead_ep2 <- cumsum(exp(dead_phi))
  wgt <- if (is.null(data$weight)) rep(1, nrow(data)) else as.numeric(data$weight)

  # Data preparation ----
  # Keep only individuals still alive at age t0.
  alive <- which(data$timeD > t0)
  if (length(alive) == 0L) {
    stop("No individuals with timeD > t0 (t0 = ", t0, ").", call. = FALSE)
  }
  wgt <- wgt[alive]
  data <- data[alive, ]
  n_obs <- nrow(data)

  # Row-wise sorted visit times, with a leading 0.
  time <- cbind(time = rep(0, n_obs), data[, time_list])
  time <- t(apply(time, 1, sort))

  # Interval boundaries: time[, 1], the midpoints (rounded to 6 digits)
  # between consecutive sorted times, then Inf. Interval j is
  # [A[, j], A[, j + 1]).
  n_int <- ncol(time)
  A <- matrix(NA_real_, nrow = n_obs, ncol = n_int + 1)
  A[, 1] <- time[, 1]
  A[, 2:n_int] <- round(
    (time[, 2:n_int, drop = FALSE] + time[, 1:(n_int - 1), drop = FALSE]) / 2,
    6
  )
  A[, n_int + 1] <- Inf
  # Interval lengths (the last one is Inf).
  CA <- A[, -1, drop = FALSE] - A[, -(n_int + 1), drop = FALSE]

  # Changepoints ----
  # Each nonzero <var>_change value is replaced by the midpoint between the
  # first visit time after it and the visit time before that one. It becomes
  # Inf when no visit follows it.
  update_changepoints <- function(data, time, var_list) {
    for (var in paste0(var_list, "_change")) {
      for (i in seq_len(nrow(data))) {
        cp <- data[i, var]
        if (cp != 0) {
          after <- which(time[i, ] > cp)
          data[i, var] <- if (length(after) == 0) {
            Inf
          } else {
            round((time[i, after[1] - 1] + time[i, after[1]]) / 2, 6)
          }
        }
      }
    }
    data
  }
  data <- update_changepoints(data, time, var_list)

  # Interval index containing tt[i] in each row. The result is 0 when
  # tt[i] < A[i, 1], and the open-ended last interval when
  # tt[i] >= A[i, n_int].
  locate_interval <- function(tt) {
    idx <- rep(0, n_obs)
    for (i in seq_len(n_obs)) {
      for (j in seq_len(n_int)) {
        if (A[i, j] <= tt[i] && tt[i] < A[i, j + 1]) {
          idx[i] <- j
          break
        }
      }
      if (tt[i] >= A[i, n_int]) {
        idx[i] <- n_int
      }
    }
    idx
  }
  t_eval <- rep(t0, n_obs)
  tau <- locate_interval(t_eval)
  if (any(tau == 0)) {
    stop("t0 precedes the first interval boundary for some individuals.",
         call. = FALSE)
  }

  # Time-varying covariates ----
  # 0/1 indicator per interval: 0 before the (updated) changepoint,
  # 1 from the interval starting at it onwards.
  compute_zz <- function(variable) {
    change <- data[[paste0(variable, "_change")]]
    zz <- matrix(NA_real_, nrow = n_obs, ncol = n_int)
    for (i in seq_len(n_obs)) {
      if (is.infinite(change[i])) {
        zz[i, ] <- 0
      } else {
        k <- which(A[i, ] == round(change[i], 6))[1] - 1
        if (k == 0) {
          zz[i, ] <- 1
        } else if (k == n_int) {
          zz[i, ] <- 0
        } else {
          zz[i, 1:k] <- 0
          zz[i, (k + 1):n_int] <- 1
        }
      }
    }
    zz
  }
  zz_list <- setNames(lapply(var_list, compute_zz),
                      paste0("zz_", seq_along(var_list)))

  # Cumulative covariate exposure.
  # Before tau it runs to the right end of each interval; inside tau it runs
  # up to tt; after tau it is 0.
  compute_ZZ <- function(tau, tt, zz_matrix) {
    ZZ <- matrix(NA_real_, nrow = n_obs, ncol = n_int)
    for (i in seq_len(n_obs)) {
      for (j in seq_len(n_int)) {
        ZZ[i, j] <- if (tau[i] == 1) {
          zz_matrix[i, j] * (tt[i] - A[i, j])
        } else if (j < tau[i]) {
          sum(zz_matrix[i, 1:j] * CA[i, 1:j])
        } else if (j == tau[i]) {
          ZZ[i, j - 1] + zz_matrix[i, j] * (tt[i] - A[i, j])
        } else {
          0
        }
      }
    }
    ZZ
  }
  # The death margin uses the same cumulative exposures, so this list also
  # serves the death survival computation below.
  ZZ_list <- setNames(lapply(zz_list, function(zz) compute_ZZ(tau, t_eval, zz)),
                      paste0("ZZ_", seq_along(var_list)))

  # Hazard helpers ----
  # Cumulative baselines use bern(); the disability baseline hazard uses
  # bern_derivative(). Both are weighted sums over k = 0..m.
  cumhaz_dis <- function(x) sum(sapply(0:m1, function(k) bern(k, m1, 0, u, x)) * ep2)
  haz_dis <- function(x) sum(sapply(0:m1, function(k) bern_derivative(k, m1, 0, u, x)) * ep2)
  cumhaz_dead <- function(x) sum(sapply(0:m2, function(k) bern(k, m2, 0, u, x)) * dead_ep2)
  # Covariate part of a cumulative hazard: sum_l beta[l] * ZZs[[l]][i, j].
  lin_pred <- function(ZZs, i, j, beta) {
    sum(vapply(ZZs, function(Z) Z[i, j], numeric(1)) * beta)
  }

  # Cumulative disability baseline at each boundary up to tau, then at t0.
  # Column j holds the value at A[, j] (or at t0 for column tau + 1);
  # columns after that are 0.
  LAMBDA <- matrix(0, nrow = n_obs, ncol = n_int + 1)
  LAMBDA[, 1] <- cumhaz_dis(0)
  for (i in seq_len(n_obs)) {
    for (j in seq_len(tau[i]) + 1) {
      LAMBDA[i, j] <- if (j - 1 < tau[i]) cumhaz_dis(A[i, j]) else cumhaz_dis(t_eval[i])
    }
  }

  # Death survival S2(t0) = P(T2 > t0) for each individual.
  dead_S <- vapply(seq_len(n_obs), function(i) {
    exp(-cumhaz_dead(t_eval[i]) - lin_pred(ZZ_list, i, tau[i], dead_Coeff))
  }, numeric(1))

  # Conditional disability probabilities ----
  # Q[i, j] = [P(T1 > A_j, T2 > t0) - P(T1 > A_{j+1}, T2 > t0)] / P(T2 > t0),
  # with the upper boundary truncated at t0 in interval tau. The last column
  # is the row total.
  Q <- matrix(NA_real_, nrow = n_obs, ncol = n_int + 1)
  for (i in seq_len(n_obs)) {
    S2 <- dead_S[i]
    for (j in seq_len(n_int)) {
      if (j == 1) {
        S1 <- exp(-LAMBDA[i, 2] - lin_pred(ZZ_list, i, 1, Coeff))
        Q[i, j] <- (S2 - C_Q(S1, S2, kappa, alpha)) / S2
      } else if (j <= tau[i]) {
        S0 <- exp(-LAMBDA[i, j] - lin_pred(ZZ_list, i, j - 1, Coeff))
        S11 <- exp(-LAMBDA[i, j + 1] - lin_pred(ZZ_list, i, j, Coeff))
        Q[i, j] <- (C_Q(S0, S2, kappa, alpha) - C_Q(S11, S2, kappa, alpha)) / S2
      } else {
        Q[i, j] <- 0
      }
    }
    Q[i, n_int + 1] <- sum(Q[i, seq_len(n_int)])
  }

  # Total (weighted) disability prevalence.
  P <- sum(Q[, n_int + 1] * wgt)

  # Attribution ----
  # Interval-level baseline disability hazard. It is evaluated at the right
  # boundary for intervals before tau and at t0 inside interval tau; later
  # intervals get 0.
  haz <- matrix(0, nrow = n_obs, ncol = n_int)
  for (i in seq_len(n_obs)) {
    for (j in seq_len(tau[i])) {
      haz[i, j] <- if (j < tau[i]) haz_dis(A[i, j + 1]) else haz_dis(t_eval[i])
    }
  }
  active <- col(haz) <= tau

  # Total additive hazard per interval: baseline + sum_l Coeff[l] * zz_l.
  n_risk_factors <- length(Coeff)
  total_risk <- haz
  for (l in seq_len(n_risk_factors)) {
    total_risk <- total_risk + Coeff[l] * zz_list[[l]]
  }

  # Share of the total hazard due to the baseline and to each covariate.
  contri_list <- list(contri_0 = ifelse(active, haz / total_risk, 0))
  for (k in seq_len(n_risk_factors)) {
    contri_list[[paste0("contri_", k)]] <-
      ifelse(active, (Coeff[k] * zz_list[[k]]) / total_risk, 0)
  }

  # Cause-specific conditional probabilities; the last column is the row total.
  Q_int <- Q[, seq_len(n_int), drop = FALSE]
  Qi_list <- list()
  for (k in seq_along(contri_list)) {
    Q_k <- ifelse(active, contri_list[[k]] * Q_int, 0)
    Qi_list[[paste0("Q_", k - 1)]] <- cbind(Q_k, rowSums(Q_k))
  }

  # Cause-specific (weighted) disability prevalence.
  D_list <- vapply(Qi_list, function(Q_k) sum(Q_k[, ncol(Q_k)] * wgt), numeric(1))

  total_wgt <- sum(wgt)
  att_final_rel <- setNames(
    D_list / P,
    c("Relative_Background", paste0("Relative_", var_list))
  )
  att_final_abs <- setNames(
    c(P / total_wgt, D_list / total_wgt),
    c("Total", "Absolute_Background", paste0("Absolute_", var_list))
  )

  att_final <- switch(type.attrib,
                      "rel" = att_final_rel,
                      "abs" = att_final_abs,
                      "both" = list(att.rel = att_final_rel, att.abs = att_final_abs))

  list(att_final = att_final, Total_Cause_People = P, Total_People = total_wgt)
}



# Estimates total and cause-specific disability prevalence over the age interval (t0, t0 + 1] among survivors,
# approximated by a survival-weighted average over a grid of ages in that interval.
Longitudinal_disability_attribution_pai <- function(object, type.attrib = "both", t0 = 0,
                                                    n_grid = 10) {
  # Survival-weighted average of the absolute attributions from
  # Longitudinal_disability_attribution(). The ages used are
  # t0 + k / n_grid, k = 1, ..., n_grid: a right-endpoint grid over
  # (t0, t0 + 1]. Each grid age is weighted by Survival(object, t0 = age).
  #
  # Arguments
  #   object      fitted model object (passed through to the callees).
  #   type.attrib "rel", "abs" or "both": which attribution vector(s) to
  #               return.
  #   t0          start of the one-unit age window.
  #   n_grid      number of grid points (default 10, the previously fixed
  #               value).
  # Value
  #   list(att_final) holding the relative and/or absolute attributions.
  type.attrib <- match.arg(type.attrib, c("rel", "abs", "both"))
  stopifnot(length(n_grid) == 1, n_grid >= 1, n_grid == round(n_grid))

  var_list <- object$var_list
  ages <- t0 + seq_len(n_grid) / n_grid

  surv <- numeric(n_grid)
  # Rows: Total, Background, one per covariate; columns: grid ages.
  weighted_abs <- matrix(0, nrow = length(var_list) + 2, ncol = n_grid)
  for (g in seq_len(n_grid)) {
    surv[g] <- Survival(object, type.attrib = "both", t0 = ages[g])
    hh <- Longitudinal_disability_attribution(object, type.attrib = "both", t0 = ages[g])
    weighted_abs[, g] <- hh$att_final$att.abs * surv[g]
  }
  # Aggregate the grid points as a survival-weighted average.
  pai <- rowSums(weighted_abs) / sum(surv)

  # Relative attribution: each component divided by the total (pai[1]).
  att_final_rel <- setNames(
    pai[2:length(pai)] / pai[1],
    c("Relative_Background", paste0("Relative_", var_list))
  )

  # Absolute attribution.
  att_final_abs <- setNames(
    pai,
    c("Total", "Absolute_Background", paste0("Absolute_", var_list))
  )

  att_final <- switch(type.attrib,
                      "rel" = att_final_rel,
                      "abs" = att_final_abs,
                      "both" = list(att.rel = att_final_rel, att.abs = att_final_abs))

  list(att_final = att_final)
}



# Estimates the total and cause-specific probability of death between ages t0 and t0 + 1 among individuals alive at t0.
Longitudinal_death_attribution <- function(object, type.attrib = "both", t0 = 0) {
  # Estimates, for individuals alive at t0 (timeD > t0), the conditional
  # probability of death between ages t0 and t0 + 1,
  # [S2(t0) - S2(t0 + 1)] / S2(t0). It attributes that probability to the
  # baseline death hazard ("Background") and to each time-varying covariate
  # in object$var_list.
  #
  # Arguments
  #   object      fitted model object. Fields read: data, u, estimates, p1,
  #               p2, m1, m2, var_list. data must contain timeD, the time_*
  #               visit columns and one <var>_change column per var_list
  #               entry. A weight column is optional (defaults to 1).
  #   type.attrib "rel", "abs" or "both": which attribution vector(s) to
  #               return.
  #   t0          start of the one-unit age window.
  # Value
  #   list(att_final, Total_Cause_People, Total_People).
  #   Total_Cause_People is the weighted sum of individual death
  #   probabilities. Total_People is the weighted number of survivors at t0.
  # Depends on bern() and bern_derivative(), which are defined elsewhere.
  type.attrib <- match.arg(type.attrib, c("rel", "abs", "both"))

  data <- object$data
  time_list <- grep("^time_", names(data), value = TRUE)
  var_list <- object$var_list
  u <- as.numeric(object$u)
  p <- as.numeric(object$estimates)
  p1 <- as.numeric(object$p1)
  p2 <- as.numeric(object$p2)
  m1 <- as.numeric(object$m1)
  m2 <- as.numeric(object$m2)

  # Only the death margin is needed. Its coefficients follow, in order:
  #   - the two C_Q parameters;
  #   - the m1 + 1 disability baseline coefficients;
  #   - the p1 disability regression coefficients.
  dead_phi <- p[(m1 + 4 + p1):(m1 + 4 + p1 + m2)]
  dead_Coeff <- p[(m1 + 4 + p1 + m2 + 1):(m1 + 4 + p1 + m2 + p2)]
  dead_ep2 <- cumsum(exp(dead_phi))
  # A missing weight column means unit weights, as in the sibling functions.
  wgt <- if (is.null(data$weight)) rep(1, nrow(data)) else as.numeric(data$weight)

  # Data preparation ----
  # Keep only individuals still alive at age t0 (timeD > t0).
  alive <- which(data$timeD > t0)
  if (length(alive) == 0L) {
    stop("No individuals with timeD > t0 (t0 = ", t0, ").", call. = FALSE)
  }
  wgt <- wgt[alive]
  data <- data[alive, ]
  n_obs <- nrow(data)

  # Row-wise sorted visit times, with a leading 0.
  time <- cbind(time = rep(0, n_obs), data[, time_list])
  time <- t(apply(time, 1, sort))

  # Interval boundaries: time[, 1], the midpoints (rounded to 6 digits)
  # between consecutive sorted times, then Inf. Interval j is
  # [A[, j], A[, j + 1]).
  n_int <- ncol(time)
  A <- matrix(NA_real_, nrow = n_obs, ncol = n_int + 1)
  A[, 1] <- time[, 1]
  A[, 2:n_int] <- round(
    (time[, 2:n_int, drop = FALSE] + time[, 1:(n_int - 1), drop = FALSE]) / 2,
    6
  )
  A[, n_int + 1] <- Inf
  # Interval lengths (the last one is Inf).
  CA <- A[, -1, drop = FALSE] - A[, -(n_int + 1), drop = FALSE]

  # Changepoints ----
  # Each nonzero <var>_change value is replaced by the midpoint between the
  # first visit time after it and the visit time before that one. It becomes
  # Inf when no visit follows it.
  update_changepoints <- function(data, time, var_list) {
    for (var in paste0(var_list, "_change")) {
      for (i in seq_len(nrow(data))) {
        cp <- data[i, var]
        if (cp != 0) {
          after <- which(time[i, ] > cp)
          data[i, var] <- if (length(after) == 0) {
            Inf
          } else {
            round((time[i, after[1] - 1] + time[i, after[1]]) / 2, 6)
          }
        }
      }
    }
    data
  }
  data <- update_changepoints(data, time, var_list)

  # Interval index containing tt[i] in each row. The result is 0 when
  # tt[i] < A[i, 1], and the open-ended last interval when
  # tt[i] >= A[i, n_int].
  locate_interval <- function(tt) {
    idx <- rep(0, n_obs)
    for (i in seq_len(n_obs)) {
      for (j in seq_len(n_int)) {
        if (A[i, j] <= tt[i] && tt[i] < A[i, j + 1]) {
          idx[i] <- j
          break
        }
      }
      if (tt[i] >= A[i, n_int]) {
        idx[i] <- n_int
      }
    }
    idx
  }
  # The window runs from t_lo (interval tau_0) to t_hi (interval tau_1).
  t_lo <- rep(t0, n_obs)
  t_hi <- t_lo + 1
  tau_0 <- locate_interval(t_lo)
  tau_1 <- locate_interval(t_hi)
  if (any(tau_0 == 0)) {
    stop("t0 precedes the first interval boundary for some individuals.",
         call. = FALSE)
  }

  # Time-varying covariates ----
  # 0/1 indicator per interval: 0 before the (updated) changepoint,
  # 1 from the interval starting at it onwards.
  compute_zz <- function(variable) {
    change <- data[[paste0(variable, "_change")]]
    zz <- matrix(NA_real_, nrow = n_obs, ncol = n_int)
    for (i in seq_len(n_obs)) {
      if (is.infinite(change[i])) {
        zz[i, ] <- 0
      } else {
        k <- which(A[i, ] == round(change[i], 6))[1] - 1
        if (k == 0) {
          zz[i, ] <- 1
        } else if (k == n_int) {
          zz[i, ] <- 0
        } else {
          zz[i, 1:k] <- 0
          zz[i, (k + 1):n_int] <- 1
        }
      }
    }
    zz
  }
  zz_list <- setNames(lapply(var_list, compute_zz),
                      paste0("zz_", seq_along(var_list)))

  # Cumulative covariate exposure up to the right boundary of every interval.
  ZZ_list <- setNames(lapply(zz_list, function(zz) {
    ZZ <- zz * CA
    for (j in 2:n_int) {
      ZZ[, j] <- ZZ[, j - 1] + ZZ[, j]
    }
    ZZ
  }), paste0("ZZ_", seq_along(var_list)))

  # Cumulative covariate exposure.
  # Before tau it runs to the right end of each interval; inside tau it runs
  # up to tt; after tau it is 0.
  compute_ZZ <- function(tau, tt, zz_matrix) {
    ZZ <- matrix(NA_real_, nrow = n_obs, ncol = n_int)
    for (i in seq_len(n_obs)) {
      for (j in seq_len(n_int)) {
        ZZ[i, j] <- if (tau[i] == 1) {
          zz_matrix[i, j] * (tt[i] - A[i, j])
        } else if (j < tau[i]) {
          sum(zz_matrix[i, 1:j] * CA[i, 1:j])
        } else if (j == tau[i]) {
          ZZ[i, j - 1] + zz_matrix[i, j] * (tt[i] - A[i, j])
        } else {
          0
        }
      }
    }
    ZZ
  }
  # Exposures at t0 and at t0 + 1.
  ZZ_dead_list <- setNames(lapply(zz_list, function(zz) compute_ZZ(tau_0, t_lo, zz)),
                           paste0("ZZ_dead_", seq_along(var_list)))
  ZZ_dead2_list <- setNames(lapply(zz_list, function(zz) compute_ZZ(tau_1, t_hi, zz)),
                            paste0("ZZ_dead2_", seq_along(var_list)))

  # Hazard helpers ----
  # The cumulative baseline uses bern(); the baseline hazard uses
  # bern_derivative(). Both are weighted sums over k = 0..m2.
  dead_cumhaz <- function(x) sum(sapply(0:m2, function(k) bern(k, m2, 0, u, x)) * dead_ep2)
  dead_haz <- function(x) sum(sapply(0:m2, function(k) bern_derivative(k, m2, 0, u, x)) * dead_ep2)
  # Covariate part of a cumulative hazard: sum_l beta[l] * ZZs[[l]][i, j].
  lin_pred <- function(ZZs, i, j, beta) {
    sum(vapply(ZZs, function(Z) Z[i, j], numeric(1)) * beta)
  }
  # Death survival at x, using cumulative exposure ZZs[[l]][i, j].
  dead_surv <- function(x, ZZs, i, j) {
    exp(-dead_cumhaz(x) - lin_pred(ZZs, i, j, dead_Coeff))
  }

  # Death probabilities ----
  # dead_Q[i, j] is the piece of [S2(t0) - S2(t0 + 1)] / S2(t0) falling in
  # interval j. It is non-zero only for tau_0 <= j <= tau_1; the last column
  # is the row total.
  dead_Q <- matrix(NA_real_, nrow = n_obs, ncol = n_int + 1)
  for (i in seq_len(n_obs)) {
    S_t <- dead_surv(t_lo[i], ZZ_dead_list, i, tau_0[i])
    S_end <- dead_surv(t_hi[i], ZZ_dead2_list, i, tau_1[i])
    for (j in seq_len(n_int)) {
      if (j < tau_0[i] || j > tau_1[i]) {
        dead_Q[i, j] <- 0
      } else {
        # Clip interval j to the window [t0, t0 + 1].
        S_left <- if (j == tau_0[i]) S_t else dead_surv(A[i, j], ZZ_list, i, j - 1)
        S_right <- if (j == tau_1[i]) S_end else dead_surv(A[i, j + 1], ZZ_list, i, j)
        dead_Q[i, j] <- (S_left - S_right) / S_t
      }
    }
    dead_Q[i, n_int + 1] <- sum(dead_Q[i, seq_len(n_int)])
  }

  # Total (weighted) death probability.
  Q_t <- sum(dead_Q[, n_int + 1] * wgt)

  # Attribution ----
  # Interval-level baseline death hazard. It is evaluated at the right
  # boundary for intervals before tau_1 and at t0 + 1 inside interval tau_1;
  # later intervals get 0.
  haz <- matrix(0, nrow = n_obs, ncol = n_int)
  for (i in seq_len(n_obs)) {
    for (j in seq_len(tau_1[i])) {
      haz[i, j] <- if (j < tau_1[i]) dead_haz(A[i, j + 1]) else dead_haz(t_hi[i])
    }
  }
  contri_active <- col(haz) <= tau_1
  in_window <- contri_active & col(haz) >= tau_0

  # Total additive hazard per interval: baseline + sum_l dead_Coeff[l] * zz_l.
  n_risk_factors <- length(dead_Coeff)
  total_risk <- haz
  for (l in seq_len(n_risk_factors)) {
    total_risk <- total_risk + dead_Coeff[l] * zz_list[[l]]
  }

  # Share of the total hazard due to the baseline and to each covariate.
  contri_list <- list(contri_0 = ifelse(contri_active, haz / total_risk, 0))
  for (k in seq_len(n_risk_factors)) {
    contri_list[[paste0("contri_", k)]] <-
      ifelse(contri_active, (dead_Coeff[k] * zz_list[[k]]) / total_risk, 0)
  }

  # Cause-specific death probabilities; the last column is the row total.
  dead_Q_int <- dead_Q[, seq_len(n_int), drop = FALSE]
  Qi_list <- list()
  for (k in seq_along(contri_list)) {
    Q_k <- ifelse(in_window, contri_list[[k]] * dead_Q_int, 0)
    Qi_list[[paste0("Q_", k - 1)]] <- cbind(Q_k, rowSums(Q_k))
  }

  # Cause-specific (weighted) death probability.
  D_list <- vapply(Qi_list, function(Q_k) sum(Q_k[, ncol(Q_k)] * wgt), numeric(1))

  total_wgt <- sum(wgt)
  att_final_rel <- setNames(
    D_list / Q_t,
    c("Relative_Background", paste0("Relative_", var_list))
  )
  att_final_abs <- setNames(
    c(Q_t / total_wgt, D_list / total_wgt),
    c("Total", "Absolute_Background", paste0("Absolute_", var_list))
  )

  att_final <- switch(type.attrib,
                      "rel" = att_final_rel,
                      "abs" = att_final_abs,
                      "both" = list(att.rel = att_final_rel, att.abs = att_final_abs))

  list(att_final = att_final, Total_Cause_People = Q_t, Total_People = total_wgt)
}



### Estimates the weighted average survival probability S2(t0) = P(T2 > t0) for death
Survival <- function(object, type.attrib = "both", t0 = 0) {
  # Weighted average, over every individual in object$data (no filtering on
  # timeD), of the death-margin survival probability
  #   S2_i(t0) = exp(-Lambda2(t0) - sum_l dead_Coeff[l] * ZZ_il(t0)).
  #
  # Arguments
  #   object      fitted model object. Fields read: data, u, estimates, p1,
  #               p2, m1, m2, var_list. data must contain the time_* visit
  #               columns and one <var>_change column per var_list entry.
  #               A weight column is optional (defaults to 1).
  #   type.attrib accepted for interface compatibility; not used.
  #   t0          age (time) at which survival is evaluated.
  # Value
  #   a single number: sum(S2_i(t0) * weight_i) / sum(weight_i).
  # Depends on bern(), which is defined elsewhere.
  data <- object$data
  time_list <- grep("^time_", names(data), value = TRUE)
  var_list <- object$var_list
  u <- as.numeric(object$u)
  p <- as.numeric(object$estimates)
  p1 <- as.numeric(object$p1)
  p2 <- as.numeric(object$p2)
  m1 <- as.numeric(object$m1)
  m2 <- as.numeric(object$m2)

  # Only the death margin is needed. Its coefficients follow, in order:
  #   - the two C_Q parameters;
  #   - the m1 + 1 disability baseline coefficients;
  #   - the p1 disability regression coefficients.
  dead_phi <- p[(m1 + 4 + p1):(m1 + 4 + p1 + m2)]
  dead_Coeff <- p[(m1 + 4 + p1 + m2 + 1):(m1 + 4 + p1 + m2 + p2)]
  dead_ep2 <- cumsum(exp(dead_phi))
  wgt <- if (is.null(data$weight)) rep(1, nrow(data)) else as.numeric(data$weight)

  n_obs <- nrow(data)
  if (n_obs == 0) {
    stop("object$data has no rows.", call. = FALSE)
  }

  # Data preparation ----
  # Row-wise sorted visit times, with a leading 0.
  time <- cbind(time = rep(0, n_obs), data[, time_list])
  time <- t(apply(time, 1, sort))

  # Interval boundaries: time[, 1], the midpoints (rounded to 6 digits)
  # between consecutive sorted times, then Inf.
  n_int <- ncol(time)
  A <- matrix(NA_real_, nrow = n_obs, ncol = n_int + 1)
  A[, 1] <- time[, 1]
  A[, 2:n_int] <- round(
    (time[, 2:n_int, drop = FALSE] + time[, 1:(n_int - 1), drop = FALSE]) / 2,
    6
  )
  A[, n_int + 1] <- Inf
  # Interval lengths (the last one is Inf).
  CA <- A[, -1, drop = FALSE] - A[, -(n_int + 1), drop = FALSE]

  # Each nonzero <var>_change value is replaced by the midpoint between the
  # first visit time after it and the visit time before that one. It becomes
  # Inf when no visit follows it.
  update_changepoints <- function(data, time, var_list) {
    for (var in paste0(var_list, "_change")) {
      for (i in seq_len(nrow(data))) {
        cp <- data[i, var]
        if (cp != 0) {
          after <- which(time[i, ] > cp)
          data[i, var] <- if (length(after) == 0) {
            Inf
          } else {
            round((time[i, after[1] - 1] + time[i, after[1]]) / 2, 6)
          }
        }
      }
    }
    data
  }
  data <- update_changepoints(data, time, var_list)

  # Interval index containing t0 in each row. The result is 0 when
  # t0 < A[i, 1], and the open-ended last interval when t0 >= A[i, n_int].
  t_eval <- rep(t0, n_obs)
  tau_0 <- rep(0, n_obs)
  for (i in seq_len(n_obs)) {
    for (j in seq_len(n_int)) {
      if (A[i, j] <= t_eval[i] && t_eval[i] < A[i, j + 1]) {
        tau_0[i] <- j
        break
      }
    }
    if (t_eval[i] >= A[i, n_int]) {
      tau_0[i] <- n_int
    }
  }
  if (any(tau_0 == 0)) {
    stop("t0 precedes the first interval boundary for some individuals.",
         call. = FALSE)
  }

  # 0/1 covariate indicator per interval: 0 before the (updated)
  # changepoint, 1 from the interval starting at it onwards.
  compute_zz <- function(variable) {
    change <- data[[paste0(variable, "_change")]]
    zz <- matrix(NA_real_, nrow = n_obs, ncol = n_int)
    for (i in seq_len(n_obs)) {
      if (is.infinite(change[i])) {
        zz[i, ] <- 0
      } else {
        k <- which(A[i, ] == round(change[i], 6))[1] - 1
        if (k == 0) {
          zz[i, ] <- 1
        } else if (k == n_int) {
          zz[i, ] <- 0
        } else {
          zz[i, 1:k] <- 0
          zz[i, (k + 1):n_int] <- 1
        }
      }
    }
    zz
  }
  zz_list <- lapply(var_list, compute_zz)

  # Cumulative exposure of covariate indicator zz for row i up to t0.
  # Full intervals before tau_0 count completely; interval tau_0 counts only
  # up to t0.
  cum_exposure <- function(zz, i) {
    k <- tau_0[i]
    if (k == 1) {
      zz[i, 1] * (t_eval[i] - A[i, 1])
    } else {
      sum(zz[i, 1:(k - 1)] * CA[i, 1:(k - 1)]) + zz[i, k] * (t_eval[i] - A[i, k])
    }
  }

  # Cumulative death baseline at x: weighted sum of bern() terms k = 0..m2.
  dead_cumhaz <- function(x) sum(sapply(0:m2, function(k) bern(k, m2, 0, u, x)) * dead_ep2)

  # Individual survival probabilities S2_i(t0) = P(T_i2 > t0).
  S_2 <- vapply(seq_len(n_obs), function(i) {
    cov_part <- sum(vapply(zz_list, cum_exposure, numeric(1), i = i) * dead_Coeff)
    exp(-dead_cumhaz(t_eval[i]) - cov_part)
  }, numeric(1))

  sum(S_2 * wgt) / sum(wgt)
}





