/* ----------------------------------------------------------------------------

Copyright (C) 2002, 2003, 2004, 2005, 2006, 2007.

A. Ronald Gallant

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libsmm.h"
using namespace std;  
using namespace scl;
using namespace libsmm;

#define ADJUST_FOR_PRIOR

// Conceptually the prior is treated as the density of the first
// observation, so that
//   the likelihood is exp(-objfun)*exp(prior.log_den),
//   the log likelihood is -objfun + prior.log_den, and
//   the objective differentiated is (objfun - prior.log_den)/sqrt(n_dat).
// The prior terms are included only when ADJUST_FOR_PRIOR is defined.

bool libsmm::bstrap_asymptotics::set_asymptotics(const realmat& chain)
{

  INTEGER n_chain = chain.ncol();

  if (chain.nrow()!=p) error("Error, bstrap_asymptotics, bad rho chain");

  if (n_tot == 0) {
    for (INTEGER t=1; t<=n_chain; ++t) {
      for (INTEGER i=1; i<=p; i++) {
        rho_ctr[i] += chain(i,t);
      }
    }
    rho_ctr = rho_ctr/n_chain;
  }

  for (INTEGER t=1; t<=n_chain; ++t) {
    for (INTEGER i=1; i<=p; i++) {
      rho_sum[i] += (chain(i,t) - rho_ctr[i]);
    }
  }

  for (INTEGER t=1; t<=n_chain; ++t) {
    for (INTEGER j=1; j<=p; j++) {
      for (INTEGER i=1; i<=p; i++) {
        rho_sse(i,j) += (chain(i,t) - rho_ctr[i])*(chain(j,t) - rho_ctr[j]);
      }
    }
  }

  n_tot += n_chain;

  mean = rho_sum/n_tot; 
  sscp = rho_sse/n_tot;

  for (INTEGER j=1; j<=p; j++) {
    for (INTEGER i=1; i<=p; i++) {
      sscp(i,j) -= mean[i]*mean[j];
    }
  }

  mean += rho_ctr;
  sscp = REAL(n_dat)*mcmc.get_temp()*sscp;

  realmat mode_new = mcmc.get_mode();
  REAL high_new = mcmc.get_high();
  if (high_new > high) {
    mode = mode_new;
    high = high_new;;
  }

  model.set_rho(mode);
  
  REAL root_n_dat = sqrt(REAL(n_dat));
  REAL eps = pow(REAL_EPSILON,0.33333333);

  vector<realmat> bs;
  if (!model.gen_bootstrap(bs)) {
    warn("Warning, bstrap_asymptotics, bootstrap from usrmod failed");
    return false;
  }

  typedef vector<realmat>::const_iterator bs_itr;
  bs_itr itr = bs.begin();

  if (itr->get_cols()!=data.get_cols() || itr->get_rows()!=data.get_rows()) {
    error("Error, bstrap_asymptotics, bad bootstrap vector");
  }

  INTEGER len = bs.size();
  INTEGER len_vec = deriv_vec.size();

  if (len == 0) {
    error("Error, bstrap_asymptotics, bad bootstrap vector");
  }
  else if (len_vec < len) {
    deriv_vec.resize(len);
  }

  for (INTEGER i=0; i<len; i++) {
    deriv_vec[i].obj = objfun.new_objfun();  // new operator called here
    (*deriv_vec[i].obj).set_data(bs[i]);
    deriv_vec[i].der_obj.resize(p,1);
  }

  (*deriv_vec[0].obj).set_data(data);     // 0th set to data
  
  realmat sim, stats;

  for (INTEGER j=1; j<=p; j++) {          // Centered differention

    REAL delta = eps*fabs(mode[j]);       // See Press et. al. p. 186
    delta = (delta == 0) ? eps : delta;

    delta = sqrt(delta);                  // This seems to help

    REAL hi = mode[j] + delta;
    REAL lo = mode[j] - delta;

    realmat rho_delta = mode;

    rho_delta[j] = hi;
    model.set_rho(rho_delta);
    if (!model.gen_sim(sim,stats)) {
      warn("Warning, bstrap_asymptotics, sim from usrmod failed");
      return false;
    }

    for (INTEGER i=0; i<len; i++) {
      REAL obj_rho_delta = (*deriv_vec[i].obj)(rho_delta,sim,stats);
      #if defined ADJUST_FOR_PRIOR
      obj_rho_delta -= model.prior(rho_delta,stats).log_den;
      #endif
      obj_rho_delta /= root_n_dat;
      deriv_vec[i].der_obj[j] = obj_rho_delta;
    }

    rho_delta[j] = lo;
    model.set_rho(rho_delta);
    if (!model.gen_sim(sim,stats)) {
      warn("Warning, bstrap_asymptotics, sim from usrmod failed");
      return false;
    }

    REAL divisor = hi - lo;

    for (INTEGER i=0; i<len; i++) {
      REAL obj_rho_delta = (*deriv_vec[i].obj)(rho_delta,sim,stats);
      #if defined ADJUST_FOR_PRIOR
      obj_rho_delta -= model.prior(rho_delta,stats).log_den;
      #endif
      obj_rho_delta /= root_n_dat;
      deriv_vec[i].der_obj[j] -= obj_rho_delta;
      deriv_vec[i].der_obj[j] /= divisor;
    }

  }

  model.set_rho(mode);
  if (!model.gen_sim(sim,stats)) {
    warn("Warning, bstrap_asymptotics, sim from usrmod failed");
    return false;
  }

  REAL f = objfun(mode,sim,stats);
  #if defined ADJUST_FOR_PRIOR
  f -= model.prior(mode,stats).log_den;
  #endif

  foc = -exp(-f)*deriv_vec[0].der_obj;       // first order conditions

  for (INTEGER i=1; i<len; ++i) {            // deriv_vec[0] excluded
    realmat score = deriv_vec[i].der_obj;
    bool score_is_finite = true;
    for (INTEGER j=1; j<=p; ++j) {
      if (!IsFinite(score[j])) score_is_finite = false;
    }
    info_item ii;
    ii.score = score;
    if (score_is_finite) info_vec.push_back(ii);
  }

  vector<info_item>::iterator iv_itr;

  realmat lo_mat, hi_mat;

  for (iv_itr = info_vec.begin(); iv_itr != info_vec.end(); ++iv_itr) {
    realmat test = (sscp*iv_itr->score)*T(sscp*iv_itr->score);
    REAL lo = test(1,1);
    REAL hi = test(1,1);
    for (INTEGER i=1; i<=p; ++i) {
      lo = test(i,i) < lo ? test(i,i) : lo;
      hi = test(i,i) > hi ? test(i,i) : hi;
    }
    iv_itr->lo = lo;
    iv_itr->hi = hi;
    lo_mat.push_back(lo);
    hi_mat.push_back(hi);
  }
  
  lo_mat.sort(1);
  hi_mat.sort(1);

  REAL lo_min = lo_mat[1];
  REAL hi_min = hi_mat[1];
  REAL lo_max = lo_mat[lo_mat.size()];
  REAL hi_max = hi_mat[hi_mat.size()];
  if (lo_mat.size() > 20) {
    INTEGER q5 = lo_mat.size()/20;
    lo_min = lo_mat[q5];
    lo_max = lo_mat[lo_mat.size()-q5];
    hi_min = hi_mat[q5];
    hi_max = hi_mat[hi_mat.size()-q5];
  }

  INTEGER count = 0;
  realmat I_mean(p,1,0.0);

  for (iv_itr = info_vec.begin(); iv_itr != info_vec.end(); ++iv_itr) {
    if (lo_min < iv_itr->lo && iv_itr->lo < lo_max 
        && hi_min < iv_itr->hi && iv_itr->hi < hi_max) {
      I_mean += iv_itr-> score;
      ++count;
    }
  }

  I_reps += count;

  I_mean = I_mean/count;
    
  for (iv_itr = info_vec.begin(); iv_itr != info_vec.end(); ++iv_itr) {
    if (lo_min < iv_itr->lo && iv_itr->lo < lo_max 
        && hi_min < iv_itr->hi && iv_itr->hi < hi_max) {
      I_mat += (iv_itr->score-I_mean)*T(iv_itr->score-I_mean);
    }
  }

  return true;
}

void libsmm::bstrap_asymptotics::get_asymptotics (realmat& rho_hat, 
         realmat& V_hat, INTEGER& n)
{
  // Point estimate: the highest posterior mode recorded so far.
  rho_hat = mode;

  // Sandwich variance inv_j * info * inv_j / n_dat, where info is the
  // accumulated score cross-product averaged over I_reps and inv_j is the
  // rescaled posterior variance held in sscp.
  realmat info = I_mat/I_reps;
  realmat inv_j = sscp;
  V_hat = inv_j*info*inv_j;
  V_hat = V_hat/n_dat;

  n = n_dat;
}

void libsmm::bstrap_asymptotics::get_asymptotics (realmat& rho_mean, 
         realmat& rho_mode, REAL& post_high, realmat& I, realmat& invJ, 
         realmat& foc_hat, INTEGER& reps)
{
  // Posterior summaries: mean, best mode and its posterior value.
  rho_mean  = this->mean;
  rho_mode  = this->mode;
  post_high = this->high;

  // Sandwich components, first-order conditions and the number of
  // retained bootstrap scores behind the information estimate.
  I       = this->I_mat/this->I_reps;
  invJ    = this->sscp;
  foc_hat = this->foc;
  reps    = this->I_reps;
}

bool libsmm::minimal_asymptotics::set_asymptotics(const realmat& chain)
{
  INTEGER n_chain = chain.ncol();

  if (chain.nrow()!=p) error("Error, minimal_asymptotics, bad chain matrix");

  if (n_tot == 0) {
    for (INTEGER t=1; t<=n_chain; ++t) {
      for (INTEGER i=1; i<=p; i++) {
        rho_ctr[i] += chain(i,t);
      }
    }
    rho_ctr = rho_ctr/n_chain;
  }

  for (INTEGER t=1; t<=n_chain; ++t) {
    for (INTEGER i=1; i<=p; i++) {
      rho_sum[i] += (chain(i,t) - rho_ctr[i]);
    }
  }

  for (INTEGER t=1; t<=n_chain; ++t) {
    for (INTEGER j=1; j<=p; j++) {
      for (INTEGER i=1; i<=p; i++) {
        rho_sse(i,j) += (chain(i,t) - rho_ctr[i])*(chain(j,t) - rho_ctr[j]);
      }
    }
  }

  n_tot += n_chain;

  mean = rho_sum/n_tot; 
  sscp = rho_sse/n_tot;

  for (INTEGER j=1; j<=p; j++) {
    for (INTEGER i=1; i<=p; i++) {
      sscp(i,j) -= mean[i]*mean[j];
    }
  }

  mean += rho_ctr;
  sscp = REAL(n_dat)*mcmc.get_temp()*sscp;

  realmat mode_new = mcmc.get_mode();
  REAL high_new = mcmc.get_high();
  if (high_new > high) {
    mode = mode_new;
    high = high_new;;
  }

  return true;
}

void minimal_asymptotics::get_asymptotics (realmat& rho_hat, realmat& V_hat,
         INTEGER& n)
{
  // With no bootstrap information the estimate is the posterior mean, and
  // the variance is sscp/n_dat (the posterior variance scaled by the MCMC
  // temperature, given how set_asymptotics builds sscp).
  rho_hat = this->mean;
  V_hat   = this->sscp/this->n_dat;
  n       = this->n_dat;
}

void minimal_asymptotics::get_asymptotics(realmat& rho_mean, realmat& rho_mode,
         REAL& post_high, realmat& I, realmat& invJ, realmat& foc_hat, 
         INTEGER& reps)
{
  // Posterior summaries are available.
  rho_mean  = this->mean;
  rho_mode  = this->mode;
  post_high = this->high;

  // This class computes no information matrix or first-order conditions,
  // so those outputs are set to empty matrices and reps to zero.
  I       = realmat();
  invJ    = this->sscp;
  foc_hat = realmat();
  reps    = 0;
}

