#include <Rcpp.h>
#include <spatial_lagging.hpp>

/**
 * @title Rcpp wrapper for single-variable lattice lag computation
 *
 * @description
 * This function provides an Rcpp interface to the C++ core function `GenLatticeLagUni()`,
 * which computes lagged mean values for a lattice-based spatial structure defined by
 * explicit neighbor relationships. The computation aggregates values from neighbors
 * up to a specified lag order and calculates the mean while ignoring missing (NaN) values.
 *
 * The function validates its inputs, converts the R numeric vector and neighborhood list
 * into C++ standard containers (shifting neighbor indices from R's 1-based to C++'s
 * 0-based indexing), and invokes the underlying C++ routine.
 *
 * @param vec Numeric vector representing spatial observations for each lattice unit.
 * @param nb List of integer vectors specifying the neighborhood structure; it must have
 *           the same length as `vec`. Each element contains 1-based indices of immediate
 *           neighbors, each in `1..length(vec)`. A value of 0 (the spdep convention for a
 *           unit without neighbors) is also accepted and is forwarded to the core routine
 *           as -1 after the 0-based shift, as before.
 * @param lagNum Integer specifying the lag order (non-negative) to compute.
 *
 * @return Numeric vector where each element represents the mean of lagged values
 *         for the corresponding spatial unit at the specified lag order.
 *         Returns NaN for spatial units with no valid neighbors.
 *
 * @note Raises an R error (via `Rcpp::stop`) in the following cases:
 *       - `length(nb) != length(vec)`;
 *       - `lagNum` is negative or NA;
 *       - any neighbor index is NA, negative, or larger than `length(vec)`.
 *       Previously these inputs reached the core routine unchecked. An NA index in particular
 *       caused signed-integer overflow when shifted, because `NA_INTEGER` is `INT_MIN`.
 *
 * @examples
 * \dontrun{
 * # Example usage in R:
 * vec <- c(1.0, 2.0, 3.0, 4.0, 5.0)
 * nb <- list(c(2, 3), c(1, 4), c(1, 5), c(2), c(3))
 * lagNum <- 1
 * result <- RcppGenLatticeLagUni(vec, nb, lagNum)
 * print(result)
 * }
 */
// [[Rcpp::export(rng = false)]]
Rcpp::NumericVector RcppGenLatticeLagUni(const Rcpp::NumericVector& vec,
                                         const Rcpp::List& nb,
                                         int lagNum = 1) {
  // Number of spatial units, as described by the neighborhood list
  const int n = nb.size();

  // --- Validate inputs before anything reaches the core routine ---
  // One neighbor set per observation; a mismatch would let the core routine
  // index past the end of one of the two containers.
  if (vec.size() != n) {
    Rcpp::stop("Length of 'nb' must match the length of 'vec'.");
  }
  // NA_integer_ arrives as INT_MIN, so this rejects NA as well as negative lags.
  if (lagNum < 0) {
    Rcpp::stop("Parameter 'lagNum' must be a non-negative integer.");
  }

  // Convert R numeric vector to std::vector<double>
  std::vector<double> cpp_vec = Rcpp::as<std::vector<double>>(vec);

  // --- Convert neighborhood list (R, 1-based) -> C++ vector of vectors (0-based) ---
  std::vector<std::vector<int>> cpp_nb(n);
  for (int i = 0; i < n; ++i) {
    const Rcpp::IntegerVector current_nb = nb[i];
    std::vector<int>& neighbors = cpp_nb[i];
    neighbors.reserve(current_nb.size());

    for (int j = 0; j < current_nb.size(); ++j) {
      const int idx = current_nb[j];
      // Valid entries are 1..n, or 0 (spdep's "no neighbors" marker, forwarded as -1).
      // NA_INTEGER is INT_MIN, so `idx < 0` also catches NA, for which `idx - 1`
      // below would be signed-integer overflow (undefined behavior).
      if (idx < 0 || idx > n) {
        Rcpp::stop("Invalid neighbor index in element %d of 'nb': indices must lie in 1..%d (or be 0 for no neighbors).",
                   i + 1, n);
      }
      neighbors.push_back(idx - 1);
    }
  }

  // Call the C++ function
  std::vector<double> result = SpatialLagging::GenLatticeLagUni(cpp_vec, cpp_nb, lagNum);

  // Convert result back to R numeric vector
  return Rcpp::wrap(result);
}

/**
 * @title Rcpp wrapper for flexible multi-variable lattice lag computation
 *
 * @description
 * This function provides an Rcpp interface to the C++ function `GenLatticeLagMulti()`,
 * which computes lagged mean values for multiple spatial variables on a common lattice structure.
 * Each column in the input matrix represents one spatial variable, and each row represents a spatial unit.
 *
 * Unlike the simpler version, this function supports *distinct lag orders* per variable.
 * The parameter `lagNums` allows specifying a custom lag order for each column in `vecs`.
 * Internally, the function validates its inputs, converts R objects into C++ containers,
 * calls `GenLatticeLagMulti()` for computation, and converts the results back to an R matrix.
 *
 * @param vecs Numeric matrix where each column represents one spatial variable
 *             and each row corresponds to a spatial unit.
 * @param nb List of integer vectors representing the neighborhood structure; it must contain
 *           exactly one element per row of `vecs`. Each element contains 1-based indices of
 *           immediate neighbors, each in `1..nrow(vecs)`. A value of 0 (the spdep convention
 *           for a unit without neighbors) is also accepted and is forwarded to the core routine
 *           as -1 after the 0-based shift, as before.
 * @param lagNums Integer vector specifying the (non-negative) lag order for each variable
 *                (column of `vecs`). Must be the same length as the number of columns in `vecs`.
 *
 * @return Numeric matrix of the same dimension as `vecs`, where each column contains
 *         the lagged mean values for the corresponding variable. If a spatial unit has
 *         no valid neighbors at the specified lag, its value is `NaN`.
 *
 * @note Raises an R error (via `Rcpp::stop`) in the following cases:
 *       - `vecs` is empty;
 *       - `length(lagNums) != ncol(vecs)` or `length(nb) != nrow(vecs)`;
 *       - any lag order is negative or NA;
 *       - any neighbor index is NA, negative, or larger than `nrow(vecs)`;
 *       - the core routine returns a result whose shape does not match `vecs`.
 *
 * @examples
 * \dontrun{
 * # Example usage in R:
 * vecs <- matrix(c(1,2,3,4,5,
 *                  2,3,4,5,6), ncol = 2)
 * nb <- list(c(2,3), c(1,4), c(1,5), c(2), c(3))
 * lagNums <- c(1, 2)  # Variable 1 uses lag=1, variable 2 uses lag=2
 * result <- RcppGenLatticeLagMulti(vecs, nb, lagNums)
 * print(result)
 * }
 */
// [[Rcpp::export(rng = false)]]
Rcpp::NumericMatrix RcppGenLatticeLagMulti(const Rcpp::NumericMatrix& vecs,
                                           const Rcpp::List& nb,
                                           const Rcpp::IntegerVector& lagNums) {
  // --- Validate inputs ---
  if (vecs.nrow() == 0 || vecs.ncol() == 0) {
    Rcpp::stop("Input matrix 'vecs' must not be empty.");
  }
  if (lagNums.size() != vecs.ncol()) {
    Rcpp::stop("Length of 'lagNums' must match the number of columns in 'vecs'.");
  }

  const int n_rows = vecs.nrow();
  const int n_cols = vecs.ncol();
  const int n_units = nb.size();

  // One neighbor set per spatial unit (row); otherwise the core routine could
  // index past the end of either the neighbor list or a data column.
  if (n_units != n_rows) {
    Rcpp::stop("Length of 'nb' must match the number of rows in 'vecs'.");
  }
  for (int k = 0; k < n_cols; ++k) {
    // NA_integer_ is INT_MIN, so this rejects NA as well as negative lags.
    if (lagNums[k] < 0) {
      Rcpp::stop("Element %d of 'lagNums' must be a non-negative integer.", k + 1);
    }
  }

  // --- Convert neighborhood list (R, 1-based) -> C++ vector of vectors (0-based) ---
  std::vector<std::vector<int>> cpp_nb(n_units);

  for (int i = 0; i < n_units; ++i) {
    const Rcpp::IntegerVector current_nb = nb[i];
    std::vector<int>& neighbors = cpp_nb[i];
    neighbors.reserve(current_nb.size());

    for (int j = 0; j < current_nb.size(); ++j) {
      const int idx = current_nb[j];
      // Valid entries are 1..n_units, or 0 (spdep's "no neighbors" marker, forwarded as -1).
      // NA_INTEGER is INT_MIN, so `idx < 0` also catches NA, for which `idx - 1`
      // below would be signed-integer overflow (undefined behavior).
      if (idx < 0 || idx > n_units) {
        Rcpp::stop("Invalid neighbor index in element %d of 'nb': indices must lie in 1..%d (or be 0 for no neighbors).",
                   i + 1, n_units);
      }
      neighbors.push_back(idx - 1);
    }
  }

  // --- Convert matrix columns (Rcpp) -> vector of vectors (C++) ---
  std::vector<std::vector<double>> cpp_vecs;
  cpp_vecs.reserve(n_cols);

  for (int j = 0; j < n_cols; ++j) {
    std::vector<double> col(n_rows);
    for (int i = 0; i < n_rows; ++i) {
      col[i] = vecs(i, j);
    }
    cpp_vecs.emplace_back(std::move(col));
  }

  // --- Convert lagNums (Rcpp) -> std::vector<int> ---
  std::vector<int> cpp_lagNums(lagNums.begin(), lagNums.end());

  // --- Call the core computation function ---
  std::vector<std::vector<double>> result_cpp = SpatialLagging::GenLatticeLagMulti(cpp_vecs, cpp_nb, cpp_lagNums);

  // --- Validate output ---
  if (result_cpp.empty()) {
    Rcpp::stop("Computation returned empty result.");
  }

  if (result_cpp.size() != static_cast<size_t>(n_cols)) {
    Rcpp::stop("Result dimension mismatch between input and output.");
  }

  // --- Convert C++ result back to Rcpp matrix ---
  Rcpp::NumericMatrix result(n_rows, n_cols);

  for (int j = 0; j < n_cols; ++j) {
    const auto& col = result_cpp[j];
    if (static_cast<int>(col.size()) != n_rows) {
      Rcpp::stop("Output column length mismatch for variable %d.", j + 1);
    }
    for (int i = 0; i < n_rows; ++i) {
      result(i, j) = col[i];
    }
  }

  return result;
}

/**
 * @title Rcpp wrapper for single-variable grid lag computation
 *
 * @description
 * Bridges an R numeric matrix to the C++ routine `SpatialLagging::GenGridLagUni()`,
 * which computes, for every grid cell, the mean of the lagged values found in the
 * cell's Moore neighborhood at the requested lag.
 *
 * Before the core routine is called, the R matrix is copied cell by cell into a
 * row-indexed `std::vector<std::vector<double>>`. The outer index is the matrix row
 * and the inner index is the matrix column.
 *
 * @param mat Numeric matrix holding the grid values; `mat[i, j]` is the cell in
 *            grid row i and grid column j.
 * @param lagNum Integer number of steps outward in the Moore neighborhood to lag;
 *               defaults to 1.
 *
 * @return Numeric vector with one entry per grid cell holding that cell's lagged mean.
 *         Entries are flattened in row-major order (cell (i, j) at 0-based position
 *         i * ncol(mat) + j). This differs from R's column-major matrix layout, so use
 *         `matrix(result, nrow(mat), ncol(mat), byrow = TRUE)` to restore the grid shape.
 *         Cells without valid neighbors receive NaN.
 *
 * @examples
 * # Example usage in R:
 * # mat <- matrix(1:9, nrow = 3, ncol = 3)
 * # result <- RcppGenGridLagUni(mat, 1)
 */
// [[Rcpp::export(rng = false)]]
Rcpp::NumericVector RcppGenGridLagUni(const Rcpp::NumericMatrix& mat,
                                      int lagNum = 1) {
  const int grid_rows = mat.nrow();
  const int grid_cols = mat.ncol();

  // Copy the R matrix row by row into nested std::vectors: grid[r][c] == mat(r, c).
  std::vector<std::vector<double>> grid;
  grid.reserve(static_cast<std::size_t>(grid_rows));
  for (int r = 0; r < grid_rows; ++r) {
    grid.emplace_back(static_cast<std::size_t>(grid_cols));
    std::vector<double>& row_values = grid.back();
    for (int c = 0; c < grid_cols; ++c) {
      row_values[c] = mat(r, c);
    }
  }

  // Delegate to the core routine and hand its output straight back to R.
  const std::vector<double> lagged = SpatialLagging::GenGridLagUni(grid, lagNum);
  return Rcpp::wrap(lagged);
}

/**
 * @title Rcpp wrapper for multi-variable grid lag computation with variable lag distances
 *
 * @description
 * This function provides an Rcpp interface to the C++ function `GenGridLagMulti()`,
 * which computes lagged mean values for multiple 2D grid variables using the Moore neighborhood
 * (also known as queen's case). Each column in the input matrix represents a separate spatial variable,
 * stored in row-major (flattened) order.
 *
 * Unlike the single-lag version, this function allows specifying a distinct lag distance for
 * each variable via the `lagNums` parameter.
 *
 * @param mat Numeric matrix where each column represents one spatial variable,
 *            stored in row-major flattened order (each row corresponds to one grid cell).
 * @param lagNums Integer vector specifying the (non-negative) lag distance for each variable
 *                (column of `mat`). Must be the same length as the number of columns in `mat`.
 * @param nrow Integer specifying the number of rows in each 2D grid.
 *             The number of columns in the grid is inferred from `mat.nrow() / nrow`.
 *
 * @return Numeric matrix of the same dimensions as `mat`, where each column contains
 *         the lagged mean values for the corresponding input variable.
 *         NaN values are returned for cells without valid neighbors.
 *
 * @note Raises an R error (via `Rcpp::stop`) in the following cases:
 *       - `mat` is empty;
 *       - `nrow` is not positive or does not evenly divide `nrow(mat)`;
 *       - `length(lagNums) != ncol(mat)`;
 *       - any lag distance is negative or NA;
 *       - the core routine returns a result whose shape does not match `mat`.
 *
 * @examples
 * \dontrun{
 * # Example usage in R:
 * nrow <- 3
 * mat <- matrix(c(1:9, 9:1), ncol = 2)
 * lagNums <- c(1, 2)  # variable 1 uses lag=1, variable 2 uses lag=2
 * result <- RcppGenGridLagMulti(mat, lagNums, nrow)
 * print(result)
 * }
 */
// [[Rcpp::export(rng = false)]]
Rcpp::NumericMatrix RcppGenGridLagMulti(const Rcpp::NumericMatrix& mat,
                                        const Rcpp::IntegerVector& lagNums,
                                        int nrow) {
  // --- Validate inputs ---
  if (mat.nrow() == 0 || mat.ncol() == 0) {
    Rcpp::stop("Input matrix 'mat' must not be empty.");
  }
  // NA_integer_ arrives as INT_MIN, so this also rejects NA.
  if (nrow <= 0) {
    Rcpp::stop("Parameter 'nrow' must be positive.");
  }

  const int n_cells = mat.nrow();
  const int n_vars = mat.ncol();

  if (lagNums.size() != n_vars) {
    Rcpp::stop("Length of 'lagNums' must match the number of columns in 'mat'.");
  }
  for (int v = 0; v < n_vars; ++v) {
    // NA_integer_ is INT_MIN, so this rejects NA as well as negative lags.
    if (lagNums[v] < 0) {
      Rcpp::stop("Element %d of 'lagNums' must be a non-negative integer.", v + 1);
    }
  }

  // Derive number of columns per grid
  if (n_cells % nrow != 0) {
    Rcpp::stop("Parameter 'nrow' does not evenly divide the number of rows in 'mat'.");
  }
  const int ncol = n_cells / nrow;

  // --- Convert input matrix to 3D std::vector: arr[var][row][col] ---
  std::vector<std::vector<std::vector<double>>> cpp_arr;
  cpp_arr.reserve(n_vars);

  for (int v = 0; v < n_vars; ++v) {
    std::vector<std::vector<double>> grid(nrow, std::vector<double>(ncol));

    for (int i = 0; i < nrow; ++i) {
      for (int j = 0; j < ncol; ++j) {
        // Row-major flattening: grid cell (i, j) lives in row i * ncol + j of `mat`.
        // Index `mat` directly rather than materializing each column as a new R vector.
        grid[i][j] = mat(i * ncol + j, v);
      }
    }
    cpp_arr.emplace_back(std::move(grid));
  }

  // --- Convert lagNums (Rcpp) -> std::vector<int> ---
  std::vector<int> cpp_lagNums(lagNums.begin(), lagNums.end());

  // --- Call core computation function ---
  std::vector<std::vector<double>> result_cpp = SpatialLagging::GenGridLagMulti(cpp_arr, cpp_lagNums);

  // --- Validate output size ---
  if (result_cpp.size() != static_cast<size_t>(n_vars)) {
    Rcpp::stop("Mismatch between input and output variable counts.");
  }

  // --- Convert C++ result back to Rcpp::NumericMatrix ---
  Rcpp::NumericMatrix result(n_cells, n_vars);

  for (int v = 0; v < n_vars; ++v) {
    const auto& flat_vec = result_cpp[v];
    if (static_cast<int>(flat_vec.size()) != n_cells) {
      Rcpp::stop("Output vector size mismatch for variable %d.", v + 1);
    }

    for (int i = 0; i < n_cells; ++i) {
      result(i, v) = flat_vec[i];
    }
  }

  return result;
}
