#include "libscl.h"
#include "tree_sim.h"
#include "tree_mf.h"
#include "rom.h"
#include "preimage_norm.h"
#include "moment_norm.h"
#include "xstar_norm.h"

using namespace std;
using namespace scl;

int main(int argc, char** argp, char** envp)
{
  const INTEGER n_obs = 500;
  const INTEGER dat_rows = 2;
  const INTEGER mf_lags = 1;
  const INTEGER HAC_lags = 1;
  const INTEGER n_prior_reps = 500;

  INT_32BIT seed = 740726;

  realmat prior_mean(4,1);
  realmat prior_sdev(4,1);
  realmat rho(4,1);
  realmat theta(2,1);

  REAL alpha = rho[1] = prior_mean[1] = 0.95;
  REAL sigma = rho[2] = prior_mean[2] = 0.03;
  REAL beta  = rho[3] = theta[1] = prior_mean[3] = 0.95;
  REAL gamma = rho[4] = theta[2] = prior_mean[4] = 1.2600114143005971e+01;

  prior_sdev[1] = 0.02;
  prior_sdev[2] = 0.02;
  prior_sdev[3] = 0.02;
  prior_sdev[4] = 2.00;

  cout << '\n';
  cout << "Parameters" << '\n';
  cout << "alpha_mean " << fmt('e',26,16,alpha) << "   " << alpha << '\n';
  cout << "sigma_mean " << fmt('e',26,16,sigma) << "   " << sigma << '\n';
  cout << "beta_mean  " << fmt('e',26,16,beta)  << "   " << beta  << '\n';
  cout << "gamma_mean " << fmt('e',26,16,gamma) << "   " << gamma << '\n';
  cout << "alpha_sdev " << fmt('e',26,16,prior_sdev[1]) 
       << "   " << prior_sdev[1] << '\n';
  cout << "sigma_sdev " << fmt('e',26,16,prior_sdev[2]) 
       << "   " << prior_sdev[2] << '\n';
  cout << "beta_sdev  " << fmt('e',26,16,prior_sdev[3])
       << "   " << prior_sdev[3] << '\n';
  cout << "gamma_sdev " << fmt('e',26,16,prior_sdev[4])
       << "   " << prior_sdev[4] << '\n';
  cout << '\n';
  cout << "Sizes" << '\n';
  cout << "n_obs " << n_obs << '\n';
  cout << "mf_lags " << mf_lags << '\n';
  cout << "HAC_lags " << HAC_lags << '\n';
  cout << "n_prior_reps " << n_prior_reps << '\n';
  cout << '\n';

  ofstream fout;
  string filename;

  bool rv = true;

  realmat data(dat_rows,n_obs);
  realmat data_orig(dat_rows,n_obs);
  realmat data_tran(n_obs,dat_rows);

  tree_variables mv;

  if(!tree_sim(prior_mean, n_obs, seed, mv)) rv = false;

  for (INTEGER t=1; t<=n_obs; ++t) {
    data(1,t) = data_orig(1,t) = data_tran(t,1) = mv.geometric_stock_return[t];
    data(2,t) = data_orig(2,t) = data_tran(t,2) = mv.log_consumption_growth[t];
  }

  intvec idx = seq(1,2);
  realmat b0,B,V,C;

  cout << '\n';
  cout << "data(1,t) = mv.geometric_stock_return[t]" << '\n';
  cout << "data(2,t) = mv.log_consumption_growth[t]" << '\n';

  cout << '\n';
  cout << "mean, sdev, var, skew, kurt" << simple(T(data));
  varcoef(data_tran,idx,1,b0,B,V,C);
  cout << '\n';
  cout << "VAR B" << B;
  cout << "VAR V" << V;
  cout << "VAR C" << C;
  cout << '\n';

  cout << "rho and theta" << rho << theta;

  cout.flush();

  xstar_norm xnorm(data,n_obs,mf_lags,HAC_lags);

  INTEGER d = xnorm.get_d();

  realmat scl(d,1,1.0/pow(n_obs,2));
  xnorm.set_scale(scl);

  tree_moment_function tree_mf;

  gmm tree_gmm(&tree_mf, mf_lags, &data, n_obs, HAC_lags);

  tree_gmm.set_correct_W_for_mean(true);
  tree_gmm.set_warning_messages(true);

  realmat m;
  realmat W;
  REAL logdetW;
  INTEGER rankW;
  realmat S;
  realmat R;
  REAL obj;

  realmat z(d,1);
  realmat s(d,1);
  realmat f(1,1);

  cout << "first call to tree_gmm and xnorm" << '\n';

  cout << '\n';
  intvec jdx = seq(n_obs-5,n_obs);
  cout << "data(,n_obs-5:n_obs) = " << data("",jdx);

  if (!tree_gmm.set_data(&data_orig)) {
    rv = false; 
    cout << "tree_gmm.set_data failed" << '\n';
  }
  obj = tree_gmm(theta,m,W,logdetW,rankW,S);
  if (cholesky(W,R) != d) {
    rv = false; 
    cout << "cholesky failed" << '\n';
  }
  z = sqrt(n_obs)*R*m;

  cout << '\n';
  cout << "m, W, z " << m << W << z;

  if (!xnorm.set_theta(theta)) {
    rv = false; 
    cout << "xnorm.set_z_theta failed" << '\n';
  }
  if (!xnorm.set_data_z(data,z,s)) {
    rv = false; 
    cout << "xnorm.set_data_z failed" << '\n';
  }

  cout << '\n';
  cout << "s, s_inverse = " << s << xnorm.get_s_inverse();
  cout << '\n';
  cout << "data(,n_obs-5:n_obs) = " << xnorm.get_data()("",jdx);

  data = xnorm.get_data();

  if (!tree_gmm.set_data(&data)) rv = false;
  obj = tree_gmm(theta,m,W,logdetW,rankW,S);
  if (cholesky(W,R) != d) rv = false;
  z = sqrt(n_obs)*R*m;

  cout << '\n';
  cout << "m, W, z " << m << W << z;

  cout.flush();

  INTEGER scope = 2;
  realmat x(scope,1);
  realmat xtst(scope,1);
  realmat xmin(scope,1);

  REAL increment = 0.02;
  REAL tol = 1.0e-3;
  REAL iter_limit = 10000;

  bool stop_here = false;
  if (stop_here) return 0;

  for (INTEGER rep=1; rep<=n_prior_reps; ++rep) {

    for (INTEGER i=1; i<=4; ++i) {
      rho[i] = prior_mean[i] + prior_sdev[i]*unsk(seed);
    }

    alpha = rho[1];
    sigma = rho[2];
    beta  = theta[1] = rho[3];
    gamma = theta[2] = rho[4];

    data = data_orig;

    if (!tree_gmm.set_data(&data)) rv = false;
    obj = tree_gmm(theta,m,W,logdetW,rankW,S);
    if (cholesky(W,R) != d) rv = false;
    z = sqrt(n_obs)*R*m;

    if (!xnorm.set_data_z(data,z,s)) rv = false;
    if (!xnorm.set_theta(theta)) rv = false;

    REAL ymin;

    INTEGER t = n_obs/2;
    INTEGER base = dat_rows*(t - 1) ;

    xnorm.set_t(t);

    for (INTEGER i=1; i<=scope; ++i) x[i] = data[base + i];

    if (!xnorm.get_f(x,f)) rv = false;

    xmin = x;
    ymin = f[1]; 

    data = xnorm.get_data();

    cout << '\n';
cout << "start" << "  " << fmt('f',20,8,f[1]) << '\n';
  
    t = 0;
    INTEGER iter;
    for (iter=1; iter<=iter_limit; ++iter) {
  
      ++t;
  
      INTEGER base = dat_rows*(t - 1);

      if (base + scope  > n_obs*dat_rows - 3*dat_rows) t = 1;

      base = dat_rows*(t - 1);
  
      xnorm.set_t(t);
  
      for (INTEGER i=1; i<=scope; ++i) x[i] = data[base + i];

       realmat Q;
       rom(scope,scope,Q,seed);
  
       REAL incr = ran(seed)*increment;
  
       for (INTEGER j=1; j<=scope; ++j) {
         for (INTEGER i=1; i<=scope; ++i) {
           xtst[i] = x[i] + incr*Q(i,j);
         }
       }

       if (xnorm.get_f(xtst,f)) { 
         if (f[1] < ymin) {
	   ymin = f[1]; 
	   xmin = xtst;
//cerr << "Got to here 1 " << t << ' ' << f[1] << ' ' << ymin << '\n';
	 }
	 else {
	   xnorm.set_x(x);
//cerr << "Got to here 2 " << t << ' ' << f[1] << ' ' << ymin << '\n';
         }
       }
       else {
         rv = false;
       }

      if (ymin < tol) break; 
    }
  
cout << fmt('d',5,iter) << "  " << fmt('f',20,8,ymin) << '\n';

  }

  cout << (rv ? "success" : "failure") << '\n';

  return 0;
}

