#include "mex.h"
#include "matrix.h"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <omp.h>

// Two-argument max/min. Each argument may be evaluated twice, so pass only
// side-effect-free expressions.
#define MAX(X,Y) ((X)>(Y)?(X):(Y))
#define MIN(X,Y) ((X)<(Y)?(X):(Y))

// Thread count requested for the parallel region: one fewer than OpenMP's
// maximum, capped at 16. Clamped to at least 1 because the num_threads clause
// requires a positive value (omp_get_max_threads() can be 1, which would
// otherwise yield 0).
#define MAXTHREADS MAX(1, MIN(16, omp_get_max_threads()-1))

// Structures
typedef struct{
    // ---- settings ----
    int Nobs;               // number of observations: bound of the loop over i and row count of the outputs
    int Nsim;               // number of draws: bound of the loop over s and column count of the outputs
    int ret_min;            // loaded from the input struct; not read by the code in this file
    int ret_max;            // loaded from the input struct; not read by the code in this file
    int age_max;            // last age included in the sums computed by value_of_ret
    int age_min;            // subtracted from an age to obtain the index into age_h / age_w
    double *RetTimeGrid;    // candidate retirement ages (truncated to int when used)
    int NumRetGrid;         // number of entries of RetTimeGrid that are visited
    int NumAge;             // the index into age_h / age_w is capped at NumAge-1
    // ---- parameters ----
    double disc;            // loaded from the input struct; not read by the code in this file
    double *disc_vec;       // per-period weights, indexed by (age - current age)
    double weight_h;        // multiplies member h's value in value_of_ret's return
    double joint_h;         // coefficient on the both-retired indicator in member h's utility
    double joint_w;         // coefficient on the both-retired indicator in member w's utility
    double *age_h;          // age terms added to member h's utility
    double *age_w;          // age terms added to member w's utility
    double SPA_w;           // coefficient on member w's SPA indicator (SPA presumably "state pension age" - confirm)
} par_struct;

// Per-member data arrays; one instance is filled for member h and one for member w.
typedef struct{
    double *Xbeta;      // per-observation term added to utility; NaN marks an observation to skip
    double *e_draws;    // shocks, read at index i + s*Nobs
    double *SPA;        // per-observation threshold age for the SPA indicator (only member w's is read)
    double *Age;        // current age per observation (truncated to int when used)
    double *Retired;    // compared against 1 and 0 after truncation to int
    double *ExpRetAge;  // retirement age imposed when Retired is 1
} data_struct;

// Per-period utility for one member.
//   i                   observation index (not used in the computation)
//   age_h, age_w        current ages of members h and w
//   RetAge_h, RetAge_w  retirement ages of members h and w
//   ExplVar, epsilon, AgeDum, SPA   additive terms
//   joint_val           coefficient applied when both members are retired
// Returns joint_val*1{both retired} + ExplVar + AgeDum + SPA + epsilon.
double utility(int i,int age_h, int age_w, int RetAge_h, int RetAge_w,
        double ExplVar, double epsilon, double AgeDum, double SPA, double joint_val)
{
    const int h_is_retired = (age_h >= RetAge_h);
    const int w_is_retired = (age_w >= RetAge_w);
    const double joint_indicator = (h_is_retired && w_is_retired) ? 1.0 : 0.0;

    // Accumulate left to right in the same order as the closed-form sum.
    double total = joint_val * joint_indicator;
    total += ExplVar;
    total += AgeDum;
    total += SPA;
    total += epsilon;
    return total;
}

double value_of_ret(int i, int s, int RetAge_h, int RetAge_w,data_struct *data_h,data_struct *data_w, par_struct *par)
{
    double SPA      = 0.0;
    double val_h=0.0;
    for(int t=0;t <= (par->age_max - RetAge_h); t++){
        int age_h       = RetAge_h + t;
        int age_w       = age_h - (int) (data_h->Age[i] - data_w->Age[i]);
        int idx         = (age_h - (int)data_h->Age[i]);
        double AgeDum   = par->age_h[MIN(age_h-par->age_min , par->NumAge-1 )];
        val_h += par->disc_vec[idx]*utility(i,age_h,age_w,RetAge_h,RetAge_w,data_h->Xbeta[i],data_h->e_draws[i+s*par->Nobs],AgeDum,SPA,par->joint_h);
    }
    
    double val_w = 0.0;
    for(int t=0;t <= (par->age_max - RetAge_w); t++){
        int age_w       = RetAge_w + t;
        int age_h       = age_w + (int) (data_h->Age[i] - data_w->Age[i]);
        int idx         = (age_w - (int)data_w->Age[i]);
        double AgeDum   = par->age_w[MIN(age_w-par->age_min , par->NumAge-1 )];
//         SPA      = par->SPA_w*((double)((age_w>=RetAge_w) & (age_w>data_w->SPA[i])));
        SPA      = par->SPA_w*((double)((age_w>=RetAge_w) & (age_w>=data_w->SPA[i])));
        
        val_w += par->disc_vec[idx]*utility(i,age_h,age_w,RetAge_h,RetAge_w,data_w->Xbeta[i],data_w->e_draws[i+s*par->Nobs],AgeDum,SPA,par->joint_w);

    }
    
    return par->weight_h*val_h + val_w;
}

// Mex wrapper
void mexFunction(int nlhs, mxArray *plhs[],int nrhs, const mxArray *prhs[])
{
    
    int in = 0, out = 0;
    double *Opt_h, *Opt_w;
    
    // 1. load input
    par_struct*     par = new par_struct;
    data_struct* data_h = new data_struct;
    data_struct* data_w = new data_struct;
    
    par->Nobs           = (int)     mxGetScalar(mxGetField(prhs[in],0,"Nobs"));
    par->Nsim           = (int)     mxGetScalar(mxGetField(prhs[in],0,"Nsim"));
    par->ret_min        = (int)     mxGetScalar(mxGetField(prhs[in],0,"ret_min"));
    par->ret_max        = (int)     mxGetScalar(mxGetField(prhs[in],0,"ret_max"));
    par->age_max        = (int)     mxGetScalar(mxGetField(prhs[in],0,"age_max"));
    par->age_min        = (int)     mxGetScalar(mxGetField(prhs[in],0,"age_min"));
    par->RetTimeGrid    = (double*) mxGetPr(mxGetField(prhs[in],0,"RetTimeGrid"));
    par->NumRetGrid     = (int)     mxGetScalar(mxGetField(prhs[in],0,"NumRetGrid"));
    par->NumAge         = (int)     mxGetScalar(mxGetField(prhs[in],0,"NumAge"));
    
    par->disc_vec       = (double*)  mxGetPr(mxGetField(prhs[in],0,"disc_vec"));
    par->disc           = (double)  mxGetScalar(mxGetField(prhs[in],0,"disc"));
    par->weight_h       = (double)  mxGetScalar(mxGetField(prhs[in],0,"weight_h"));
    par->joint_h        = (double)  mxGetScalar(mxGetField(prhs[in],0,"joint_h"));
    par->joint_w        = (double)  mxGetScalar(mxGetField(prhs[in],0,"joint_w"));
    par->age_h          = (double*) mxGetPr(mxGetField(prhs[in],0,"age_h"));
    par->age_w          = (double*) mxGetPr(mxGetField(prhs[in],0,"age_w"));
    par->SPA_w          = (double)  mxGetScalar(mxGetField(prhs[in],0,"SPA_w"));
    in++;
    
    data_h->Xbeta       = (double*) mxGetPr(mxGetField(prhs[in],0,"Xbeta"));
    data_h->e_draws     = (double*) mxGetPr(mxGetField(prhs[in],0,"e_draws"));
    data_h->SPA         = (double*) mxGetPr(mxGetField(prhs[in],0,"SPA"));
    data_h->Age         = (double*) mxGetPr(mxGetField(prhs[in],0,"Age"));
    data_h->Retired     = (double*) mxGetPr(mxGetField(prhs[in],0,"Retired"));
    data_h->ExpRetAge   = (double*) mxGetPr(mxGetField(prhs[in],0,"ExpRetAge"));
    in++;
    
    data_w->Xbeta       = (double*) mxGetPr(mxGetField(prhs[in],0,"Xbeta"));
    data_w->e_draws     = (double*) mxGetPr(mxGetField(prhs[in],0,"e_draws"));
    data_w->SPA         = (double*) mxGetPr(mxGetField(prhs[in],0,"SPA"));
    data_w->Age         = (double*) mxGetPr(mxGetField(prhs[in],0,"Age"));
    data_w->Retired     = (double*) mxGetPr(mxGetField(prhs[in],0,"Retired"));
    data_w->ExpRetAge   = (double*) mxGetPr(mxGetField(prhs[in],0,"ExpRetAge"));
    in++;
    
    // 2. output
    plhs[out]   = mxCreateDoubleMatrix(par->Nobs, par->Nsim,mxREAL);
    Opt_h       = mxGetPr(plhs[out]); out++;
    plhs[out]   = mxCreateDoubleMatrix(par->Nobs, par->Nsim,mxREAL);
    Opt_w       = mxGetPr(plhs[out]); out++;
    
    // 3. loop through individuals, simulations and combinations
#pragma omp parallel num_threads(MAXTHREADS)
    {
        double *max = new double[par->Nsim]; // allocate memory to store the simulation-specific optima
        
#pragma omp for
        for(int i=0;i<par->Nobs;i++){
            
            if(mxIsNaN(data_h->Xbeta[i])==1 || mxIsNaN(data_w->Xbeta[i])==1){
                for(int s=0;s<par->Nsim;s++){
                    Opt_h[i+s*par->Nobs] = mxGetNaN();
                    Opt_w[i+s*par->Nobs] = mxGetNaN();
                }
                continue;
            }
            
            int alt = 0;
            for (int th=0;th<par->NumRetGrid;th++){
                int RetAge_h = (int) par->RetTimeGrid[th];
                
                // only calculate the admissable combinations
                if((int)data_h->Retired[i]==1 && (int)data_h->ExpRetAge[i]!=RetAge_h){  continue; }
                if((int)data_h->Retired[i]==0 && (int)data_h->Age[i]>RetAge_h){         continue; }
                
                for (int tw=0;tw<par->NumRetGrid;tw++){
                    int RetAge_w = (int) par->RetTimeGrid[tw];
                    
                    // only calculate the admissable combinations
                    if((int)data_w->Retired[i]==1 && (int)data_w->ExpRetAge[i]!=RetAge_w){  continue; }
                    if((int)data_w->Retired[i]==0 && (int)data_w->Age[i]>RetAge_w){         continue; }
                    
                    for(int s=0;s<par->Nsim;s++){
                         
                        int idx      = i + s*par->Nobs + th*par->Nobs*par->Nsim + tw*par->Nobs*par->Nsim*par->NumRetGrid;
                        double value = value_of_ret(i,s,RetAge_h,RetAge_w,data_h,data_w,par);
                        
                        // update the optimal choice
                        if(alt==0){
                            max[s]               = value;
                            Opt_h[i+s*par->Nobs] = (double) RetAge_h;
                            Opt_w[i+s*par->Nobs] = (double) RetAge_w;
                        } else if(value>max[s]){
                            max[s]               = value;
                            Opt_h[i+s*par->Nobs] = (double) RetAge_h;
                            Opt_w[i+s*par->Nobs] = (double) RetAge_w;
                        }
                           
                    } // s
                    
                    alt++; // update the alternative number (used to determ)
                    
                } // tw
            }   // th
        } // i

        delete[] max; // free memory allocated 
        
    } // pragma
    
    // free memory allocated 
    delete par;
    delete data_h;
    delete data_w;
}

