/* ----------------------------------------------------------------------------

Copyright (C) 2018.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libmle.h"
#include "mle.h"

using namespace std;
using namespace scl;
using namespace mle;

namespace mle {
  
  // Pack the six simulated series into one matrix, one series per column,
  // and return their names in the same column order.
  //
  // vars (out): overwritten with the series, column k+1 holding the series
  //             named by element k of the returned vector.
  // returns:    the six series names, in column order.
  vector<string> 
  tree_usrmod_variables::get_tree_usrmod_variables(realmat& vars)
  {
    vector<string> labels;
    labels.reserve(6);

    // First column seeds the matrix; every later one is appended with cbind.
    labels.push_back("log_consumption");
    vars = log_consumption;

    labels.push_back("log_stock_price");
    vars = cbind(vars, log_stock_price);

    labels.push_back("log_marginal_rate_of_substitution");
    vars = cbind(vars, log_marginal_rate_of_substitution);

    labels.push_back("log_consumption_growth");
    vars = cbind(vars, log_consumption_growth);

    labels.push_back("geometric_stock_return");
    vars = cbind(vars, geometric_stock_return);

    labels.push_back("geometric_risk_free_rate");
    vars = cbind(vars, geometric_risk_free_rate);

    return labels;
  }
  
  // Construct the tree user model.
  //
  // dat          data matrix; its row count must equal n_vars.
  // len_mod_parm number of model parameters declared in the parmfile;
  //              must equal n_parms.
  // len_mod_func number of model statistics declared in the parmfile;
  //              must equal n_stats.
  // pfvec        not used by this constructor.
  // alvec        parmfile lines; exactly 10 are required. alvec[1..4] hold
  //              the prior means and alvec[5..8] the prior standard
  //              deviations of rho[1..4] = (alpha, sigma, beta, gamma).
  //              Only the first 12 characters of each line are parsed.
  // detail       not used by this constructor.
  //
  // Side effects: writes tree.prior_mean.dat, tree.prior_sdev.dat and
  // tree.prior_sim.dat. The last is a simulation at the prior mean.
  tree_usrmod::tree_usrmod(const scl::realmat& dat, INTEGER len_mod_parm,
      INTEGER len_mod_func, const std::vector<std::string>& pfvec,
      const std::vector<std::string>& alvec, std::ostream& detail)
  : data(dat), simulation_seed(740726), mv(n_sim), 
    prior_mean(len_mod_parm,1), prior_sdev(len_mod_parm,1) 
  {
    // Cross-check the sizes declared in the parmfile against the model's
    // own n_parms / n_stats / n_vars.
    if (n_parms != len_mod_parm) {
      error("Error, usrmod, constructor, len_mod_parm is wrong in parmfile");
    }

    if (n_stats != len_mod_func) {
      error("Error, usrmod, constructor, len_mod_func is wrong in parmfile");
    }

    if (n_vars != dat.nrow()) {
      error("Error, usrmod, constructor, M is wrong in parmfile");
    }

    /*
    if (n_obs != dat.ncol()) {
      error("Error, usrmod, constructor, n is wrong in parmfile");
    }
    */

    if (alvec.size() != 10) {
      error("Error, tree_usrmod, prior mean and sdev not in parmfile");
    }

    // alvec[1..4] -> prior means, alvec[5..8] -> prior sdevs.
    for (INTEGER i=1; i<=4; ++i) {
      prior_mean[i] = atof(alvec[i].substr(0,12).c_str());
      prior_sdev[i] = atof(alvec[i+4].substr(0,12).c_str());
    }

    // prior() takes log(prior_sdev[i]) and divides by prior_sdev[i].
    // A zero, negative, NaN, or unparsable entry (atof yields 0) would make
    // every prior evaluation non-finite, so reject it here.
    // The negated comparison also catches NaN.
    for (INTEGER i=1; i<=4; ++i) {
      if (!(prior_sdev[i] > 0.0)) {
        error("Error, tree_usrmod, prior sdev must be positive in parmfile");
      }
    }

    vecwrite("tree.prior_mean.dat",prior_mean);
    vecwrite("tree.prior_sdev.dat",prior_sdev);

    // Simulate once at the prior mean and save the simulated log stock
    // price path (gen_sim returns mv.log_stock_price in sim).
    realmat sim, stats;
    set_rho(prior_mean);
    gen_sim(sim,stats);

    vecwrite("tree.prior_sim.dat",sim);
  }
  
  // Simulate mv.N periods of the model at the current rho = (alpha, sigma,
  // beta, gamma) and store every simulated series in mv.
  //
  // Consumption follows y_t = alpha*y_{t-1} + sigma*e_t, with e_t drawn by
  // unsk. The start value is scaled by sigma/sqrt(1-alpha^2); presumably
  // unsk gives standard normal draws, making that the stationary sdev.
  // The seed is copied from simulation_seed, so every call replays the
  // same draws.
  //
  // sim   (out): the simulated log stock price path.
  // stats (out): resized to n_stats x 1 (n_stats must be 6). It holds
  //              simple()'s first two entries for consumption growth,
  //              stock return, and risk-free rate, in that order.
  // returns:     the convergence flag from tree_policy. The simulation is
  //              performed whether or not it converged.
  bool tree_usrmod::gen_sim(realmat& sim, realmat& stats)
  {
    INT_32BIT draw_seed = simulation_seed;

    const REAL alpha = rho[1];
    const REAL sigma = rho[2];
    const REAL beta  = rho[3];
    const REAL gamma = rho[4];

    // Solve for the pricing functions P, Q, p, q at this rho.
    const bool solved = tree_policy(alpha, sigma, beta, gamma, P, Q, p, q);

    const REAL log_beta = log(beta);

    const REAL start_scale = sigma/sqrt(1.0 - alpha*alpha);
    REAL y_prev = start_scale*unsk(draw_seed);

    for (INTEGER t=1; t<=mv.N; ++t) {
      const REAL y = alpha*y_prev + sigma*unsk(draw_seed);
      const REAL growth = y - y_prev;
      const REAL log_price = p(y);

      mv.log_consumption[t] = y;
      mv.log_stock_price[t] = log_price;
      mv.log_marginal_rate_of_substitution[t] = log_beta - gamma*growth;
      mv.log_consumption_growth[t] = growth;
      // Payoff is consumption plus price, both recovered from logs.
      mv.geometric_stock_return[t] = log(exp(y) + exp(log_price)) - p(y_prev);
      mv.geometric_risk_free_rate[t] 
        = -log_beta - (1.0-alpha)*gamma*y - 0.5*pow(gamma*sigma,2);

      y_prev = y;
    }

    stats.resize(n_stats,1);

    if (n_stats != 6) error("Error, mleusr, gen_sim, n_stats wrong");

    // Take the first two entries of simple() for each summarized series.
    // write_usrvar labels these rows mean and sdev.
    realmat moments = simple(mv.log_consumption_growth);
    stats[1] = moments[1];
    stats[2] = moments[2];

    moments = simple(mv.geometric_stock_return);
    stats[3] = moments[1];
    stats[4] = moments[2];

    moments = simple(mv.geometric_risk_free_rate);
    stats[5] = moments[1];
    stats[6] = moments[2];

    sim = mv.log_stock_price;

    return solved;
  }
  
  bool tree_usrmod::support(const realmat& rho) 
  {
    realmat rho_lo(n_parms,1);
    
    rho_lo[1] = -0.99;   // alpha
    rho_lo[2] =  0.01;   // sigma
    rho_lo[3] =  0.8;    // beta
    rho_lo[4] =  0.0;    // gamma
  
    realmat rho_hi(n_parms,1);
  
    rho_hi[1] =  0.99;   // alpha
    rho_hi[2] =  100.0;  // sigma
    rho_hi[3] =  0.99;   // beta
    rho_hi[4] =  100.0;  // gamma
  
    return ( (rho_lo < rho) && (rho < rho_hi) );
  }
  
  // Log prior density: independent normals, parameter i distributed
  // N(prior_mean[i], prior_sdev[i]^2), for i = 1..n_parms.
  //
  // parms: parameter vector to evaluate.
  // stats: not used.
  // returns: den_val(true, sum of the normal log densities).
  den_val tree_usrmod::prior(const realmat& parms, const realmat& stats)
  {
    const REAL neg_log_sqrt_2pi = -9.1893853320467278e-01;   // -log(sqrt(2*pi))

    den_val total(true,0.0);

    for (INTEGER k=1; k<=n_parms; ++k) {
      const REAL scale = prior_sdev[k];
      const REAL standardized = (parms[k] - prior_mean[k])/scale;
      const REAL log_density
        = neg_log_sqrt_2pi - log(scale) - 0.5*pow(standardized,2);
      total += den_val(true,log_density);
    }

    return total;
  }

  // Compute the model statistics at the current rho by running gen_sim.
  // The simulated path gen_sim also produces is discarded.
  // Returns gen_sim's convergence flag.
  bool tree_usrmod::get_stats(realmat& stats)
  {
    realmat discarded_path;
    const bool ok = gen_sim(discarded_path, stats);
    return ok;
  }

  // Log likelihood of the observed log stock prices (data[1..ncol]) at the
  // current rho = (alpha, sigma, beta, gamma).
  //
  // Each price is mapped to consumption by y = q(log P). The density of y is
  // multiplied by the Jacobian |q'(log P)|:
  //   t = 1:  y_1 ~ N(0, sigma^2/(1-alpha^2))
  //   t > 1:  y_t ~ N(alpha*y_{t-1}, sigma^2)
  //
  // Returns den_val(false,-REAL_MAX) when tree_policy does not converge or
  // the total is not finite; otherwise den_val(true, total).
  den_val tree_usrmod::likelihood()
  {
    const REAL alpha = rho[1];
    const REAL sigma = rho[2];
    const REAL beta  = rho[3];
    const REAL gamma = rho[4];

    if (!tree_policy(alpha, sigma, beta, gamma, P, Q, p, q)) {
      return den_val(false,-REAL_MAX);
    }

    const INTEGER n_cols = data.ncol();

    const REAL neg_log_sqrt_2pi = -9.1893853320467278e-01;   // -log(sqrt(2*pi))

    REAL loglik = 0.0;

    // First observation: stationary density.
    REAL x = data[1];
    REAL center = 0.0;
    REAL scale = sigma/sqrt( 1.0 - pow(alpha,2) );
    REAL jac = fabs(q.derivative(x));
    REAL y = q(x);
    REAL u = (y - center)/scale;
    loglik += log(jac) + neg_log_sqrt_2pi - log(scale) - 0.5*pow(u,2);

    // Remaining observations: conditional density given the previous y.
    // The previous y is carried forward instead of re-evaluating q.
    for (INTEGER t=2; t<=n_cols; ++t) {
      const REAL y_prev = y;
      x = data[t];
      center = alpha*y_prev;
      scale = sigma;
      jac = fabs(q.derivative(x));
      y = q(x);
      u = (y - center)/scale;
      loglik += log(jac) + neg_log_sqrt_2pi - log(scale) - 0.5*pow(u,2);
    }

    return IsFinite(loglik) ? den_val(true,loglik) : den_val(false,-REAL_MAX);
  }

  void  tree_usrmod::write_usrvar(const char* filename) 
  { 
    ofstream fout(filename);

    realmat vars;

    realmat sim, stats;

    gen_sim(sim,stats);

    vector<string> names = mv.get_tree_usrmod_variables(vars);

    stats = simple(vars);

    INTEGER top = names.size();
    
    for (INTEGER j=0; j<top; ++j) {
      fout << starbox(string("/")+names[j]+string("//"));
      fout << "\t\t\t mean = " << fmt('f',10,5,stats(1,j+1)) << '\n';
      fout << "\t\t\t sdev = " << fmt('f',10,5,stats(2,j+1)) << '\n';
      fout << "\t\t\t var  = " << fmt('f',10,5,stats(3,j+1)) << '\n';
      fout << "\t\t\t skew = " << fmt('f',10,5,stats(4,j+1)) << '\n';
      fout << "\t\t\t kurt = " << fmt('f',10,5,stats(5,j+1)) << '\n';
    }

    return; 
  }

}
