#include "RcppArmadillo.h"
#include "aux_functions.h"

/**
 * Gibbs sampler that regresses the observed entries of R on the design
 * matrix Z = make_Z(X, Y, method) and imputes the unobserved entries.
 *
 * Each sweep:
 *   1. draws one augmentation weight per observed entry with
 *      rcpp_pgdraw(k - 1, z_i' * beta) (presumably a Polya-Gamma draw; the
 *      helper is defined outside this file);
 *   2. draws beta from N(mu_B, Sigma_B) with
 *        Sigma_B = (Z_o' diag(omega) Z_o + Sigma_0_inv)^{-1}
 *        mu_B    = Sigma_B (Z_o' kappa_o + Sigma_0_inv mu_0),
 *      where kappa = R - (k + 1) / 2 and the subscript o means "observed rows";
 *   3. refreshes the horseshoe scale parameters (lambda, tau) and their
 *      auxiliary variables (nu, zeta) through inverse-gamma draws and writes
 *      1 / (tau * lambda) onto the diagonal of the prior precision;
 *   4. after burn-in, stores beta and a prediction 1 + Binomial(k - 1, prob)
 *      for every unobserved entry of R.
 *
 * @param X            n x p matrix; its rows index the rows of R.
 * @param Y            m x q matrix; its rows index the columns of R.
 * @param R            n x m response matrix. The entries to model / impute are
 *                     selected by find_observed(R) / find_unobserved(R).
 * @param k            Passed to the sampler as k - 1 trials; predictions
 *                     therefore lie in 1..k.
 * @param method       "bilinear" gives p * q coefficients, anything else p + q.
 * @param mcmc_samples Total number of sweeps.
 * @param burnin       Number of initial sweeps that are not stored; must
 *                     satisfy 0 <= burnin <= mcmc_samples.
 * @param verbose      Print a progress line for every sweep.
 * @return List with "B_hat" (var_dim x 1 x (mcmc_samples - burnin) cube of
 *         coefficient draws) and "R_hat" (n x m x (mcmc_samples - burnin) cube
 *         of predictions; entries that are observed in R are stored as 0).
 */
Rcpp::List model_sparse(arma::mat X, arma::mat Y, arma::mat R, int k, 
                        std::string method, int mcmc_samples, int burnin, 
                        bool verbose) 
{
  // A negative slice count would be converted to a huge unsigned size by the
  // cube constructors below, and a negative burnin would leave all-zero
  // slices in the output, so reject both up front.
  if (burnin < 0 || burnin > mcmc_samples) {
    Rcpp::stop("burnin must satisfy 0 <= burnin <= mcmc_samples");
  }
  const int n_saved = mcmc_samples - burnin;
  
  int n = X.n_rows;
  int p = X.n_cols;
  int m = Y.n_rows;
  int q = Y.n_cols;
  
  arma::mat Z = make_Z(X, Y, method);
  
  arma::uvec observed_vec = find_observed(R);
  arma::uvec unobserved_vec = find_unobserved(R);
  const int observed_vec_size = observed_vec.size();
  const int unobserved_vec_size = unobserved_vec.size();
  
  const int var_dim = (method == "bilinear") ? p * q : p + q;
  
  // kappa = R - (k + 1) / 2, restricted to the observed entries.
  arma::vec kappa_reduced = arma::vectorise(R - 0.5 * (k + 1));
  kappa_reduced.shed_rows(unobserved_vec);
  
  // Design matrix restricted to the observed entries, plus its transpose.
  arma::mat Z_reduced = Z;
  Z_reduced.shed_rows(unobserved_vec);
  const arma::mat Z_reduced_t = Z_reduced.t();
  
  // Z_o' * kappa_o does not depend on the sampler state: compute it once.
  const arma::vec Zt_kappa = Z_reduced_t * kappa_reduced;
  
  // Prior on beta: zero mean, identity precision to start with. Only the
  // diagonal of the precision is updated inside the loop.
  const arma::vec mu_0(var_dim, arma::fill::zeros);
  arma::mat Sigma_0_inv(var_dim, var_dim, arma::fill::eye);

  // Bayesian horseshoe priors (induce sparsity on beta)
  double zeta = 1.0;
  double tau = 1.0;
  arma::vec nu(var_dim, arma::fill::ones);
  // Every element of lambda is redrawn in the first sweep before it is read;
  // the random fill is kept so the sequence of random draws is unchanged.
  arma::vec lambda(var_dim, arma::fill::randu);
  
  // saved posterior draws for Beta and R
  arma::cube B_hat(var_dim, 1, n_saved, arma::fill::zeros);
  arma::cube R_hat(n, m, n_saved, arma::fill::zeros);
  
  // current draw of Beta
  arma::vec B_est(var_dim, arma::fill::zeros);
  
  // Lower bound applied to every freshly drawn scale parameter.
  const double scale_floor = 1e-9;
  
  for(int iter = 0; iter < mcmc_samples; ++iter){
    if(verbose){
      Rcpp::Rcout << "This is iteration: " << iter + 1 << " out of " << mcmc_samples << "\n";
    }
    
    // One augmentation weight per observed entry, in observed_vec order.
    arma::vec omega_obs(observed_vec_size, arma::fill::zeros);
    for(int i = 0; i < observed_vec_size; ++i){
      const arma::uword row_i = observed_vec(i);
      omega_obs(i) = rcpp_pgdraw(k - 1, arma::dot(Z.row(row_i), B_est));
    }
    
    // diag(omega) * Z_o is a row scaling; do it directly rather than
    // materialising a dense observed x observed diagonal matrix.
    const arma::mat weighted_Z = Z_reduced.each_col() % omega_obs;
    
    arma::mat Sigma_B = arma::inv(Z_reduced_t * weighted_Z + Sigma_0_inv);
    arma::vec mu_B = Sigma_B * (Zt_kappa + Sigma_0_inv * mu_0);
    
    arma::mat mv_sample = mvrnorm_arma(1, mu_B, Sigma_B);
    B_est = arma::vectorise(mv_sample);
    
    // Local scales: one inverse-gamma draw per coefficient.
    for(int i = 0; i < var_dim; ++i){
      lambda(i) = 1.0/R::rgamma(1.0, 1.0/(1.0/nu(i) + pow(B_est(i), 2)/(2.0*tau)));
      if(lambda(i) < scale_floor){
        lambda(i) = scale_floor;
      }
    }
    
    // Global scale.
    tau = 1.0/R::rgamma(0.5*(var_dim + 1.0), 1.0/(1.0/zeta + 0.5*arma::sum(arma::pow(B_est, 2)/lambda)));
    if(tau < scale_floor){
      tau = scale_floor;
    } 
    
    // Auxiliary variables paired with the local scales.
    for(int i = 0; i < var_dim; ++i){
      nu(i) = 1.0/R::rgamma(1.0, 1.0/(1.0 + 1.0/lambda(i)));
      if(nu(i) < scale_floor){
        nu(i) = scale_floor;
      }
    }
    
    // Auxiliary variable paired with the global scale.
    zeta = 1.0/R::rgamma(1.0, 1.0/(1.0 + 1.0/tau));
    if(zeta < scale_floor){
      zeta = scale_floor;
    } 
    
    // Prior precision of beta for the next sweep.
    Sigma_0_inv.diag() = 1.0/(tau*lambda);
    
    if(iter >= burnin){
      const int slot = iter - burnin;
      B_hat.slice(slot) = B_est;
      
      // Predict the unobserved entries of R; observed positions stay 0.
      arma::vec R_pred(n*m, arma::fill::zeros);
      for(int i = 0; i < unobserved_vec_size; ++i){
        const arma::uword row_i = unobserved_vec(i);
        // NOTE(review): R::rbinom needs a probability in [0, 1], so `logit`
        // (defined outside this file) is assumed to map the linear predictor
        // to a probability, i.e. to be the inverse-logit. Confirm.
        const double prob_i = logit(arma::dot(Z.row(row_i), B_est));
        R_pred(row_i) = R::rbinom(k - 1.0, prob_i) + 1;
      }
      
      // NOTE(review): the original author flagged this slice assignment as
      // failing. It requires vec_2_mat(R_pred, n, m) to return an n x m
      // matrix; verify against the helper's definition.
      R_hat.slice(slot) = vec_2_mat(R_pred, n, m);
    }
  }
  return Rcpp::List::create(Rcpp::Named("B_hat") = B_hat,
                            Rcpp::Named("R_hat") = R_hat);
}
