clear all; clc
% 0. Setup
% Fixed (non-estimated) parameters per education group: element 1 is
% EDUC=0 (low skilled), element 2 is EDUC=1 (high skilled). Here they are
% only used to build the file name of the saved estimation results.
BETA    = [.97 .98];
RHO     = ones(1,2)*1.5;
SIGMA   = ones(1,2)*0.1;

% Load the estimated parameter struct `par` for each education group.
% Loading into a struct (instead of straight into the workspace) keeps any
% other variables stored in the .mat file from overwriting BETA, RHO,
% SIGMA, EDUC, name, ... and makes a missing `par` an immediate error
% instead of silently reusing the previous group's `par`.
est_pars = cell(1,2);
for EDUC=0:1
    name = ['_educ' num2str(EDUC) '_beta' num2str(BETA(EDUC+1)) '_rho' num2str(RHO(EDUC+1)) '_sigma' num2str(SIGMA(EDUC+1))];
    loaded = load(fullfile('results',['est' name '.mat']),'par');
    est_pars{EDUC+1} = loaded.par;
end

% Names of the estimated parameters and targeted moments, taken from the
% low-skilled results. The tables below index both groups' est/SE/moment
% vectors with the same positions, so both groups must share this ordering.
EstPar      = est_pars{1}.EstPar;
UseMoments  = est_pars{1}.UseMoments;
%%

% 1. Estimation results
% LaTeX table with the point estimate (first row) and standard error (second
% row, in parentheses) of every estimated parameter for both education
% groups. Console output is captured with diary into results\tab_estimates.tex.
diary off
tex_file = fullfile('results','tab_estimates.tex');
if exist(tex_file,'file'), delete(tex_file); end % diary appends, so start from an empty file
diary(tex_file)

fprintf('\n\\begin{tabular}{llcc} \\toprule \n')
%fprintf('\\multicolumn{2}{l}{} & \\multicolumn{2}{c}{Estimates} \\\\ \\cmidrule(lr){2-3} \n');
fprintf('\\multicolumn{2}{l}{Parameter} & Low skilled & High skilled \\\\ \\hline \n');
for p=1:numel(EstPar)
    % Row label: "symbol & description". LaTeX reads a leading "[2mm]" that
    % follows the previous row's "\\" as that command's vertical-space argument.
    switch char(EstPar(p))
        case 'rho'
            fprintf('$\\rho$ & crra coefficient');
        case 'beta'
            fprintf('$\\beta$ & discount factor ');
        case 'add_1'
            fprintf('$\\kappa_1$ & marg. value of 1st child');
        case 'add_2'
            fprintf('[2mm] $\\kappa_2$ & marg. value of 2nd child');
        case 'add_3'
            fprintf('[2mm] $\\kappa_3$ & marg. value of 3rd child');
        case 'marg'
            fprintf('[2mm] $\\nu$ & children effect on marg. util.');
        case 'cost_abort'
            fprintf('[2mm] $\\psi$ & abortion cost ');
        case 'p_unplanned'
            fprintf('[2mm] $\\underline{\\wp}$ & prob. of unintended pregnancy ');
        case 'cost_perm_1'
            fprintf('[2mm] $\\omega_1$ & income effect of 1st child');
        case 'cost_perm_2'
            fprintf('[2mm] $\\omega_2$ & income effect of 2nd child');
        otherwise
            % Parameter without a label: print its name (underscores escaped
            % for LaTeX) and an empty description. Without this the label
            % cells would be missing and the estimates would shift columns.
            fprintf('%s & ',strrep(char(EstPar(p)),'_','\_'));
    end

    % point estimates
    for e=0:1
        fprintf(' & $%2.3f$ ',est_pars{e+1}.est(p));
    end
    fprintf('\\\\ \n & ');
    % standard errors, below the estimates
    for e=0:1
        fprintf('& $(%2.3f)$ ',est_pars{e+1}.SE(p));
    end
    fprintf('\\\\ \n');
end
fprintf('\\bottomrule \\end{tabular}\n');
diary off

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 2. Model moments and fit
% LaTeX table with, for each targeted moment and education group, the data
% moment, its 95% confidence interval (data +/- 1.96 std) and the simulated
% model moment. Console output is captured with diary into results\tab_fit.tex.
diary off
tex_file = fullfile('results','tab_fit.tex');
if exist(tex_file,'file'), delete(tex_file); end % diary appends, so start from an empty file
diary(tex_file)

% 95% confidence bounds of the data moments per education group. These do
% not depend on the moment loop, so they are computed once. The names avoid
% shadowing the builtins plus/minus.
ci_lo = cell(1,2);
ci_hi = cell(1,2);
for e=1:2
    ci_hi{e} = est_pars{e}.mom_data + 1.96*est_pars{e}.mom_data_std;
    ci_lo{e} = est_pars{e}.mom_data - 1.96*est_pars{e}.mom_data_std;
end

fprintf('\n\n\\begin{tabular}{lccccccc} \\toprule \n')
fprintf('& \\multicolumn{3}{c}{Low skilled} & & \\multicolumn{3}{c}{High skilled}   \\\\ \\cmidrule(lr){2-4}\\cmidrule(lr){6-8} \n')
fprintf('& \\multicolumn{1}{c}{Data} & [95\\%% CI] & Model && \\multicolumn{1}{c}{Data} & [95\\%% CI] & Model  \\\\ \\hline \n')
num = 0; % number of moments already printed = offset into mom_data/mom_sim
for m=1:numel(UseMoments)
    moment = char(UseMoments(m));
    % s: group heading, ss: labels of the sub-moments in this group
    switch moment
        case 'RelShare'
            s  = 'Saving rate growth ($\Delta_2$)';
            ss = {'Wanted','Unwanted'};
        case 'Unplanned'
            s  = 'Unexpected birth ratio';
            ss = {'1st child','2nd child','3rd child'};
        case 'abortion_rate'
            s  = 'Abortion rate (external)';
            ss = {'25-39'};
        case 'Kids1'
            s  = 'Share with 1 child';
            ss = {'25-29','30-34','35-39'};
        case 'Kids2'
            s  = 'Share with 2 children';
            ss = {'25-29','30-34','35-39'};
        case 'Kids3'
            s  = 'Share with 3 children';
            ss = {'25-29','30-34','35-39'};
        case 'IncGrowth'
            s  = 'Log income growth';
            ss = {'1st child','2nd child'};
        otherwise
            % Without knowing how many sub-moments this group has, the
            % offset into the moment vectors would be wrong for every
            % following group, so stop instead of printing a misleading table.
            diary off
            error('Unknown moment group "%s" in UseMoments.',moment);
    end

    fprintf('\\multicolumn{4}{l}{\\textbf{\\emph{%s}}} \\\\ \n',s)

    for l=1:numel(ss)
        k = num + l; % position of this sub-moment in the moment vectors
        fprintf('%s',ss{l});

        for e=1:2
            par_e = est_pars{e};
            if strcmp(moment,'abortion_rate')
                % external moment: printed without a confidence interval
                fprintf('& $%2.3f$ &  & $%2.3f$ ',par_e.mom_data(k),par_e.mom_sim(k));
            else
                fprintf('& $%2.3f$ & $[%2.3f,%2.3f]$ & $%2.3f$ ',par_e.mom_data(k),ci_lo{e}(k),ci_hi{e}(k),par_e.mom_sim(k));
            end
            if e==1
                fprintf('&'); % empty spacer column between the two groups
            end
        end
        fprintf('\\\\ \n');
    end

    num = num + numel(ss);
end
fprintf('\\bottomrule \\end{tabular}\n');

diary off

%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 3. Counter factual simulations
% For each education group, solve and simulate the baseline model and a set
% of counterfactual alternatives (all with the same random draws). Then
% compute lifetime statistics, age profiles and BPP transmission parameters
% for every model. Results are stored in stat(model, group).
% load(['results\est' num2str(EDUC) strW '.mat']);
%load(['results\est' name '.mat']);

% Counterfactual alternatives: LaTeX label (used in the plot legends) and the
% parameter overrides, as name/value pairs, applied to the baseline parameters.
alternatives = {
    '$\sigma_{\eta}=\sigma_{\varepsilon}=0$', {'sigma_perm',0,'sigma_trans',0}   % No income uncertainty
    '$\omega_{1}=\omega_{2}=0$',              {'cost_perm_1',0,'cost_perm_2',0} % No career costs of children
    '$\psi=\infty$',                          {'cost_abort',1000}               % No abortion option (1000 stands in for infinity)
    '$\underline{\wp}=0$',                    {'p_unplanned',0}                 % Perfect contraceptive control
    };

n_models  = 1 + size(alternatives,1);
models    = cell(1,n_models);
models{1} = 'Baseline';
for a = 1:size(alternatives,1)
    models{a+1} = sprintf('alt. %d: %s',a,alternatives{a,1});
end

max_age  = 50;     % age at which lifetime outcomes are measured
age_grid = 25:60;  % ages of the age profiles

for EDUC=0:1
    e   = EDUC+1;
    par = est_pars{e};
    par.do_welfare = 1;

    % TEMP:
    %par.simN = 500000;

    % i. Load BHPS data used to simulate savings series.
    data        = est.load_data(EDUC,par);
    rng(par.seed) % set seed
    data.draws  = est.draws(par,data);
    draws{e}    = data.draws;
    if EDUC==0
        data_low = data;
    end

    % ii. Solve and simulate the baseline economy and the alternatives, all
    % with the same draws.
    par_baseline = par;
    pars{1,e}    = par;
    sol_baseline = mex_solve(par_baseline);
    sims{1,e}    = est.simulate(par_baseline,sol_baseline,data.draws);

    for a = 1:size(alternatives,1)
        n         = a+1;
        overrides = alternatives{a,2};
        pars{n,e} = par_baseline;
        for v = 1:2:numel(overrides)
            pars{n,e}.(overrides{v}) = overrides{v+1};
        end
        % presumably rebuilds grids that depend on the changed parameters
        pars{n,e} = prep.construct_grids(pars{n,e});

        sol       = mex_solve(pars{n,e});
        sims{n,e} = est.simulate(pars{n,e},sol,data.draws);
    end

    % iii. calculate relevant statistics for all models
    % Simulated periods up to max_age. This assumes column t of the
    % simulated panels is age par.agemin+t-1 -- TODO confirm.
    cols = 1:(max_age-par.agemin+1);
    for i=1:n_models
        sim = sims{i,e};

        I = sim.b==1 & sim.k==1;
        stat(i,e).age1     = mean(sim.age(I));

        mxm = max(sim.UnplannedBirth(:,cols),[],2);
        stat(i,e).unintended = mean(mxm);

        ever_abort = max(sim.q(:,cols),[],2);
        stat(i,e).abort = mean(ever_abort);

        mxm = nansum(sim.q(:,cols),2);
        stat(i,e).abort_num = mean(mxm);

        %     mxm = max(sim.EffortNotPregnant(:,cols),[],2);
        %     stat(i).effort_no_child = mean(mxm);
        last_effort       = par.max_age_pregnant-par.agemin+1 - 1; % TODO: check this
        missed_conception = sim.e(:,last_effort).*(1-sim.b(:,last_effort+2));
        stat(i,e).missed_conception = mean(missed_conception);

        % age profiles
        stat(i,e).age_wealth = est.age_profile(age_grid,sim.age,sim.A./sim.Y);
        stat(i,e).age_kids   = est.age_profile(age_grid,sim.age,sim.k);
        stat(i,e).age_income = est.age_profile(age_grid,sim.age,sim.Y);
        stat(i,e).age_effort = est.age_profile(age_grid,sim.age,sim.e);
        % zero average effort is shown as missing; index (i,e), not the
        % linear stat(i), so each group is masked with its own profile
        stat(i,e).age_effort(stat(i,e).age_effort==0) = NaN;
        stat(i,e).age_abort  = est.age_profile(age_grid,sim.age,sim.q);

        stat(i,e).age_pregnancies = est.age_profile(age_grid,sim.age,sim.g>0);

        % NOTE(review): the children statistic is the number of children at
        % the last age of age_grid (60), not at max_age -- marked "test"
        % originally, confirm this is intended.
        stat(i,e).children = stat(i,e).age_kids(end);

        % true where g in the previous period was < 1 (presumably: not pregnant)
        lag_pregnant = lagmatrix(sim.g',1)'<1;

        UnplannedBirth = sim.UnplannedBirth;
        UnplannedBirth(lag_pregnant) = NaN;
        stat(i,e).age_unplanned = est.age_profile(age_grid,sim.age,UnplannedBirth);

        unplanned_preg          = double(sim.g==2);
        unplanned_preg(sim.g<1) = NaN;
        stat(i,e).age_unplanned_preg = est.age_profile(age_grid,sim.age,unplanned_preg);

        unintended_birth = sim.b.*lagmatrix(unplanned_preg',1)';
        mxm = nansum(unintended_birth(:,cols),2);
        stat(i,e).unintented_births = mean(mxm);

        unintended_birth(sim.b<1) = NaN;
        stat(i,e).age_unintented_births = est.age_profile(age_grid,sim.age,unintended_birth);

        abortion          = sim.q;
        abortion(sim.g<1) = NaN;
        stat(i,e).age_abortions = est.age_profile(age_grid,sim.age,abortion);

        % Welfare measure
        stat(i,e).welfare = nanmean(sim.welfare);

        % clean c and y for time and children effects: residuals from a
        % regression on children dummies and a quadratic in age
        age_min = 25;
        age_max = 40;
        I = sim.age>=age_min & sim.age<=age_max;
        T = age_max - age_min+1;
        N = size(sim.C,1);
        kids = dummyvar(sim.k(I)+1);
        age  = sim.age(I)/10;
        X    = [kids, age, age.^2];

        logC = log(sim.C);
        b = regress(logC(I),X);
        c = reshape(logC(I) - X*b,N,T);

        logY = log(sim.Y);
        b = regress(logY(I),X);
        y = reshape(logY(I) - X*b,N,T);

        % BPP: use their moments (first differences over time)
        dc = diff(c,1,2);
        dy = diff(y,1,2);

        % calculate the transmission parameters using the residuals
        dy_lag  = lagmatrix(dy',1)';
        dy_lead = lagmatrix(dy',-1)';
        dy_sum  = dy_lag + dy + dy_lead;

        mat = nancov(dc(:),dy_lead(:));
        cov_dcdy = mat(1,2);
        mat = nancov(dy(:),dy_lead(:));
        cov_dydy = mat(1,2);

        mat = nancov(dc(:),dy_sum(:));
        cov_dcdy_sum = mat(1,2);
        mat = nancov(dy(:),dy_sum(:));
        cov_dydy_sum = mat(1,2);

        stat(i,e).BPP_trans = cov_dcdy./cov_dydy;
        stat(i,e).BPP_perm  = cov_dcdy_sum./cov_dydy_sum;

        fprintf("i,e:%d,%d -> %2.5f,%2.5f\n",i,e,stat(i,e).BPP_perm  ,stat(i,e).BPP_trans)
    end
end % EDUC-loop

% Age profiles to plot (fields of stat) and their y-axis labels.
vars = {'age_wealth','age_kids','age_income','age_abortions','age_unplanned','age_pregnancies','age_effort','age_unplanned_preg','age_abort'};
var_names = {'wealth-to-income','number of children','income','abortion ratio','unexpected birth ratio','pregnancies (share)','effort','unintended pregnancy share','abortions'};

FontSize = 17;
%%
% One figure per age profile showing every simulated model, saved as
% results\policy_<field><EDUC>.eps for each education group.
line_specs = {'-black','--black','-.black',':black','-red'}; % one per model, in the order of models
assert(numel(models) <= numel(line_specs), 'Add a line style for every model in models.');
for EDUC=0:1
    e = EDUC+1;
    for v=1:numel(vars)
        figure(v)
        field = vars{v}; % not "var", which would shadow the builtin

        % (x, y, linespec) triplets for all models, drawn by a single plot call
        args = cell(1,3*numel(models));
        for i=1:numel(models)
            args(3*i-2:3*i) = {age_grid, stat(i,e).(field), line_specs{i}};
        end
        plot(args{:},'LineWidth',2)

        if strcmp(field,'age_income') || strcmp(field,'age_pregnancies')
            legend(models,'location','best','Interpreter','Latex')
        end
        grid on
        xlabel('age')
        ylabel(var_names{v})
        if ismember(field,{'age_pregnancies','age_abortions','age_unplanned','age_unplanned_preg'})
            % these profiles are shown only up to age 47
            set(gca,'FontSize',FontSize,'XTick',25:5:46)
            xlim([25 47])
        else
            set(gca,'FontSize',FontSize,'XTick',25:5:60)
            xlim([25 60])
        end
        if strcmp(field,'age_wealth')
            ylim([0 7])
        end
        print('-depsc',fullfile('results',['policy_' field num2str(EDUC) '.eps']));
    end
end % educ loop

%%
% iv. output table
% LaTeX table of the statistics listed in `stats` (rows) for the baseline
% and each alternative model (columns), low and high skilled side by side.
stats = {'children','abort_num','unintented_births','missed_conception'};
%stats = {'children','abort_num','unintented_births','missed_conception','BPP_perm','BPP_trans','welfare'};

% Row labels, keyed by the stat field names.
names.children   = 'children';
names.age1       = 'Age at first birth';
names.unintended = 'Had unintended birth';
names.unintented_births = 'unintended births';
names.abort      = 'Had abortion';
names.abort_num  = 'abortions';
names.missed_conception = 'missed conceptions';
names.BPP_perm = '\midrule Transmission par. (perm)';
names.BPP_trans = 'Transmission par. (trans)';
names.welfare = 'Welfare';

diary off
% same lower-case output folder as the other tables
name = fullfile('results','tab_counter_factual.tex');
if exist(name,'file'), delete(name); end % diary appends, so start from an empty file
diary(name)

nm = numel(models); % columns per education group: baseline + alternatives
fprintf('\n\n\\begin{tabular}{l*{%d}{c}} \\toprule \n',2*nm)
fprintf(' & \\multicolumn{%d}{c}{Low skilled} & \\multicolumn{%d}{c}{High skilled} \\\\ \n \\cmidrule(lr){2-%d} \\cmidrule(lr){%d-%d}',nm,nm,2+nm-1,2+nm,2+2*nm-1);
fprintf(' & Base & \\multicolumn{%d}{c}{Alternative simulations} & Base & \\multicolumn{%d}{c}{Alternative simulations} \\\\ \n \\cmidrule(lr){3-%d} \\cmidrule(lr){%d-%d}',nm-1,nm-1,1+nm,2+nm+1,2+2*nm-1);

% Column numbers (1),(2),... above the alternatives; base columns stay empty.
fprintf('Number of &');
for e=1:2
    if e==2, fprintf('&'); end
    for i=2:nm
        fprintf('& (%d) ',i-1);
    end
end
fprintf('\\\\ \\midrule \n')

for j=1:numel(stats)
    fprintf('%s',names.(stats{j}));
    for e=1:2
        for i=1:nm
            fprintf('& $%2.2f$ ',stat(i,e).(stats{j}));
        end
    end
    fprintf('\\\\ \n')
end

fprintf('\\bottomrule \\end{tabular}\n');
diary off


%% Sensitivity to the saving rate moments
% For each education group: plot the absolute sensitivity of the first
% eight estimated parameters to a 1% change in each moment (two figures of
% four parameters each), and write the full sensitivity matrix as a LaTeX
% table results\tab_sens<group>.tex.
specs = {'-','--','-.',':'}; % line styles for the four parameters per figure

for e=1:2
    par = est_pars{e};
    % NOTE: we use the negative of M1. The moment errors are defined as
    % (sim-data) rather than (data-sim), so the gradient has the "wrong"
    % sign whenever it enters an odd number of times, as it does in M1.
    M = -par.sens.M1;
    [num_par,num_mom] = size(M); % rows: parameters, columns: moments
    mom_data = ones(num_par,1)*par.mom_data';
    Me       = M.*mom_data/100; % 1pct change in moment/misspecification

    % NOTE(review): the legends assume parameters 1-8 are, in order,
    % kappa_1..3, nu, psi, wp, omega_1, omega_2 -- confirm against EstPar.
    figure(1)
    for j=1:4
        plot(abs(Me(j,:))',specs{j},'LineWidth',2)
        hold on
    end
    hold off
    xlabel('moment')
    ylabel('sensitivity')
    set(gca,'Xtick',1:numel(par.mom_data))
    xlim([1 numel(par.mom_data)])
    set(gca,'FontSize',15)
    lgd = legend('$\kappa_1$','$\kappa_2$','$\kappa_3$','$\nu$');
    set(lgd,'Interpreter','latex','FontSize',20,'Location','NorthWest')
    print('-depsc',fullfile('results',['sens1_' num2str(e-1) '.eps']))

    figure(2)
    for j=1:4
        plot(abs(Me(4+j,:))',specs{j},'LineWidth',2)
        hold on
    end
    hold off
    xlabel('moment')
    ylabel('sensitivity')
    set(gca,'Xtick',1:numel(par.mom_data))
    xlim([1 numel(par.mom_data)])
    set(gca,'FontSize',15)
    lgd = legend('$\psi$','$\underline{\wp}$','$\omega_1$','$\omega_2$');
    set(lgd,'Interpreter','latex','FontSize',20,'Location','NorthWest')
    print('-depsc',fullfile('results',['sens2_' num2str(e-1) '.eps']))

    % Print the entire original matrix in a table
    diary off
    name = fullfile('results',['tab_sens' num2str(e-1) '.tex']);
    if exist(name,'file'), delete(name); end % diary appends, so start from an empty file
    diary(name)

    % The column count follows num_par so the header always matches the body.
    fprintf('\\begin{tabular}{l*{%d}{r}} \\toprule\n & \\multicolumn{%d}{c}{Parameter}\\\\ \\cmidrule(lr){2-%d} \n',num_par,num_par,num_par+1);
    fprintf('Moment');
    for j=1:num_par
        fprintf('& \\multicolumn{1}{c}{$%d$}',j);
    end
    fprintf('\\\\ \\hline \n');
    for k=1:num_mom
        fprintf('$%d$',k)
        for j=1:num_par
            fprintf('& $%2.2f $',M(j,k));
        end
        fprintf('\\\\ \n ');
    end
    fprintf('\\bottomrule \\end{tabular}\n');
    diary off

end
