/*-----------------------------------------------------------------------------

Copyright (C) 2018

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libmle.h"
#include "mle.h"
#include "mleusr.h"
#include <cerrno>

using namespace scl;
using namespace libmle;
using namespace mle;
using namespace std;

// Constructor.  Forwards dat, mod_pfvec, mod_alvec and detail to snpmod,
// writes two warnings and an echo of the configuration to detail, reads the
// user-supplied values from mod_alvec, and configures the moment function
// crramf and the GMM objective crragmm on dat.
//
// Entries of mod_alvec consumed here (entry 0 is not read by this body):
//   mod_alvec[1]  lsim            first 12 characters, atoi, echoed only
//   mod_alvec[2]  spin            first 12 characters, atoi, echoed only
//   mod_alvec[3]  crra_parms[1]   first 25 characters, checked with isREAL
//   mod_alvec[4]  crra_parms[2]   first 25 characters, checked with isREAL
//   mod_alvec[5]  log10 of lambda first 25 characters, checked with isREAL
//
// error() is called if mod_alvec has fewer than 6 entries, if any of the
// three real-valued entries fails isREAL, or if a crramf setter reports
// failure.  len_mod_parm and len_mod_func are not used in this body.
mle::snp_usrmod::snp_usrmod
  (const realmat& dat, INTEGER len_mod_parm, INTEGER len_mod_func,
   const std::vector<std::string>& mod_pfvec,
   const std::vector<std::string>& mod_alvec, 
   std::ostream& detail)
: snpmod(&dat, mod_pfvec, mod_alvec, detail)
{ 
  detail << starbox("/snp_usrmod constructor now controls print//");
  detail << '\n';

  string msg; 
  msg= "\nWarning: When computing seinfo, and hence sesand, the degrees of\n";
  msg+="freedom are n from the mle parmfile and not n - drop from the snp\n";
  msg+="parmfile. You may want to inflate sseinfo and sesand reported in\n";
  msg+="summary.dat by the factor sqrt(n-drop)/sqrt(n)";
  warn(msg); 
  detail << msg << '\n';
  msg= "\nWarning: The sample size n in the snp parmfile and in the mle\n";
  msg+="parmfile must agree.  The code does not check this assumption.\n";
  warn(msg);
  detail << msg << '\n';
  detail.flush();

  crra_parms.resize(2,1);

  // Entries 1 through 5 are read below; report a short vector instead of
  // reading past its end.
  if (mod_alvec.size() < 6) {
    error("Error, snp_usrmod, mod_alvec must have at least 6 entries");
  }

  // lsim and spin are read and used by snp_stat_mod but not used by snp_usrmod
  // remainder of alvec is read and used only by snp_usrmod

  const REAL lsim = atoi(mod_alvec[1].substr(0,12).c_str());
  const REAL spin = atoi(mod_alvec[2].substr(0,12).c_str());

  string vstr;
  REAL val;

  vstr = mod_alvec[3].substr(0,25);
  if (!isREAL(vstr.c_str(),val)) error("Error, snp_usrmod, vstr = " + vstr); 
  crra_parms[1] = val;

  vstr = mod_alvec[4].substr(0,25);
  if (!isREAL(vstr.c_str(),val)) error("Error, snp_usrmod, vstr = " + vstr); 
  crra_parms[2] = val;

  // Validated the same way as the two entries above, so that malformed text
  // is reported instead of being silently read as 0 (i.e. lambda = 1).
  vstr = mod_alvec[5].substr(0,25);
  if (!isREAL(vstr.c_str(),val)) error("Error, snp_usrmod, vstr = " + vstr); 
  const REAL log10lam = val;
  lambda = pow(10,log10lam);

  detail << '\n';
  detail << "\t lsim = "<< lsim <<'\n';
  detail << "\t spin = "<< spin <<'\n';
  detail << "\t log10lam = "<< log10lam <<'\n';
  detail << "\t lambda = "<< lambda <<'\n';
  detail << "\t crra_parms[1] =" << fmt('f',25,16,crra_parms[1]) << '\n';
  detail << "\t crra_parms[2] =" << fmt('f',25,16,crra_parms[2]) << '\n';
  detail.flush();

  bool rv = crramf.set_data(&dat);
  if (!rv) error("Error, snp_usrmod, crramf.set_data failed ");

  const INTEGER r = dat.nrow();
  const INTEGER c = dat.ncol();

  // The number of columns of dat is what is passed on as the sample size.
  rv = crramf.set_sample_size(c);
  if (!rv) error("Error, snp_usrmod, crramf.set_sample_size failed ");

  INTEGER mflags = 0;
  rv = crramf.set_L(mflags);
  if (!rv) error("Error, snp_usrmod, crramf.set_L failed ");

  INTEGER HAC_lags = 0;

  gmm tmp_gmm(&crramf,mflags,&dat,c,HAC_lags);

  detail << "\t moment_function_lags = " << mflags << '\n';
  detail << "\t data first element data(1,1) = " << dat(1,1) << '\n';
  detail << "\t data last element data(r,c) = " << dat(r,c) << '\n';
  detail << "\t r = " << r << '\n';
  detail << "\t c = " << c << '\n';
  detail << "\t sample size = " << c << '\n';
  detail << "\t Lhac = " << HAC_lags << '\n';

  crragmm = tmp_gmm;

  const bool correct_W_for_mean = true;
  crragmm.set_correct_W_for_mean(correct_W_for_mean);

  const bool regularize_W = true;
  const REAL ridge = 0.0;
  crragmm.set_regularize_W(regularize_W,ridge);

  const bool warning_messages = true;
  crragmm.set_warning_messages(warning_messages);

  // boolalpha is inserted once here and is not reset, so the bool values
  // written below are printed as words.
  detail << "\t correct_W_for_mean set "
         << "to " << boolalpha << correct_W_for_mean << '\n';
  detail << "\t regularize_W set "
         << "to (" << regularize_W << ", " << ridge <<')'<<'\n';

  detail << "\t warning_messages set "
         << "to " << warning_messages << '\n';

  detail << starbox("/mle now controls print//");
  detail.flush();
}

// Fills stats (resized to len_stats() by 1) from one simulation of snpmod.
//
//   stats[1]           -(n/2) * sum_{i=1..d} m[i]^2 - 0.5*log(2*pi), where m is
//                      the vector returned by crragmm at crra_parms, n is the
//                      number of columns of the simulation, d = crramf.get_d()
//   stats[2..4]        m[1], m[2], m[3]
//   stats[5..n_funcs]  snpstats[1..n_funcs-4] as returned by snpmod.gen_sim
//
// Returns false when the sum of squares or any element of stats is not
// finite, or when errno is ERANGE or EDOM on exit (a warning is issued and
// errno is cleared in that case); returns true otherwise.  error() is called
// if gen_sim, crragmm.set_data or crragmm.set_sample_size reports failure.
bool mle::snp_usrmod::get_stats(scl::realmat& stats) 
{ 
  const INTEGER n_funcs = len_stats();
  const INTEGER d = crramf.get_d();
  
  stats.resize(n_funcs,1);

  // Cleared here so that the errno test at the end reflects only the work
  // done in this call.
  errno = 0;

  realmat snpsim;
  realmat snpstats;

  bool rv = snpmod.gen_sim(snpsim, snpstats);
  if (!rv) error("Error, snp_usrmod, gen_sim failed");

  // The first four slots of stats are filled further down.
  for (INTEGER i=5; i<=n_funcs; ++i) stats[i] = snpstats[i-4]; 

  rv = crragmm.set_data(&snpsim);
  if (!rv) error("Error, snp_usrmod, crragmm.set_data failed ");

  const INTEGER n = snpsim.ncol();

  rv = crragmm.set_sample_size(n);
  if (!rv) error("Error, snp_usrmod, crragmm.set_sample_size failed");

  /*
  realmat z;

  denval gmmll = crragmm.likelihood(crra_parms,z);

  stats[1] = gmmll.log_den;
  stats[2] = z[1];
  stats[3] = z[2];
  stats[4] = z[3];

  rv = gmmll.positive;
  if (!rv) error("Error, snp_usrmod, crragmm.likelihood failed");
  */

  // Only m is read afterwards; the returned value and the other outputs of
  // this call are not used here.
  realmat m,W,S;
  REAL logdetW;
  INTEGER rankW;
  crragmm(crra_parms, m, W, logdetW, rankW, S);

  // pow is kept (rather than m[i]*m[i]) because it can set errno on range
  // errors, and errno is examined below.
  REAL sumsq = 0.0;
  for (INTEGER i=1; i<=d; ++i) sumsq += pow(m[i],2);

  if (!IsFinite(sumsq)) rv = false;

  const REAL halflogtwopi = 0.5*log(2.0*M_PI);

  stats[1] = -(0.5)*REAL(n)*sumsq - halflogtwopi;
  stats[2] = m[1];
  stats[3] = m[2];
  stats[4] = m[3];

  for (INTEGER i=1; i<=n_funcs; ++i) if (!IsFinite(stats[i])) rv = false;

  if ( (errno == ERANGE) || (errno == EDOM) ) {
    rv = false;
    warn("Warning, get_stats, fail, errno = " + fmt('d',3,errno)());
    errno = 0;
  }
  return rv;
}     

// Returns den_val(true, stats[1]*lambda).  rho_in is accepted to satisfy the
// signature but is not read.
den_val mle::snp_usrmod::prior(const realmat& rho_in, const realmat& stats) 
{
  const REAL scaled = stats[1] * lambda;
  return den_val(true, scaled);
}

// Asks snpmod.loglikelihood to fill scores and returns the `positive` member
// of the den_val it produces.  The other two outputs of that call are
// collected in locals and discarded.
bool mle::snp_usrmod::get_scores(realmat& scores)
{
  realmat discarded_dlogl;
  realmat discarded_infmat;
  return snpmod.loglikelihood(discarded_dlogl, discarded_infmat, scores).positive;
}

void mle::snp_usrmod::write_usrvar(const char* filename)
{
  /*
  realmat sim;
  realmat stats;
  gen_sim(sim,stats);
  vecwrite(filename,sim);
  */

  string stem = filename;
  string::size_type nchar = stem.size();
  stem = stem.substr(0,nchar-3);
  string parmfile = stem + "pf";

  if (!snpmod.write_parmfile(parmfile.c_str())) {
    error("Error, mleusr::write_usrvar, cannot write + parmfile");
  }
}
