#' @importFrom MASS mvrnorm
#' @importFrom magrittr "%>%"
#' @importFrom graphics matplot
#' @importFrom stats ar median rWishart arima.sim na.omit rnorm var
#' @importFrom utils tail head
#' @useDynLib bvarsv3, .registration = TRUE

# Reshape the stacked VAR slope vector into matrix form.
#   vec_Pi : the k*p^2 slope coefficients, read row by row into Pi
#   p      : number of variables
#   k      : number of lags
# Returns a list with
#   Pi : p x (p*k) matrix [Pi_1, ..., Pi_k]
#   U  : ((k+1)*p) x p matrix stacking I_p on top of Pi_1, ..., Pi_k
make_Pi_U <- function(vec_Pi, p, k){
  Pi <- matrix(vec_Pi, nrow = p, ncol = p*k, byrow = TRUE)
  # p x p coefficient block belonging to a given lag
  lag_block <- function(lag) Pi[, ((lag - 1)*p + 1):(lag*p)]
  U <- do.call(rbind, c(list(diag(p)), lapply(1:k, lag_block)))
  list(Pi = Pi, U = U)
}

##############################################################################
# Parameters of a 7-component normal mixture approximation to the log-chi2
# distribution: component probabilities q, means m and variances u2.
# NOTE(review): the values appear to be the Kim, Shephard & Chib (1998) table.
# The probability-weighted mean of m is ~0, whereas E[log chi2(1)] ~ -1.2704;
# in that paper the means are used with an extra -1.2704 offset, which is not
# applied here — confirm any caller accounts for it.
##############################################################################
getmix <- function(){
  list(
    q  = c(0.00730, 0.10556, 0.00002, 0.04395, 0.34001, 0.24566, 0.25750),      # probabilities
    m  = c(-10.12999, -3.97281, -8.56686, 2.77786, 0.61942, 1.79518, -1.08819), # means
    u2 = c(5.79596, 2.61369, 5.17950, 0.16735, 0.64009, 0.34023, 1.26261)       # variances
  )
}

# Parameters of a 10-component normal mixture approximation to the log-chi2
# distribution: component probabilities q, means m and variances u2.
# The values appear to be those of Omori et al. (2007); their
# probability-weighted mean is ~ -1.2704 (= E[log chi2(1)]), so no extra
# offset is needed when using them.
getmix_omori <- function(){
  probs <- c(0.00609, 0.04775, 0.13057, 0.20674, 0.22715, 0.18842, 0.12047,
             0.05591, 0.01575, 0.00115)
  means <- c(1.92677, 1.34744, 0.73504, 0.02266, -0.85173, -1.97278, -3.46788,
             -5.55246, -8.68384, -14.65000)
  variances <- c(0.11265, 0.17788, 0.26768, 0.40611, 0.62699, 0.98583, 1.57469,
                 2.54498, 4.16591, 7.33342)
  list(q = probs, m = means, u2 = variances)
}

# Draw the unit lower-triangular matrix A (residual correlations) row by row.
#   y_hat      : residual matrix, one row per observation, one column per
#                variable (p = ncol(y_hat))
#   Lambda_all : residual variances with the same layout as y_hat; equation ii
#                is standardised by sqrt(Lambda_all[, ii])
#   const1     : value repeated as the prior mean of each coefficient
#   const2     : value placed on the diagonal of the prior matrix handed to
#                ols_post() (presumably a prior precision — confirm)
# Row ii (ii >= 2) regresses y_hat[, ii] on -y_hat[, 1:(ii-1)], both scaled by
# sqrt(Lambda_all[, ii]); assuming ols_post() returns a coefficient draw, A
# then maps y_hat[t, ] to orthogonalised residuals.
# Returns a p x p matrix; for p == 1 this is diag(1).
A_post <- function(y_hat, Lambda_all, const1, const2){
  p <- ncol(y_hat)
  A_draw <- diag(p)
  s_Lambda_all <- sqrt(Lambda_all)
  # seq_len(p)[-1] is empty for p == 1 (2:p would iterate over c(2, 1) and fail)
  for (ii in seq_len(p)[-1]) {
    y_tmp <- y_hat[, ii] / s_Lambda_all[, ii]
    # dividing by a length-nrow vector scales each row of the regressor matrix
    x_tmp <- -y_hat[, 1:(ii-1), drop = FALSE] / s_Lambda_all[, ii]
    m_tmp <- rep(const1, ii-1)
    v_i_tmp <- diag(const2, ii-1, ii-1)
    A_draw[ii, 1:(ii-1)] <- ols_post(m_tmp, v_i_tmp, y_tmp, x_tmp)
  }
  A_draw
}

# Construct a Minnesota-type prior for the VAR slope coefficients.
#   dat : data matrix, one column per variable (p = ncol(dat))
#   k   : number of lags
#   rho : prior mean of the first own-lag coefficient, one value per variable
#   a1  : own-lag variance scale (prior variance a1 / l^2 at lag l)
#   a2  : cross-lag variance scale
#         (prior variance (a2 / l^2) * sigma_i / sigma_j at lag l)
# Coefficient jj is element jj of matrix(1:(k*p^2), nrow = p, byrow = TRUE),
# i.e. the p x (k*p) matrix [Pi_1, ..., Pi_k] read row by row: equation by
# equation, then lag by lag, then regressor by regressor.
# Returns list(prior_m = prior mean vector of length k*p^2,
#              prior_v = diagonal (k*p^2) x (k*p^2) prior variance matrix).
minnesota_prior <- function(dat, k, rho = rep(0.9, ncol(dat)), a1 = 0.5, a2 = 0.5){
  p <- ncol(dat)
  n_elem <- k*p^2
  n_col <- k*p

  # Residual standard deviation of a univariate AR (OLS) fit per variable.
  # NOTE(review): the AR order is p (number of variables), not k (lag
  # length) — confirm this is intended.
  sigma <- vapply(seq_len(p), function(jj) {
    fit <- ar(dat[, jj], aic = FALSE, order.max = p, method = "ols")
    sqrt(mean(fit$resid^2, na.rm = TRUE))
  }, numeric(1))

  # Location of every coefficient, computed from its index directly instead
  # of searching an index matrix with which() for each element (quadratic).
  pos <- seq_len(n_elem) - 1
  eq_l <- pos %/% n_col + 1     # target equation (lhs variable)
  cc <- pos %% n_col + 1        # column within [Pi_1, ..., Pi_k]
  l <- (cc - 1) %/% p + 1       # lag the coefficient refers to
  eq_r <- cc - (l - 1)*p        # origin variable (rhs)
  own <- eq_l == eq_r

  # Non-zero prior mean only for the first own lag
  prior_m <- rep(0, n_elem)
  first_own <- own & l == 1
  prior_m[first_own] <- rho[eq_l[first_own]]

  # NOTE(review): the textbook Minnesota prior scales cross-lag variances by
  # (sigma_i / sigma_j)^2 (sigma is a standard deviation); the ratio is not
  # squared here — confirm this is intended.
  prior_var <- ifelse(own, a1/(l^2), (a2/(l^2)) * (sigma[eq_l]/sigma[eq_r]))
  # nrow/ncol are explicit: diag() of a length-1 vector would read it as a size
  prior_v <- diag(prior_var, nrow = n_elem, ncol = n_elem)

  list(prior_m = prior_m, prior_v = prior_v)
}

# One-step-ahead mean forecast of a VAR(k) written in deviations from its
# long-run mean: Psi + Pi %*% (x - Psi), where x stacks the last k rows of
# dat, most recent first.
#   dat : data matrix, one row per period, one column per variable
#   Pi  : p x (p*k) slope matrix [Pi_1, ..., Pi_k]
#   Psi : long-run mean (p values)
# Returns a p x 1 matrix.
make_forecast <- function(dat, Pi, Psi){
  n_vars <- ncol(dat)
  n_lags <- ncol(Pi) / n_vars
  last <- nrow(dat)
  lag_rows <- last:(last - n_lags + 1)
  # c(y_T, y_{T-1}, ...) with each lag centred at Psi
  lagged <- as.numeric(t(dat[lag_rows, ]))
  centred <- lagged - rep(Psi, n_lags)
  Psi + Pi %*% centred
}

# Simulate a bivariate VAR(2) with fixed parameters:
#   y_t = Psi + Pi1 (y_{t-1} - Psi) + Pi2 (y_{t-2} - Psi) + e_t,
#   e_t ~ N(0, Sigma), Sigma = [1 .5; .5 2], Psi = (1, 3)'.
# The recursion starts from zeros; the first `init` simulated rows are
# discarded and the following T rows are returned.
# NOTE(review): another function named `sim_var` (different signature) is
# defined further down this file; when the file is sourced, that later
# definition replaces this one — confirm which one is meant to be kept.
# Returns list(y = T x 2 matrix, Pi = cbind(Pi1, Pi2), Psi = 2 x 1 matrix).
sim_var <- function(T = 300, init = 500){
  S <- matrix(c(1, .5, .5, 2), 2, 2)
  mu <- c(1, 3)
  A1 <- matrix(c(.5, .05, .05, .4), 2, 2)
  A2 <- matrix(c(-.1, 0, 0, -.1), 2, 2)
  n_total <- T + init
  shocks <- mvrnorm(n_total, rep(0, 2), S)
  sims <- matrix(0, n_total, 2)
  for (tt in 3:n_total) {
    dev1 <- sims[tt - 1, ] - mu
    dev2 <- sims[tt - 2, ] - mu
    sims[tt, ] <- mu + A1 %*% dev1 + A2 %*% dev2 + shocks[tt, ]
  }
  list(y = sims[(init + 1):n_total, ], Pi = cbind(A1, A2),
       Psi = matrix(mu, ncol = 1))
}

# Bayesian VAR(k) with stochastic volatility, estimated by Gibbs sampling.
#   dat                 data matrix (rows = periods, columns = variables)
#   k                   number of VAR lags
#   n_mcmc, n_bi        total MCMC iterations / burn-in iterations discarded
#   n_fc                number of forecast horizons simulated per kept draw
#   simulate_vola_drift if TRUE, log variances are simulated forward over the
#                       forecast horizon; otherwise the last in-sample value
#                       is used at every horizon
#   rho, a1, a2         Minnesota prior settings passed to minnesota_prior()
#   Psi_prm, Psi_prvi   prior mean and inverse variance of Psi (long-run mean)
#   log_var_prm,
#   log_var_prv         passed to sigmahelper2_new() (presumably the prior
#                       mean / variance of the initial log variances — confirm)
#   Phi_prdf, k_Phi     prior degrees of freedom and scale for Phi, the VCV of
#                       the log-volatility innovations
#   A_const1, A_const2  prior constants passed to A_post()
# Returns a list with the kept draws of A (lower triangle), Psi and vec(Pi);
# var_pm, the posterior mean of the residual variances over time; and
# forecast draws m_draws (mean), v_draws (residual variance) and y_draws
# (simulated values), each (n_mcmc - n_bi) x n_fc x p.
# Side effect: plots var_pm with matplot().
bvar_sv <- function(dat, k, n_mcmc = 15000, n_bi = 5000, n_fc = 5,
                    simulate_vola_drift = TRUE,
                    rho = rep(0.9, ncol(dat)), 
                    a1 = .5, a2 = .2, Psi_prm = rep(0, ncol(dat)), 
                    Psi_prvi = 1e-4*diag(ncol(dat)), 
                    log_var_prm = rep(log(.5^2), ncol(dat)),
                    log_var_prv = 10*diag(ncol(dat)),
                    Phi_prdf = ncol(dat) + 9, k_Phi = .2, 
                    A_const1 = 0, A_const2 = 1e-4){
  # dimensions
  p <- ncol(dat)
  T_orig <- nrow(dat)
  n_obs <- T_orig - k        # observations left after conditioning on k lags
  n_keep <- n_mcmc - n_bi    # number of retained draws
  
  # get y and x
  aux <- make_y_x(dat, k)
  y <- aux$y
  x <- aux$x  
  
  # Initialize: identity inverse residual VCV at every date, unit variances,
  # A = Phi = identity, Psi = 0
  Sigma_i_all <- matrix(diag(p), p*n_obs, p, byrow = TRUE)
  Lambda_all <- matrix(1, n_obs, p)
  log_Lambda_all <- log(Lambda_all)
  A <- Phi <- diag(p)
  Psi <- rep(0, p)
  
  # Priors
  # mean and inverse variance of pi (VAR slope coefficients)
  minn <- minnesota_prior(dat, k, rho, a1, a2)
  Pi_prm <- minn$prior_m
  Pi_prvi <- solve(minn$prior_v)  
  # prior mean for Phi (VCV of log vola residuals)
  # This is the prior mean matrix of an IW dist for X
  # Ensures that X has mean matrix = (k_Phi^2)*identity
  Phi_prm <- (k_Phi^2)*(Phi_prdf-p-1)*diag(p)
  
  # Mixture approximation to the log-chi2 distribution (constant across
  # iterations, so fetched once)
  mix <- getmix_omori()

  # Matrices for draws
  Pi_draws <- matrix(0, n_keep, k*(p^2))
  Psi_draws <- matrix(0, n_keep, p)
  log_Lambda_draws <- array(0, c(n_keep, n_obs, p))
  if (p > 1){
    A_draws <- matrix(0, n_keep, .5*p*(p-1))  
  } else {
    A_draws <- NULL
  }
  fc_m_draws <- fc_v_draws <- fc_y_draws <- array(0, c(n_keep, n_fc, p))
  
  # MCMC loop
  for (jj in 1:n_mcmc){
    
    # Step 1: Draw Pi (VAR slope coefficients)
    # De-mean y and x (using current mean estimate Psi)
    dm <- demean(y, x, Psi)
    # draw vec of Pi
    # note: formula in Clark (JBES, 2011) is wrong. 
    # Correct formula (used here) is in online appendix of Krueger et al (JBES, 2017)
    Pi_vec <- Pi_post(Pi_prm, Pi_prvi, Sigma_i_all, dm$y_dm, dm$x_dm)
    # convert to matrix form
    PiU <- make_Pi_U(Pi_vec$Pi_post_draw, p, k)
    Pi <- PiU$Pi
    
    # Step 2: Draw Psi (long-run mean)
    U <- PiU$U
    # subtract impact of dynamics from y
    q_all <- y - x %*% t(Pi)
    Psi_list <- Psi_post(Psi_prm, Psi_prvi, Sigma_i_all, q_all, U, k)
    Psi <- Psi_list$Psi_post_draw
    
    # Step 3: Draw A (residual correlations), weighting each equation by the
    # current draw of the residual variances
    y_hat <- make_y_hat(y, x, Psi, Pi)
    if (p > 1){
      A <- A_post(y_hat, Lambda_all, A_const1, A_const2)  
    } else {
      A <- matrix(1)
    }
    
    # Step 4: Draw volatilities 
    Lambda_all_list <- sigmahelper2_new(A, y_hat, mix$q, mix$m, mix$u2, t(log_Lambda_all), Phi,
                                        log_var_prm, log_var_prv)
    # Update draws of log residual variance
    log_Lambda_all <- Lambda_all_list$log_resid_var
    # Keep the level variances in sync; without this, Step 3 would keep
    # using the initial unit variances for the whole run
    Lambda_all <- exp(log_Lambda_all)
    # Update inverse of residual covariance matrix (at each date)
    Sigma_i_all <- Sigmahelper(A, Lambda_all_list$resid_var)
    
    # Step 5: Draw VCV of vola residuals
    aux_sse <- diff_helper(log_Lambda_all)
    Phi <- solve(rWishart(1, n_obs + Phi_prdf, solve(aux_sse + Phi_prm))[,,1])
    
    # If post burn in
    if (jj > n_bi){
      # save parameter draws
      jj2 <- jj - n_bi
      Pi_draws[jj2, ] <- Pi_vec$Pi_post_draw
      Psi_draws[jj2, ] <- Psi
      if (p > 1){
        A_draws[jj2, ] <- A[lower.tri(A)]
      }
      log_Lambda_draws[jj2, , ] <- log_Lambda_all
      
      # make and save forecasts
      dat_fc <- dat
      log_Lambda_tmp <- tail(log_Lambda_all, 1) %>% as.numeric
      for (hh in 1:n_fc){
        # Compute mean forecast
        mean_fc <- fc_m_draws[jj2,hh, ] <- make_forecast(dat_fc, Pi, Psi)
        
        if (simulate_vola_drift){
          # Simulate log variances
          log_Lambda_tmp <- mvndrawC(log_Lambda_tmp, Phi)  
        }
        
        # Get residual vcv
        Sigma_tmp <- Sigmahelper(A, matrix(exp(log_Lambda_tmp), nrow = 1)) %>% 
          solve 
        fc_v_draws[jj2,hh, ] <- diag(Sigma_tmp)
        
        # Draw from forecast dist
        fc_draw_tmp <- fc_y_draws[jj2,hh, ] <- mvndrawC(mean_fc, Sigma_tmp)
        
        # Append draws to data so the next horizon conditions on them
        dat_fc <- rbind(dat_fc, t(fc_draw_tmp))
      }
    }
  }
  
  # Plot vola estimate over time (posterior mean)
  var_pm <- apply(exp(log_Lambda_draws), c(2, 3), mean)
  matplot(var_pm, type = "l", bty = "n", 
          col = paste0(c("blue", "green", "red", "orange", "brown"), 4))
  
  return(list(A_draws = A_draws, 
              Psi_draws = Psi_draws,
              Pi_draws = Pi_draws,
              var_pm = var_pm,
              m_draws = fc_m_draws, v_draws = fc_v_draws,
              y_draws = fc_y_draws))
  
}


# Function for baseline model of Clark, Mertens and McCracken (2019)
#
# Gibbs sampler for a 5-column matrix y (treated as mean zero; no mean
# parameters are estimated) with residual correlations A, stochastic log
# variances and their innovation VCV Phi.
#   y                   n x 5 matrix; NA entries are re-imputed at every
#                       iteration from their conditional normal distribution
#                       given the observed entries of the same row
#   n_mcmc, n_bi        total MCMC iterations / burn-in iterations discarded
#   simulate_vola_drift if TRUE, log variances are simulated forward over the
#                       5 steps; otherwise the last in-sample value is used
#   log_var_prm,
#   log_var_prv         passed to sigmahelper2_new() (presumably the prior
#                       mean / variance of the initial log variances — confirm)
#   Phi_prdf, k_Phi     prior degrees of freedom and scale for Phi
#   A_const1, A_const2  prior constants passed to A_post()
# Returns list(e_m_draws, e_v_draws, e_draws), each (n_mcmc - n_bi) x 5:
# the (always zero) mean, the variances and one draw of the combined errors
# whose covariance is Sigma_fe (see `sel` below).
# Side effect: plots the posterior median of the variances with matplot().
cmm <- function(y, n_mcmc = 15000, n_bi = 5000, 
                simulate_vola_drift = TRUE,
                log_var_prm = rep(log(.5^2), ncol(y)),
                log_var_prv = 10*diag(ncol(y)),
                Phi_prdf = 14, k_Phi = .2, 
                A_const1 = 0, A_const2 = 1e-4){
  # dimensions
  n_obs <- nrow(y)
  if (ncol(y) != 5) stop("y not in correct form")
  n_keep <- n_mcmc - n_bi
  
  # prior mean for Phi (VCV of log vola residuals)
  # This is the prior mean matrix of an IW dist for a matrix X
  # Ensures that X has mean matrix = (k_Phi^2)*identity
  Phi_prm <- (k_Phi^2)*(Phi_prdf-5-1)*diag(5)
  
  # check missing values in y
  y_na <- is.na(y)
  any_y_na <- any(y_na)
  incomplete_rows <- which(rowSums(y_na) > 0)
  
  # Initialize: inverse sample covariance at every date, unit variances,
  # A = Phi = identity
  Sigma_i_all <- matrix(solve(var(na.omit(y))), 5*n_obs, 5, byrow = TRUE)
  Lambda_all <- matrix(1, n_obs, 5)
  log_Lambda_all <- log(Lambda_all)
  A <- Phi <- diag(5)
  # Selection matrix: row i picks variable (i - j + 1) from step j's 5x5 block
  # of Big_Sigma, j = 1..i (presumably: nowcast error plus the subsequent
  # revisions that make up the i-step forecast error — confirm)
  sel <- matrix(0, 5, 25)
  sel[1,1] <- sel[2, c(2, 6)] <- sel[3, c(3, 7, 11)] <- 
    sel[4, c(4, 8, 12, 16)] <- sel[5, c(5, 9, 13, 17, 21)] <- 1
  
  # Mixture approximation to the log-chi2 distribution (constant, fetched once)
  mix <- getmix_omori()
  
  # Matrices for draws
  log_Lambda_draws <- array(0, c(n_keep, n_obs, 5))
  A_draws <- matrix(0, n_keep, 10)
  # e_m_draws is never filled: the mean forecast error is zero by construction
  e_m_draws <- e_v_draws <- e_draws <- matrix(0, n_keep, 5)
  
  # MCMC loop
  for (jj in 1:n_mcmc){
    
    # Handle missing values in y
    if (any_y_na){
      for (kk in incomplete_rows){
        # Identify missing values in current row
        sel_tmp <- which(y_na[kk, ])
        # Variance matrix of all elems in current row
        v_tmp <- Sigma_i_all[((kk-1)*5 + 1):(kk*5), ] %>% solve 
        if (length(sel_tmp) == 5){
          # Whole row missing: nothing to condition on, draw from N(0, v_tmp)
          # (the conditional formulas below would need solve() of a 0x0 matrix)
          y[kk, ] <- mvndrawC(rep(0, 5), v_tmp)
          next
        }
        # Non-missing elems in current row
        y_tmp <- y[kk, -sel_tmp, drop = FALSE]
        # Covariance btw missing and non-missing elems
        cov12_tmp <- v_tmp[sel_tmp, -sel_tmp, drop = FALSE]
        # Variance of missing elems
        v11_tmp <- v_tmp[sel_tmp, sel_tmp, drop = FALSE]
        # Inverse variance of non-missing elems
        v22_tmp_i <- v_tmp[-sel_tmp, -sel_tmp, drop = FALSE] %>% solve
        # Conditional mean of missing elems
        cm_tmp <- cov12_tmp %*% v22_tmp_i %*% t(y_tmp)
        # Conditional variance of missing elems
        cv_tmp <- v11_tmp - cov12_tmp %*% v22_tmp_i %*% t(cov12_tmp)
        # Simulate missing data
        y[kk, sel_tmp] <- mvndrawC(cm_tmp, cv_tmp)
      }
    }
    
    # Draw A (residual correlations), weighted by the current variances
    A <- A_post(y, Lambda_all, A_const1, A_const2)
    
    # Draw volatilities 
    Lambda_all_list <- sigmahelper2_new(A, y, mix$q, mix$m, mix$u2, 
                                        t(log_Lambda_all), Phi,
                                        log_var_prm, log_var_prv)
    # Update draws of log residual variance
    log_Lambda_all <- Lambda_all_list$log_resid_var
    # Keep the level variances in sync; without this, A_post would keep
    # using the initial unit variances for the whole run
    Lambda_all <- exp(log_Lambda_all)
    # Update inverse of residual covariance matrix (at each date)
    Sigma_i_all <- Sigmahelper(A, Lambda_all_list$resid_var)
    
    # Draw VCV of vola residuals
    aux_sse <- diff_helper(log_Lambda_all)
    Phi <- solve(rWishart(1, n_obs + Phi_prdf, solve(aux_sse + Phi_prm))[,,1])
    
    # If post burn in
    if (jj > n_bi){
      # save parameter draws
      jj2 <- jj - n_bi
      A_draws[jj2, ] <- A[lower.tri(A)]
      log_Lambda_draws[jj2, , ] <- log_Lambda_all
      
      # make and save forecasts (mean forecast is always equal to zero)
      mean_fc <- rep(0, 5)
      log_Lambda_tmp <- tail(log_Lambda_all, 1) %>% as.numeric
      Big_Sigma <- matrix(0, 25, 25)
      for (hh in 1:5){
        if (simulate_vola_drift){
          # Simulate log variances
          log_Lambda_tmp <- mvndrawC(log_Lambda_tmp, Phi)  
        }
        
        # Get residual vcv
        Sigma_tmp <- Sigmahelper(A, matrix(exp(log_Lambda_tmp), nrow = 1)) %>% solve 
        
        # Enter into joint VCV matrix (across horizons)
        Big_Sigma[((hh-1)*5+1):(hh*5), ((hh-1)*5+1):(hh*5)] <- Sigma_tmp
      }
      
      # Construct VCV of forecast errors
      Sigma_fe <- sel %*% Big_Sigma %*% t(sel)
      e_draws[jj2, ] <- mvndrawC(mean_fc, Sigma_fe)
      e_v_draws[jj2, ] <- diag(Sigma_fe)
      
    }
  }
  
  # Plot vola estimate over time (posterior median)
  vola_pm <- apply(exp(log_Lambda_draws), c(2, 3), median)
  matplot(vola_pm, type = "l", bty = "n", 
          col = paste0(c("blue", "green", "red", "orange", "brown"), 4))
  
  return(list(e_m_draws = e_m_draws, e_v_draws = e_v_draws,
              e_draws = e_draws))
  
}

# Simulate data for the CMM model: an AR(1) target y (coefficient a,
# innovation sd s) and noisy forecasts fc[, h] = a^h * y + N(0, s_f^2) for
# horizons h = 1..5.
# Returns a list with
#   sim   : rows without NA of [1-step errors, forecast revisions h = 2..5]
#   e     : rows without NA of the h-step forecast errors, h = 1..5
#   v_th1 : theoretical h-step error variances s^2 * sum_{i<h} a^(2i) + s_f^2
#   v_th2 : the same with the forecast-noise term counted (1 + 2*(h-1)) times
sim_cmmdata <- function(n = 1e3, a = .8, s = 1, s_f = .1){
  
  y <- arima.sim(list(ar = a), n, innov = rnorm(n, sd = s))
  fc <- e <- out <- matrix(NA, n, 5)
  for (jj in 1:5){
    # Forecast at horizon jj
    fc[,jj] <- y*(a^jj) + rnorm(n, sd = s_f)
    # Forecast errors
    e[-(1:jj),jj] <- tail(y, -jj) - head(fc[,jj], -jj)
  }
  # 'Nowcast' (aka 1-step ahead) errors
  out[-1, 1] <- tail(y, -1) - head(fc[,1], -1)
  # Forecast revisions: forecasts of the same target made one period apart
  out[-1, -1] <- fc[-1, -5] - head(fc[,-1], -1)
  # h-step error = sum_{i<h} a^i * shock (each with variance s^2) plus the
  # forecast noise; the shock part must scale with s^2, not assume s = 1
  shock_var <- s^2 * cumsum(a^(2*(0:4)))
  # Theoretical error variance nr 1 (true)
  v_th1 <- shock_var + s_f^2
  # Theoretical error variance nr 2 (forecast bias counted multiple times)
  v_th2 <- shock_var + (1 + 2*0:4)*s_f^2
  list(sim = na.omit(out), e = na.omit(e), v_th1 = v_th1, v_th2 = v_th2)
}

# Simulate n_vars AR(1) series (coefficient phi), each driven by one column of
# correlated Gaussian innovations with covariance vol_mult[t] * V, where
# V = A_inv %*% diag(1, ..., n_vars) %*% t(A_inv) and A_inv is unit
# lower-triangular with rho below the diagonal. With sv = TRUE the innovation
# covariance is doubled on periods 100-200.
# NOTE(review): an earlier function in this file is also named `sim_var`
# (different signature); this later definition is the one that survives when
# the file is sourced.
# Returns list(A_true = solve(A_inv), dat = n x n_vars matrix of series).
sim_var <- function(n = 5e2, rho = .5, phi = .6, n_vars = 3, 
                    sv = TRUE){
  A_inv <- diag(n_vars)
  A_inv[lower.tri(A_inv)] <- rho
  struct_var <- diag(n_vars)
  diag(struct_var) <- 1:n_vars
  V <- A_inv %*% struct_var %*% t(A_inv)

  # Per-period covariance multiplier
  vol_mult <- rep(1, n)
  if (sv) {
    vol_mult[100:200] <- 2
  }

  shocks <- matrix(0, n, n_vars)
  for (tt in 1:n) {
    shocks[tt, ] <- mvndrawC(mu = rep(0, n_vars), sig = vol_mult[tt]*V)
  }

  series <- matrix(0, n, n_vars)
  for (vv in 1:n_vars) {
    series[, vv] <- arima.sim(n, model = list(ar = phi), innov = shocks[, vv])
  }
  list(A_true = solve(A_inv), dat = series)
}
