
#' @title Compute optimal transport barycenters
#' @description Compute the optimal transport (OT) barycenter of multiple probability vectors via linear programming.
#' @param mu matrix (row-wise) or list containing \eqn{K} probability vectors of length \eqn{N}.
#' @param w weight vector \eqn{w \in \mathbb{R}_+^K}. The default is \eqn{w = (1 / K, \ldots, 1 / K)}.
#' @param costm cost matrix \eqn{c \in \mathbb{R}^{N \times N}}.
#' @param solver the LP solver to use, see [`ot_test_lp_solver`].
#' @param constr_mat the constraint matrix for the underlying LP.
#' @returns A list containing the entries `cost` and `barycenter`.
#' @details The OT barycenter is defined as the minimizer of the cost functional,
#' \deqn{
#' B_c^w(\mu^1, \ldots, \mu^K) := \min_{\nu} \sum_{k=1}^K w_k \, \mathrm{OT}_c(\mu^k, \nu)\,,
#' }
#' where the minimum is taken over all probability vectors \eqn{\nu}.
#' The OT barycenter is solved via linear programming (LP) and the underlying solver can be controlled
#' via the parameter `solver`.
#' @example examples/bary_ot.R
#' @export
ot_barycenter <- function(mu, costm, w = NULL, solver = ot_test_lp_solver(), constr_mat = NULL) {
    # A list of probability vectors is stacked into a row-wise matrix first.
    if (is.list(mu)) {
        mu <- do.call(rbind, mu)
    }
    check_mu(mu)

    n_meas <- nrow(mu)
    n_pts <- ncol(mu)

    # Uniform weights unless the caller supplied their own.
    if (is.null(w)) {
        w <- rep(1 / n_meas, n_meas)
    }
    check_w(w, n_meas)
    check_cost_mat(costm, n_pts)
    check_lp_solver(solver)

    # Objective coefficients: one copy of the flattened cost matrix per measure,
    # each scaled by the weight of that measure.
    flat_cost <- c_byrow(costm)
    obj_coef <- do.call(c, lapply(w, `*`, flat_cost))

    if (is.null(constr_mat)) {
        constr_mat <- ot_barycenter_constrmat(n_meas, n_pts)
    }

    # Right-hand side: the flattened measures followed by (K - 1) * (N - 1) zeros.
    # NOTE(review): the zero block presumably belongs to the constraints that tie the
    # K transport plans to a common marginal -- confirm against ot_barycenter_constrmat.
    rhs <- c(c_byrow(mu), rep(0, (n_meas - 1) * (n_pts - 1)))
    eq_dir <- ROI::eq(length(rhs))

    lp_res <- lp_solve(
        objective   = ROI::L_objective(obj_coef),
        constraints = ROI::L_constraint(constr_mat, eq_dir, rhs),
        types       = NULL,
        bounds      = NULL,
        maximum     = FALSE,
        solver      = solver,
        add.info    = "OT barycenter"
    )

    opt_val <- ROI::solution(lp_res, "objval")

    # The first N^2 primal entries are read as a row-wise N x N matrix; its column
    # sums are reported as the barycenter.
    primal <- ROI::solution(lp_res, "primal")
    plan_first <- matrix(primal[seq_len(n_pts^2)], n_pts, n_pts, byrow = TRUE)

    list(
        cost       = opt_val,
        barycenter = colSums(plan_first)
    )
}

# OT cost between two mass vectors. Both inputs are rescaled to unit mass before the
# transport problem is solved, and the result is scaled by the mean of the two masses.
# If the total masses differ more than roughly, the value may not be what the caller
# intended. A total mass of at most `tol.0` on either side yields a cost of zero.
ot_cost <- function(mu, nu, costm, tol.0 = 1e-10) {
    mass_mu <- sum(mu)
    mass_nu <- sum(nu)

    # (numerically) empty measure on either side: nothing to transport
    if (mass_mu <= tol.0 || mass_nu <= tol.0) {
        return(0.0)
    }

    plan <- transport(mu / mass_mu, nu / mass_nu, costm, fullreturn = TRUE)
    0.5 * (mass_mu + mass_nu) * plan$cost
}

# Positive part of `mu`: negative entries are replaced by zero.
pos <- function(mu) {
    pmax(mu, 0)
}
# Negative part of `mu`, returned as nonnegative values: positive entries become zero.
neg <- function(mu) {
    -pmin(mu, 0)
}

# Signed OT cost between two vectors: the negative part of each vector is moved to the
# other side, i.e. the cost is OT(mu^+ + nu^-, nu^+ + mu^-).
ot_cost_sgn0 <- function(mu, nu, costm) {
    # source side: positive part of mu plus negative part of nu
    a <- pos(mu) + neg(nu)
    # target side: positive part of nu plus negative part of mu
    b <- pos(nu) + neg(mu)

    # Both sides must carry the same mass (up to the all.equal tolerance); otherwise
    # stopifnot raises an error.
    stopifnot(all.equal(sum(a), sum(b), check.attributes = FALSE))

    ot_cost(a, b, costm)
}

# Signed OT cost between matching rows of `mu` and `nu`: entry k of the result is the
# cost between mu[k, ] and nu[k, ]. A matrix without rows gives an empty numeric vector.
ot_cost_sgn_rowwise <- function(mu, nu, costm) {
    n_rows <- nrow(mu)
    if (n_rows == 0) {
        return(numeric())
    }

    row_cost <- function(k) {
        ot_cost_sgn0(mu[k, ], nu[k, ], costm)
    }

    # future.seed = NULL: no RNG streams are set up for the workers; the original note
    # states that transport::transport with method = "networkflow" does not need an RNG.
    future_sapply(seq_len(n_rows), row_cost, future.seed = NULL)
}

# OT cost between the positive and the negative part of every row of `mu`: entry k of
# the result is ot_cost(pos(mu[k, ]), neg(mu[k, ]), costm).
# Returns a numeric vector with one entry per row of `mu`.
ot_cost_sgn_rowwise_posneg <- \(mu, costm) {
    # Same guard as in ot_cost_sgn_rowwise: without it, applying over an empty index
    # sequence would return list() instead of an empty numeric vector.
    if (nrow(mu) == 0) {
        return(numeric())
    }
    future_sapply(seq_len(nrow(mu)), \(k) ot_cost(pos(mu[k, ]), neg(mu[k, ]), costm), future.seed = NULL)
}

# Symmetric K x K matrix of pairwise signed OT costs between the rows of `mu`.
# Only one triangle is computed; the diagonal is zero and the other triangle is
# obtained by reflection.
ot_cost_sgn_mat <- function(mu, costm) {
    n_meas <- nrow(mu)
    # With at most one row there are no pairs to compare.
    if (n_meas <= 1) {
        return(matrix(0, n_meas, n_meas))
    }

    pair_cost <- function(ij) {
        ot_cost_sgn0(mu[ij[1], ], mu[ij[2], ], costm)
    }
    pair_vals <- future_sapply(indicesAboveDiag(n_meas), pair_cost, future.seed = NULL)

    # NOTE(review): filling the lower triangle in column-major order assumes that
    # indicesAboveDiag enumerates the pairs in the matching (mirrored) order -- confirm.
    half <- matrix(0, n_meas, n_meas)
    half[lower.tri(half)] <- pair_vals
    half + t(half)
}

# K1 x K2 matrix of signed OT costs between every row of `mu` and every row of `nu`.
# If either input has no rows, a K1 x K2 zero matrix (with a zero extent) is returned.
ot_cost_sgn_mat1 <- function(mu, nu, costm) {
    n_mu <- nrow(mu)
    n_nu <- nrow(nu)
    if (n_mu == 0 || n_nu == 0) {
        return(matrix(0, n_mu, n_nu))
    }

    pair_cost <- function(ij) {
        ot_cost_sgn0(mu[ij[1], ], nu[ij[2], ], costm)
    }
    pair_vals <- future_sapply(indicesAll(n_mu, n_nu), pair_cost, future.seed = NULL)

    # NOTE(review): byrow = TRUE assumes that indicesAll lists the pairs row by row
    # (second index varying fastest) -- confirm against its definition.
    matrix(pair_vals, n_mu, n_nu, byrow = TRUE)
}

#' @title Compute optimal transport costs for signed measures
#' @description Compute the optimal transport (OT) cost between signed measures that have the same total mass.
#' @param mu matrix (row-wise) or list containing \eqn{K_1} vectors of length \eqn{N}.
#' @param nu matrix (row-wise) or list containing \eqn{K_2} vectors of length \eqn{N} or `NULL`.
#' @param costm cost matrix \eqn{c \in \mathbb{R}^{N \times N}}.
#' @param mode controls which of the pairwise OT costs are computed.
#' @details The extended OT functional for vectors \eqn{\mu,\,\nu \in \mathbb{R}^N} with \eqn{\sum_{i=1}^N \mu_i = \sum_{i=1}^N \nu_i} is defined as
#' \deqn{
#'  \mathrm{OT}^{\pm}_c(\mu, \nu) := \mathrm{OT}_c(\mu^+ + \nu^-, \, \nu^+ + \mu^-)\,,
#' }
#' where \eqn{\mu^+ = \max(0, \mu)} and \eqn{\mu^- = -\min(0, \mu)} denote the positive and negative part of \eqn{\mu}, and \eqn{\mathrm{OT}_c}
#' is the standard OT functional. To compute the standard OT, the function [`transport::transport`] is used.
#' The values may be computed in parallel via [`future::plan`].
#' @returns The OT cost between the vectors in `mu` and `nu`.
#'
#' For `mode = "all"` the whole matrix of size \eqn{K_1 \times K_2} is returned. If `mu` or `nu` is a vector, then this matrix is also returned as a vector.
#' `nu = NULL` means that `nu = mu` and only the lower triangular part is actually computed and then reflected.
#'
#' If `mode = "diag"`, then only the diagonal is returned (requiring \eqn{K_1 = K_2}).
#' @example examples/sgn_ot.R
#' @seealso [`transport::transport`]
#' @export
ot_cost_sgn <- function(mu, nu, costm, mode = c("all", "diag")) {
    mode <- match.arg(mode)
    stopifnot(is_num_mat(costm), nrow(costm) == ncol(costm))
    stopifnot(!is.null(mu))

    n_pts <- nrow(costm)

    # Remember whether any input came in as a plain vector: in that case the result of
    # mode = "all" is flattened to a vector at the end.
    vec_input <- is_num_vec(mu)
    mu <- as_rowmat0(mu, n_pts)
    if (!is.null(nu)) {
        vec_input <- vec_input || is_num_vec(nu)
        nu <- as_rowmat0(nu, n_pts)
    }

    # mode == "diag": costs between matching rows only; needs an explicit `nu` with
    # as many rows as `mu`.
    if (mode == "diag") {
        stopifnot(
            !is.null(nu),
            nrow(mu) == nrow(nu)
        )
        return(ot_cost_sgn_rowwise(mu, nu, costm))
    }

    # mode == "all": full cost matrix; with nu = NULL the rows of `mu` are compared
    # against each other.
    pairwise <- if (is.null(nu)) {
        ot_cost_sgn_mat(mu, costm)
    } else {
        ot_cost_sgn_mat1(mu, nu, costm)
    }

    if (vec_input) {
        c(pairwise)
    } else {
        pairwise
    }
}

# Signed OT cost for vectors whose total masses may differ, computed by adding one
# extra (dummy) point that absorbs the mass difference.
#   mu, nu : numeric vectors, one entry per row/column of `costm`
#   costm  : square cost matrix
#   C.p    : cost cap; costs between original points are truncated at C.p^p and moving
#            mass to or from the dummy point costs C.p^p / 2
#   p      : exponent applied to the costs
# Returns the value of ot_cost() on the augmented problem.
uot_cost_sgn <- \(mu, nu, costm, C.p, p = 1) {

    N <- nrow(costm)

    # Augmented (N + 1) x (N + 1) cost matrix, initialised with the dummy cost.
    costm2 <- matrix(C.p^p / 2, N + 1, N + 1)
    # seq_len(N) rather than 1:N: for N = 0 the latter is c(1, 0) and the
    # sub-assignment would fail instead of selecting nothing.
    idx <- seq_len(N)
    costm2[idx, idx] <- pmin(costm^p, C.p^p)
    # staying at the dummy point is free
    costm2[N + 1, N + 1] <- 0

    # Move negative parts to the opposite side, as in the balanced signed cost.
    a <- pos(mu) + neg(nu)
    b <- pos(nu) + neg(mu)

    sa <- sum(a)
    sb <- sum(b)

    # Pad both sides at the dummy point up to the larger of the two total masses.
    r <- max(sa, sb)

    a <- c(a, r - sa)
    b <- c(b, r - sb)

    ot_cost(a, b, costm2)
}
