/*-----------------------------------------------------------------------------

Copyright (C) 2014, 2016, 2018

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libmle.h"
#include "mle.h"
#include "mleusr.h"

using namespace scl;
using namespace libmle;
using namespace mle;
using namespace std;

// Constructs the CRRA user model.
//
// Steps performed, in order:
//   1. Checks len_mod_parm and len_mod_func against n_parms and n_funcs.
//   2. Sizes parms, prior_mean and prior_sdev as n_parms x 1 and copies dat
//      into the member data.
//   3. Reads the prior means, prior standard deviations and the ridge value
//      from mod_alvec.
//   4. Attaches the member data to the moment function crramf, checks the
//      data dimensions, and uses the number of columns as the sample size.
//   5. Builds the gmm object crragmm and sets its options.
//
// Parameters:
//   dat           data matrix; must have n_data_rows rows and at most
//                 n_data_cols columns.
//   len_mod_parm  must equal n_parms.
//   len_mod_func  must equal n_funcs.
//   mod_pfvec     not read by this constructor.
//   mod_alvec     must have exactly 7 entries; entries 1..5 are parsed with
//                 atof after truncation to their first 12 characters.
//                 Entries 0 and 6 are not read here.
//   detail        stream that receives a log of the settings chosen here.
//
// Every failed check is reported through error(...).
mle::crra_usrmod::crra_usrmod
  (const realmat& dat, INTEGER len_mod_parm, INTEGER len_mod_func,
   const std::vector<std::string>& mod_pfvec,
   const std::vector<std::string>& mod_alvec, 
   std::ostream& detail)
{ 
  detail << starbox("/crra_usrmod constructor now controls print//");
  detail << '\n';

  if (len_mod_parm != n_parms) error("Error, crra_usrmod, len_mod_parm wrong");
  if (len_mod_func != n_funcs) error("Error, crra_usrmod, len_mod_func wrong");

  detail << '\n';
  detail << "\t len_mod_parm = " << len_mod_parm << '\n'; 
  detail << "\t len_mod_func = " << len_mod_func << '\n'; 
  detail << '\n';
  detail.flush();
  
  parms.resize(n_parms,1);

  prior_mean.resize(n_parms,1);
  prior_sdev.resize(n_parms,1);

  // Keep a private copy; crramf and crragmm are handed pointers to this
  // member below, not to the caller's dat.
  data = dat;

  if (mod_alvec.size() != 7) error("Error, crra_usrmod, bad mod_alvec");

  // Only the first 12 characters of each entry are converted.
  prior_mean[1] = atof(mod_alvec[1].substr(0,12).c_str());
  prior_mean[2] = atof(mod_alvec[2].substr(0,12).c_str());
  prior_sdev[1] = atof(mod_alvec[3].substr(0,12).c_str());
  prior_sdev[2] = atof(mod_alvec[4].substr(0,12).c_str());
  ridge         = atof(mod_alvec[5].substr(0,12).c_str());

  detail << "\t prior_mean = " << prior_mean << '\n';
  detail << "\t prior_sdev = " << prior_sdev << '\n';
  detail << "\t ridge = " << ridge << '\n';
  detail << '\n';

  bool rv;
  rv = crramf.set_data(&data);
  if (!rv) error("Error, crra_usrmod, crramf.set_data failed ");

  INTEGER r = data.nrow();
  INTEGER c = data.ncol();

  if (r != n_data_rows) error("Error, crra_usrmod, data has wrong nrow");
  if (c > n_data_cols) error("Error, crra_usrmod, data too large ncol");

  // The number of data columns is used as the sample size.
  rv = crramf.set_sample_size(c);
  if (!rv) error("Error, crra_usrmod, crramf.set_sample_size failed ");

  if (crramf.get_d() != n_moments) error("Error, crra_usrmod, n_moments wrong");

  detail << "\t crramf.sample_size set by crra_usrmod constructor to "<<c<<'\n';

  INTEGER mflags = 0;

  // The message now names this class and member; it previously referred to
  // habit_usrmod / hmf, which do not appear in this constructor.
  rv = crramf.set_L(mflags);
  if (!rv) error("Error, crra_usrmod, crramf.set_L failed ");

  gmm tmp_gmm(&crramf,mflags,&data,c,HAC_lags);

  detail << "\t crragmm constructed with these values:" << '\n';
  detail << "\t\t moment_function_lags = " << mflags << '\n';
  detail << "\t\t data first element data(1,1) = " << data(1,1) << '\n';
  detail << "\t\t data last element data(r,c) = " << data(r,c) << '\n';
  detail << "\t\t r = " << r << '\n';
  detail << "\t\t c = " << c << '\n';
  detail << "\t\t sample size = " << c << '\n';
  detail << "\t\t Lhac = " << HAC_lags << '\n';

  crragmm = tmp_gmm;

  const bool correct_W_for_mean = true;
  crragmm.set_correct_W_for_mean(correct_W_for_mean);
  // boolalpha is left set on detail, so the bool values written below (and
  // any written to detail afterwards) appear as true/false.
  detail << "\t crragmm.set_correct_W_for_mean set by crra_usrmod "
         << "constructor to " << boolalpha << correct_W_for_mean << '\n';

  const bool regularize_W = true;
  crragmm.set_regularize_W(regularize_W,ridge);
  detail << "\t crragmm.set_regularize_W set by crra_usrmod "
         << "constructor to (" << regularize_W << ", " << ridge <<')'<<'\n';

  const bool warning_messages = true;
  crragmm.set_warning_messages(warning_messages);
  detail << "\t crragmm.set_warning_messages set by crra_usrmod "
         << "constructor to " << warning_messages << '\n';

  detail << starbox("/mle now controls print//");
  detail.flush();
}


bool mle::crra_usrmod::support(const realmat& rho) 
{
  REAL beta  = rho[1];
  REAL gamma = rho[2];

  if ( (beta  <= 0.8) || (0.99 <= beta  ) ) return false;

  if ( (gamma <= 0.0) || (100.0 <= gamma) ) return false;

  return true;
}


// Returns the sum over i = 1..n_parms of the normal log density of parms[i]
// with mean prior_mean[i] and standard deviation prior_sdev[i], accumulated
// as a den_val.
//
// NOTE(review): the values evaluated are the member parms, not the rho
// argument; neither rho nor stats is read here. Presumably the caller has
// already stored the same values in parms - confirm against the caller.
den_val mle::crra_usrmod::prior(const realmat& rho, const realmat& stats) 
{
    // -log(sqrt(2*pi))
    const REAL minus_log_root_two_pi = -9.1893853320467278e-01;

    den_val log_prior(true,0.0);

    INTEGER k = 1;
    while (k <= n_parms) {
      const REAL standardized = (parms[k] - prior_mean[k])/prior_sdev[k];
      const REAL log_term
        = minus_log_root_two_pi - log(prior_sdev[k]) - 0.5*pow(standardized,2);
      log_prior += den_val(true,log_term);
      ++k;
    }

    return log_prior;
}

// Evaluates crragmm.likelihood at the member parms and returns the result,
// unless crragmm.get_W_numerr() is nonzero afterwards, in which case
// denval(false,-REAL_MAX) is returned instead.
//
// NOTE(review): a nonzero get_W_numerr() presumably flags a numerical
// problem with the weighting matrix W - confirm against the gmm class.
den_val mle::crra_usrmod::likelihood() 
{
  realmat scratch;  // second argument of crragmm.likelihood; not used afterwards
  denval gmm_value = crragmm.likelihood(parms,scratch);

  // Must be queried after the likelihood call above.
  const bool numerr_reported = (crragmm.get_W_numerr() != 0);

  if (numerr_reported) {
    return denval(false,-REAL_MAX);
  }
  return gmm_value;
}

// Currently a no-op: the body consists solely of a commented-out template,
// so filename is not used and nothing is written.
//
// The template shows how an output name could be derived: copy filename into
// a string, drop its last three characters, append "whatever.dat", and open
// that file for writing, reporting a failure through error(...).
// NOTE(review): if the template is enabled, stem.substr(0,nchar-3) assumes
// filename has at least three characters - confirm before use.
void mle::crra_usrmod::write_usrvar(const char* filename)
{
  /*
  string stem;
  const char* fn = filename;
  while(*fn) stem.push_back(*fn++);
  string::size_type nchar = stem.size();
  stem = stem.substr(0,nchar-3);

  ofstream fout;
  string usrfile;

  usrfile = stem+"whatever.dat";
  fout.open(usrfile.c_str());
  if (!fout) error("Error, crra_usrmod, write_usrvar, cannot open "+usrfile);
  */
}

