/*-----------------------------------------------------------------------------

Copyright (C) 2018

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libmle.h"
#include "mle.h"
#include "mleusr.h"
#include "lprior.h"

using namespace scl;
using namespace libmle;
using namespace mle;
using namespace std;

// Advances `cursor` by one entry of the model's alvec and returns the first
// 12 characters of the entry it now points at; numeric values are parsed
// from that leading field.
static std::string next_alvec_field
  (std::vector<std::string>::const_iterator& cursor)
{
  ++cursor;
  return cursor->substr(0,12);
}

// Constructor for the user model.  Validates the model dimensions against
// the compiled-in constants, echoes them to `detail`, parses the prior and
// target settings from `mod_alvec`, stores a copy of `dat`, and prints the
// first and last six observation columns.  Errors are reported via error().
mle::dcf_usrmod::dcf_usrmod
  (const realmat& dat, INTEGER len_mod_parm, INTEGER len_mod_func,
   const std::vector<std::string>& mod_pfvec,
   const std::vector<std::string>& mod_alvec, 
   std::ostream& detail)
{ 
  detail << starbox("/dcf_usrmod constructor now controls print//");
  detail << '\n';

  // The model code below is written for a 3-variable VAR with one lag.
  if (len_mod_parm != n_parms) error("Error, dcf_usrmod, len_mod_parm wrong");
  if (len_mod_func != n_funcs) error("Error, dcf_usrmod, len_mod_func wrong");
  if (VAR_M != 3) error("Error, dcf_usrmod, VAR_M != 3");
  if (VAR_L != 1) error("Error, dcf_usrmod, VAR_L != 1");

  // Echo the sizes and data/variable layout to the detail stream.
  detail << '\n'
         << "\t len_mod_parm = " << len_mod_parm << '\n'
         << "\t len_mod_func = " << len_mod_func << '\n'
         << '\n'
         << "\t n_years = " << n_years << '\n'
         << "\t n_parms = " << n_parms << '\n'
         << "\t n_funcs = " << n_funcs << '\n'
         << "\t n_target = " << n_target << '\n'
         << '\n'
         << "\t n_data_rows = " << n_data_rows << '\n'
         << "\t n_data_cols = " << n_data_cols << '\n'
         << '\n'
         << "\t VAR_M = " << VAR_M << '\n'
         << "\t VAR_L = " << VAR_L << '\n'
         << '\n'
         << "\t loc_dat_yr = "     << loc_dat_yr << '\n'
         << "\t loc_dat_mrs = "    << loc_dat_mrs << '\n'
         << "\t loc_dat_gdp_cp = " << loc_dat_gdp_cp << '\n'
         << "\t loc_dat_cp = "     << loc_dat_cp << '\n'
         << '\n'
         << "\t loc_var_mrs = "    << loc_var_mrs << '\n'
         << "\t loc_var_gdp_cp = " << loc_var_gdp_cp << '\n'
         << "\t loc_var_cp = "     << loc_var_cp << '\n'
         << '\n';
  detail.flush();

  parms.resize(n_parms,1);
  target.resize(n_target,1);
  idx.resize(n_target);

  data = dat;

  if (mod_alvec.size() != 27) error("Error, dcf_usrmod, bad mod_alvec");

  // The first alvec entry is skipped; the values that follow are, in order:
  // four prior settings, n_target targets, n_target indices, and lambda.
  vector<string>::const_iterator cursor = mod_alvec.begin();
  prior_mean_01 = atof(next_alvec_field(cursor).c_str());
  prior_sdev_01 = atof(next_alvec_field(cursor).c_str());
  prior_mean_30 = atof(next_alvec_field(cursor).c_str());
  prior_sdev_30 = atof(next_alvec_field(cursor).c_str());
  for (INTEGER k=1; k<=n_target; ++k) {
    target[k] = atof(next_alvec_field(cursor).c_str());
  }
  for (INTEGER k=1; k<=n_target; ++k) {
    idx[k] = atoi(next_alvec_field(cursor).c_str());
  }
  lambda = atof(next_alvec_field(cursor).c_str());

  detail << '\n'
         << "\t prior_mean_01 = "<< prior_mean_01 <<'\n'
         << "\t prior_sdev_01 = "<< prior_sdev_01 <<'\n'
         << "\t prior_mean_30 = "<< prior_mean_30 <<'\n'
         << "\t prior_sdev_30 = "<< prior_sdev_30 <<'\n'
         << "\t target = "<< target <<'\n'
         << "\t idx = "<< idx <<'\n'
         << "\t lambda = "<< lambda <<'\n'
         << '\n';
  detail.flush();

  // Observations are stored one per column.
  const INTEGER rows = data.nrow();
  const INTEGER cols = data.ncol();

  if (rows != n_data_rows) error("Error, dcf_usrmod, data has wrong nrow");
  if (cols != n_data_cols) error("Error, dcf_usrmod, data has wrong ncol");

  detail << starbox("/First 6 observations/yr,mrs,gdp,gdp_cp,cp//");
  detail << data("",seq(1,6));
  detail << starbox("/Last 6 observations/yr,mrs,gdp,gdp_cp,cp//");
  detail << data("",seq(cols-5,cols));
  detail.flush();

  detail << starbox("/mle now controls print//");
  detail.flush();
}

// This model supplies no auxiliary statistics: `stats` is set to a 1 x 1
// zero matrix and success is reported.
bool mle::dcf_usrmod::get_stats(scl::realmat& stats) 
{
  const INTEGER one = 1;
  stats.resize(one, one, 0.0);
  return true;
}

bool mle::dcf_usrmod::support(const realmat& rho) 
{
  const INTEGER M = VAR_M;
  const INTEGER L = VAR_L;

  realmat b0, B, S, R, theta;

  frac_rho(M, L,rho, b0, B, S, R, theta);

  REAL beta = theta[1];
  REAL gamma = theta[2];

  INTEGER ier;
  REAL maxlam = eigen(B,S,ier);

  if ( ier != 0 || maxlam > 0.99 ) return false;
  if ( (beta  <= 0.8) || (0.99 <= beta  ) ) return false;
  if ( (gamma <= 0.0) || (100.0 <= gamma) ) return false;

  return true;
}

// Log likelihood of `data` at the current `parms`.
// Sums mvn(y_t, b0 + B*y_{t-1}, S) over t = 1+L..n, where y_t holds the
// mrs, gdp_cp and cp series at column t.  Only one lag is formed, which
// matches the VAR_L == 1 check in the constructor.
// Returns den_val(false,-REAL_MAX) after a warning if mvn() signals a
// lib_error (taken to mean S is less than full rank).
den_val mle::dcf_usrmod::likelihood() 
{
  const INTEGER n = n_data_cols;
  const INTEGER M = VAR_M;
  const INTEGER L = VAR_L;

  realmat b0, B, S, R, theta;
  
  frac_rho(M, L, parms, b0, B, S, R, theta);

  // Reinstates the default library error handler on every exit path:
  // normal return, the early return in the catch below, or any exception
  // that is not a lib_error.
  struct restore_default_handler {
    ~restore_default_handler()
    {
      set_lib_error_handler(&default_lib_error_handler);
    }
  };

  den_val rv(true,0.0);

  realmat y(M,1);
  realmat ylag(M,1,0.0);

  // Presumably the alternate handler throws lib_error instead of aborting,
  // so a rank-deficient S can be caught below -- confirm.
  set_lib_error_handler(&alternate_lib_error_handler);
  restore_default_handler guard;

  for (INTEGER t=1+L; t<=n; ++t) {
      y[loc_var_mrs]       = data(loc_dat_mrs,t);
      y[loc_var_gdp_cp]    = data(loc_dat_gdp_cp,t);
      y[loc_var_cp]        = data(loc_dat_cp,t);
      ylag[loc_var_mrs]    = data(loc_dat_mrs,t-1);
      ylag[loc_var_gdp_cp] = data(loc_dat_gdp_cp,t-1);
      ylag[loc_var_cp]     = data(loc_dat_cp,t-1);
      realmat mu = b0 + B*ylag;
      try { 
        rv += mvn(y,mu,S);
      }
      catch (const lib_error&) {
        warn("Warn, likelihood, S less than full rank");
        return den_val(false,-REAL_MAX);
      }
  }

  return rv;
}

// Log prior density at `rho`.
// For each observation column t, the VAR state y0 is set from the data.
// dcfyld() then yields the model-implied `yld`.  yld[1] and yld[30]
// (presumably the 1- and 30-year yields) are scored against independent
// normal priors N(prior_mean_01, prior_sdev_01^2) and
// N(prior_mean_30, prior_sdev_30^2).
// `stats` is not used.
// NOTE(review): yld[30] assumes dcfyld returns at least 30 entries for
// n_years -- confirm.
den_val mle::dcf_usrmod::prior(const realmat& rho, const realmat& stats)
{
  const INTEGER n = n_data_cols;
  const INTEGER M = VAR_M;
  const INTEGER L = VAR_L;
  const INTEGER years = n_years;

  // Alternative (disabled):
  //   sum += lprior(M,L,loc_var_mrs,loc_var_cp,years,rho,lambda,target,idx,discard);

  realmat b0, B, S, R, theta;
  frac_rho(M,L,rho, b0, B, S, R, theta);

  // Normal log-density constants, hoisted out of the loop.
  const REAL pi = 4.0*atan(1.0);
  const REAL minus_log_root_two_pi = -log(sqrt(2.0*pi));
  const REAL log_norm_01 = minus_log_root_two_pi - log(prior_sdev_01);
  const REAL log_norm_30 = minus_log_root_two_pi - log(prior_sdev_30);

  den_val sum(true,0.0);

  // Every component of the state must be set.  Previously the gdp_cp
  // element was left uninitialized and still fed into dcfyld().
  realmat y0(M,1,0.0);
  realmat ecf, pvcf, pv1, cecf, dcf, yld;

  for (INTEGER t=1; t<=n; ++t) {

    y0[loc_var_mrs]       = data(loc_dat_mrs,t);
    y0[loc_var_gdp_cp]    = data(loc_dat_gdp_cp,t);
    y0[loc_var_cp]        = data(loc_dat_cp,t);

    dcfyld(b0,B,S,y0,loc_var_mrs,loc_var_cp,years,ecf,pvcf,pv1,cecf,dcf,yld);

    const REAL z01 = (yld[1] - prior_mean_01)/prior_sdev_01;
    sum += den_val(true, log_norm_01 - 0.5*z01*z01);

    const REAL z30 = (yld[30] - prior_mean_30)/prior_sdev_30;
    sum += den_val(true, log_norm_30 - 0.5*z30*z30);
  }

  return sum;
}

void mle::dcf_usrmod::write_usrvar(const char* filename)
{
  string stem;
  const char* fn = filename;
  while(*fn) stem.push_back(*fn++);
  string::size_type nchar = stem.size();
  stem = stem.substr(0,nchar-3);

  ofstream fout;
  string usrfile;

  const INTEGER M = VAR_M;
  const INTEGER L = VAR_L;
  const INTEGER mrs_pos = loc_var_mrs;
  const INTEGER cf_pos = loc_var_cp;
  const INTEGER years = n_years;

  realmat b0, B, S, R, theta;

  frac_rho(M, L, parms, b0, B, S, R, theta);

  REAL beta = theta[1];
  REAL gamma = theta[2];

  realmat I(M,M,0.0);
  for (INTEGER i=1; i<=M; ++i) I(i,i) = 1.0;

  realmat IB = I - B;
  realmat IBinv = inv(IB);

  realmat lim_mean = IBinv*b0;
  realmat vecS(M*M,1);
  for (INTEGER i=1; i<=M*M; ++i) vecS[i] = S[i];
  realmat vecV = inv(kronprd(I,I) - kronprd(B,B))*vecS;
  realmat lim_V(M,M);
  for (INTEGER i=1; i<=M*M; ++i) lim_V[i] = vecV[i];
  realmat lim_R = lim_V;
  if (factor(lim_R)!=0) error("Error, factor failed");

  realmat stats;
  
  denval lp = lprior(M,L,mrs_pos,cf_pos,years,parms,lambda,target,idx,stats);

  usrfile = stem+"stats.dat";
  vecwrite(usrfile.c_str(),stats);

  usrfile = stem + "stats.txt";
  fout.open(usrfile.c_str());
  if (!fout) error("Error, dcf_usrmod, write_usrvar, cannot open "+usrfile);

  fout << '\n';
  fout << '\n';
  fout << "stat        idx    value    target" << '\n';
  fout << "lcg_mean     "  << idx[ 1]
    << fmt('f',10,5,stats[ 1]) << fmt('f',10,5,target[ 1]) << '\n';
  fout << "lcg_sdev     "  << idx[ 2]
    << fmt('f',10,5,stats[ 2]) << fmt('f',10,5,target[ 2]) << '\n';
  fout << "yld_mean_01  "  << idx[ 3]
    << fmt('f',10,5,stats[ 3]) << fmt('f',10,5,target[ 3]) << '\n';
  fout << "yld_sdev_01  "  << idx[ 4]
    << fmt('f',10,5,stats[ 4]) << fmt('f',10,5,target[ 4]) << '\n';
  fout << "yld_mean_yr  "  << idx[ 5]
    << fmt('f',10,5,stats[ 5]) << fmt('f',10,5,target[ 5]) << '\n';
  fout << "yld_sdev_yr  "  << idx[ 6]
    << fmt('f',10,5,stats[ 6]) << fmt('f',10,5,target[ 6]) << '\n';
  fout << "stk_mean_01  "  << idx[ 7]
    << fmt('f',10,5,stats[ 7]) << fmt('f',10,5,target[ 7]) << '\n';
  fout << "stk_sdev_01  "  << idx[ 8]
    << fmt('f',10,5,stats[ 8]) << fmt('f',10,5,target[ 8]) << '\n';
  fout << "stk_mean_yr  "  << idx[ 9]
    << fmt('f',10,5,stats[ 9]) << fmt('f',10,5,target[ 9]) << '\n';
  fout << "stk_sdev_yr  "  << idx[10]
    << fmt('f',10,5,stats[10]) << fmt('f',10,5,target[10]) << '\n';
  fout << '\n';

  REAL implied_gamma = sqrt(lim_V(1,1))/target[2];
  REAL implied_beta = exp(lim_mean[1] + implied_gamma*target[1]);

  fout << "mean lmrs     " << fmt('f',10,5,lim_mean[1]) << '\n';
  fout << "sdev lmrs     " << fmt('f',10,5,sqrt(lim_V(1,1))) << '\n';
  fout << "implied gamma " << fmt('f',10,5,implied_gamma) << '\n';
  fout << "implied beta  " << fmt('f',10,5,implied_beta) << '\n';
  fout << '\n';

  fout << '\n';
  fout << b0 << B << R << theta << '\n';
  fout << boolalpha;
  fout << "\t\t\tlp = (" << lp.positive << ", " << lp.log_den << ')' << '\n';
  fout << '\n';

  fout.clear(); fout.close(); 

  usrfile = stem + "quad.txt";
  fout.open(usrfile.c_str());
  if (!fout) error("Error, dcf_usrmod, write_usrvar, cannot open "+usrfile);

  const REAL pi = M_PI;
  const REAL sqrt2 = M_SQRT2;

  const INTEGER quad_n = 5;
  realmat x, w;

  if (hquad(quad_n,x,w)!=0) error("Error, hquad failed");

  realmat z(M,1);

  REAL lcg_mean = 0.0;
  REAL lcg_sdev = 0.0;

  REAL yld_mean_01 = 0.0;
  REAL yld_sdev_01 = 0.0;
  REAL yld_mean_yr = 0.0;
  REAL yld_sdev_yr = 0.0;

  REAL stk_mean_01 = 0.0;
  REAL stk_sdev_01 = 0.0;
  REAL stk_mean_yr = 0.0;
  REAL stk_sdev_yr = 0.0;

  realmat ecf,pvcf,pv1,cecf,dcf,yld;

  realmat y0;

  realmat zero(M,1,0.0);

  for (INTEGER i=1; i<=quad_n; ++i) {
    for (INTEGER j=1; j<=quad_n; ++j) {
      for (INTEGER k=1; k<=quad_n; ++k) {

        z[1] = sqrt2*x[i];
        z[2] = sqrt2*x[j];
	z[3] = sqrt2*x[k];

        y0 = IBinv*(b0 + R*z);

        REAL weight = w[i]*w[j]/pi;

        dcfyld(b0,B,S,y0,mrs_pos,cf_pos,years,ecf,pvcf,pv1,cecf,dcf,yld);

        lcg_mean = weight*(log(beta)-y0[1])/gamma;
        lcg_sdev = weight*pow((log(beta)-y0[1])/gamma,2);

        yld_mean_01 = weight*yld[1];
        yld_mean_yr = weight*yld[years];
        yld_sdev_01 = weight*pow(yld[1],2);
        yld_sdev_yr = weight*pow(yld[years],2);

        REAL ret01 = log(dcf[years]+ecf[1]) -log(dcf[years]);
        REAL retyr = ( log(dcf[years]+cecf[years]) -log(dcf[years]))/years;

        stk_mean_01 = weight*ret01;
        stk_mean_yr = weight*retyr;
        stk_sdev_01 = weight*pow(ret01,2);
        stk_sdev_yr = weight*pow(retyr,2);

        fout << '\n';

        if (z == zero) fout << "center" << '\n';

        fout << "y0[1]  =" << fmt('f',10,5,y0[1]) << '\n';
        fout << "y0[2]  =" << fmt('f',10,5,y0[2]) << '\n';
        fout << "y0[3]  =" << fmt('f',10,5,y0[3]) << '\n';

        fout << "weight =" << fmt('f',10,5,weight) << '\n';

        fout << "lcg    =" << fmt('f',10,5,lcg_mean/weight) << '\n';

        fout << "yld_01 =" << fmt('f',10,5,yld_mean_01/weight) << '\n';
        fout << "yld_yr =" << fmt('f',10,5,yld_mean_yr/weight) << '\n';

        fout << "stk_01 =" << fmt('f',10,5,stk_mean_01/weight) << '\n';
        fout << "stk_yr =" << fmt('f',10,5,stk_mean_yr/weight) << '\n';
      }
    }
  }
    
  fout.clear(); fout.close(); 
} 
