#include "libscl.h"
#include "kronprd.h"
#include "tree_sim.h"
#include "tree_mf.h"

using namespace std;
using namespace scl;

int main(int argc, char** argp, char** envp)
{
  const INTEGER n_obs = 50;

  const INTEGER mf_lags = 1;
  const INTEGER HAC_lags = 0;
  const INTEGER n_prior_reps = 1000;
  const INTEGER ngrid = 11;

  const INTEGER p_ge = 4;
  const INTEGER p_pe = 2;
  const INTEGER K_pe = 2;

  INT_32BIT seed = 740726;

  realmat prior_mean(p_ge,1);
  realmat prior_sdev(p_ge,1);

  REAL alpha = prior_mean[1] = 0.95;
  REAL sigma = prior_mean[2] = 0.02;
  REAL beta  = prior_mean[3] = 0.95;
  REAL gamma = prior_mean[4] = 12.5;

  prior_sdev[1] = 0.01;
  prior_sdev[2] = 0.01;
  prior_sdev[3] = 0.01;
  prior_sdev[4] = 2.00;

  const REAL beta_lo = 0.8;
  const REAL beta_hi = 0.99;

  const REAL gamma_lo = 0.5;
  const REAL gamma_hi = 100.0;

  cout << '\n';
  cout << "Prior mean" << '\n';
  cout << "alpha_mean " << fmt('e',26,16,alpha) << "   " << alpha << '\n';
  cout << "sigma_mean " << fmt('e',26,16,sigma) << "   " << sigma << '\n';
  cout << "beta_mean  " << fmt('e',26,16,beta)  << "   " << beta  << '\n';
  cout << "gamma_mean " << fmt('e',26,16,gamma) << "   " << gamma << '\n';
  cout << '\n';
  cout << "Prior sdev" << '\n';
  cout << "alpha_sdev " << fmt('e',26,16,prior_sdev[1])
       << "   " << prior_sdev[1] << '\n';
  cout << "sigma_sdev " << fmt('e',26,16,prior_sdev[2])
       << "   " << prior_sdev[2] << '\n';
  cout << "beta_sdev  " << fmt('e',26,16,prior_sdev[3])
       << "   " << prior_sdev[3] << '\n';
  cout << "gamma_sdev " << fmt('e',26,16,prior_sdev[4])
       << "   " << prior_sdev[4] << '\n';
  cout << '\n';
  cout << "Sizes" << '\n';
  cout << "n_obs " << n_obs << '\n';
  cout << "mf_lags " << mf_lags << '\n';
  cout << "HAC_lags " << HAC_lags << '\n';
  cout << "n_prior_reps " << n_prior_reps << '\n';
  cout << "ngrid " << ngrid << '\n';
  cout << '\n';

  string filename;
  ofstream fout;

  realmat rho(p_ge,1);
  realmat theta(p_pe,1);

  alpha = rho[1] = prior_mean[1];
  sigma = rho[2] = prior_mean[2];
  beta  = theta[1] = rho[3] = prior_mean[3];
  gamma = theta[2] = rho[4] = prior_mean[4];

  bool rv = true;

  const REAL minus_log_root_two_pi = -0.5*(M_LN2 + log(M_PI));

  realmat data(K_pe,n_obs);
  realmat data_orig(K_pe,n_obs);

  tree_variables mv;

  bool success = tree_sim(prior_mean, n_obs, seed, mv);
  if (!success) rv = false;

  for (INTEGER t=1; t<=n_obs; ++t) {
    data(1,t) = data_orig(1,t) = mv.geometric_stock_return[t];
    data(2,t) = data_orig(2,t) = mv.log_consumption_growth[t];
  }

  tree_moment_function tree_mf;

  success = tree_mf.set_data(&data);
  if (!success) rv = false;

  success = tree_mf.set_sample_size(n_obs);
  if (!success) rv = false;

  success = tree_mf.set_L(mf_lags);
  if (!success) rv = false;

  success = tree_mf.set_theta(theta);
  if (!success) rv = false;

  INTEGER T0 = tree_mf.get_T0();

  INTEGER d = tree_mf.get_d();

  gmm tree_gmm(&tree_mf, mf_lags, &data, n_obs, HAC_lags);

  tree_gmm.set_correct_W_for_mean(true);
  tree_gmm.set_warning_messages(true);
  tree_gmm.set_regularize_W(true);

  success = tree_gmm.set_data(&data);
  if (!success) rv = false;

  alpha = rho[1] = prior_mean[1];
  sigma = rho[2] = prior_mean[2];
  beta  = theta[1] = rho[3] = prior_mean[3];
  gamma = theta[2] = rho[4] = prior_mean[4];

  d = tree_gmm.get_d();

  tree_mf.set_theta(theta);
  
  realmat z;
  REAL log_likelihood = tree_gmm.likelihood(theta, z).log_den;

  REAL z_beta = (rho[3]-prior_mean[3])/prior_sdev[3];
  REAL z_gamma = (rho[4]-prior_mean[4])/prior_sdev[4];
  REAL e_beta = minus_log_root_two_pi-log(prior_sdev[3])-0.5*pow(z_beta,2);
  REAL e_gamma = minus_log_root_two_pi-log(prior_sdev[4])-0.5*pow(z_gamma,2);
  REAL log_pe_prior = e_beta + e_gamma;

  const REAL sf = 1.0 - exp(1.0);

  REAL tz1 = tanh(z[1]/4.0);
  REAL dZ1;
  if (-1.0 < tz1 && tz1 < 1.0) {
    REAL top = -4.0*(1.0 - tz1);
    REAL bot = 1 - pow(tz1,2);
    dZ1 = top/bot;
  }
  else if (tz1 >= 1.0){
    dZ1 = -2.0;
  }
  else {
    dZ1 = 2.0;
  }
  REAL det = dZ1*sf*sf;
  REAL adj = det < 0.0 ? -det : det;
  REAL log_adj = log(adj);
  
  cout << '\n';
  cout << "First tree_gmm call" << '\n';
  cout << "z " << z << '\n';
  cout << "T0                 " << T0 <<'\n';  
  cout << "log_likelihood " << log_likelihood << '\n';
  cout << "log_pe_prior " << log_pe_prior << '\n';
  cout << "log_adj " << log_adj << '\n';
  cout << '\n';

  realmat prior_range(p_ge,2);
  for (INTEGER i=1; i<=p_ge; ++i) {
    prior_range(i,1) =  REAL_MAX;
    prior_range(i,2) = -REAL_MAX;
  }

// Contour starts here  

  realmat xgrid(ngrid,1);       // beta
  realmat ygrid(ngrid,1);       // gamma
  realmat lgrid(ngrid,ngrid);   // likelihood
  realmat pgrid(ngrid,ngrid);   // prior
  realmat agrid(ngrid,ngrid);   // adj
  realmat dgrid(ngrid,ngrid);   // fabs(lgrid) - fabs(agrid)
  realmat lgrid_save(ngrid,ngrid);
  realmat pgrid_save(ngrid,ngrid);
  realmat agrid_save(ngrid,ngrid);

  for (INTEGER i=1; i<=ngrid; ++i) {
    REAL delta = REAL(i-1)/REAL(ngrid-1);
    xgrid[i] = beta_lo + (beta_hi-beta_lo)*delta;
    ygrid[i] = gamma_lo + (gamma_hi-gamma_lo)*delta;
  }

  REAL min_adj = REAL_MAX;
  REAL max_adj = -REAL_MAX;
  REAL max_dif = -REAL_MAX;

  INTEGER W_error_count = 0;
  INTEGER W_success_count = 0;

  INTEGER Z_error_count = 0;

  for (INTEGER rep=1; rep<=n_prior_reps; ++rep) {

    for (INTEGER i=1; i<=p_ge; ++i) {
      rho[i] = prior_mean[i] + prior_sdev[i]*unsk(seed);
    }
    
    while (fabs(rho[1]) >= 0.99) 
      rho[1]=prior_mean[1]+prior_sdev[1]*unsk(seed);
    while ((rho[2] <= 0.01) || (rho[2] >= 100.0)) 
      rho[2]=prior_mean[2]+prior_sdev[2]*unsk(seed);
    while ((rho[3] <= beta_lo) || (rho[3] >= beta_hi)) 
      rho[3]=prior_mean[3]+prior_sdev[3]*unsk(seed);
    while ((rho[4] <= gamma_lo) || (rho[4] >= gamma_hi))
      rho[4]=prior_mean[4]+prior_sdev[4]*unsk(seed);
    
    alpha = rho[1];
    sigma = rho[2];
    beta  = theta[1] = rho[3];
    gamma = theta[2] = rho[4];

    for (INTEGER i=1; i<=p_ge; ++i) {
      prior_range(i,1) = rho[i] < prior_range(i,1) ? rho[i] : prior_range(i,1);
      prior_range(i,2) = rho[i] > prior_range(i,2) ? rho[i] : prior_range(i,2);
    }

    success = tree_sim(rho, n_obs, seed, mv);
    if (!success) rv = false;

    for (INTEGER t=1; t<=n_obs; ++t) {
      data(1,t) = mv.geometric_stock_return[t];
      data(2,t) = mv.log_consumption_growth[t];
    }

    if (data != data) rv = false;

    INTEGER ier = 0;

    success = tree_gmm.set_data(&data);
    if (!success) rv = false;

    for (INTEGER i=1; i<=ngrid; ++i) {
      for (INTEGER j=1; j<=ngrid; ++j) {

        beta  = theta[1] = xgrid[i];
        gamma = theta[2] = ygrid[j];

        log_likelihood = tree_gmm.likelihood(theta,z).log_den;

	ier = tree_gmm.get_W_numerr();

        z_beta = (beta-prior_mean[3])/prior_sdev[3];
        z_gamma = (gamma-prior_mean[4])/prior_sdev[4];
        e_beta = minus_log_root_two_pi-log(prior_sdev[3])-0.5*pow(z_beta,2);
        e_gamma = minus_log_root_two_pi-log(prior_sdev[4])-0.5*pow(z_gamma,2);
        log_pe_prior = e_beta + e_gamma;

        const REAL c = 4.0;

        tz1 = tanh(z[1]/c);
        if (-1.0 < tz1 && tz1 < 1.0) {
          REAL top = -c*(1.0 - tz1);
          REAL bot = 1 - pow(tz1,2);
          dZ1 = top/bot;
        }
        else if (tz1 >= 1.0){
          dZ1 = -2.0;
        }
        else {
          dZ1 = -REAL_MAX;
	  ++Z_error_count;
        }
        det = dZ1*sf*sf;
        adj = det < 0.0 ? -det : det;
        log_adj = log(adj);

        if (adj != adj) 
          //cerr << beta <<' '<< gamma <<' '<< z[1] 
          //     <<' '<< s1 <<' '<< esb <<' '<< adj << '\n';
          cerr << z[1] <<' '<< tanh(z[1]/c) << '\n';

        lgrid(i,j) = log_likelihood;   // likelihood
        pgrid(i,j) = log_pe_prior;     // prior
        agrid(i,j) = log_adj;          // adj

        min_adj = log_adj < min_adj ? log_adj : min_adj;
        max_adj = log_adj > max_adj ? log_adj : max_adj;
        if (ier == 0) {
          ++W_success_count;
        }
        else {
          ++W_error_count;
          /*
          if (debug) {
            cerr << "W_numerr = " << fmt('d',3,ier)
                 << fmt('f',8,4,beta) << fmt('f',9,4,gamma)
                 << fmt('f',7,2,lgrid(i,j)) << fmt('f',7,2,agrid(i,j)) <<'\n';
          }
          */
        }

      }
    }
    REAL dif = max_adj-min_adj;
    if (dif > max_dif && ier == 0) {
      max_dif = dif;
      lgrid_save = lgrid;
      pgrid_save = pgrid;
      agrid_save = agrid;
      filename = "pc_gmetric_data_" + fmt('d',4,n_obs)('0') + ".txt";
      realmat Tdata;
      Tdata = T(data);
      INTEGER cnt = writetable(filename.c_str(),Tdata,20,16);
      if (cnt == 0) warn("Warning, writetable failed, " + filename);
      filename = "pc_gmetric_rho_" + fmt('d',4,n_obs)('0') + ".txt";
      cnt = writetable(filename.c_str(),rho,20,16);
      if (cnt == 0) warn("Warning, writetable failed, " + filename);
    }
    ier = 0;
  }

  lgrid = lgrid_save;
  pgrid = pgrid_save;
  agrid = agrid_save;

  for (INTEGER i=1; i<=ngrid*ngrid; ++i) {
    dgrid[i] = fabs(lgrid[i]) - fabs(agrid[i]);
    dgrid[i] *= M_LOG10E;
  }

  cout << "prior_range " << prior_range << '\n';
  cout << "xgrid " << xgrid << '\n';
  cout << "ygrid " << ygrid << '\n';
  cout << "lgrid " << lgrid << '\n';
  cout << "pgrid " << pgrid << '\n';
  cout << "agrid " << agrid << '\n';
  cout << "dgrid " << dgrid << '\n';

  REAL dmin = REAL_MAX;
  for (INTEGER i=1; i<=ngrid*ngrid; ++i) {
    if (IsFinite(dgrid[i])) {
      dmin = (dgrid[i] < dmin ? dgrid[i] : dmin);
    }
  }

  cout << '\n';
  cout << "dmin = " << dmin << '\n';

  filename = "contour_" + fmt('d',4,n_obs)('0') + ".tex";
  fout.open(filename.c_str());
  if (!fout) error("Error, cannot open " + filename);

  fout << "\\multicolumn{" << ngrid+1 << "}{c}{Log Likelihood}\\\\" << '\n';
  fout << "$\\gamma/\\beta$&";
  for (INTEGER j=1; j<=ngrid; ++j) fout << "& " <<fmt('f',7,2,xgrid[j]) <<' ';
  fout << " \\\\" << '\n';
  for (INTEGER i=1; i<=ngrid; ++i) {
    fout << fmt('f',6,2,ygrid[i]) << " &";
    for (INTEGER j=1; j<=ngrid; ++j) fout <<"& "<<fmt('f',7,2,lgrid(i,j)) <<' ';
    fout << " \\\\" << '\n';
  }

  fout << " \\\\" << '\n';

  fout << "\\multicolumn{" << ngrid+1 << "}{c}{Log Adjustment}\\\\" << '\n';
  fout << "$\\gamma/\\beta$&";
  for (INTEGER j=1; j<=ngrid; ++j) fout << "& " <<fmt('f',7,2,xgrid[j]) <<' ';
  fout << " \\\\" << '\n';
  for (INTEGER i=1; i<=ngrid; ++i) {
    fout << fmt('f',6,2,ygrid[i]) << " &";
    for (INTEGER j=1; j<=ngrid; ++j) fout <<"& "<<fmt('f',7,2,agrid(i,j)) <<' ';
    fout << " \\\\" << '\n';
  }
  
  fout.clear(); fout.close();

  cout << "W_success_count = " << W_success_count << '\n';
  cout << "W_error_count = " << W_error_count << '\n';
  cout << "Z_error_count = " << Z_error_count << '\n';

  cout << '\n';

  cout << (rv ? "success" : "failure") << '\n';

  return 0;
}

