#include "libscl.h"
#include "spline_interpolator.h"
#include "tree_policy.h"

using namespace std;
using namespace scl;

struct tree_usrmod_variables {
  INTEGER N;   // simulation length
  scl::realmat log_consumption;
  scl::realmat log_stock_price;
  scl::realmat log_marginal_rate_of_substitution;
  scl::realmat log_consumption_growth;
  scl::realmat geometric_stock_return;
  scl::realmat geometric_risk_free_rate;
  scl::realmat pricing_errors;
  tree_usrmod_variables() { }
  tree_usrmod_variables(INTEGER num_sim)
  : N(num_sim),
    log_consumption(N,1,0.0),
    log_stock_price(N,1,0.0),
    log_marginal_rate_of_substitution(N,1,0.0),
    log_consumption_growth(N,1,0.0),
    geometric_stock_return(N,1,0.0),
    geometric_risk_free_rate(N,1,0.0),
    pricing_errors(N,1,0.0)
  { }
  std::vector<std::string> get_tree_usrmod_variables(scl::realmat& mv);
};

// Gather every simulated series into a single matrix and label its columns.
//
// On return, vars holds, left to right, the columns log_consumption,
// log_stock_price, log_marginal_rate_of_substitution, log_consumption_growth,
// geometric_stock_return, geometric_risk_free_rate and pricing_errors.  The
// returned vector lists those same names in the same order.
vector<string> 
tree_usrmod_variables::get_tree_usrmod_variables(realmat& vars)
{
  const INTEGER n_series = 7;

  // Column labels and the member each one refers to, kept in parallel.
  const char* const labels[n_series] = {
    "log_consumption",
    "log_stock_price",
    "log_marginal_rate_of_substitution",
    "log_consumption_growth",
    "geometric_stock_return",
    "geometric_risk_free_rate",
    "pricing_errors"
  };
  realmat* const series[n_series] = {
    &log_consumption,
    &log_stock_price,
    &log_marginal_rate_of_substitution,
    &log_consumption_growth,
    &geometric_stock_return,
    &geometric_risk_free_rate,
    &pricing_errors
  };

  // Seed with the first column, then append the rest one at a time.
  vars = *series[0];
  for (INTEGER k = 1; k < n_series; ++k) {
    vars = cbind(vars, *series[k]);
  }

  return vector<string>(labels, labels + n_series);
}

int main(int argc, char** argp, char** envp)
{
  INT_32BIT seed = 740726;

  const INTEGER n_sim = 200000; 

  /*
  REAL alpha = 0.95;
  REAL sigma = 0.03;
  REAL beta  = 0.95;

  REAL r = -0.02015;

  REAL alpha = 0.95;
  REAL sigma = 0.02;
  REAL beta  = 0.95;

  REAL r = 0.02;

  REAL gamma = sqrt(-2.0*(r + log(beta)))/sigma;  // Gives risk free rate of r
  */

  REAL alpha = 0.95;
  REAL sigma = 0.02;
  REAL beta  = 0.95;
  REAL gamma = 12.5;

  cout << '\n';
  cout << "Parameters" << '\n';
  cout << "alpha " << fmt('f',10,5,alpha) << '\n';
  cout << "sigma " << fmt('f',10,5,sigma) << '\n';
  cout << "beta  " << fmt('f',10,5,beta) << '\n';
  cout << "gamma " << fmt('f',10,5,gamma) << fmt('e',27,16,gamma) << '\n';

  spline_interpolator P;
  spline_interpolator Q;
  spline_interpolator p;
  spline_interpolator q;

  tree_usrmod_variables mv(n_sim);

  bool converge = tree_policy(alpha, sigma, beta, gamma, P, Q, p, q);

  if (!converge) warn("Warning, tree_policy did not converge");

  REAL sdev_y = sigma/sqrt(1.0 - alpha*alpha);
  REAL y_lag = sdev_y*unsk(seed);

  for (INTEGER i=1; i<=100; ++i) {
    REAL y = alpha*y_lag + sigma*unsk(seed);
    y_lag = y;
  }

  for (INTEGER i=1; i<=mv.N; ++i) {
    REAL y = alpha*y_lag + sigma*unsk(seed);
    mv.log_consumption[i] = y;
    mv.log_stock_price[i] = p(y);
    REAL stock_payoff = exp(y) + exp(p(y));
    mv.log_marginal_rate_of_substitution[i] = log(beta) - gamma*(y - y_lag);
    mv.log_consumption_growth[i] = y - y_lag;
    mv.geometric_stock_return[i] = log(stock_payoff) - p(y_lag);
    mv.geometric_risk_free_rate[i] 
      = -log(beta) - (1.0-alpha)*gamma*y - 0.5*pow(gamma*sigma,2);
    mv.pricing_errors[i] 
      = 1.0 - exp(mv.log_marginal_rate_of_substitution[i] +
        mv.geometric_stock_return[i]);
    y_lag = y;  
  }

  realmat s = simple(mv.log_consumption_growth);

  cout << '\n';
  cout << "log_consumption_growth" << '\n';
  cout << "mean = " << fmt('f',10,5,s[1]) << '\n';
  cout << "sdev = " << fmt('f',10,5,s[2]) << '\n';

  s = simple(mv.geometric_stock_return);

  cout << '\n';
  cout << "geometric_stock_return" << '\n';
  cout << "mean = " << fmt('f',10,5,s[1]) << '\n';
  cout << "sdev = " << fmt('f',10,5,s[2]) << '\n';

  s = simple(mv.geometric_risk_free_rate);

  cout << '\n';
  cout << "geometric_risk_free_rate" << '\n';
  cout << "mean = " << fmt('f',10,5,s[1]) << '\n';
  cout << "sdev = " << fmt('f',10,5,s[2]) << '\n';

  s = simple(mv.pricing_errors);

  cout << '\n';
  cout << "pricing_errors" << '\n';
  cout << "mean = " << fmt('f',10,5,s[1]) << '\n';
  cout << "sdev = " << fmt('f',10,5,s[2]) << '\n';

  cout << '\n';

  realmat pc_log_level = cbind(mv.log_stock_price,mv.log_consumption);

  writetable("pc_log_level.txt",pc_log_level("1:5000",""),20,16);

  realmat pc_geometric 
    = cbind(mv.geometric_stock_return,mv.log_consumption_growth);

  writetable("pc_geometric.txt",pc_geometric("1:5000",""),20,16);

  return 0;
}

