#include "scalablebayesm.h"
#include <RcppArmadillo.h>

using namespace Rcpp;
using namespace arma;


//[[Rcpp::export]]
Rcpp::List rhierLinearMixtureParallel_rcpp_loop(List const& regdata, 
                                                arma::mat const& Z, 
                                                arma::vec const& deltabar, 
                                                arma::mat const& Ad, 
                                                arma::mat const& mubar, 
                                                arma::mat const& Amu,
                                                int const& nu, 
                                                arma::mat const& V, 
                                                int nu_e, 
                                                arma::vec const& ssq,
                                                int R, 
                                                int keep, 
                                                int nprint, 
                                                arma::mat olddelta, 
                                                arma::vec const& a, 
                                                arma::vec oldprob, 
                                                arma::vec ind, 
                                                arma::vec tau, 
                                                bool drawdelta, bool verbose
                                                )
{
  // Wayne Taylor 10/02/2014
  // Boyang Yu 06/2023
  //
  // Gibbs sampler loop for a hierarchical linear model whose unit-level
  // coefficients beta_i have a normal-mixture prior, optionally shifted by
  // Delta * z_i (beta_i = Delta*z_i + u_i, Delta is nvar x nz).
  //
  // Each sweep:
  //   1. draw mixture comps, probabilities p and indicators ind given the
  //      current betas (minus Z*Delta' when drawdelta) via rmixGibbs1;
  //   2. if drawdelta, draw Delta via drawDelta1;
  //   3. for every unit, draw (beta_i, tau_i) via runiregG1 using prior
  //      precision rooti'rooti and prior mean mu_ind[i] (+ Delta*z_i) of the
  //      unit's current component.
  // Every keep-th sweep the comps, p and (if drawdelta) vec(Delta) are stored.
  //
  // Parameters:
  //   regdata   list of units; each element is a list with "y", "X", "XpX", "Xpy"
  //   Z         unit-level covariates; row reg is used for unit reg (only when drawdelta)
  //   deltabar, Ad          passed to drawDelta1 (presumably the Delta prior — confirm)
  //   mubar, Amu, nu, V, a  passed to rmixGibbs1 (mixture prior hyperparameters)
  //   nu_e, ssq             passed to runiregG1; ssq is indexed per unit
  //   R         number of sweeps; keep: thinning (must be > 0); nprint: progress interval
  //   olddelta  starting Delta (nvar*nz elements, reshaped to nvar x nz)
  //   oldprob   starting mixture probabilities
  //   ind       starting component indicators (1-based, see ind[reg]-1 below)
  //   tau       starting per-unit error variances
  //   drawdelta whether Delta is sampled; verbose: enable timer output
  //
  // Returns a list with
  //   compdraw  list of R/keep stored component lists
  //   probdraw  (R/keep) x length(oldprob) matrix of mixture probabilities
  //   Deltadraw (R/keep) x (nz*nvar) matrix of vec(Delta) — only if drawdelta
  // The per-unit betas and tau are updated every sweep but not stored/returned.

  // Guard against integer division by zero in R/keep and (rep+1)%keep.
  if (keep <= 0) Rcpp::stop("keep must be a positive integer");

  int const nreg = regdata.size();
  int const nvar = V.n_cols;
  int const nz = Z.n_cols;
  int const nkeep = R/keep;  // integer division: number of stored draws

  // Unpack each unit's R list once into a "moments" struct so the sampler
  // loop below does not repeatedly convert R objects.
  std::vector<moments> regdata_vector;
  regdata_vector.reserve(nreg);
  for (int reg = 0; reg < nreg; reg++) {
    Rcpp::List regdatai = regdata[reg];
    regdata_vector.emplace_back();
    moments& unitdata = regdata_vector.back();
    unitdata.y = as<vec>(regdatai["y"]);
    unitdata.X = as<mat>(regdatai["X"]);
    unitdata.XpX = as<mat>(regdatai["XpX"]);
    unitdata.Xpy = as<vec>(regdatai["Xpy"]);
  }

  // Storage for the sampler state and the kept draws.
  mat oldbetas = zeros<mat>(nreg, nvar);
  mat probdraw(nkeep, oldprob.size());
  mat Deltadraw;  // only allocated when it will actually be filled/returned
  if (drawdelta) Deltadraw.zeros(nkeep, nz*nvar);
  Rcpp::List compdraw(nkeep);

  // Work matrices reused across units to avoid reallocations.
  mat rootpi, betabar, Abeta, Abetabar;
  unireg runiregout_struct;

  if ((nprint > 0) && verbose) startMcmcTimer();

  for (int rep = 0; rep < R; rep++) {
    // 1. Draw comps, ind, p | {beta_i}, Delta.
    Rcpp::List mgout;
    if (drawdelta) {
      olddelta.reshape(nvar, nz);
      mgout = rmixGibbs1(oldbetas - Z*trans(olddelta), mubar, Amu, nu, V, a, oldprob, ind);
    } else {
      mgout = rmixGibbs1(oldbetas, mubar, Amu, nu, V, a, oldprob, ind);
    }

    Rcpp::List oldcomp = mgout["comps"];
    oldprob = as<vec>(mgout["p"]);  // Rcpp -> Armadillo needs an explicit as<>
    ind = as<vec>(mgout["z"]);

    // 2. Draw Delta | {beta_i}, ind, comps. The unit loop below does not
    //    modify olddelta, so reshape it once here rather than per unit.
    if (drawdelta) {
      olddelta = drawDelta1(Z, oldbetas, ind, oldcomp, deltabar, Ad);
      olddelta.reshape(nvar, nz);
    }

    // 3. Draw (beta_i, tau_i) for each unit given its current component.
    for (int reg = 0; reg < nreg; reg++) {
      // ind holds 1-based component labels stored as doubles.
      Rcpp::List oldcompreg = oldcomp[static_cast<int>(ind[reg]) - 1];
      rootpi = as<arma::mat>(oldcompreg[1]);

      // beta_i = Delta*z_i + u_i, Delta is nvar x nz
      if (drawdelta) {
        betabar = as<vec>(oldcompreg[0]) + olddelta*vectorise(Z(reg, span::all));
      } else {
        betabar = as<vec>(oldcompreg[0]);
      }

      Abeta = trans(rootpi)*rootpi;
      Abetabar = Abeta*betabar;

      // One Gibbs iteration of the univariate regression for this unit.
      moments const& unitdata = regdata_vector[reg];
      runiregout_struct = runiregG1(unitdata.y, unitdata.X, unitdata.XpX, unitdata.Xpy,
                                    tau[reg], Abeta, Abetabar, nu_e, ssq[reg]);
      oldbetas(reg, span::all) = trans(runiregout_struct.beta);
      tau[reg] = runiregout_struct.sigmasq;  // tau_i is the variance of e_i
    }

    if ((nprint > 0) && verbose && ((rep + 1) % nprint == 0)) infoMcmcTimer(rep, R);

    // Store thinned draws.
    if ((rep + 1) % keep == 0) {
      int const mkeep = (rep + 1)/keep;
      probdraw(mkeep - 1, span::all) = trans(oldprob);
      if (drawdelta) Deltadraw(mkeep - 1, span::all) = trans(vectorise(olddelta));
      compdraw[mkeep - 1] = oldcomp;
    }
  }

  if ((nprint > 0) && verbose) endMcmcTimer();

  if (drawdelta) {
    return Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
                              Rcpp::Named("probdraw") = probdraw,
                              Rcpp::Named("Deltadraw") = Deltadraw);
  }
  return Rcpp::List::create(Rcpp::Named("compdraw") = compdraw,
                            Rcpp::Named("probdraw") = probdraw);
}
