/* ----------------------------------------------------------------------------

Copyright (C) 2005, 2006, 2009.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/
#undef SMOOTH_IMPLIED_MAP

#include "libscl.h"
#include "libgsm.h"
using namespace std;  
using namespace scl;
using namespace libgsm;

// Construct an assessor for statistical-model parameters of length
// len_stat_parm. sum_squares starts as a len_stat_parm by len_stat_parm
// zero matrix and n at 0. The implied map, kappa and sigma are all unset
// and must be supplied before operator() can be used.
//
// sum_squares is sized from the constructor argument rather than from the
// member l, so its initialization does not depend on l being declared
// before sum_squares in the class definition.
libgsm::assess_prior::assess_prior(const INTEGER len_stat_parm)
: l(len_stat_parm), n(0), sum_squares(len_stat_parm,len_stat_parm,0.0), 
  imap_set(false), kappa_set(false), sigma_set(false)
{ }
  
// Store kappa and mark it as set. operator() divides the minimum
// quadratic-form distance by 2*kappa when penalizing log_den.
void libgsm::assess_prior::set_kappa(const REAL kap)
{
  kappa_set = true;
  kappa = kap;
}

// Return the current sigma (sum_squares/n as maintained by update_sigma,
// increment_sigma or read_sigma). Calls error() if sigma has not been set.
const realmat& libgsm::assess_prior::get_sigma() const
{
  if (sigma_set) {
    return sigma;
  }
  error("Error, assess_prior, get_sigma, sigma not set");
  return sigma;
}

// Return n, the observation count accumulated into sum_squares.
// Calls error() if sigma has not been set.
INTEGER libgsm::assess_prior::get_n() const
{
  if (sigma_set) {
    return n;
  }
  error("Error, assess_prior, get_n, sigma not set");
  return n;
}

namespace {
  // Working record for the kernel smoother in set_implied_map. It is only
  // referenced inside the "#if defined SMOOTH_IMPLIED_MAP" branch; that
  // macro is #undef'd at the top of this file, so by default it is unused.
  struct smo_val {
    realmat sci_mod_parm;   // map key; rescaled in place by the smoother
    realmat stat_mod_parm;  // copied from the map entry's stat_mod_parm
    den_val sci_mod_prior;  // copied from the map entry's sci_mod_prior
  };
}

void libgsm::assess_prior::set_implied_map(const implied_map& im)
{
  if (im.empty()) {
    error("Error, assess_prior, set_implied_map, empty map");
  }

  sci_val mode = im.get_mode();

  if (l != mode.stat_mod_parm.size()) {
    error("Error, assess_prior, set_implied_map, len_stat_parm wrong");
  }

  #if defined SMOOTH_IMPLIED_MAP

    realmat center = mode.sci_mod_parm;
    INTEGER d = center.size();
    realmat sum(d,d,0.0);
    REAL count = 0.0;
    
    vector<smo_val> smovec;
    smovec.reserve(im.size());
  
    for (implied_map::im_c_itr itr=im.get_begin(); itr!=im.get_end(); ++itr) {
      smo_val smov;
      map_val mapv = itr->second;
      bool good = mapv.sci_mod_support 
                    && mapv.sci_mod_simulate 
                    && mapv.sci_mod_prior.positive
                    && mapv.stat_mod_logl.positive
                    && mapv.stat_mod_support
                    && mapv.stat_mod_prior.positive;
      if (good) {
        smov.sci_mod_parm = itr->first;
        smov.stat_mod_parm = mapv.stat_mod_parm;
        smov.sci_mod_prior = mapv.sci_mod_prior;
        smovec.push_back(smov);
        sum += (smov.sci_mod_parm - center)*T(smov.sci_mod_parm - center);
        ++count;
      }
    }
  
    if (smovec.empty()) {
      error("Error, assess_prior, set_implied_map, no good elements in map");
    }

    realmat scale(d,1);
    for (INTEGER i=1; i<=d; ++i) scale[i] = sqrt(count/sum(i,i));
  
    typedef vector<smo_val>::iterator smovec_itr;
    for (smovec_itr itr=smovec.begin(); itr!=smovec.end(); ++itr) {
      for (INTEGER i=1; i<=d; ++i) {
        itr->sci_mod_parm[i] *= scale[i];
      }
    }
  
    REAL h = 1.0/pow(count,1.0/REAL(d+4));
  
    manifold.clear();
    manifold.reserve(smovec.size());
  
    typedef vector<smo_val>::const_iterator smovec_c_itr;
    for (smovec_c_itr t=smovec.begin(); t!=smovec.end(); ++t) {
      realmat top(l,1,0.0);
      REAL bot = 0.0;
      for (smovec_c_itr s=smovec.begin(); s!=smovec.end(); ++s) {
        realmat z = (s->sci_mod_parm - t->sci_mod_parm);
        REAL weight = exp(-0.5*((T(z)*z)[1])/h);
        top += weight*s->stat_mod_parm;
        bot += weight;
      }
      man_val manv;
      manv.stat_parm = (1.0/bot)*top;
      manv.sci_prior = t->sci_mod_prior;
      manifold.push_back(manv);
    }

  #else

    manifold.clear();
    manifold.reserve(im.size());
  
    for (implied_map::im_c_itr itr=im.get_begin(); itr!=im.get_end(); ++itr) {
      map_val mapv = itr->second;
      bool good = mapv.sci_mod_support 
                    && mapv.sci_mod_simulate 
                    && mapv.sci_mod_prior.positive
                    && mapv.stat_mod_logl.positive
                    && mapv.stat_mod_support
                    && mapv.stat_mod_prior.positive;
      if (good) {
        man_val manv;
        manv.stat_parm = mapv.stat_mod_parm;
        manv.sci_prior = mapv.sci_mod_prior;
        manifold.push_back(manv);
      }
    }

    if (manifold.empty()) {
      error("Error, assess_prior, set_implied_map, no good elements in map");
    }

  #endif

  imap_set = true;
}     

// Pool a sigma estimate based on nobs observations into the running sums.
// sum_squares gains nobs*sig, n gains nobs, and sigma becomes sum_squares/n.
// Calls error() if sig is not the same size as sum_squares. Warns and leaves
// the state unchanged if nobs <= 0.
void libgsm::assess_prior::update_sigma(const realmat& sig, INTEGER nobs)
{
  const bool conformable = sig.nrow() == sum_squares.nrow()
                        && sig.ncol() == sum_squares.ncol();
  if (!conformable) {
    error("Error, assess_prior, update_sigma, input sigma wrong size");
  }

  if (nobs > 0) {
    sum_squares += REAL(nobs)*sig;
    n += nobs;
    sigma = sum_squares/n;
    sigma_set = true;
  }
  else {
    warn("Warning, assess_prior, update_sigma, nobs <= 0");
  }
}

bool libgsm::assess_prior::read_sigma(const char* filename)
{
  ifstream sigma_ifs(filename);
  if (!sigma_ifs.good()) return false;
  INTEGER sz = vecread(sigma_ifs,sigma);
  sigma_ifs >> n;
  sum_squares = REAL(n)*sigma;
  sigma_set = true; 
  return bool(sz);
}

// Write sigma (with vecwrite) followed by n and a newline to filename, in
// the layout read_sigma expects.
//
// Returns false in any of these cases:
//   - the file cannot be opened;
//   - vecwrite returns 0;
//   - the stream is in a failed state after writing n and flushing.
// Previously, a failure while writing n went undetected.
bool libgsm::assess_prior::write_sigma(const char* filename) const
{
  ofstream sigma_ofs(filename);
  if (!sigma_ofs.good()) return false;
  INTEGER sz = vecwrite(sigma_ofs,sigma);
  sigma_ofs << n << '\n';
  // Flush so buffered write errors show up in the stream state checked below.
  sigma_ofs.flush();
  return bool(sz) && sigma_ofs.good();
}

/* Evaluate the assessed prior at the statistical-model parameter eta.

   Scans the manifold for the point minimizing the quadratic form
     dist = (eta - stat_parm)' * invpsd(sigma) * (eta - stat_parm).
   Ties keep the earliest point, because the comparison is strict. It
   returns that point's sci_prior with -dist/(2*kappa) added to log_den.

   Calls error() unless the implied map, kappa and sigma have all been set. */
den_val libgsm::assess_prior::operator()(const realmat& eta) const
{
  if (!imap_set) {
    error("Error, assess_prior, operator(), implied_map not set");
  }
  if (!kappa_set) {
    error("Error, assess_prior, operator(), kappa not set");
  }
  if (!sigma_set) {
    error("Error, assess_prior, operator(), sigma not set");
  }
  
  realmat siginv = invpsd(sigma);

  typedef std::vector<man_val>::const_iterator c_itr;
  const c_itr top = manifold.end();

  REAL dist = REAL_MAX;
  den_val prior;

  for (c_itr itr=manifold.begin(); itr!=top; ++itr) {
    // Form the deviation once and reuse it on both sides of the quadratic
    // form. The original computed eta - stat_parm twice per manifold point.
    realmat dev = eta - itr->stat_parm;
    realmat q = T(dev)*(siginv*dev);
    if (q[1] < dist) {
      dist = q[1];
      prior = itr->sci_prior;
    }
  }

  prior.log_den += -dist/(2.0*kappa);

  return prior;
}

void libgsm::assess_prior::increment_sigma(const realmat& eta_residual)
{
  sum_squares += eta_residual*T(eta_residual);
  ++n;
  sigma = sum_squares/n;
  sigma_set = true;
}
