#include "libscl.h"
#include "tree_sim.h"
#include "tree_mf.h"
#include "preimage_norm.h"
#include "moment_norm.h"

using namespace std;
using namespace scl;

extern void rom(INTEGER rows, INTEGER cols, scl::realmat& Q, INT_32BIT& seed);

int main(int argc, char** argp, char** envp)
{
  const INTEGER n_obs = 500;
  const INTEGER dat_rows = 2;
  const INTEGER mf_lags = 1;
  const INTEGER HAC_lags = 1;
  const INTEGER n_prior_reps = 1;

  INT_32BIT seed = 740726;

  realmat rho(4,1);
  realmat prior_mean(4,1);
  realmat prior_sdev(4,1);

  REAL alpha = prior_mean[1] = 0.95;
  REAL sigma = prior_mean[2] = 0.03;
  REAL beta  = prior_mean[3] = 0.95;
  REAL gamma = prior_mean[4] = 1.2600114143005971e+01;

  prior_sdev[1] = 0.02;
  prior_sdev[2] = 0.02;
  prior_sdev[3] = 0.02;
  prior_sdev[4] = 2.00;

  cout << '\n';
  cout << "Parameters" << '\n';
  cout << "alpha_mean " << fmt('e',27,17,alpha) << '\n';
  cout << "sigma_mean " << fmt('e',27,17,sigma) << '\n';
  cout << "beta_mean  " << fmt('e',27,17,beta)  << '\n';
  cout << "gamma_mean " << fmt('e',27,17,gamma) << '\n';
  cout << "alpha_sdev " << fmt('e',27,17,prior_sdev[1]) << '\n';
  cout << "sigma_sdev " << fmt('e',27,17,prior_sdev[2]) << '\n';
  cout << "beta_sdev  " << fmt('e',27,17,prior_sdev[3]) << '\n';
  cout << "gamma_sdev " << fmt('e',27,17,prior_sdev[4]) << '\n';
  cout << '\n';
  cout << "Sizes" << '\n';
  cout << "n_obs " << n_obs << '\n';
  cout << "mf_lags " << mf_lags << '\n';
  cout << "HAC_lags " << HAC_lags << '\n';
  cout << "n_prior_reps " << n_prior_reps << '\n';
  cout << '\n';

  bool rv = true;

  realmat data(dat_rows,n_obs);
  realmat data_orig(dat_rows,n_obs);
  realmat data_tran(n_obs,dat_rows);

  tree_variables mv;

  bool success = tree_sim(prior_mean, n_obs, seed, mv);

  if (!success) rv = false;

  for (INTEGER t=1; t<=n_obs; ++t) {
    data(1,t) = data_orig(1,t) = data_tran(t,1) = mv.geometric_stock_return[t];
    data(2,t) = data_orig(2,t) = data_tran(t,2) = mv.log_consumption_growth[t];
  }

  intvec idx = seq(1,2);
  realmat b0,B,V,C;

  varcoef(data_tran,idx,1,b0,B,V,C);
  cout << B;

  //cout << simple(T(data));

  preimage_norm Znorm(data,n_obs,mf_lags,HAC_lags);
  moment_norm mnorm(data,n_obs,mf_lags,HAC_lags);

  INTEGER d = Znorm.get_d();

  realmat z(d,1);
  realmat f(1,1);

  INTEGER scope = d;
  realmat x(scope,1);
  realmat xtst(scope,1);
  realmat xmin(scope,1);

  REAL increment = 0.1;
  REAL ztol = 1.0e-3;
  REAL mtol = 1.0e-8;
  REAL iter_limit = 5000;

  tree_moment_function tree_mf;

  gmm tree_gmm(&tree_mf, mf_lags, &data, n_obs, HAC_lags);

  tree_gmm.set_correct_W_for_mean(true);
  tree_gmm.set_warning_messages(true);

  realmat m;
  realmat W;
  REAL logdetW;
  INTEGER rankW;
  realmat S;

  nlopt bfgs(mnorm);
  bfgs.set_lower_bound(0.0);
  bfgs.set_solution_tolerance(mtol);
  bfgs.set_iter_limit(iter_limit);

  realmat m_sum(d,1,0.0);
  realmat m_ssq(d,d,0.0);

  for (INTEGER rep=1; rep<=n_prior_reps; ++rep) {

    for (INTEGER i=1; i<=4; ++i) {
      rho[i] = prior_mean[i] + prior_sdev[i]*unsk(seed);
    }

    for (INTEGER zcase = 0; zcase<=1; ++zcase) {

      if (zcase == 0) {
        data = data_orig;
        for (INTEGER i=1; i<=d; ++i) z[i] = 0.0;
        Znorm.set_z_rho(z,rho);
        Znorm.set_data(data);
        cout << "zcase = " << zcase << '\n';
      } 
      else {
        if(!tree_gmm.set_data(&data_orig)) rv = false;
        tree_gmm(rho,m,W,logdetW,rankW,S);
        realmat R;
        if (cholesky(W,R) != d) rv = false;
        z = sqrt(n_obs)*R*m;
        Znorm.set_z_rho(z,rho);
        Znorm.set_data(data);
	mnorm.set_m_rho(m,rho);
        mnorm.set_data(data);
        cout << "zcase = " << zcase << '\n';
      }

      REAL ymin;
      INTEGER t = 1;
      INTEGER iter;

      string filename = "peplt.txt";
      ofstream fout;
      fout.open(filename.c_str());
      if (!fout) error("Error, cannot open " + filename);
  
      for (iter=1; iter<=iter_limit; ++iter) {
  
        if (zcase == 0) {
          ++t;
        }
        else {
          t = n_obs/2;
        }
  
        INTEGER base = dat_rows*(t - 1) + t - 1;

        if (base + scope  > n_obs*dat_rows) t = 1;
  
        Znorm.set_t(t);
	mnorm.set_t(t);

        for (INTEGER i=1; i<=scope; ++i) x[i] = data[base + i];

        if (zcase == 0) {
          realmat Q;
          rom(scope,scope,Q,seed);
  
          REAL incr = ran(seed)*increment;
  
          if (!Znorm.get_f(x,f)) rv = false;
          ymin = f[1]; xmin = x;
  
          for (INTEGER j=1; j<=scope; ++j) {
            for (INTEGER i=1; i<=scope; ++i) {
              xtst[i] = x[i] + incr*Q(i,j);
            }
            if (Znorm.get_f(xtst,f)) { 
              if (f[1] < ymin) {ymin = f[1]; xmin = xtst;}
            }
          }
        }
        else {
          realmat Q;
          rom(scope,scope,Q,seed);
  
          REAL incr = ran(seed)*increment;
  
          if (!Znorm.get_f(x,f)) rv = false;
          ymin = f[1]; xmin = x;

	  for (INTEGER i=1; i<=scope; ++i) {
	    fout << fmt('f',15,10,xtst[i]);
	  }
          fout << fmt('f',15,10,f[1]) << '\n';
  
          for (INTEGER j=1; j<=scope; ++j) {
            for (INTEGER i=1; i<=scope; ++i) {
              xtst[i] = x[i] + incr*Q(i,j);
            }
            if (Znorm.get_f(xtst,f)) { 
              if (f[1] < ymin) {ymin = f[1]; xmin = xtst;}
            }
          }
        }

        Znorm.set_x(xmin); 
        mnorm.set_x(xmin); 

        data = Znorm.get_data();
  
        if (ymin < ztol) break; 
      }

      fout.clear(); fout.close(); 
  
      if (zcase == 0) {
        cout << fmt('d',5,iter) <<' '<< fmt('f',20,8,ymin) << '\n';
      }
      else {
        cout << fmt('d',5,iter) <<' '<< fmt('f',20,8,ymin) << '\n';
	cout.flush();
      }
  
      //cout << simple(T(data));

      if (iter >= iter_limit) rv = false;

      if(!tree_gmm.set_data(&data)) rv = false;

      tree_gmm(prior_mean,m,W,logdetW,rankW,S);

      m_sum += m;
      m_ssq += m*T(m);
    }
  }

  realmat W_pre = m_ssq - m_sum*T(m_sum)/n_prior_reps;

  W_pre = W_pre/(n_prior_reps - 1);

  vecwrite("W_pre.dat", W_pre);

  cout << (rv ? "success" : "failure") << '\n';

  return 0;
}

