//////////////////
// 1. variables //
//////////////////

// Container for everything the solver and the simulator share: grid sizes and grids,
// model parameters, simulation inputs and the output buffers.
// par::setup assigns the members from the MATLAB inputs; nothing is initialised here,
// so a member that setup does not assign for a given call type holds an indeterminate value.
struct par_struct {

    // a. grids
    //    sizes (read as scalars from the MATLAB parameter struct)
    size_t Na;
    size_t Na_ret;
    size_t Nm;
    size_t Nk;
    size_t Np;
    size_t Ng;
    size_t Ne;
    size_t Nq;
    size_t Nc_guess;

    //    grid data (pointers into MATLAB-owned arrays) and the scalar min_C
    double *grid_a;
    double *grid_a_ret;
    double *grid_m;
    double *grid_P;
    double *grid_P_alpha;
    double min_C;

    //    switches (stored as size_t; setup compares do_euler_error against 1)
    size_t do_pd;
    size_t do_nlopt;
    size_t do_multistart;
    size_t do_egm;
    size_t do_euler_error;  // only read by setup when type == 2
    size_t do_welfare;      // only read by setup when type == 2

    // b. demographics
    size_t T;
    size_t TR;              // setup sizes the solution cells as TR+1
    size_t agemin;
    size_t max_age_pregnant;
    double p_unplanned;     // stored as the MATLAB value divided by 100
    double *fecundity;

    // c. preferences
    double beta;
    double rho;
    double gamma;
    double marg;            // stored as the MATLAB value divided by 100
    double add_1;           // stored as the MATLAB value divided by 100
    double add_2;           // stored as the MATLAB value divided by 100
    double add_3;           // stored as the MATLAB value divided by 100
    double cost_abort;      // stored as the MATLAB value divided by 100

    // d. income process
    double *G;
    double *perm;
    double *trans;
    double *weight;
    double kappa;
    double alpha;
    size_t Nshocks;
    double sigma_perm;
    double sigma_trans;
    double cost_perm_1;     // stored as the MATLAB value divided by 100
    double cost_perm_2;     // stored as the MATLAB value divided by 100
    double cost_perm_3;     // stored as the MATLAB value divided by 100
    double cost_perm_1_u;   // stored exactly as passed in
    double cost_perm_2_u;   // stored exactly as passed in
    double cost_perm_3_u;   // stored exactly as passed in
    double cost_trans_1;    // stored exactly as passed in
    double cost_trans_2;    // stored exactly as passed in
    double cost_trans_3;    // stored exactly as passed in

    // e. assets
    double R;
    double credit;

    // f. simulation inputs (assigned by setup only when type == 2)
    size_t simT;
    size_t simN;
    double *A0;
    double *P0;
    double *g0;
    double *k0;
    double *draws_perm;
    double *draws_trans;
    double *draws_pregnant;

    // g. cost-parameters
    double cost_welfare;    // stored as 0.01 times the MATLAB value

    // h. output

        // solve: one data pointer per stored period (TR+1 entries where setup assigns them);
        // e and q are not assigned by setup
        double **C;
        double **m;
        double **V;
        double **e;
        double **q;
        double **Vd;

        // simulation: buffers behind the fields of the simulation output struct;
        // sim_effort is not assigned by setup
        double *sim_P;
        double *sim_Y;
        double *sim_M;
        double *sim_C;
        double *sim_A;
        double *sim_e;
        double *sim_q;
        double *sim_g;
        double *sim_k;
        double *sim_b;
        double *sim_age;
        double *sim_euler_error;
        double *sim_effort;
        double *sim_dlogY;
        double *sim_welfare;
        double *sim_d2SavingRate;

};


namespace par {

//////////////
// 2. setup //
//////////////

// Loads the model parameters from MATLAB into `par` and allocates the MATLAB output struct.
//
//   par  : container that is filled in; members not mentioned for the given `type` are left untouched.
//   plhs : MEX outputs; plhs[0] is created here when type is 1 or 2.
//   prhs : MEX inputs; prhs[0] is the parameter struct and is always read.
//          type == 2 additionally reads the cell fields "C" and "Vd" (TR+1 cells each) from prhs[1]
//          and the fields "A0","P0","g0","k0","perm","trans","pregnant" from prhs[2].
//   type : 1 = solve, 2 = simulate; any other value only loads the parameters.
//
// Preconditions: every field named below must exist in the corresponding input struct.
// mxGetField returns NULL for a missing field and that result is passed on unchecked,
// so a misspelt or missing field is not reported as a MATLAB error.
//
// Ownership: the data pointers stored in `par` point into MATLAB-owned arrays.
// For type == 2 the pointer tables par->C and par->Vd are allocated with new[] and are
// not released in this function.
void setup(par_struct *par, mxArray *plhs[], const mxArray *prhs[],int type){

    // readers for the parameter struct; the array reader takes the struct explicitly
    // because simulation inputs come from prhs[2]
    const mxArray *in_par = prhs[0];
    auto read_count = [in_par](const char *name) -> size_t {
        return static_cast<size_t>(mxGetScalar(mxGetField(in_par,0,name)));
    };
    auto read_value = [in_par](const char *name) -> double {
        return static_cast<double>(mxGetScalar(mxGetField(in_par,0,name)));
    };
    auto read_array = [](const mxArray *source, const char *name) -> double* {
        return static_cast<double*>(mxGetPr(mxGetField(source,0,name)));
    };

    ///////////////
    // 1. inputs //
    ///////////////

        // a. grids
        par->Na     = read_count("Na");
        par->Na_ret = read_count("Na_ret");
        par->Nm     = read_count("Nm");
        par->Nk     = read_count("Nk");
        par->Np     = read_count("Np");
        par->Ng     = read_count("Ng");
        par->Ne     = read_count("Ne");
        par->Nq     = read_count("Nq");

        par->do_pd         = read_count("do_pd");
        par->do_nlopt      = read_count("do_nlopt");
        par->do_multistart = read_count("do_multistart");
        par->do_egm        = read_count("do_egm");

        par->Nc_guess = read_count("Nc_guess");
        par->min_C    = read_value("min_C");

        par->grid_a       = read_array(in_par,"grid_a");
        par->grid_a_ret   = read_array(in_par,"grid_a_ret");
        par->grid_m       = read_array(in_par,"grid_m");
        par->grid_P       = read_array(in_par,"grid_P");
        par->grid_P_alpha = read_array(in_par,"grid_P_alpha");

        // b. demographics
        par->T                = read_count("T");
        par->TR               = read_count("TR");
        par->agemin           = read_count("agemin");
        par->max_age_pregnant = read_count("max_age_pregnant");
        par->p_unplanned      = read_value("p_unplanned")/100.0;   // rescaled by 1/100
        par->fecundity        = read_array(in_par,"fecundity");

        // c. preferences (marg, add_*, cost_abort are rescaled by 1/100)
        par->beta       = read_value("beta");
        par->rho        = read_value("rho");
        par->gamma      = read_value("gamma");
        par->marg       = read_value("marg")/100.0;
        par->add_1      = read_value("add_1")/100.0;
        par->add_2      = read_value("add_2")/100.0;
        par->add_3      = read_value("add_3")/100.0;
        par->cost_abort = read_value("cost_abort")/100.0;

        // cost_welfare is rescaled by multiplying with 0.01 (kept as a product, not a division)
        par->cost_welfare = 0.01*read_value("cost_welfare");

        // d. income process
        par->Nshocks     = read_count("Nshocks");
        par->sigma_perm  = read_value("sigma_perm");
        par->sigma_trans = read_value("sigma_trans");
        par->perm        = read_array(in_par,"perm");
        par->trans       = read_array(in_par,"trans");
        par->weight      = read_array(in_par,"weight");
        par->G           = read_array(in_par,"G");
        par->kappa       = read_value("kappa");
        par->alpha       = read_value("alpha");

        // cost_perm_1..3 are rescaled by 1/100
        par->cost_perm_1 = read_value("cost_perm_1")/100.0;
        par->cost_perm_2 = read_value("cost_perm_2")/100.0;
        par->cost_perm_3 = read_value("cost_perm_3")/100.0;

        // NOTE(review): cost_trans_* and cost_perm_*_u are stored unscaled, unlike cost_perm_1..3
        // - confirm this asymmetry is intended.
        par->cost_trans_1 = read_value("cost_trans_1");
        par->cost_trans_2 = read_value("cost_trans_2");
        par->cost_trans_3 = read_value("cost_trans_3");

        par->cost_perm_1_u = read_value("cost_perm_1_u");
        par->cost_perm_2_u = read_value("cost_perm_2_u");
        par->cost_perm_3_u = read_value("cost_perm_3_u");

        // e. assets
        par->R      = read_value("R");
        par->credit = read_value("credit");

        // f. simulate
        if(type == 2){

            const mxArray *in_sol = prhs[1];
            const mxArray *in_sim = prhs[2];

            par->simT = read_count("simT");
            par->simN = read_count("simN");

            par->do_euler_error = read_count("do_euler_error");
            par->do_welfare     = read_count("do_welfare");

            // initial states and shock draws
            par->A0 = read_array(in_sim,"A0");
            par->P0 = read_array(in_sim,"P0");
            par->g0 = read_array(in_sim,"g0");
            par->k0 = read_array(in_sim,"k0");

            par->draws_perm     = read_array(in_sim,"perm");
            par->draws_trans    = read_array(in_sim,"trans");
            par->draws_pregnant = read_array(in_sim,"pregnant");

            // solution: one data pointer per cell of the "C" and "Vd" cell arrays
            const size_t num_cells = par->TR+1;
            par->C  = new double*[num_cells];
            par->Vd = new double*[num_cells];

            const mxArray *cells_C  = mxGetField(in_sol,0,"C");
            const mxArray *cells_Vd = mxGetField(in_sol,0,"Vd");
            for(size_t t = 0; t < num_cells; t++){
                par->C[t]  = static_cast<double*>(mxGetPr(mxGetCell(cells_C,t)));
                par->Vd[t] = static_cast<double*>(mxGetPr(mxGetCell(cells_Vd,t)));
            }

        }

    ////////////////////////
    // 2. outputs - solve //
    ////////////////////////

    if(type == 1){

        // a. struct
        const char *field_names[] = {"C", "m", "V","Vd"};

        int num_fields = sizeof(field_names)/sizeof(*field_names);
        plhs[0] = mxCreateStructMatrix(1, 1, num_fields, field_names);
        mxArray *sol_struct = plhs[0];

        // b. cell dimensions: the last stored solution is the retirement solution
        size_t ndim_cell    = 1;
        size_t dims_cell[1] = {par->TR+1};

        // c. array dimensions (automatic storage: nothing to free on any exit path)
        size_t ndim[1]        = {4};
        size_t shape_state[4] = {par->Nm, par->Np, par->Nk, par->Ng};
        size_t *dims[1]       = {shape_state};

        // d. solution
        par->C = misc::set_field_cell(sol_struct,"C",ndim_cell,dims_cell,ndim,dims);
        par->m = misc::set_field_cell(sol_struct,"m",ndim_cell,dims_cell,ndim,dims);
        par->V = misc::set_field_cell(sol_struct,"V",ndim_cell,dims_cell,ndim,dims);

        // e. choice-specific values: one extra dimension holding the Ne*Nq discrete choices
        ndim[0] = 5;
        size_t shape_choice[5] = {par->Nm, par->Np, par->Nk, par->Ng, par->Ne*par->Nq};
        size_t *dims_d[1]      = {shape_choice};

        par->Vd = misc::set_field_cell(sol_struct,"Vd",ndim_cell,dims_cell,ndim,dims_d);

        // f. initialize with NaNs for security.
        //    The NaN value is fetched once, outside the parallel region.
        const double nan_value  = mxGetNaN();
        const size_t num_cells  = par->TR+1;

        #pragma omp parallel num_threads(THREADS)
        {

        // every thread walks all periods; the work-sharing loop splits the state space
        for(size_t t = 0; t < num_cells; t++){
            #pragma omp for collapse(4)
            for(size_t g=0 ; g<par->Ng ; g++){
                for(size_t k=0 ; k<par->Nk ; k++){
                    for(size_t i_P=0 ; i_P<par->Np ; i_P++){
                        for(size_t i_m=0; i_m<par->Nm;i_m++){

                            const size_t i_state = index::d4(g,k,i_P,i_m,par->Nk,par->Np,par->Nm);
                            par->C[t][i_state] = nan_value;
                            par->m[t][i_state] = nan_value;
                            par->V[t][i_state] = nan_value;

                            // loop through discrete choices available
                            for(size_t e=0; e<par->Ne; e++) {
                                for(size_t q=0; q<par->Nq; q++) {
                                    const size_t d        = index::discrete(e,q,par->Nq);
                                    const size_t i_choice = index::d5(d,g,k,i_P,i_m,par->Ng,par->Nk,par->Np,par->Nm);
                                    par->Vd[t][i_choice] = nan_value;
                                }
                            }

                        }
                    }
                }
            }
        }

        } // parallel

    ///////////////////////////
    // 3. outputs - simulate //
    ///////////////////////////

    } else if (type==2) {

        // a. struct
        const char *field_names[] = {"P", "Y", "M", "C", "A",
                                     "e", "q","g", "k","b","age","euler_error","dlogY","welfare","d2SavingRate"};

        int num_fields = sizeof(field_names)/sizeof(*field_names);
        plhs[0] = mxCreateStructMatrix(1, 1, num_fields, field_names);
        mxArray *sim_struct = plhs[0];

        // b. dimensions: simN x simT
        size_t ndim    = 2;
        size_t dims[2] = {par->simN, par->simT};

        // c. elements
        par->sim_P   = misc::set_field_double(sim_struct,"P",ndim,dims);
        par->sim_Y   = misc::set_field_double(sim_struct,"Y",ndim,dims);
        par->sim_M   = misc::set_field_double(sim_struct,"M",ndim,dims);
        par->sim_C   = misc::set_field_double(sim_struct,"C",ndim,dims);
        par->sim_A   = misc::set_field_double(sim_struct,"A",ndim,dims);
        par->sim_e   = misc::set_field_double(sim_struct,"e",ndim,dims);
        par->sim_q   = misc::set_field_double(sim_struct,"q",ndim,dims);
        par->sim_g   = misc::set_field_double(sim_struct,"g",ndim,dims);
        par->sim_k   = misc::set_field_double(sim_struct,"k",ndim,dims);
        par->sim_b   = misc::set_field_double(sim_struct,"b",ndim,dims);
        par->sim_age = misc::set_field_double(sim_struct,"age",ndim,dims);
        par->sim_dlogY        = misc::set_field_double(sim_struct,"dlogY",ndim,dims);
        par->sim_d2SavingRate = misc::set_field_double(sim_struct,"d2SavingRate",ndim,dims);

        // no "effort" output exists, and "euler_error" is only allocated on request;
        // keep both pointers null rather than indeterminate when they have no buffer
        par->sim_effort      = nullptr;
        par->sim_euler_error = nullptr;
        if(par->do_euler_error==1){
            par->sim_euler_error = misc::set_field_double(sim_struct,"euler_error",ndim,dims);
        }

        // d. welfare: simN x 1
        dims[1] = 1;
        par->sim_welfare = misc::set_field_double(sim_struct,"welfare",ndim,dims);

    } // type-end

}


////////////////
// 3. destroy //
////////////////

// Releases the C++ heap memory that par::setup attached to `par`.
//
//   par  : the container previously filled by par::setup with the same `type`.
//   type : 1 = solve, 2 = simulate; other values do nothing.
//
// type == 1: nothing is released here; the pointers stored by setup in that mode come from
//            misc::set_field_cell, whose allocation scheme is not visible from this file section.
// type == 2: setup allocates the pointer tables par->C and par->Vd with new[]; they are freed
//            here and reset to null so that a repeated call is harmless. The arrays the table
//            entries point to belong to MATLAB and are not touched.
//
// NOTE(review): assumes no caller frees par->C / par->Vd itself after a type-2 setup - confirm
// against the MEX gateway before relying on this.
void destroy(par_struct *par, size_t type){

    if(type == 1){

        // nothing to release for the solve outputs

    } else if(type == 2){

        delete[] par->C;
        delete[] par->Vd;
        par->C  = nullptr;
        par->Vd = nullptr;

    }

}

} // namespace