#include "libscl.h"
#include "spline_interpolator.h"

using namespace scl;
using namespace std;

// Forward declaration of tree_policy (defined below in this file).
// Solves by fixed-point iteration for an interpolated function P(y) and fills
// P (y -> P), Q (P -> y), p (y -> log P) and q (log P -> y).  The four
// interpolators are updated whether or not the iteration converges; the
// return value is true only if it converged.
extern bool tree_policy(REAL alpha, REAL sigma, REAL beta, REAL gamma,
  spline_interpolator& P, spline_interpolator& Q,
  spline_interpolator& p, spline_interpolator& q);

// Solves, by fixed-point iteration on a grid of y values, for the function
// P(y) satisfying
//
//   P(y) = beta*exp(gamma*y)*( E[exp((1-gamma)*y')] + E[exp(-gamma*y')*P(y')] ),
//   y'   = alpha*y + e,   e ~ N(0, sigma^2),
//
// with the expectation over e computed by 5-point Gauss-Hermite quadrature
// (hquad).  NOTE(review): this has the form of a Lucas-tree asset price with
// CRRA utility and y the log dividend -- confirm against the caller.
//
// Outputs (updated on every return path after hquad succeeds):
//   P : y -> P(y)          Q : P(y) -> y
//   p : y -> log P(y)      q : log P(y) -> y
//
// The grid bounds are +/- 3.5*sigma/sqrt(1 - alpha^2), so |alpha| < 1 and
// sigma > 0 are required for a finite, non-degenerate grid.
//
// Returns true if the iteration converged within maxiter sweeps.  Otherwise it
// returns false and, if print_warn_msg is set, writes the parameters to cerr.
bool tree_policy(REAL alpha, REAL sigma, REAL beta, REAL gamma,
  spline_interpolator& P, spline_interpolator& Q,
  spline_interpolator& p, spline_interpolator& q)
{
  const bool print_warn_msg = true;

  // Grid spans +/- 3.5 unconditional standard deviations of the AR(1) y.
  const REAL lo = -3.5*sigma/sqrt(1.0 - alpha*alpha);
  const REAL hi = -lo;

  const INTEGER n_quad_points = 5;
  const INTEGER n_base_points = 20;

  // Equispaced base grid on [lo, hi].
  realmat grid(n_base_points,1);
  for (INTEGER i=1; i<=n_base_points; ++i) {
    grid[i] = lo + (REAL(i-1)/REAL(n_base_points-1))*(hi - lo);
  }

  // Half of the base grid step.
  const REAL grid_tol = (0.5/REAL(n_base_points-1))*(hi - lo);

  realmat x, w;
  if( hquad(n_quad_points,x,w) ) error("Error, tree_policy, hquad failed");

  const REAL root2 = sqrt(2.0);              // 1.4142135623730951
  const REAL rootpi = sqrt(4.0*atan(1.0));   // 1.7724538509055161

  // Rescale the quadrature rule (presumably Hermite weight exp(-x^2)) so that
  // E[f(e)] ~= sum_j weights[j]*f(abscissae[j]) for e ~ N(0, sigma^2).
  realmat abscissae(n_quad_points,1);
  realmat weights(n_quad_points,1);
  for (INTEGER j=1; j<=n_quad_points; ++j) {
    abscissae[j] = root2*sigma*x[j];
    weights[j] = w[j]/rootpi;
  }

  // Delete grid points that are too close to abscissae (an idx entry of -1
  // drops that row in grid(idx,"")), then append the abscissae themselves.
  intvec idx = seq(1,n_base_points);
  for (INTEGER i=1; i<=n_base_points; ++i) {
    for (INTEGER j=1; j<=n_quad_points; ++j) {
      if (fabs(grid[i] - abscissae[j]) < grid_tol) idx[i] = -1;
    }
  }

  grid = grid(idx,"");

  for (INTEGER j=1; j<=n_quad_points; ++j) {
    grid.push_back(abscissae[j]);
  }

  // NOTE: the grid is no longer sorted -- the abscissae sit at the end.
  const INTEGER n_points = grid.size();

  // Starting guess for P on the grid.
  // NOTE(review): the origin of these constants is not documented here.
  realmat values(n_points,1);
  for (INTEGER i=1; i<=n_points; ++i) {
    const REAL log_guess = 0.87654 + 7.65819*(grid[i] + 0.5);
    values[i] = exp(log_guess);
  }

  P.update(grid,values);

  const INTEGER maxiter = 5000;
  const REAL tol = 1.0e-4;

  // Loop-invariant part of log E[exp((1-gamma)*y')] for normal e.
  const REAL rhs1_shift = 0.5*pow((1.0-gamma)*sigma,2);

  bool converge = false;

  for (INTEGER iter=1; iter<=maxiter; ++iter) {

    realmat values_lag = values;

    // One application of the fixed-point map at every grid point.
    for (INTEGER i=1; i<=n_points; ++i) {
      const REAL yt = grid[i];
      const REAL rhs0 = beta*exp(gamma*yt);
      const REAL rhs1 = exp(alpha*(1.0-gamma)*yt + rhs1_shift);
      REAL rhs2 = 0.0;
      for (INTEGER j=1; j<=n_quad_points; ++j) {
        const REAL vt = alpha*yt + abscissae[j];
        rhs2 += weights[j]*exp(-gamma*vt)*P(vt);
      }
      values[i] = rhs0*(rhs1 + rhs2);
    }

    P.update(grid,values);

    // Converged when every point other than the endpoints lo and hi moved by
    // at most (|old| + tol)*tol.  The endpoints are identified by value
    // because the grid is unsorted (abscissae were appended), so they are not
    // simply the first and last entries.  The test is written as
    // !(change <= bound) so that NaN (e.g. inf - inf from a diverging
    // iteration) counts as "not converged"; NaN compares false to anything.
    converge = true;
    for (INTEGER i=1; i<=n_points; ++i) {
      if (grid[i] <= lo || grid[i] >= hi) continue;
      const REAL change = fabs(values[i] - values_lag[i]);
      const REAL bound = (fabs(values_lag[i]) + tol)*tol;
      if (!(change <= bound)) {
        converge = false;
        break;
      }
    }

    if (converge) break;
  }

  realmat log_values(n_points,1);
  for (INTEGER i=1; i<=n_points; ++i) {
    log_values[i] = log(values[i]);
  }

  // Inverse maps: swap the roles of abscissa and ordinate.
  Q.update(values,grid);

  p.update(grid,log_values);
  q.update(log_values,grid);

  if (!converge) {
    if (print_warn_msg) {
      warn("Warning, convergence failed at these parameter values");
      cerr << "\t alpha " << alpha << '\n';
      cerr << "\t sigma " << sigma << '\n';
      cerr << "\t beta  " << beta << '\n';
      cerr << "\t gamma " << gamma << '\n';
    }
    return false;
  }

  return true;
}
