/* ----------------------------------------------------------------------------

Copyright (C) 2002, 2003, 2004, 2005, 2006, 2007.

A. Ronald Gallant

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libsmm.h"
using namespace std;  
using namespace scl;
using namespace libsmm;

/*
Sets the number of iterations that draw() runs.  A value that is not
strictly positive is reported through error() and the stored
simulation_size is left unchanged.
*/
void libsmm::smm_mcmc::set_simulation_size(INTEGER n)
{
  if (n <= 0) {
    error("Error, smm_mcmc, simulation_size must be positive");
    return;
  }

  simulation_size = n;
}

/*
Sets the recording interval used by draw(): one state is written to the
output matrices every k iterations.  A value that is not strictly
positive is reported through error() and the stored stride is left
unchanged.
*/
void libsmm::smm_mcmc::set_stride(INTEGER k)
{
  if (k <= 0) {
    error("Error, smm_mcmc, stride must be positive");
    return;
  }

  stride = k;
}

void libsmm::smm_mcmc::set_draw_from_posterior(bool from_posterior)
{
  draw_from_posterior = from_posterior;
}

/*
Sets the factor by which draw() multiplies the log density of the
target.  Any value for which (temperature > 0.0) is false, including
NaN, is reported through error() and the stored temp is left unchanged.
*/
void libsmm::smm_mcmc::set_temp(REAL temperature)
{
  // Written as a negated '>' so that NaN takes the error path as well.
  if (!(temperature > 0.0)) {
    error("Error, smm_mcmc, temperature must be positive");
    return;
  }

  temp = temperature;
}

void libsmm::smm_mcmc::set_mode(const realmat& new_mode, REAL new_high)
{
  if (U.len_rho() != new_mode.size()) error("Error, smm_mcmc, bad new_mode");
  if (new_high > high) {
    mode = new_mode;
    high = new_high;
  }
}

/*
Runs simulation_size iterations of a propose / accept-or-reject chain
starting at rho_start and records one state every stride iterations
(iterations t = 1, 1+stride, 1+2*stride, ...).

Arguments
  seed       In/out.  Copied to a local, handed to T.draw and scl::ran on
             every iteration, and written back on exit.
  rho_start  In/out.  Starting parameter vector; its row count must equal
             U.len_rho(), L.len_rho() and T.len_rho().  On exit it is
             overwritten with the last column written to rho_sim.
  rho_sim    Out.  Resized if necessary to len by out_size, where
             out_size = ceil(simulation_size/stride).  Column j is the
             j-th recorded parameter vector.
  stats_sim  Out.  Resized if necessary to L.len_stats() by out_size.
             Column j is the statistics vector that goes with column j
             of rho_sim.
  pi_sim     Out.  Resized if necessary to 3 by out_size.  For column j:
             row 1 is the target log_den after multiplication by temp,
             row 2 is the likelihood log_den, row 3 is the prior log_den.

Return value
  A (len+1) by 4 matrix.  Row i <= len refers to element i of rho,
  row len+1 to the whole vector.
    col 4  number of proposals in which the element differed from the
           current state (row len+1: number of iterations)
    col 3  number of those proposals that were not accepted
    col 1  col 3 / col 4
    col 2  col 4 / (number of iterations)

Side effects
  mode and high are reset to rho_start and its untempered target log_den,
  then updated whenever a proposal inside the support with a positive
  density has a larger untempered log_den, whether or not that proposal
  is subsequently accepted.  U.set_rho is called with every proposal.
*/
realmat libsmm::smm_mcmc::draw(INT_32BIT& seed, realmat& rho_start, 
                 realmat& rho_sim, realmat& stats_sim, realmat& pi_sim) 
{
  INTEGER len = rho_start.get_rows();
  INTEGER num = L.len_stats();
  
  // The user model U, the likelihood L and the proposal T must all agree
  // with rho_start on the length of the parameter vector.
  if (U.len_rho() != len) error("Error, smm_mcmc, bad rho_start or usrmod");
  if (L.len_rho() != len) error("Error, smm_mcmc, bad rho_start or cachemgr");
  if (T.len_rho() != len) error("Error, smm_mcmc, bad rho_start or proposal");

  // Number of recorded states: simulation_size/stride rounded up.
  INTEGER out_size = simulation_size/stride;
  if (simulation_size % stride != 0) ++out_size;
  INTEGER out_count = 0;
  
  // Resize the output matrices only when their dimensions are wrong.
  if ( rho_sim.get_rows() != len || rho_sim.get_cols() != out_size ) {
    rho_sim.resize(len,out_size);
  }
  if ( stats_sim.get_rows() != num || stats_sim.get_cols() != out_size ) {
    stats_sim.resize(num,out_size);
  }
  if ( pi_sim.get_rows() != 3 || pi_sim.get_cols() != out_size ) {
    pi_sim.resize(3,out_size);
  }

  // Work on a local copy of the seed; the caller's seed is updated only
  // at the very end.
  INT_32BIT  jseed = seed;

  // Rejection bookkeeping, all counts start at zero (layout described in
  // the header comment).
  realmat reject(len+1,4,0.0);
  
  realmat rho_old = rho_start;
  realmat stats_old(num,1);

  // Evaluate likelihood and prior at the starting value.
  // NOTE(review): stats_old is presumably filled in by L -- confirm.
  den_val likehood_old = L(rho_old,stats_old);
  den_val prior_old = U.prior(rho_old,stats_old);

  // Target is the prior alone, or prior combined with likelihood when
  // drawing from the posterior.
  den_val pi_old = prior_old;
  if (draw_from_posterior) {
    pi_old += likehood_old;
  }

  // Mode tracking starts from rho_start and uses the untempered value.
  mode = rho_start;
  high = pi_old.log_den;

  // Temper the log density only after high has been recorded.
  if (pi_old.positive) pi_old.log_den *= temp; 

  if (!pi_old.positive) error("Error, smm_mcmc, bad rho_start or prior");

  // The "new" quantities are initialised from the "old" ones so that they
  // have the right dimensions; they are overwritten for each proposal.
  realmat stats_new = stats_old;
  realmat rho_new = rho_old;
  den_val likehood_new = likehood_old;
  den_val prior_new = prior_old;
  den_val pi_new = pi_old;

  for (INTEGER t=1; t <= simulation_size; t++) {

    // Obtain a candidate rho_new from the proposal given the current state.
    T.draw(jseed, rho_old, rho_new);

    // The uniform draw is taken unconditionally, before the support check,
    // so the seed is used in the same order on every iteration.
    REAL u = scl::ran(&jseed);

    bool move = false;
    
     /*
    Setting the parameter before checking support is intentional.
    Do not change it.  SNP will not work if you do.
    */

    U.set_rho(rho_new);

    if (U.support(rho_new)) {

      likehood_new = L(rho_new,stats_new);
      prior_new = U.prior(rho_new,stats_new);

      pi_new = prior_new;
      if (draw_from_posterior) {
        pi_new += likehood_new;
      }

      if (pi_new.positive) {
        // Mode update uses the untempered log density and happens even if
        // the proposal is rejected below.
        if (pi_new.log_den > high) {
          mode = rho_new; 
          high = pi_new.log_den;
        }
        pi_new.log_den *= temp;
        // Numerator: tempered target at the candidate plus T(rho_new,rho_old).
        den_val top = pi_new;
        den_val prob = T(rho_new,rho_old);
        top += prob;
        // Denominator: tempered target at the current state plus the
        // proposal density with the arguments swapped; a symmetric
        // proposal reuses prob instead of evaluating T a second time.
        den_val bot = pi_old;
        if (T.transition_is_symmetric()) {
          bot += prob;
        }
        else {
          bot += T(rho_old,rho_new);
        }
        // Acceptance probability is min(1, exp(top - bot)) on the log scale.
        REAL diff = top.log_den - bot.log_den;
        REAL r = ( (0.0 < diff) ? 1.0 : exp(diff) );
        if (u <= r) move = true;
      }
    }
    
    if (move) {
      // Accepted: count the proposal (col 4) but not a rejection (col 3).
      reject(len+1,4) += 1;
      for (INTEGER i=1; i<=len; i++) {
        if (rho_old[i] != rho_new[i]) {
          reject(i,4) += 1;
        }
      }
      // Record on iterations 1, 1+stride, ...; stride == 1 is tested
      // separately because t % 1 is never 1.
      if ( (stride == 1) || (t % stride == 1) ) {
        ++out_count;
        for (INTEGER i=1; i<=len; i++) rho_sim(i,out_count)=rho_new[i];
        for (INTEGER i=1; i<=num; i++) stats_sim(i,out_count)=stats_new[i];
        pi_sim(1,out_count)=pi_new.log_den;
        pi_sim(2,out_count)=likehood_new.log_den;
        pi_sim(3,out_count)=prior_new.log_den;
      }
      // The candidate becomes the current state.
      rho_old = rho_new;
      stats_old = stats_new;
      pi_old = pi_new;
      likehood_old = likehood_new;
      prior_old = prior_new;
    }
    else {
      // Rejected: count both the proposal (col 4) and the rejection (col 3).
      reject(len+1,4) += 1;
      reject(len+1,3) += 1;
      for (INTEGER i=1; i<=len; i++) {
        if (rho_old[i] != rho_new[i]) {
          reject(i,4) += 1;
          reject(i,3) += 1;
        }
      }
      // Same recording rule as above, but the current state is repeated.
      if ( (stride == 1) || (t % stride == 1) ) {
        ++out_count;
        for (INTEGER i=1; i<=len; i++) rho_sim(i,out_count)=rho_old[i];
        for (INTEGER i=1; i<=num; i++) stats_sim(i,out_count)=stats_old[i];
        pi_sim(1,out_count)=pi_old.log_den;
        pi_sim(2,out_count)=likehood_old.log_den;
        pi_sim(3,out_count)=prior_old.log_den;
      }
    }
  }
  
  // Hand back the last recorded state as the next starting value.
  // NOTE(review): with stride > 1 this is the last recorded column, which
  // is not necessarily the state after the final iteration -- confirm that
  // this is intended.
  for (INTEGER i=1; i<=len; i++) {
    rho_start[i]=rho_sim(i,out_count);
  }

  // Consistency check on the recording rule against out_size.
  if (out_count != out_size) error("Error, smm_mcmc, this should not happen");

  seed = jseed;
  // Convert counts to rates: col 1 is rejections per proposal that changed
  // the element, col 2 is the share of all iterations that changed it.
  for (INTEGER i=1; i<=len+1; ++i) {
    REAL bot = reject(i,4);
    if (bot>0.0) reject(i,1) = reject(i,3)/bot;
    bot = reject(len+1,4);
    if (bot>0.0) reject(i,2) = reject(i,4)/bot;
  }
  return reject;
}


