/* ----------------------------------------------------------------------------

Copyright (C) 2002, 2003, 2004, 2005, 2006, 2007.

A. Ronald Gallant

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libsmm.h"
using namespace std;  
using namespace scl;
using namespace libsmm;

// Compile-time switch for the form of conditional_move.
//
// When defined: the intercept b0 computed in the conditional_move
// constructor omits the mean of the dependent parameter, draw() ADDS
// b0 + b'x + s*z to the current value of that parameter, and operator()
// centers the density at the old value plus b0 + b'x.
//
// When not defined: b0 includes the mean of the dependent parameter,
// draw() REPLACES the parameter with b0 + b'x + s*z, and operator()
// centers the density at b0 + b'x.
//
// Uncomment the #undef line below to select the second form.
#define RANDOM_WALK_CONDITIONAL_MOVE
//#undef  RANDOM_WALK_CONDITIONAL_MOVE

namespace {

  bool check_prop_def(prop_def def) 
  {
    if (def.empty()) error("Error, check_prop_def, empty proposal_def");

    INTEGER lrho = 0;
    INTEGER lpd = def.size();
 
   for (INTEGER i=0; i<lpd; i++) {
      INTEGER lgvec = def[i].gvec.size();
      if (lgvec <= 0) error("Error, check_prop_def, empty gvec");
      if (def[i].freq < 0) error("Error, check_prop_def, negative freq");
      if (lgvec != def[i].ginc.size()) error("Error, check_prop_def, ginc");
      if (lgvec != def[i].mean.size()) error("Error, check_prop_def, mean");
      lrho += lgvec;
    }

    intvec sorted(lrho);
    for (INTEGER i=0; i<lpd; i++) {
      intvec gvec = def[i].gvec;
      INTEGER lgvec = gvec.size();
      for (INTEGER j=1; j<= lgvec; j++) {
        if ( (gvec[j] < 1) || (lrho < gvec[j]) ) {
          error("Error, check_prop_def, some values of gvec out of range");
        }
        sorted[gvec[j]] = gvec[j];
      }
    }
  
    if (sorted != seq(1,lrho) ) {
      error("Error, check_prop_def, some values of gvec not set");
    }
    
    for (INTEGER i=0; i<lpd; i++) {
      INTEGER lgvec = def[i].gvec.size();
      realmat Vmat = def[i].Vmat;
      if ( Vmat.get_rows() != lgvec || Vmat.get_cols() != lgvec ) {
        error("Error, check_prop_def, Vmat wrong size");
      }
      for (INTEGER j=1; j<=lgvec; j++) {
        REAL var = Vmat(j,j);
        if (var <= 0) {
          error("Error, check_prop_def, variances must be postive.");
        }
      }
    }

    return true;
  }
}

bool libsmm::grid_group_move::get_grid(const realmat& parm, realmat& parm_real, 
      intvec& parm_int )
{
  const REAL lo = REAL(INT_MIN + 2);
  const REAL hi = REAL(INT_MAX - 2);

  INTEGER rows = parm.get_rows();
  INTEGER cols = parm.get_cols();

  parm_real = parm;
  parm_int.resize(rows,1);

  if ( increment.get_rows() != rows || cols != 1 ) {
    error("Error, grid_group_move, get_grid, bad dimension");
  }

  bool rv = true;
  
  for (INTEGER i=1; i<=rows; ++i) {
    REAL r_grid = parm[i]/increment[i];
    if (lo < r_grid && r_grid < hi) {
      parm_int[i] = r_grid < 0 ? -INTEGER(-r_grid)-1 : INTEGER(r_grid);
    }
    else if (r_grid <= lo) {
      parm_int[i] = INT_MIN + 2;
      rv = false;
    } 
    else {
      parm_int[i] = INT_MAX - 2;
      rv = false;
    }
    parm_real[i] = REAL(parm_int[i])*increment[i] + 0.5*increment[i];
  }

  return rv;
}


// Recursively enumerates the integer lattice lo[k] <= index[k] <= hi[k],
// k = 1..lgvec, and stores the unnormalized weight exp(-0.5*x'*Vinv*x) for
// each point in pdf, where x[k] = index[k]*ginc[k].
//
// d is the coordinate handled at this level of the recursion; the initial
// call passes d == ginc.size(), at which point index is set to lo.
// Coordinate 1 is the innermost loop and is where the weights are stored.
//
// error() is called as soon as pdf holds more than 5000000 points.  The
// message is assembled only at that moment rather than on every recursive
// call.
void libsmm::grid_group_move::make_pdf
  (INTEGER d, const realmat& ginc, const realmat& Vinv, 
   const intvec& lo, const intvec& hi, intvec& index, pdf_type& pdf)
{
  const pdf_type::size_type max_size = 5000000;

  INTEGER lgvec = ginc.size();

  if (d == lgvec) {
    index = lo;
  }

  for (INTEGER j = lo[d]; j<=hi[d]; ++j) {
    index[d] = j;
    if (d == 1) {
      realmat x(lgvec,1);
      for (INTEGER i=1; i<=lgvec; ++i) x[i]=REAL(index[i])*ginc[i];
      realmat q = T(x)*(Vinv*x);
      pdf[index] = exp(-0.5*q[1]);
      if (pdf.size() > max_size) {
        string msg;
        msg += "\nError, grid_group_move, make_pdf.  The number of points \n";
        msg += "in the support of the density for a group exceeds";
        msg += string(fmt('d',8,INTEGER(max_size)).get_ostr()) + ".\n";
        msg += "The number of support points can be reduced by reducing \n"; 
        msg += "scale, increasing increment size, and/or deleting \n";
        msg += "members from the group.\n";
        error(msg);
      }
    }
    else {
      make_pdf(d-1, ginc, Vinv, lo, hi, index, pdf);
    }
  }
}

// Builds a grid_group_move from a proposal definition.
//
// For each group i of def the constructor
//   - sets the selection probability pa[i].prob = freq/sum of all freq;
//   - copies the group's increments into the member vector increment at
//     the positions named by gvec;
//   - symmetrizes Vmat and forms its inverse from the singular value
//     decomposition (the value returned by svd is compared with lgvec and a
//     mismatch is reported as "not a variance matrix");
//   - bounds the support in coordinate j by +/- INTEGER(2*sqrt(Vmat(j,j))
//     /ginc[j]) grid steps;
//   - fills pa[i].pdf with make_pdf, gives the all-zero displacement weight
//     zero, and normalizes the weights to sum to one.
//
// error() is called when def fails check_prop_def, when the frequencies do
// not sum to a positive number, when Vmat is rank deficient, or when a
// support bound is zero or not representable.
libsmm::grid_group_move::grid_group_move(const prop_def& def) 
  : pd(def)
{
  if(!check_prop_def(pd)) error("Error, grid_group_move, bad prop_def");

  lpd = pd.size();
  pa.resize(lpd);

  lparm = 0;
  REAL sum_freq = 0;
  
  for (INTEGER i=0; i<lpd; i++) {
    INTEGER lgvec = pd[i].gvec.size();
    lparm += lgvec;
    sum_freq += pd[i].freq;
  }

  // check_prop_def rejects negative freq but accepts zero, so every group
  // may have freq == 0.  Dividing by that sum would make every selection
  // probability NaN; report it instead.
  if (!(sum_freq > 0)) {
    error("Error, grid_group_move, sum of freq must be positive");
  }

  increment.resize(lparm,1);

  for (INTEGER i=0; i<lpd; i++) {
    pa[i].prob = pd[i].freq/sum_freq;
    intvec gvec = pd[i].gvec;
    realmat ginc = pd[i].ginc;
    INTEGER lgvec = gvec.size();
    for (INTEGER j=1; j<=lgvec; ++j) increment[gvec[j]] = ginc[j];

    // Symmetrize, then invert through the singular value decomposition.
    realmat Vmat = pd[i].Vmat;
    Vmat = (T(Vmat)+Vmat)/2.0;
    realmat U,S,V;
    INTEGER rank = svd(Vmat,U,S,V);
    if (rank != lgvec) {
      error("Error, grid_group_move, Vmat not a variance matrix");
    }
    realmat invS(lgvec,1);
    for (INTEGER j=1; j<=lgvec; j++) {
      invS[j] = 1.0/S[j];
    }
    realmat Vinv = V*diag(invS)*T(U);

    // Support bounds in grid steps: two standard deviations either side.
    const REAL top = REAL(INT_MAX)/2.0;
    const REAL bot = REAL(INT_MIN)/2.0;
    intvec lo(lgvec,1);
    intvec hi(lgvec,1);
    for (INTEGER j=1; j<=lgvec; ++j) {
      REAL r_grid = 2.0*sqrt(Vmat(j,j))/ginc[j];
      if (bot < r_grid && r_grid < top) {
        hi[j]=INTEGER(r_grid);
        lo[j]=-hi[j];
        if (hi[j] == 0) {
          error("Error, grid_group_move, Vmat too small or ginc too big");
        }
      }
      else {
        error("Error, grid_group_move, Vmat too big or ginc too small");
      }
    }

    intvec index;
    make_pdf(lgvec, pd[i].ginc, Vinv, lo, hi, index, pa[i].pdf); 
    fill(index,0);
    pa[i].pdf[index] = 0.0;  // No sense to propose not moving
  }

  // Normalize each group's weights so they sum to one.
  for (INTEGER i=0; i<lpd; i++) {
    pdf_type& p = pa[i].pdf;
    REAL sum = 0.0;
    for (pdf_itr itr = p.begin(); itr!=p.end(); ++itr) sum += itr->second;
    for (pdf_itr itr = p.begin(); itr!=p.end(); ++itr) itr->second /= sum;
  }

}

// Evaluates the proposal density of moving from parm_old to parm_new.
//
// Both arguments are first mapped to grid cells with get_grid.  A group i
// can have produced the move only if the cell midpoints of all parameters
// outside the group are unchanged; for each such group the probability of
// the observed displacement (in grid steps) is looked up in pa[i].pdf and
// weighted by pa[i].prob.
//
// Returns den_val(false, -REAL_MAX) when the accumulated probability is
// zero; when exactly one group matched, the log density computed directly
// from the two logs; otherwise den_val(true, log(sum)).
//
// The table lookup uses find rather than operator[] so that asking about a
// displacement outside the support does not insert a new zero entry into
// the table.
den_val libsmm::grid_group_move::operator()
        (const realmat& parm_old, const realmat& parm_new)
{
  if (parm_old.get_rows()!=lparm) error("Error, grid_group_move, bad parm_old");

  realmat parm_old_real;
  intvec parm_old_int;
  if ( ! get_grid(parm_old,parm_old_real,parm_old_int) ) {
    error("Error, grid_group_move, get_grid, increment too small");
  }

  realmat parm_new_real;
  intvec parm_new_int;
  if ( ! get_grid(parm_new,parm_new_real,parm_new_int) ) {
    error("Error, grid_group_move, get_grid, increment too small");
  }

  REAL log_den = -REAL_MAX;
  bool positive = false;

  REAL sum = 0.0;
  INTEGER count = 0;

  for (INTEGER i=0; i<lpd; i++) {
    intvec gvec = pd[i].gvec;
    INTEGER lgvec = gvec.size();
    intvec index(lgvec);

    // ivec keeps the positions outside group i; members are zeroed out.
    intvec ivec = seq(1,lparm);
    for (INTEGER j=1; j<=lgvec; j++) {
      ivec[gvec[j]] = 0;
      index[j] = parm_new_int[gvec[j]]-parm_old_int[gvec[j]];
    }

    if ( parm_new_real(ivec,1) == parm_old_real(ivec,1) ) {
      REAL prob_i = pa[i].prob;

      // A displacement that is not in the table has probability zero.
      REAL prob_index = 0.0;
      pdf_itr hit = pa[i].pdf.find(index);
      if (hit != pa[i].pdf.end()) prob_index = hit->second;

      sum += prob_i*prob_index;
      count++;
      positive = (prob_i > 0.0) && (prob_index > 0.0);
      if (positive) log_den = log(prob_i) + log(prob_index);
    }
  }

  if (sum == 0.0) {
    return den_val(false, -REAL_MAX); 
  } 
  else if (count == 1) {
    return den_val(positive, log_den);
  }
  else {
    return den_val(true, log(sum));
  }
}

// Draws a proposal parm_new given parm_old.
//
// parm_old is snapped to its grid cell midpoints, a group is selected with
// probabilities pa[.].prob, a displacement (in grid steps) is selected from
// that group's table, and the displacement times the group's increments is
// added to the midpoints of the group's parameters.  seed is advanced by
// the two uniform draws used.
//
// The group's table is accessed by reference; it can hold millions of
// entries and must not be copied on every draw.
void libsmm::grid_group_move::draw
        (INT_32BIT& seed, const realmat& parm_old, realmat& parm_new)
{
  INT_32BIT jseed = seed;

  realmat parm_new_real;
  intvec parm_new_int;
  if ( ! get_grid(parm_old, parm_new_real, parm_new_int) ) {
    error("Error, grid_group_move, get_grid, increment too small");
  }
  
  // Select the group; the last group is the fallback if rounding leaves
  // the cumulative probability short of x.
  REAL x = scl::ran(&jseed);
  REAL cum_prob = 0.0;
  INTEGER i = lpd-1;

  for (INTEGER j=0; j<lpd; j++) {
    cum_prob += pa[j].prob;
    if (x <= cum_prob) {
      i = j;
      break;
    }
  }
  
  intvec gvec = pd[i].gvec;
  INTEGER lgvec = gvec.size();
  realmat ginc = pd[i].ginc;
  pdf_type& pdf = pa[i].pdf;

  // Select the displacement; the last table entry is the fallback.
  x = scl::ran(&jseed);
  cum_prob = 0.0;

  pdf_itr pick = pdf.end();
  --pick;

  for (pdf_itr itr = pdf.begin(); itr!=pdf.end(); ++itr) {
    cum_prob += itr->second;
    if (x <= cum_prob) {
      pick = itr;
      break;
    }
  }

  const intvec& index = pick->first;

  for (INTEGER j=1; j<=lgvec; ++j) {
    parm_new_real[gvec[j]] += REAL(index[j])*ginc[j];
  }

  parm_new = parm_new_real;

  seed = jseed;
}


// Writes a listing of the proposal to os: for every group its selection
// probability followed by one line per support point giving the
// probability and the displacement in grid increments.  Returns os.
//
// Each group's table is read through a reference so that printing does not
// duplicate the table.
ostream& libsmm::grid_group_move::write_proposal(ostream& os)
{
  os << starbox("/grid_group_move proposal//") << '\n';
  for (vector<prop_aux>::size_type i=0; i<pa.size(); ++i) {
    os <<"\t Probability select group " << i << " is " << pa[i].prob << '\n';
    os <<"\t Group " << i << " density function is:" << '\n';
    os <<"\t prob      support in increments" << '\n';
    pdf_type& pdf = pa[i].pdf;
    for (pdf_itr itr = pdf.begin(); itr!=pdf.end(); ++itr) {
      REAL prob = itr->second;
      const intvec& support = itr->first;
      os << "\t " << fmt('f',8,7,prob);
      for (INTEGER j=1; j<=support.size(); ++j) os << fmt('d',3,support[j]);
      os << '\n';
    }
    os << '\n';
  }
  return os;
}
  

// Builds a group_move from a proposal definition.
//
// For each group i the constructor symmetrizes Vmat, takes its singular
// value decomposition, and stores
//   pa[i].prob  = freq / sum of all freq,
//   pa[i].scale = 1/( sqrt(2*pi)^lgvec * product of sqrt(S[j]) ),
//   pa[i].Vinv  = V*diag(1/S)*U',
//   pa[i].Rmat  = U*diag(sqrt(S)).
//
// error() is called when def fails check_prop_def, when the frequencies do
// not sum to a positive number, or when the value returned by svd differs
// from the group size.
libsmm::group_move::group_move(prop_def def) 
  : pd(def)
{
  if(!check_prop_def(pd)) error("Error, group_move, bad prop_def");

  const REAL root_two_pi = 2.5066282746310005024;

  lpd = pd.size();
  pa.resize(lpd);

  lrho = 0;
  REAL sum_freq = 0;
  
  for (INTEGER i=0; i<lpd; i++) {
    INTEGER lgvec = pd[i].gvec.size();
    lrho += lgvec;
    sum_freq += pd[i].freq;
  }

  // check_prop_def rejects negative freq but accepts zero, so every group
  // may have freq == 0.  Dividing by that sum would make every selection
  // probability NaN; report it instead.
  if (!(sum_freq > 0)) {
    error("Error, group_move, sum of freq must be positive");
  }

  for (INTEGER i=0; i<lpd; i++) {
    intvec  gvec = pd[i].gvec;
    INTEGER lgvec = gvec.size();

    realmat Vmat = pd[i].Vmat;
    Vmat = (T(Vmat)+Vmat)/2.0;

    realmat U,S,V;
    INTEGER rank = svd(Vmat,U,S,V);
    if (rank != lgvec) {
      error("Error, group_move, Vmat not a variance matrix");
    }

    // rscale accumulates the reciprocal of the normal density's constant.
    realmat sqrtS(lgvec,1);
    realmat invS(lgvec,1);
    REAL rscale = pow(root_two_pi,lgvec);
    for (INTEGER j=1; j<=lgvec; j++) {
      sqrtS[j] = sqrt(S[j]);
      invS[j] = 1.0/S[j];
      rscale *= sqrtS[j];
    }

    pa[i].prob = pd[i].freq/sum_freq;

    pa[i].scale = 1.0/rscale;

    pa[i].Vinv = V*diag(invS)*T(U);

    pa[i].Rmat = U*diag(sqrtS); 
  }
}

// Evaluates the proposal density of moving from rho_old to rho_new.
//
// A group can have produced the move only when every parameter outside the
// group is unchanged.  Each such group contributes
// prob*scale*exp(-0.5*x'*Vinv*x), x being the change in the group's
// parameters.
//
// Returns den_val(false, -REAL_MAX) when the accumulated density is zero;
// when exactly one group matched, the log density assembled from its logs;
// otherwise den_val(true, log of the accumulated density).
den_val libsmm::group_move::operator()
        (const realmat& rho_old, const realmat& rho_new)
{
  if (rho_old.get_rows() != lrho) error("Error, group_move, bad rho_old");

  INTEGER matches = 0;
  REAL total = 0.0;
  bool single_positive = false;
  REAL single_log_den = -REAL_MAX;

  for (INTEGER g=0; g<lpd; g++) {

    // others lists the positions outside group g; members become zero.
    intvec others = seq(1,lrho);
    intvec members = pd[g].gvec;
    for (INTEGER j=1; j<=members.size(); j++) others[members[j]] = 0;

    if ( !(rho_new(others,1) == rho_old(others,1)) ) continue;

    realmat delta = rho_new(members,1) - rho_old(members,1);
    realmat quad = T(delta)*(pa[g].Vinv*delta);

    total += (pa[g].prob)*(pa[g].scale)*exp(-0.5*quad[1]);
    ++matches;

    single_positive = (pa[g].prob > 0.0) && (pa[g].scale > 0.0);
    if (single_positive) {
      single_log_den = log(pa[g].prob) + log(pa[g].scale) - 0.5*quad[1];
    }
  }

  if (total == 0.0) return den_val(false, -REAL_MAX); 

  if (matches == 1) return den_val(single_positive, single_log_den);

  return den_val(true, log(total));
}

// Draws a proposal rho_new given rho_old.
//
// rho_new starts as a copy of rho_old.  A group is selected with
// probabilities pa[.].prob, one value per group member is taken from
// scl::unsk, the resulting vector is premultiplied by the group's Rmat,
// and the product is added to the group's parameters.  seed is advanced by
// the random numbers consumed.
void libsmm::group_move::draw
        (INT_32BIT& seed, const realmat& rho_old, realmat& rho_new)
{
  INT_32BIT local_seed = seed;

  rho_new = rho_old;

  if (rho_new.get_rows() != lrho) error("Error, group_move, draw, bad rho");

  // Group selection; the last group is the fallback if rounding leaves the
  // cumulative probability short of the uniform draw.
  const REAL u = scl::ran(&local_seed);
  INTEGER chosen = lpd-1;
  REAL cdf = 0.0;
  for (INTEGER g=0; g<lpd; g++) {
    cdf += pa[g].prob;
    if (u <= cdf) {
      chosen = g;
      break;
    }
  }

  intvec members = pd[chosen].gvec;
  const INTEGER n = members.size();

  realmat shocks(n,1);
  for (INTEGER j=1; j<=n; j++) shocks[j] = scl::unsk(&local_seed);

  realmat step = pa[chosen].Rmat*shocks;

  for (INTEGER j=1; j<=n; j++) {
    rho_new[members[j]] += step[j];
  }

  seed = local_seed;
}

// Builds a conditional_move from a proposal definition.
//
// The groups with freq > 0 supply the lpm parameters that may be moved.
// Their means and Vmat blocks are assembled into a full mean vector and a
// block variance matrix.  Then, for each movable parameter y, the remaining
// movable parameters x are used as regressors:
//   b = inv(sig_xx)*sig_xy,  v = sig_yy - b'*sig_xx*b.
// Coefficients with |b[i]| < sqrt(invXX(i,i)*v) are set to zero, the
// corresponding rows and columns of XX and entries of Xy are zeroed, and,
// unless every coefficient was dropped, b is taken from Xy after the call
// psdsol(XX,Xy).
// NOTE(review): this presumes psdsol leaves the solution in its second
// argument and that indexing by an intvec skips zero entries -- confirm
// against libscl.
// The results are stored in pm[k-1] as y, x, b0, b and s = sqrt(v).
//
// When print is true a listing of the regressions is written to detail.
//
// error() is called when pd fails check_prop_def, when fewer than two
// parameters are movable, or when a conditional variance is not positive.
libsmm::conditional_move::conditional_move(prop_def pd, ostream& detail, 
           bool print) 
{
  if(!check_prop_def(pd)) error("Error, conditional_move, bad prop_def");

  lrho = 0;
  lpm = 0;
  
  for (prop_def::const_iterator itr=pd.begin(); itr!=pd.end(); ++itr) {
    lrho += itr->gvec.size();
    if(itr->freq>0) lpm += itr->gvec.size();
  }

  // Every movable parameter is regressed on the other movable parameters,
  // so at least two are required.  lpm == 0 (all freq zero) must be
  // rejected too: draw would otherwise index pm at position lpm-1 == -1.
  // The test comes before anything of size lpm is allocated.
  if (lpm < 2) error("Error, conditional_move, need at least two parameters");

  intvec pmvec(lpm);
  realmat mean(lrho,1,0.0);
  realmat Vmat(lrho,lrho,0.0);
  
  INTEGER row = 0;
  for (prop_def::const_iterator itr=pd.begin(); itr!=pd.end(); ++itr) {
    if(itr->freq>0) {
      intvec gvec = itr->gvec; 
      for (INTEGER i=1; i<=gvec.size(); ++i) {
        pmvec[++row] = gvec[i];
        mean[gvec[i]] = itr->mean[i];
        for (INTEGER j=1; j<=gvec.size(); ++j) {
          Vmat(gvec[j],gvec[i]) = itr->Vmat(j,i);
        }
      }
    }
  }
  Vmat = (T(Vmat)+Vmat)/2.0;

  if (row != lpm) error("Error, conditional_move, inexplicable");

  pm.resize(lpm);

  for (INTEGER k=1; k<=lpm; ++k) {
    INTEGER y = pmvec[k];

    // x is pmvec with the dependent parameter's slot zeroed out.
    intvec x = pmvec;
    x[k] = 0;

    realmat mu_y(1,1,mean[y]);
    realmat mu_x = mean(x,1);
    realmat sig_yy(1,1,Vmat(y,y));
    realmat sig_xy = Vmat(x,y);
    realmat sig_xx = Vmat(x,x);
    realmat XX = sig_xx;
    realmat Xy = sig_xy;
    realmat invXX = invpsd(XX);
    realmat b = invXX*Xy;
    realmat v = sig_yy - (T(b)*sig_xx)*b;
    if (v[1] <= 0.0) error("Error, conditional_move, bad prop_def");

    // Drop the coefficients that are small relative to the threshold.
    INTEGER count = 0;
    for (INTEGER i=1; i<=b.size(); ++i) {
      if (fabs(b[i]) < sqrt(invXX(i,i)*v[1])) {
        ++count;
        b[i] = 0.0;
        Xy[i] = 0.0;
        for (INTEGER j=1; j<=b.size(); ++j) {
          XX(i,j) = 0.0;
          XX(j,i) = 0.0;
        }
      }
    }

    // Refit with the retained regressors unless none remain.
    if (count != b.size()) {
      psdsol(XX,Xy);
      b = Xy;
    }
    v = sig_yy - (T(b)*sig_xx)*b;
    if (v[1] <= 0.0) error("Error, conditional_move, bad prop_def");
    #if defined RANDOM_WALK_CONDITIONAL_MOVE
      realmat b0 = - T(b)*mu_x;
    #else
      realmat b0 = mu_y - T(b)*mu_x;
    #endif
    pm[k-1].y = y;
    pm[k-1].x = x;
    pm[k-1].b0 = b0[1];
    pm[k-1].b = b;
    pm[k-1].s = sqrt(v[1]);
  }
  if (print) {
    detail << starbox("/Conditional Move Proposal Regression//");
    for (INTEGER k=0; k<lpm; ++k) {
      detail << '\n';
      detail << "\t dependent = " << pm[k].y << '\n';
      detail << "\t independent = " << '\n';
      for (INTEGER i=1; i<=pm[k].x.size(); ++i) {
        if (pm[k].x[i] != 0) detail << "\t\t\t" << pm[k].x[i] << '\n';
      }
      detail << "\t coefficients = " << pm[k].b << '\n';
      detail << "\t cond scale = " << pm[k].s << '\n';
    }
  }
}


// Draws a proposal rho_new given rho_old.
//
// rho_new starts as a copy of rho_old.  One of the lpm regressions held in
// pm is selected with equal probability, and its dependent parameter is
// moved using b0 + b'x + s*z with z taken from scl::unsk.  Whether that
// quantity is added to the current value or replaces it is decided by
// RANDOM_WALK_CONDITIONAL_MOVE.  seed is advanced by the random numbers
// consumed.
void libsmm::conditional_move::draw
        (INT_32BIT& seed, const realmat& rho_old, realmat& rho_new)
{
  INT_32BIT local_seed = seed;

  rho_new = rho_old;

  if (rho_new.get_rows() != lrho) error("Error, conditional_move, bad rho");

  // Equal-probability selection; the last regression is the fallback if
  // rounding leaves the cumulative probability short of the uniform draw.
  const REAL pick = scl::ran(&local_seed);
  const REAL share = 1.0/REAL(lpm);

  INTEGER k = lpm-1;
  REAL cdf = 0.0;
  for (INTEGER m=0; m<lpm; ++m) {
    cdf += share;
    if (pick <= cdf) {
      k = m;
      break;
    }
  }

  const REAL shock = scl::unsk(&local_seed);

  realmat fitted = T(pm[k].b)*rho_new(pm[k].x,1);

  #if defined RANDOM_WALK_CONDITIONAL_MOVE
    rho_new[pm[k].y] += pm[k].b0 + fitted[1] + pm[k].s*shock;
  #else  
    rho_new[pm[k].y] = pm[k].b0 + fitted[1] + pm[k].s*shock;
  #endif

  seed = local_seed;
}

// Evaluates the proposal density of moving from rho_old to rho_new.
//
// Regression k can have produced the move only when every parameter other
// than its dependent parameter is unchanged.  Each such regression
// contributes freq times a univariate normal density with scale pm[k].s;
// the location is b0 + b'x, plus the old value of the dependent parameter
// when RANDOM_WALK_CONDITIONAL_MOVE is defined.
//
// Returns den_val(false, -REAL_MAX) when the accumulated density is zero;
// when exactly one regression matched, its log density computed directly;
// otherwise den_val(true, log of the accumulated density).
//
// The dimension error now names conditional_move; it previously named
// group_move.
den_val libsmm::conditional_move::operator()
        (const realmat& rho_old, const realmat& rho_new)
{
  if (rho_old.get_rows() != lrho) {
    error("Error, conditional_move, bad rho_old");
  }

  const REAL freq = 1.0/REAL(lpm);
  const REAL root_two_pi = 2.5066282746310005024;

  REAL log_den = -REAL_MAX;
  REAL sum = 0.0;
  INTEGER count = 0;

  for (INTEGER k=0; k<lpm; ++k) {
    // ivec lists every position except the dependent parameter's.
    intvec ivec = seq(1,lrho);
    ivec[pm[k].y] = 0;
    if ( rho_new(ivec,1) == rho_old(ivec,1) ) {
      realmat bx = T(pm[k].b)*rho_old(pm[k].x,1);
      #if defined RANDOM_WALK_CONDITIONAL_MOVE
        REAL mu = rho_old[pm[k].y] + pm[k].b0 + bx[1];
      #else
        REAL mu = pm[k].b0 + bx[1];
      #endif
      REAL sig = pm[k].s;
      REAL x = rho_new[pm[k].y] - mu;
      REAL q = pow(x/sig,2);
      REAL c = 1.0/(root_two_pi*sig);
      sum += freq*c*exp(-0.5*q);
      count++;
      log_den = log(freq*c) - 0.5*q;
    }
  }
  if (sum == 0.0) {
    return den_val(false, -REAL_MAX); 
  } 
  else if (count == 1) {
    return den_val(true, log_den);
  }
  else {
    return den_val(true, log(sum));
  }
}
