/*-----------------------------------------------------------------------------

Copyright (C) 2018

A. Ronald Gallant
Post Office Box 659
Raleigh NC 27514-0659
USA

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libscl.h"
#include "lprior.h"

using namespace scl;
using namespace std;

// Standard deviation from a mean and a raw second moment.  A slightly
// negative variance produced by floating-point cancellation is clamped to
// zero instead of becoming sqrt(negative) = NaN; a NaN variance (e.g. from
// NaN moments) still propagates because NaN < 0.0 is false.
static REAL sdev_from_moments(REAL mean, REAL second_moment)
{
  const REAL var = second_moment - mean*mean;
  return var < 0.0 ? 0.0 : sqrt(var);
}

// ssd -- sum of squared deviations of selected model statistics from targets.
//
// parms is unpacked by frac_rho into b0, B, S, R and theta, with
// beta = theta[1] and gamma = theta[2].  With IB = I - B, the state
// y0 = IB^{-1} (b0 + R z) is averaged over z using a tensor-product rule
// built from hquad(quad_n, x, w): nodes z[m] = sqrt(2) x and weights
// prod(w) / pi^(M/2) (e.g. w[i]*w[j]/pi for M = 2).  Presumably hquad
// returns Gauss-Hermite nodes/weights, making this an expectation over
// z ~ N(0, I_M) -- confirm against hquad.  At each node dcfyld(...)
// supplies ecf, cecf, dcf and yld.
//
// On return stats (10 x 1) holds, in order:
//   1 mean,  2 sdev  of (log(beta) - y0[1]) / gamma
//   3 mean,  4 sdev  of yld[1]
//   5 mean,  6 sdev  of yld[years]
//   7 mean,  8 sdev  of ret01 = log(dcf[years] + ecf[1]) - log(dcf[years])
//   9 mean, 10 sdev  of retyr = (log(dcf[years] + cecf[years])
//                                - log(dcf[years])) / years
//
// Returns the sum over i with idx[i] == 1 of (stats[i] - target[i])^2,
// or -REAL_MAX if hquad reports failure (stats is then resized but not
// filled).  Calls error() on invalid M, parms size, mrs_pos/cf_pos,
// years, target size or idx size.  lambda is not referenced here.
REAL ssd
  (INTEGER M, INTEGER L, INTEGER mrs_pos, INTEGER cf_pos, INTEGER years,
  const realmat& parms, REAL lambda, const realmat& target, 
  const intvec& idx, realmat& stats)

{
  const INTEGER nstats = 10;

  if (M != 2 && M != 3) error("Error, lprior, M must be either 2 or 3");
  if (parms.size() != M+M*M*L+(M*M+M)/2+2) error("Error, lprior, bad parms");
  if (mrs_pos > M || cf_pos > M)  error("Error, lprior, bad mrs_pos or cf_pos");
  if (mrs_pos < 1 || cf_pos < 1)  error("Error, lprior, bad mrs_pos or cf_pos");
  if (years < 1)  error("Error, lprior, bad years");
  if (target.size() != nstats) error("Error, lprior, bad target");
  if (idx.size() != nstats) error("Error, lprior, bad idx");

  stats.resize(nstats,1);

  realmat b0, B, R, S, theta;
  frac_rho(M, L, parms, b0, B, S, R, theta);

  const REAL beta = theta[1];
  const REAL gamma = theta[2];
  const REAL log_beta = log(beta);   // loop invariant, hoisted

  // IB = I - B.  NOTE(review): B is M x M*L, so IB is square (and thus
  // invertible) only when L == 1 -- confirm callers only use L == 1.
  realmat IB = -B;
  for (INTEGER i=1; i<=M; ++i) IB(i,i) += 1.0;
  realmat IBinv = inv(IB);

  const REAL pi = M_PI;
  const REAL sqrt2 = sqrt(2.0);

  const INTEGER quad_n = 5;

  realmat x, w;
  if (hquad(quad_n,x,w)!=0) return -REAL_MAX;

  // Size of the M-dimensional product grid and its weight normalization.
  INTEGER npts = 1;
  for (INTEGER m=1; m<=M; ++m) npts *= quad_n;
  const REAL norm = pow(pi, 0.5*M);

  // Accumulators for weighted first moments and raw second moments.
  REAL lcg_mean = 0.0;
  REAL lcg_m2 = 0.0;

  REAL yld_mean_01 = 0.0;
  REAL yld_m2_01 = 0.0;
  REAL yld_mean_yr = 0.0;
  REAL yld_m2_yr = 0.0;

  REAL stk_mean_01 = 0.0;
  REAL stk_m2_01 = 0.0;
  REAL stk_mean_yr = 0.0;
  REAL stk_m2_yr = 0.0;

  realmat z(M,1);
  realmat y0(M,1);
  realmat ecf,pvcf,pv1,cecf,dcf,yld;

  for (INTEGER cell=0; cell<npts; ++cell) {
    // Decode cell into one node index per dimension (base-quad_n digits).
    // z[M] takes the fastest-varying digit, the same visiting order as
    // nested loops with the index of z[1] outermost.
    REAL weight = 1.0;
    INTEGER rem = cell;
    for (INTEGER m=M; m>=1; --m) {
      const INTEGER node = rem % quad_n + 1;
      rem /= quad_n;
      z[m] = sqrt2*x[node];
      weight *= w[node];
    }
    weight /= norm;

    y0 = IBinv*(b0 + R*z);

    dcfyld(b0,B,S,y0,mrs_pos,cf_pos,years,ecf,pvcf,pv1,cecf,dcf,yld);

    const REAL lcg = (log_beta - y0[1])/gamma;
    lcg_mean += weight*lcg;
    lcg_m2   += weight*lcg*lcg;

    const REAL yld01 = yld[1];
    const REAL yldyr = yld[years];
    yld_mean_01 += weight*yld01;
    yld_mean_yr += weight*yldyr;
    yld_m2_01   += weight*yld01*yld01;
    yld_m2_yr   += weight*yldyr*yldyr;

    const REAL log_dcf = log(dcf[years]);
    const REAL ret01 = log(dcf[years]+ecf[1]) - log_dcf;
    const REAL retyr = ( log(dcf[years]+cecf[years]) - log_dcf )/years;

    stk_mean_01 += weight*ret01;
    stk_mean_yr += weight*retyr;
    stk_m2_01   += weight*ret01*ret01;
    stk_m2_yr   += weight*retyr*retyr;
  }

  stats[ 1] = lcg_mean;
  stats[ 2] = sdev_from_moments(lcg_mean, lcg_m2);
  stats[ 3] = yld_mean_01;
  stats[ 4] = sdev_from_moments(yld_mean_01, yld_m2_01);
  stats[ 5] = yld_mean_yr;
  stats[ 6] = sdev_from_moments(yld_mean_yr, yld_m2_yr);
  stats[ 7] = stk_mean_01;
  stats[ 8] = sdev_from_moments(stk_mean_01, stk_m2_01);
  stats[ 9] = stk_mean_yr;
  stats[10] = sdev_from_moments(stk_mean_yr, stk_m2_yr);

  REAL ss = 0.0;
  for (INTEGER i=1; i<=nstats; ++i) {
    if (idx[i] == 1) {
      const REAL d = stats[i]-target[i];
      ss += d*d;
    }
  }
  return ss;
}

// lprior -- log prior density built from the fit statistic of ssd.
//
// Computes ss = ssd(...) (which also fills stats) and returns
//   log( sqrt(lambda / (2 pi)) ) - 0.5 * lambda * ss
// as denval(true, logprior).  Returns denval(false, -REAL_MAX) when ss is
// not a usable sum of squares (non-finite, or negative -- ssd returns
// -REAL_MAX when quadrature setup fails), or when the resulting log
// density is not finite (e.g. lambda <= 0 or lambda*ss overflowing).
scl::denval lprior
  (INTEGER M, INTEGER L, INTEGER mrs_pos, INTEGER cf_pos, INTEGER years,
  const realmat& parms, REAL lambda, const realmat& target, 
  const intvec& idx, realmat& stats)
{
  const REAL pi = 4.0*atan(1.0);
  const REAL minus_log_root_two_pi = -log(sqrt(2.0*pi));

  REAL ss = ssd(M,L,mrs_pos,cf_pos,years,parms,lambda,target,idx,stats);

  // A sum of squares is never negative; -REAL_MAX is finite, so without the
  // sign test a failed ssd would produce a huge *positive* log prior.
  if (!IsFinite(ss) || ss < 0.0) return denval(false,-REAL_MAX);

  REAL logprior = minus_log_root_two_pi + 0.5*log(lambda) - 0.5*lambda*ss;

  // log(lambda) is -inf/NaN for lambda <= 0, and lambda*ss can overflow.
  if (!IsFinite(logprior)) return denval(false,-REAL_MAX);

  return denval(true,logprior);
}

// frac_rho -- unpack the packed parameter vector rho into model pieces.
//
// rho is read in order (1-based):
//   first M elements            -> b0 (M x 1)
//   next M*M*L elements         -> B (M x M*L), assigned through B's
//                                  single-index element order B[i]
//   next M*(M+1)/2 elements     -> upper triangle of R (M x M, rest 0),
//                                  column by column: R(1,1), R(1,2),
//                                  R(2,2), R(1,3), ...
//   next 2 elements             -> theta (2 x 1)
// S is set to R * R'.  Returns p = M + M*M*L + M*(M+1)/2 + 2, the number
// of elements consumed.  Calls error() if rho holds fewer than p elements;
// extra trailing elements are ignored.
INTEGER frac_rho (INTEGER M, INTEGER L, const realmat& rho,
  realmat& b0, realmat& B, realmat& S, realmat& R, realmat& theta)
{
  const INTEGER lB = M*M*L;
  const INTEGER lR = (M*(M+1))/2;
  const INTEGER p = M + lB + lR + 2;

  // Check up front: the reads below would otherwise run past the end of rho.
  if (rho.size() < p) error("Error, frac_rho, rho too short");

  b0.resize(M,1);
  B.resize(M,M*L);
  R.resize(M,M,0.0);
  theta.resize(2,1);

  INTEGER count = 0;
  for (INTEGER i=1; i<=M; ++i) b0[i] = rho[++count];
  for (INTEGER i=1; i<=lB; ++i) B[i] = rho[++count];
  for (INTEGER j=1; j<=M; ++j) {
    for (INTEGER i=1; i<=j; ++i) {
       R(i,j) = rho[++count];
    }
  }
  for (INTEGER i=1; i<=2; ++i) theta[i] = rho[++count];

  S = R*T(R);

  if (count != p) error("Error, frac_rho , should never happen");
  return p;
}

// frac_tau -- unpack the packed parameter vector tau into b0, B, S, R.
//
// tau is read in order (1-based):
//   first M elements            -> b0 (M x 1)
//   next M*M*L elements         -> B (M x M*L), assigned through B's
//                                  single-index element order B[i]
//   next M*(M+1)/2 elements     -> upper triangle of R (M x M, rest 0),
//                                  column by column: R(1,1), R(1,2),
//                                  R(2,2), R(1,3), ...
// S is set to R * R'.  Returns p = M + M*M*L + M*(M+1)/2, the number of
// elements consumed.  Calls error() if tau holds fewer than p elements;
// extra trailing elements are ignored.
INTEGER frac_tau (INTEGER M, INTEGER L, const realmat& tau,
  realmat& b0, realmat& B, realmat& S, realmat& R)
{
  const INTEGER lB = M*M*L;
  const INTEGER lR = (M*(M+1))/2;
  const INTEGER p = M + lB + lR;

  // Check up front: the reads below would otherwise run past the end of tau.
  if (tau.size() < p) error("Error, frac_tau, tau too short");

  b0.resize(M,1);
  B.resize(M,M*L);
  R.resize(M,M,0.0);

  INTEGER count = 0;
  for (INTEGER i=1; i<=M; ++i) b0[i] = tau[++count];
  for (INTEGER i=1; i<=lB; ++i) B[i] = tau[++count];
  for (INTEGER j=1; j<=M; ++j) {
    for (INTEGER i=1; i<=j; ++i) {
       R(i,j) = tau[++count];
    }
  }

  S = R*T(R);

  if (count != p) error("Error, frac_tau , should never happen");
  return p;
}

