/*-----------------------------------------------------------------------------

Copyright (C) 2013.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/
#undef USR_PROPOSAL_TYPE_IMPLEMENTED

#include "pathname.h"

#include <typeinfo>
#include "mpi.h"

#include "libscl.h"
#include "libmle.h"
#include "mle.h"

using namespace scl;
using namespace libmle;
using namespace mle;
using namespace std;

namespace {

  // Rank of this process in MPI_COMM_WORLD and the number of processes in
  // that communicator.  Both are assigned by main immediately after MPI_Init.
  int my_rank;
  int no_procs;

  /*
  Library error handler (installed by main via set_lib_error_handler).
  Writes msg to cerr and aborts every process in MPI_COMM_WORLD.

  The error code handed to MPI_Abort is a fixed nonzero value.  Passing the
  rank, as was done before, made a failure on process 0 report code 0 to
  the invoking environment, which a calling script can mistake for success.
  */
  void mpi_error (string msg) 
  {
    const int abort_code = 1;
    cerr << msg << endl; 
    MPI_Abort(MPI_COMM_WORLD, abort_code);
  }

  /*
  Library warning handler (installed by main via set_lib_warn_handler).
  Only process 0 writes the warning to cerr; other processes stay silent.
  */
  void mpi_warn (string msg) 
  {
    if (my_rank == 0) cerr << msg << endl;
  }
  
  // Forward declaration; the definition follows main.  Writes the result
  // files for one received chain and, if printing is enabled, a report to
  // the detail stream.
  void output(const estblock& est_blk,ostream& detail,INTEGER ifile,
      string prefix, const realmat& rho_sim, const realmat& stats_sim, 
      const realmat& pi_sim, realmat reject, const realmat& rho_hat, 
      const realmat& V_hat, INTEGER n, const realmat& rho_mean, 
      const realmat& rho_mode, REAL post_high, const realmat& foc_hat, 
      const realmat& I, const realmat& invJ, INTEGER reps);
}

int main(int argc, char** argp, char** envp)
{
  MPI_Init(&argc, &argp);
  MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &no_procs);

  LIB_ERROR_HANDLER_PTR previous_error = set_lib_error_handler(&mpi_error);
  LIB_WARN_HANDLER_PTR previous_warn = set_lib_warn_handler(&mpi_warn);

  string pathname = "/tmp/";

  if (my_rank == 0) {
    pathname = string(PATHNAME) + string("/");
  }

  if (my_rank == 0) { // Check some type assumptions

    REAL x1; double x2;
    if (typeid(x1) != typeid(x2) ) {
      error("Error, a REAL is not a double");
      return 1;
    }

    INTEGER i1; int i2;
    if (typeid(i1) != typeid(i2)) {
      error("Error, an INTEGER is not an int");
      return 1;
    }

    /*
    Cannot check these assumptions:
    MPI_DOUBLE is double
    MPI_INT    is int
    MPI__CHAR  is char
    */
  }

  ostream* detail_ptr = &cout;

  string parmfile, prefix;

  vector<string> pfvec; // Parmfile as a vector of strings

  vecstrbuf pfbuf;   // Parmfile as char array of null
                     // terminated strings stored end to end

  int  dim[4];       // dim[0] size, dim[1] rows, dim[2] cols, dim[3] extra

  dim[0] = dim[1] = dim[2] = dim[3] = 0;

  if (my_rank == 0) {

    string filename = pathname + "control.dat";

    ifstream ctrl_ifs(filename.c_str());
    if (!ctrl_ifs) error("Error, mle, control.dat open failed");
 
    ctrl_ifs >> parmfile >> prefix;

    filename = pathname + prefix + ".detail.dat";

    detail_ptr = new(nothrow) ofstream(filename.c_str());
    if ( (detail_ptr==0) || (!*detail_ptr) )
      error("Error, mle, detail.dat open failed");

    filename = pathname + parmfile;

    ifstream pf_ifs(filename.c_str());
    if (!pf_ifs) error("Error, gsm, cannot open " + parmfile);

    string line;

    while (getline(pf_ifs, line)) pfvec.push_back(line);

    vecstrbuf send_buf(pfvec);

    dim[0] = send_buf.size();
    dim[1] = send_buf.get_rows();
    dim[2] = send_buf.get_cols();

    pfbuf = send_buf;
  }

  MPI_Bcast(&dim,4,MPI_INT,0,MPI_COMM_WORLD);
  if (my_rank != 0) pfbuf.resize(dim[1],dim[2]);
  MPI_Bcast(pfbuf.get_ptr(),dim[0],MPI_CHAR,0,MPI_COMM_WORLD);
  if (my_rank != 0) pfvec = pfbuf.get_vec();

  if (my_rank != 0) { // Must suppress print if my_rank != 0
    keyword kw;
    vector<string>::iterator kw_ptr;
    string header=kw.set_keyword("ESTIMATION DESCRIPTION");
    kw_ptr = find_if(pfvec.begin(), pfvec.end(), kw);
    if (kw_ptr == pfvec.end() || kw_ptr + 5 > pfvec.end())
      error("Error, mle, " + header);
    kw_ptr += 5;
    *kw_ptr = "0 ";
  }

  ostream& detail = *detail_ptr;

  mleparms pf;

  pf.set_parms(pfvec, detail);

  estblock est_blk = pf.get_estblock();
  datblock dat_blk = pf.get_datblock();
  modblock mod_blk = pf.get_modblock();

  realmat data; 

  if (my_rank == 0) {

    if (!dat_blk.read_data(pathname,data)) {
      error("Error, mleparm, cannot read data, dsn = " + dat_blk.dsn);
    }

    if (est_blk.print) {
      detail << starbox("/First 12 observations//");
      detail << data("",seq(1,12));
      detail << starbox("/Last 12 observations//");
      detail << data("",seq(dat_blk.n-11, dat_blk.n));
      detail.flush();
    }
  
    dim[0] = data.size();
    dim[1] = data.get_rows();
    dim[2] = data.get_cols();
  }

  MPI_Bcast(&dim,4,MPI_INT,0,MPI_COMM_WORLD);
  if (my_rank != 0) data.resize(dim[1],dim[2]);
  MPI_Bcast(data.get_x(),dim[0],MPI_DOUBLE,0,MPI_COMM_WORLD);

  vector<string> mod_pfvec;
  vecstrbuf mod_pfbuf;

  if (mod_blk.is_mod_parmfile) {

    if (my_rank == 0) {
      string filename;

      if (mod_blk.mod_parmfile[0] == '/') {
        filename = mod_blk.mod_parmfile;
      } 
      else {
        filename = pathname + mod_blk.mod_parmfile;
      }

      ifstream mod_pf_ifs(filename.c_str());
      if (!mod_pf_ifs) error("Error, mle, cannot open " + filename);

      string line;
      while (getline(mod_pf_ifs, line)) mod_pfvec.push_back(line);

      vecstrbuf send_buf(mod_pfvec);

      dim[0] = send_buf.size();
      dim[1] = send_buf.get_rows();
      dim[2] = send_buf.get_cols();

      mod_pfbuf = send_buf;
    }

    MPI_Bcast(&dim,4,MPI_INT,0,MPI_COMM_WORLD);
    if (my_rank != 0) mod_pfbuf.resize(dim[1],dim[2]);
    MPI_Bcast(mod_pfbuf.get_ptr(),dim[0],MPI_CHAR,0,MPI_COMM_WORLD);
    if (my_rank != 0) mod_pfvec = mod_pfbuf.get_vec();
  }

  usrmod_type usrmod 
    (data, mod_blk.len_mod_parm, mod_blk.len_mod_func, 
     mod_pfvec, pf.get_mod_alvec(), detail);

  proposal_base* proposal_ptr;

  if (est_blk.proptype == 0) {
    proposal_ptr = new(nothrow) group_move(pf.get_prop_groups());
    if (proposal_ptr == 0) error("Error, mle main, operator new failed.");
  }
  else if (est_blk.proptype == 1) {
    bool print = est_blk.print;
    proposal_ptr 
      = new(nothrow) conditional_move(pf.get_prop_groups(),detail,print);
    if (proposal_ptr == 0) error("Error, mle main, operator new failed.");
  } else {
    #if defined USR_PROPOSAL_TYPE_IMPLEMENTED
      proposal_ptr = new(nothrow) proposal_type(pf.get_prop_groups());
      if (proposal_ptr == 0) error("Error, mle main, operator new failed.");
    #else
      error("Error, mle, no user proposal type implemented");
      proposal_ptr = new(nothrow) group_move(pf.get_prop_groups());
      if (proposal_ptr == 0) error("Error, mle main, operator new failed.");
    #endif
  }

  proposal_base& proposal = *proposal_ptr;

  vector<string> obj_pfvec;
  vecstrbuf obj_pfbuf;

  mle_mcmc mcmc(proposal, usrmod);
  mcmc.set_simulation_size(est_blk.lchain);
  mcmc.set_stride(est_blk.stride);
  mcmc.set_draw_from_posterior(!est_blk.draw_from_prior);
  mcmc.set_temp(est_blk.temperature);

  asymptotics_base* asymptotics_ptr;

  if (est_blk.kilse) {
    asymptotics_ptr 
      = new(nothrow) minimal_asymptotics(data,usrmod,mcmc);
    if (asymptotics_ptr == 0) error("Error, mle main, operator new failed.");
  }
  else {
    asymptotics_ptr 
      = new(nothrow) sandwich_asymptotics(data,usrmod,mcmc,est_blk.lhac);
    if (asymptotics_ptr == 0) error("Error, mle main, operator new failed.");
  }

  asymptotics_base& asymptotics = *asymptotics_ptr;

  INT_32BIT seed = est_blk.seed;
  realmat rho = pf.get_rho();

  // To be interchanged are the following

  realmat new_mode;
  REAL    new_high;
  realmat rho_chain;
  realmat stats_chain;
  realmat pi_chain;
  realmat reject;

  if (my_rank == 0) {

    INTEGER count = 1;

    while (count < no_procs) {

      int  src;
      int  tag;
      int  ifile;
      char number[5];
      char sender[5];

      MPI_Status status;

      MPI_Recv
        (&ifile,1,MPI_INT,MPI_ANY_SOURCE,MPI_ANY_TAG,MPI_COMM_WORLD,&status);

      src = status.MPI_SOURCE;
      tag = status.MPI_TAG;

      sprintf(number,"%03d",ifile);
      sprintf(sender,"%03d",src);

      string id = string(sender) + string(".") + string(number);

      REAL* buf;
      string filename;

      if (tag == 50) {

        MPI_Recv(&dim,4,MPI_INT,src,tag,MPI_COMM_WORLD,&status);
        new_mode.resize(dim[1],dim[2]);
        buf = new_mode.get_x();
        MPI_Recv(buf,dim[0],MPI_DOUBLE,src,tag,MPI_COMM_WORLD,&status);
        MPI_Recv(&new_high,1,MPI_DOUBLE,src,tag,MPI_COMM_WORLD,&status);
        mcmc.set_mode(new_mode,new_high);

        MPI_Recv(&dim,4,MPI_INT,src,tag,MPI_COMM_WORLD,&status);
        rho_chain.resize(dim[1],dim[2]);
        buf = rho_chain.get_x();
        MPI_Recv(buf,dim[0],MPI_DOUBLE,src,tag,MPI_COMM_WORLD,&status);
        filename = pathname + prefix + ".rho." + id;
        vecwrite(filename.c_str(),rho_chain);
        asymptotics.set_asymptotics(rho_chain);

        MPI_Recv(&dim,4,MPI_INT,src,tag,MPI_COMM_WORLD,&status);
        stats_chain.resize(dim[1],dim[2]);
        buf = stats_chain.get_x();
        MPI_Recv(buf,dim[0],MPI_DOUBLE,src,tag,MPI_COMM_WORLD,&status);
        filename = pathname + prefix + ".stats." + id;
        vecwrite(filename.c_str(),stats_chain);

        MPI_Recv(&dim,4,MPI_INT,src,tag,MPI_COMM_WORLD,&status);
        pi_chain.resize(dim[1],dim[2]);
        buf = pi_chain.get_x();
        MPI_Recv(buf,dim[0],MPI_DOUBLE,src,tag,MPI_COMM_WORLD,&status);
        filename = pathname + prefix + ".pi." + id;
        vecwrite(filename.c_str(),pi_chain);

        MPI_Recv(&dim,4,MPI_INT,src,tag,MPI_COMM_WORLD,&status);
        reject.resize(dim[1],dim[2]);
        buf = reject.get_x();
        MPI_Recv(buf,dim[0],MPI_DOUBLE,src,tag,MPI_COMM_WORLD,&status);
        filename = pathname + prefix + ".reject." + id;
        vecwrite(filename.c_str(),reject);

        realmat rho_hat;
        realmat V_hat;
        INTEGER n;
  
        realmat rho_mean;
        realmat rho_mode;
        REAL post_high;
        realmat foc_hat;
        realmat I;
        realmat invJ;
        INTEGER reps = 0;
  
        asymptotics.get_asymptotics(rho_hat,V_hat,n);
        asymptotics.get_asymptotics
          (rho_mean,rho_mode,post_high,I,invJ,foc_hat,reps);
    
        usrmod.set_rho(rho_mode);
        realmat usr_stats;
        usrmod.get_stats(usr_stats);
        filename = pathname + prefix + ".diagnostics.dat";
      
        pf.write_parms(parmfile, pathname + prefix, seed, 
	  rho, rho_mode, invJ/n);
        usrmod.set_rho(rho_mode);
        filename = pathname + prefix + ".usrvar.dat";
        usrmod.write_usrvar(filename.c_str());

        output(est_blk, detail, ifile, pathname + prefix,
          rho_chain, stats_chain, pi_chain, reject, rho_hat, V_hat, n, 
          rho_mean, rho_mode, post_high, foc_hat, I, invJ, reps);
      } 
      else {
        count++;
        cout << "Process " << src << " has finished\n";
      }
    }
  } 
  else {
      
    for (INTEGER ifile = 0; ifile <= est_blk.nfile; ++ifile) {

      seed += my_rank*1111;

      reject = mcmc.draw(seed, rho, rho_chain, stats_chain, pi_chain);

      int tag = 50;
      int dest = 0;
      REAL* buf = 0;

      MPI_Send(&ifile,1,MPI_INT,dest,tag,MPI_COMM_WORLD);

      new_mode = mcmc.get_mode();
      new_high = mcmc.get_high(); 
      
      dim[0] = new_mode.size();
      dim[1] = new_mode.get_rows();
      dim[2] = new_mode.get_cols();
      MPI_Send(&dim,4,MPI_INT,dest,tag,MPI_COMM_WORLD);
      buf = new_mode.get_x();
      MPI_Send(buf,dim[0],MPI_DOUBLE,dest,tag,MPI_COMM_WORLD);
      MPI_Send(&new_high,1,MPI_DOUBLE,dest,tag,MPI_COMM_WORLD);

      dim[0] = rho_chain.size();
      dim[1] = rho_chain.get_rows();
      dim[2] = rho_chain.get_cols();
      MPI_Send(&dim,4,MPI_INT,dest,tag,MPI_COMM_WORLD);
      buf = rho_chain.get_x();
      MPI_Send(buf,dim[0],MPI_DOUBLE,dest,tag,MPI_COMM_WORLD);

      dim[0] = stats_chain.size();
      dim[1] = stats_chain.get_rows();
      dim[2] = stats_chain.get_cols();
      MPI_Send(&dim,4,MPI_INT,dest,tag,MPI_COMM_WORLD);
      buf = stats_chain.get_x();
      MPI_Send(buf,dim[0],MPI_DOUBLE,dest,tag,MPI_COMM_WORLD);

      dim[0] = pi_chain.size();
      dim[1] = pi_chain.get_rows();
      dim[2] = pi_chain.get_cols();
      MPI_Send(&dim,4,MPI_INT,dest,tag,MPI_COMM_WORLD);
      buf = pi_chain.get_x();
      MPI_Send(buf,dim[0],MPI_DOUBLE,dest,tag,MPI_COMM_WORLD);

      dim[0] = reject.size();
      dim[1] = reject.get_rows();
      dim[2] = reject.get_cols();
      MPI_Send(&dim,4,MPI_INT,dest,tag,MPI_COMM_WORLD);
      buf = reject.get_x();
      MPI_Send(buf,dim[0],MPI_DOUBLE,dest,tag,MPI_COMM_WORLD);
    }

    int tag = 90;
    int dest = 0;
    int i = 0;
    MPI_Send(&i,1,MPI_INT,dest,tag,MPI_COMM_WORLD);
  }

  delete proposal_ptr;
  delete asymptotics_ptr;

  if (my_rank == 0) delete detail_ptr;

  MPI_Finalize();

  previous_error = set_lib_error_handler(previous_error);
  previous_warn = set_lib_warn_handler(previous_warn);

  return 0;
}

namespace {
  void output(const estblock& est_blk,ostream& detail,INTEGER ifile,
      string prefix, const realmat& rho_sim, const realmat& stats_sim, 
      const realmat& pi_sim, realmat reject, const realmat& rho_hat, 
      const realmat& V_hat, INTEGER n, const realmat& rho_mean, 
      const realmat& rho_mode, REAL post_high, const realmat& foc_hat, 
      const realmat& I, const realmat& invJ, INTEGER reps)
  {
    if (ifile > 999) error("Error, mle, output, nfile too big");

    string filename;
  
    /*
    filename = prefix + ".rho_hat.dat";
    vecwrite(filename.c_str(), rho_hat);
    */

    filename = prefix + ".rho_mean.dat";
    vecwrite(filename.c_str(), rho_mean);
  
    filename = prefix + ".rho_mode.dat";
    vecwrite(filename.c_str(), rho_mode);
  
    realmat V_hat_hess = invJ/n;

    filename = prefix + ".V_hat_hess.dat";
    vecwrite(filename.c_str(), V_hat_hess);
  
    INTEGER lrho = rho_sim.nrow();
    realmat V_hat_info(lrho,lrho,0.0);

    if (!est_blk.kilse) {
      
      filename = prefix + ".V_hat_sand.dat";
      vecwrite(filename.c_str(), V_hat);

      if (I.size() > 0) V_hat_info = inv(I)/n;

      filename = prefix + ".V_hat_info.dat";
      vecwrite(filename.c_str(), V_hat_info);

      if (est_blk.print) {
        detail << starbox("/Get asymptotics/results are cumulative//") << '\n';
        detail << "\t ifile = " << ifile << '\n';
        detail << '\n';
        detail << "\t rho_hat = " << rho_hat << '\n';
        detail << "\t V_hat = " << V_hat << '\n';
        detail << "\t n = " << n << '\n';
        detail << "\t rho_mean = " << rho_mean << '\n';
        detail << "\t rho_mode = " << rho_mode << '\n';
        detail << "\t post_high = " << post_high << '\n';
        detail << "\t I = " << I << '\n';
        detail << "\t invJ = " << invJ << '\n';
        detail << "\t foc_hat = " << foc_hat << '\n';
        detail << "\t reps = " << reps << '\n';
        detail.flush();
      }
  
    }

    filename = prefix + ".summary.dat";
    ofstream summary_ofs(filename.c_str());
    if (summary_ofs) {
      summary_ofs 
        << "   parm"
        << "     rhomean"
        << "     rhomode"
        << "      sesand"
        << "      sehess"
        << "      seinfo"
        << '\n';
      for (INTEGER i=1; i<=rho_hat.size(); ++i) {
        if (est_blk.kilse) {
          summary_ofs
            << fmt('i',7,i)
            << fmt('g',12,5,rho_mean[i])
            << fmt('g',12,5,rho_mode[i])
            << "            "
            << fmt('g',12,5,sqrt(V_hat_hess(i,i)))
            << "            "
            << '\n';
        }
        else {
          summary_ofs
            << fmt('i',7,i)
            << fmt('g',12,5,rho_mean[i])
            << fmt('g',12,5,rho_mode[i])
            << fmt('g',12,5,sqrt(V_hat(i,i)))
            << fmt('g',12,5,sqrt(V_hat_hess(i,i)))
            << fmt('g',12,5,sqrt(V_hat_info(i,i)))
            << '\n';
        }
      }
      summary_ofs << '\n';
      summary_ofs << "The log posterior (log prior + log likelihood) at the";
      summary_ofs << " mode is" << fmt('g',12,5,post_high) << ".\n";
    }

  }
}

