/*-----------------------------------------------------------------------------

Copyright (C) 2018

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-------------------------------------------------------------------------------

Function      wishart - Wishart density function

Syntax        #include "libscl.h"
              denval wishart(INTEGER n, const realmat& X, const realmat& V) 

Prototype in  libscl.h

Description   X is a p by p positive definite, symmetric matrix at which 
              the density is to be evaluated. V is a p by p positive 
	      definite, symmetric variance-covariance matrix. p is
	      inferred from V.  n is the degrees of freedom parameter 
	      and must satisfy n > p - 1.

Remarks       If either X or V is not positive definite, symmetric, then 
              scl::error is called.  If scl::error is not modified by 
	      the calling program, scl::error will terminate execution.  
	      If scl::error is modified so that execution does not 
	      terminate and X or V are semi-definite, then a 
	      denval(false,-REAL_MAX) is returned.  If scl::error is 
	      modified so that execution does not terminate and either 
	      X or V is symmetric with positive diagonal elements but not 
	      semi-definite, then results are unpredictable.

Return value  A denval(true,log_den) where log_den is the natural 
              logarithm of the density.  

Functions     Library: atan, fabs, log, pow, lgamma 
called        libscl: cholesky, logdetR, rinv, error

-----------------------------------------------------------------------------*/

#include "libscl.h"

namespace scl {

  /*
  Natural log of the Wishart W_p(V, n) density evaluated at X:

    log f(X) = ((n-p-1)/2) log|X| - tr(V^{-1} X)/2
               - (n p/2) log 2 - (n/2) log|V| - log Gamma_p(n/2)

  with the multivariate gamma function

    log Gamma_p(n/2) = (p(p-1)/4) log(pi) + sum_{j=1}^{p} lgamma((n-j+1)/2).

  n  degrees of freedom; must satisfy n > p - 1.
  X  p by p symmetric positive definite point of evaluation.
  V  p by p symmetric positive definite scale matrix; p is V.ncol().

  Returns denval(true,log_den) on success.  Every failed input check calls
  scl::error; if that returns, denval(false,-REAL_MAX) is returned.
  */
  denval wishart(INTEGER n, const realmat& X, const realmat& V)
  {
    const REAL log_pi = log(4.0*atan(1.0));

    // Relative tolerance for the symmetry checks; also handed to
    // cholesky and logdetR.
    const REAL tolerance = 16.0*EPS;

    // Failure value returned from every input-check branch.
    denval rv(false,-REAL_MAX);

    const INTEGER p = V.ncol();

    if (n <= p - 1) {
      error("Error, wishart, n <= p - 1");
      return rv;
    }

    if (V.nrow() != p) {
      error("Error, wishart, V is not square");
      return rv;
    }

    if ((X.nrow() != p) || (X.ncol() != p)) {
      error("Error, wishart, X dimensions do not agree with V's dimensions");
      return rv;
    }

    // A positive definite matrix has a strictly positive diagonal; this
    // also guarantees the divisions in the symmetry check below are safe.
    for (INTEGER i=1; i<=p; ++i) {
      if ((V(i,i) <= 0.0) || (X(i,i) <= 0.0)) {
        error("Error, wishart, V or X is not positive definite");
        return rv;
      }
    }

    // Symmetry check on the strict upper triangle, scaled by the diagonal.
    for (INTEGER j=1; j<=p; ++j) {
      for (INTEGER i=1; i<j; ++i) {
        if (fabs((V(i,j)-V(j,i))/V(j,j)) > tolerance) {
          error("Error, wishart, V is not symmetric");
          return rv;
        }
        if (fabs((X(i,j)-X(j,i))/X(j,j)) > tolerance) {
          error("Error, wishart, X is not symmetric");
          return rv;
        }
      }
    }

    realmat Rv;

    INTEGER rankV = cholesky(V,Rv,tolerance);

    if (rankV != p) {
      error("Error, wishart, V is less than full rank");
      return rv;
    }

    realmat Rvinv;

    rinv(Rv,Rvinv);

    REAL log_det_Rv = logdetR(Rv,tolerance);

    realmat Rx;

    INTEGER rankX = cholesky(X,Rx,tolerance);

    if (rankX != p) {
      error("Error, wishart, X is less than full rank");
      return rv;
    }

    REAL log_det_Rx = logdetR(Rx,tolerance);

    // With V = Rv'Rv and X = Rx'Rx:
    // tr(VinvX) = tr[(RxRvinv)'(RxRvinv)], the sum of squared entries of A.
    realmat A = Rx*Rvinv;

    REAL trVinvX = 0.0;
    for (INTEGER i=1; i<=p*p; ++i) trVinvX += pow(A[i],2);

    // numerator terms
    // logdetR is taken to return log det R, so log|X| = 2*log_det_Rx and
    // ((n-p-1)/2) log|X| = (n-p-1)*log_det_Rx.

    REAL term1 = REAL(n - p - 1) * log_det_Rx;

    REAL term2 = -0.5*trVinvX;

    // denominator terms
    // Products are formed in REAL, not INTEGER, so large n cannot overflow.

    REAL term3 = REAL(n)*REAL(p)/2.0 * log(2.0);

    REAL term4 = REAL(n) * log_det_Rv;   // (n/2) log|V|

    REAL term5 = REAL(p)*REAL(p-1)/4.0 * log_pi;   // log Gamma_p(n/2)
    for (INTEGER j=1; j<=p; ++j) term5 += lgamma(REAL(n-j+1)/2.0);

    rv.positive = true;
    rv.log_den = term1 + term2 - term3 - term4 - term5;

    return rv;
  }

}


