#include "RcppArmadillo.h"
#include <cmath>
#define _USE_MATH_DEFINES
#include <math.h>

// [[Rcpp::depends(RcppArmadillo)]]

using namespace Rcpp;

// Standard normal density: exp(-x^2/2) / sqrt(2*pi).
// [[Rcpp::export]]
double dnormC(double x){
	// Normalising constant and Gaussian kernel, multiplied in that order.
	const double inv_sqrt_two_pi = 1/std::sqrt(2*M_PI);
	const double kernel = std::exp(-std::pow(x,2.0)*0.5);
	return inv_sqrt_two_pi*kernel;
}

// Column vector of zeros with n elements (n is converted to an unsigned length).
// [[Rcpp::export]]
arma::colvec mz(double n){
  const arma::uword len = static_cast<arma::uword>(n);
  arma::colvec zero_vec(len, arma::fill::zeros);
  return zero_vec;
}

// Standard normal cumulative distribution function, P(Z <= x).
//
// Computed as 0.5 * erfc(-x / sqrt(2)). Using the complementary error
// function (rather than 1 - erf-style arithmetic or a low-order polynomial
// approximation) keeps relative accuracy in the lower tail, where the
// probability is tiny, and yields exactly 0.5 at x = 0.
//
// x       : point at which the CDF is evaluated.
// returns : value in [0, 1]; NaN propagates from a NaN input.
// [[Rcpp::export]]
double pnormC(double x){
    return 0.5*std::erfc(-x/std::sqrt(2.0));
}

// Draw one vector from a multivariate normal with mean mu and covariance sig.
//
// mu      : mean vector.
// sig     : covariance matrix; it is passed to arma::chol, so it has to admit
//           a Cholesky factorization.
// returns : mu + L*z, where z holds mu.size() draws from Rcpp's rnorm and
//           L is the transpose of the factor returned by arma::chol(sig).
// [[Rcpp::export]]
arma::colvec mvndrawC(arma::colvec mu, arma::mat sig) {
  const double dim = mu.size();

  // The standard-normal shocks are drawn first, then sig is factorized.
  const arma::colvec shocks = as<arma::colvec>(rnorm(dim));
  const arma::mat lower = arma::chol(sig).t();

  arma::colvec draw = mu + lower*shocks;
  return draw;
}
 
 // [[Rcpp::export]]
List carterkohn(arma::mat y, arma::mat Z, arma::mat Ht, arma::mat Qt, double m, double p, double t, arma::colvec B0, arma::mat V0) {
  // Forward Kalman filter followed by a backward simulation pass that draws
  // the whole state path.
  //
  // y   : observations, one column per period (column s is used in period s).
  // Z   : stacked measurement matrices; rows s*p .. (s+1)*p-1 belong to period s.
  // Ht  : stacked measurement-error covariances, blocked the same way as Z.
  // Qt  : state-innovation covariance, added to the filtered covariance in
  //       every prediction step (the predicted mean equals the filtered mean).
  // m   : state dimension.   p : rows per period block.   t : number of periods.
  // B0, V0 : mean and covariance used for the first prediction.
  //
  // Returns a list with
  //   "loglik" : 1-element vector accumulating log(det(f)) + e' f^{-1} e over
  //              all periods (no scaling or constant is applied here);
  //   "bdraws" : m x t matrix of drawn states (one column per period).

  // Predicted moments for the current period.
  arma::colvec pred_mean = B0;
  arma::mat pred_cov = V0;
  // Filtered moments live outside the loop: the last period's values seed
  // the backward pass.
  arma::colvec filt_mean = B0;
  arma::mat filt_cov = V0;

  // Storage for every period's filtered mean (rows) and vectorised covariance (columns).
  arma::mat filt_mean_all = arma::zeros(t, m);
  arma::mat filt_cov_all = arma::zeros(m*m, t);
  arma::colvec loglik = arma::zeros(1);

  for (int s = 0; s < t; s++) {
    const arma::mat R = Ht.rows(s*p, (s+1)*p-1);
    const arma::mat H = Z.rows(s*p, (s+1)*p-1);

    // Forecast error, its covariance, and the inverse of that covariance.
    const arma::colvec fe = y.col(s) - H*pred_mean;
    const arma::mat fe_cov = H*pred_cov*H.t() + R;
    const arma::mat fe_prec = fe_cov.i();
    loglik = loglik + std::log(arma::det(fe_cov)) + fe.t()*fe_prec*fe;

    // Measurement update.
    filt_mean = pred_mean + pred_cov*H.t()*fe_prec*fe;
    filt_cov = pred_cov - pred_cov*H.t()*fe_prec*H*pred_cov;

    // Prediction for the next period; skipped after the final observation.
    if (s + 1 < t){
      pred_mean = filt_mean;
      pred_cov = filt_cov + Qt;
    }

    filt_mean_all.row(s) = filt_mean.t();
    filt_cov_all.col(s) = arma::vectorise(filt_cov);
  }

  // The last state is drawn from the final filtered distribution.
  arma::mat bdraw = arma::zeros(t, m);
  bdraw.row(t-1) = mvndrawC(filt_mean, filt_cov).t();

  // Backward recursion: condition each earlier state on the draw that follows it.
  for (int back = 1; back < t; back++) {
    const double next = t - back;
    const double cur = next - 1;

    const arma::colvec next_state = bdraw.row(next).t();
    const arma::colvec mean_cur = filt_mean_all.row(cur).t();
    const arma::mat cov_cur = arma::reshape(filt_cov_all.col(cur), m, m);

    const arma::mat pred = cov_cur + Qt;
    const arma::mat pred_inv = pred.i();
    const arma::colvec innov = next_state - mean_cur;

    const arma::colvec bmean = mean_cur + cov_cur*pred_inv*innov;
    const arma::mat bvar = cov_cur - cov_cur*pred_inv*cov_cur;
    bdraw.row(cur) = mvndrawC(bmean, bvar).t();
  }

  return List::create(Named("loglik") = loglik, Named("bdraws") = bdraw.t());
}



// n-by-n identity matrix (n is converted to an unsigned dimension).
// [[Rcpp::export]]
arma::mat meye(double n){
  const arma::uword dim = static_cast<arma::uword>(n);
  arma::mat identity(dim, dim, arma::fill::eye);
  return identity;
}

// Half-vectorisation: stacks, column by column, the entries of x on and below
// the main diagonal into one vector of length n*(n+1)/2, with n = x.n_rows.
// [[Rcpp::export]]
arma::colvec vechC(arma::mat x) {
  const arma::uword dim = x.n_rows;
  arma::colvec stacked = arma::zeros(dim*(dim+1)/2);

  arma::uword offset = 0;   // next free position in the output
  for (arma::uword col = 0; col < dim; ++col) {
    // Column `col` contributes its rows col .. dim-1.
    const arma::uword len = dim - col;
    stacked.rows(offset, offset + len - 1) = x.submat(col, col, dim - 1, col);
    offset += len;
  }

  return stacked;
}

// Sum of outer products of n independent zero-mean normal draws with
// covariance h, i.e. X * X' where the columns of X are the draws.
//
// h       : k x k covariance matrix; it must admit a Cholesky factorization.
// n       : number of draws (columns of X); converted to an unsigned count.
// returns : k x k matrix X * X' (all zeros when no draws are made).
//
// h is factorized once and the factor is reused for every draw, instead of
// re-running the Cholesky decomposition and building a zero mean vector
// inside the loop. Each draw is still L * z with z drawn by Rcpp's rnorm,
// one column at a time in column order.
// [[Rcpp::export]]
arma::mat wishdrawC(arma::mat h, double n) {
  const arma::uword dim = h.n_rows;
  arma::mat draws = arma::zeros(dim, n);

  // Only factorize when at least one draw is requested, so a zero-draw call
  // never touches h.
  if (draws.n_cols > 0) {
    const arma::mat lower = arma::chol(h).t();
    for (arma::uword j = 0; j < draws.n_cols; ++j) {
      const arma::colvec shocks = as<arma::colvec>(rnorm(dim));
      draws.col(j) = lower * shocks;
    }
  }

  return draws * draws.t();
}


// Matrix power by repeated right-multiplication.
//
// x       : square matrix to be raised to a power.
// nt      : exponent. For nt >= 1 the product starts at x and is multiplied
//           by x once for every integer ii with 1 <= ii < nt; for any other
//           value (including 0, negatives and NaN) the identity is returned.
// [[Rcpp::export]]
arma::mat matmult(arma::mat x, double nt){
  // Anything that is not at least 1 maps to the identity of matching size.
  if (!(nt >= 1)) {
    return meye(x.n_rows);
  }

  arma::mat power = x;
  for (int step = 1; step < nt; ++step) {
    power = power * x;
  }
  return power;
}

// nf-step-ahead forecast mean and forecast-error variance of a VAR written in
// companion form.
//
// b   : k x (1 + k*p) coefficient matrix; column 0 is the intercept and the
//       remaining columns are the lag coefficient matrices.
// sig : k x k innovation covariance.
// y   : data whose rows 0 .. p-1 supply the initial state; row p-1 is placed
//       first in the stacked state vector, row 0 last.
// nf  : forecast horizon.
//
// Returns a list with
//   "mean"     : first k elements of (I + A + ... + A^(nf-1)) * nu + A^nf * state,
//   "variance" : top-left k x k block of sum_{h=0}^{nf-1} A^h * Om * (A^h)',
// where A, nu and Om are the companion-form coefficient matrix, intercept and
// innovation covariance.
//
// The powers of A are carried forward from one horizon to the next (one
// multiplication per step) rather than being rebuilt from scratch several
// times per step; the multiplication order is the same as in matmult().
// [[Rcpp::export]]
List varfcst(arma::mat b, arma::mat sig, arma::mat y, double nf){
  const double k = b.n_rows;
  const double p = (b.n_cols - 1)/k;

  // Stack the initial observations into the companion-form state vector.
  arma::colvec bigy = arma::zeros(p*k);
  for (int ii = 1; ii < (p+1); ii++) {
    bigy.rows((ii-1)*k,ii*k-1) = y.row(p-ii).t();
  }

  arma::colvec nu = b.col(0);            // intercept vector
  arma::mat a = b.cols(1,b.n_cols-1);    // VAR coefficient matrices
  arma::mat om0 = sig;

  // With more than one lag, pad everything out to the k*p companion dimension.
  if (p > 1){
    arma::colvec nu_pad = arma::zeros(k*(p-1));
    nu = arma::join_cols(nu, nu_pad);

    arma::mat shift_eye = meye(k*(p-1));
    arma::mat shift_zero = arma::zeros(k*(p-1),k);
    a = arma::join_cols(a, arma::join_rows(shift_eye, shift_zero));

    arma::mat om_right = arma::zeros(k, k*(p-1));
    arma::mat om_bottom = arma::zeros(k*(p-1), k*p);
    om0 = arma::join_cols(arma::join_rows(sig, om_right), om_bottom);
  }

  arma::mat fcv = om0;          // h = 0 term of the variance sum
  arma::mat aux = meye(k*p);    // h = 0 term of the intercept sum
  arma::mat a_pow = a;          // running power, A^hh at the top of each pass
  for (int hh = 1; hh < nf; hh++) {
    aux = aux + a_pow;
    fcv = fcv + a_pow*om0*a_pow.t();
    a_pow = a_pow * a;
  }
  // After the loop a_pow holds A^nf whenever nf >= 1; any other horizon uses
  // the identity, matching matmult()'s convention.
  if (!(nf >= 1)) {
    a_pow = meye(a.n_rows);
  }

  arma::mat aux2 = aux*nu + a_pow*bigy;
  return List::create(Named("mean") = aux2.rows(0, k-1), Named("variance") = fcv(arma::span(0,k-1), arma::span(0, k-1)));
}

// For every row s of lambda_draw, forms A' * diag(1 / lambda_s) * A and stacks
// the resulting M x M blocks vertically (M = A_draw.n_rows).
//
// returns : (M * rows(lambda_draw)) x M matrix; block s occupies rows
//           s*M .. (s+1)*M - 1.
// [[Rcpp::export]]
arma::mat Sigmahelper(arma::mat A_draw, arma::mat lambda_draw){
  const double n_var = A_draw.n_rows;
  const double n_obs = lambda_draw.n_rows;
  const arma::mat A_trans = A_draw.t();
  const arma::colvec unit = arma::ones(n_var, 1);

  arma::mat stacked = arma::zeros(n_var*n_obs, n_var);
  for (int s = 0; s < n_obs; ++s) {
    // Element-wise reciprocal of this period's lambda values on a diagonal.
    const arma::colvec lambda_s = lambda_draw.row(s).t();
    const arma::colvec recip = unit / lambda_s;
    const arma::mat recip_diag = arma::diagmat(recip);
    stacked.rows(s*n_var, (s+1)*n_var - 1) = A_trans * recip_diag * A_draw;
  }
  return stacked;
}

// Posterior moments and one draw for the vectorised slope coefficients.
//
// Pi_prior_m   : prior mean.
// Pi_prior_v_i : prior precision (it is added to, then inverted).
// Sigma_i_all  : stacked M x M weighting blocks, one per row of y
//                (presumably period-specific inverse error covariances).
// y, x         : dependent variables (one row per period) and regressors.
//
// Posterior precision = prior precision + sum_s kron(Sigma_s, x_s x_s');
// posterior mean      = V * (vec(sum_s x_s y_s' Sigma_s) + prior precision * prior mean),
// with V the inverse of the posterior precision. The draw comes from mvndrawC.
//
// returns : list with "Pi_post_m", "Pi_post_v" and "Pi_post_draw".
// [[Rcpp::export]]
List Pi_post(arma::colvec Pi_prior_m, arma::mat Pi_prior_v_i, arma::mat Sigma_i_all, arma::mat y, arma::mat x){
  const double n_var = y.n_cols;
  const double n_obs = y.n_rows;
  const double n_reg = x.n_cols / n_var;

  // Running sums, started at the prior precision and at zero respectively.
  arma::mat post_prec = Pi_prior_v_i;
  arma::mat cross = arma::zeros(n_var*n_reg, n_var);
  for (int s = 0; s < n_obs; ++s) {
    const arma::mat weight_s = Sigma_i_all.rows(s*n_var, (s+1)*n_var - 1);
    const arma::colvec x_s = x.row(s).t();
    const arma::colvec y_s = y.row(s).t();
    post_prec = post_prec + arma::kron(weight_s, x_s * x_s.t());
    cross = cross + x_s * y_s.t() * weight_s;
  }

  const arma::colvec cross_vec = arma::vectorise(cross);
  const arma::mat Pi_post_v = post_prec.i();
  const arma::colvec Pi_post_m = Pi_post_v * (cross_vec + Pi_prior_v_i * Pi_prior_m);
  const arma::colvec Pi_post_draw = mvndrawC(Pi_post_m, Pi_post_v);
  return List::create(Named("Pi_post_m") = Pi_post_m, Named("Pi_post_v") = Pi_post_v, Named("Pi_post_draw") = Pi_post_draw);
}


// Splits a data matrix into dependent variables and a matrix of k lags.
//
// dat     : raw data, one row per period.
// k       : number of lags.
// returns : list with
//   "y" : dat without its first k rows,
//   "x" : regressors whose column block (lag-1)*M .. lag*M-1 holds dat
//         shifted back by `lag` rows (M = dat.n_cols),
//   "t" : number of rows of "y".
// [[Rcpp::export]]
List make_y_x(arma::mat dat, double k){
  const double n_raw = dat.n_rows;
  const double n_var = dat.n_cols;

  // Dependent variables: everything after the first k rows.
  arma::mat y = dat.rows(k, n_raw - 1);
  const double t = y.n_rows;

  // Regressors: one block of columns per lag.
  arma::mat x = arma::zeros(n_raw - k, n_var*k);
  for (int lag = 1; lag < (k + 1); ++lag) {
    x.cols((lag - 1)*n_var, lag*n_var - 1) = dat.rows(k - lag, n_raw - lag - 1);
  }

  return List::create(Named("y") = y, Named("x") = x, Named("t") = t);
}

// Posterior moments and one draw for the Psi parameter vector.
//
// Psi_prior_m   : prior mean (its length p sets the block size).
// Psi_prior_v_i : prior precision.
// Sigma_i_all   : stacked p x p weighting blocks, one per row of q_all.
// q_all         : one row of data per period.
// U             : matrix applied on both sides of the accumulated sums.
// k             : sets the length (k + 1) of the vector d = (1, -1, ..., -1)'.
//
// Posterior precision = prior precision + U' * [sum_s kron(d d', Sigma_s)] * U;
// posterior mean      = V * (U' * vec(sum_s Sigma_s q_s d') + prior precision * prior mean),
// with V the inverse of the posterior precision. The draw comes from mvndrawC.
//
// returns : list with "Psi_post_m", "Psi_post_v" and "Psi_post_draw".
// [[Rcpp::export]]
List Psi_post(arma::colvec Psi_prior_m, arma::mat Psi_prior_v_i, arma::mat Sigma_i_all, arma::mat q_all, arma::mat U, double k){
  const double n_obs = q_all.n_rows;
  const double n_par = Psi_prior_m.n_rows;

  // d = (1, -1, ..., -1)' and its outer product d d'.
  arma::colvec d_vec = arma::ones(k+1) * (-1);
  d_vec(0) = 1;
  const arma::rowvec d_row = d_vec.t();
  const arma::mat d_outer = arma::kron(d_vec, d_row);

  // Accumulate the data-dependent parts of the precision and of the mean.
  arma::mat prec_sum = arma::zeros((k+1)*n_par, (k+1)*n_par);
  arma::mat cross_sum = arma::zeros(n_par, k+1);
  for (int s = 0; s < n_obs; ++s) {
    const arma::mat weight_s = Sigma_i_all.rows(s*n_par, (s+1)*n_par - 1);
    const arma::colvec q_s = q_all.row(s).t();
    prec_sum = prec_sum + arma::kron(d_outer, weight_s);
    cross_sum = cross_sum + weight_s * q_s * d_row;
  }

  const arma::mat post_prec = Psi_prior_v_i + U.t() * prec_sum * U;
  const arma::mat Psi_post_v = post_prec.i();
  const arma::colvec Psi_post_m = Psi_post_v * (U.t() * arma::vectorise(cross_sum) + Psi_prior_v_i * Psi_prior_m);
  const arma::colvec Psi_post_draw = mvndrawC(Psi_post_m, Psi_post_v);
  return List::create(Named("Psi_post_m") = Psi_post_m, Named("Psi_post_v") = Psi_post_v, Named("Psi_post_draw") = Psi_post_draw);
}

// Tiles the row vector v into a matrix: rep_row copies down, rep_col copies across.
// [[Rcpp::export]]
arma::mat rep_vec(arma::rowvec v, double rep_row, double rep_col){
  arma::mat tiled = arma::repmat(v, rep_row, rep_col);
  return tiled;
}

// Residual-type series obtained by removing Psi from y and from every lag
// block of x, then subtracting the fitted part x_dm * Pi'.
//
// y   : dependent variables, one row per period.
// x   : lagged regressors; x.n_cols / y.n_cols gives the number of lag blocks.
// Psi : vector subtracted from each row of y and from each lag block of x.
// Pi  : coefficient matrix, applied as x_dm * Pi'.
// [[Rcpp::export]]
arma::mat make_y_hat(arma::mat y, arma::mat x, arma::colvec Psi, arma::mat Pi){
  const double n_obs = y.n_rows;
  const double n_var = y.n_cols;
  const double n_lags = x.n_cols/n_var;

  // Subtract the long-run mean from y and from each lag block of x.
  const arma::rowvec mean_row = Psi.t();
  const arma::mat y_centered = y - rep_vec(mean_row, n_obs, 1);
  const arma::mat x_centered = x - rep_vec(mean_row, n_obs, n_lags);

  // Apply the lag polynomial to the demeaned data (see Eq 10 in Clark paper).
  return y_centered - x_centered * Pi.t();
}

// One draw from the normal posterior of a linear-regression coefficient vector.
//
// prior_m   : prior mean.
// prior_v_i : prior precision.
// y, x      : response vector and regressor matrix.
//
// Posterior precision = prior_v_i + x'x; posterior covariance V is its
// inverse; posterior mean = V * (prior_v_i * prior_m + x'y).
// The precision is inverted once and the result is used for both the mean
// and the draw, as Pi_post and Psi_post do.
//
// returns : a draw from N(posterior mean, V) obtained through mvndrawC.
// [[Rcpp::export]]
arma::colvec ols_post(arma::colvec prior_m, arma::mat prior_v_i, arma::colvec y, arma::mat x){
  const arma::mat post_v_i = prior_v_i + x.t()*x;
  const arma::mat post_v = post_v_i.i();
  const arma::colvec post_m = post_v * (prior_v_i * prior_m + x.t()*y);
  return mvndrawC(post_m, post_v);
}


// Draws mixture-component indicators for the log squared transformed
// residuals and then draws new log-variance states with carterkohn().
//
// A_draw       : matrix applied to the residuals (A_draw * y_hat').
// y_hat        : residuals, one row per period (t x M).
// qs, ms, u2s  : weights, means and variances of the normal mixture; the
//                number of components is taken from qs.
// Sigtdraw     : current log-variance states, indexed (variable, period).
// Wdraw        : state-innovation covariance handed to carterkohn().
// sigma_prmean, sigma_prvar : initial state mean and covariance for carterkohn().
//
// Returns a list with
//   "log_resid_var" : transpose of the "bdraws" element returned by carterkohn(),
//   "resid_var"     : its element-wise exponential.
//
// Uniform draws are taken one at a time with the variable as the outer loop
// and the period as the inner loop. If a draw is not below any cumulative
// weight (the normalised weights can sum to slightly less than one, or be
// NaN when every density underflows) the last component is used, so the
// lookups into u2s and ms always stay in range.
// [[Rcpp::export]]
List sigmahelper2_new(arma::mat A_draw, arma::mat y_hat, arma::colvec qs, arma::colvec ms, arma::colvec u2s, arma::mat Sigtdraw, arma::mat Wdraw, arma::colvec sigma_prmean, arma::mat sigma_prvar){

  const double M = y_hat.n_cols;
  const double t = y_hat.n_rows;

  const arma::uword n_mix = qs.n_elem;
  if (n_mix == 0) {
    Rcpp::stop("sigmahelper2_new: qs must contain at least one mixture component");
  }
  if (ms.n_elem < n_mix || u2s.n_elem < n_mix) {
    Rcpp::stop("sigmahelper2_new: ms and u2s must have at least as many elements as qs");
  }

  // log(0.001 + squared transformed residual); the offset keeps the log finite.
  arma::mat y2 = arma::pow(A_draw * y_hat.t(), 2);
  arma::mat aux = 0.001 * arma::ones(t,M);
  arma::mat yss = arma::log( aux + y2.t() );

  // vart stacks one M x M diagonal block per period holding the variance of
  // the selected component; yss1 is yss minus the selected component's mean.
  arma::mat vart = arma::zeros(t*M,M);
  arma::mat yss1 = arma::zeros(t,M);

  arma::colvec prw = arma::zeros(n_mix);
  arma::colvec cprw = arma::zeros(n_mix);
  for (int jj = 0; jj < M; jj++) {
    for (int i = 0; i < t; i++) {
      // Unnormalised weight of every component at this observation.
      const double resid = yss(i,jj) - Sigtdraw(jj,i);
      for (arma::uword c = 0; c < n_mix; c++) {
        prw(c) = qs(c) * (1/std::sqrt(2*M_PI*u2s(c)))*std::exp(-0.5*((std::pow(resid - ms(c),2))/u2s(c)));
      }
      cprw = arma::cumsum(prw/arma::sum(prw));

      // Inverse-CDF selection: first component whose cumulative weight
      // exceeds the uniform draw, defaulting to the last one.
      const double trand = as<double>(runif(1));
      arma::uword comp = n_mix - 1;
      for (arma::uword c = 0; c < n_mix; c++) {
        if (trand < cprw(c)) {
          comp = c;
          break;
        }
      }

      vart(i*M + jj, jj) = u2s(comp);
      yss1(i,jj) = yss(i,jj) - ms(comp);
    }
  }

  // Every period uses an identity measurement matrix.
  arma::mat auxm1 = arma::ones(t, 1);
  arma::mat Zs = arma::kron(auxm1, meye(M));
  arma::mat Sigtdraw_new = carterkohn(yss1.t(),Zs,vart,Wdraw,M,M,t,sigma_prmean,sigma_prvar)["bdraws"];
  arma::mat exp_Sigtdraw_new = arma::exp(Sigtdraw_new);

  return List::create(Named("log_resid_var") = Sigtdraw_new.t(), Named("resid_var") = exp_Sigtdraw_new.t());
}

// Cross-product d' * d of the first differences of x taken down the rows
// (row s+1 minus row s). x needs at least two rows.
// [[Rcpp::export]]
arma::mat diff_helper(arma::mat x){
  const double n_obs = x.n_rows;
  const arma::mat increments = x.rows(1, n_obs - 1) - x.rows(0, n_obs - 2);
  return increments.t() * increments;
}

// For every M x M block s of capAt and sigt, computes F_s = inv(capAt_s) * sigt_s
// and stores F_s and F_s * F_s' in the matching blocks of the outputs.
//
// capAt, sigt : vertically stacked M x M blocks, with M = sigt.n_cols.
// returns     : list with "Ht" (blocks F_s * F_s') and "Htsd" (blocks F_s).
// [[Rcpp::export]]
List sigmahelper3(arma::mat capAt, arma::mat sigt){

  const double n_var = sigt.n_cols;
  const double n_obs = sigt.n_rows/n_var;
  arma::mat Ht = arma::zeros(n_var*n_obs, n_var);
  arma::mat Htsd = arma::zeros(n_var*n_obs, n_var);

  for (int s = 0; s < n_obs; ++s) {
    // Row range of block s.
    const double first = s*n_var;
    const double last = ((s + 1)*n_var) - 1;

    const arma::mat a_inv = capAt.rows(first, last).i();
    const arma::mat sd_block = sigt.rows(first, last);
    const arma::mat factor = a_inv*sd_block;
    Ht.rows(first, last) = factor * factor.t();
    Htsd.rows(first, last) = factor;
  }

  return List::create(Named("Ht") = Ht, Named("Htsd") = Htsd);
}

// Subtracts Psi' from every row of y and from every lag block of x.
//
// The number of lag blocks is x.n_cols / y.n_cols (integer division).
// returns : list with "y_dm" and "x_dm".
// [[Rcpp::export]]
List demean(arma::mat y, arma::mat x, arma::colvec Psi){
  const double n_lags = x.n_cols / y.n_cols;
  const double n_obs = y.n_rows;
  const arma::rowvec mean_row = Psi.t();

  const arma::mat y_means = arma::repmat(mean_row, n_obs, 1);
  const arma::mat x_means = arma::repmat(mean_row, n_obs, n_lags);
  const arma::mat y_dm = y - y_means;
  const arma::mat x_dm = x - x_means;
  return List::create(Named("y_dm") = y_dm, Named("x_dm") = x_dm);
}
