clearvars
close all
% Optional: start a parallel pool manually before running the simulations.
% evalc('feature(''numcores'')')
% parpool('local',6)
t0 = tic();  % wall-clock timer; elapsed minutes are printed at the end

% ------------------------------------------------------------------
% Run configuration
% ------------------------------------------------------------------
date_string = '20210618';  % data vintage; selects the data_raw sub-folder
country = 'Austria';       % uncomment one alternative below to switch
% country = 'France';
% country = 'Germany';
% country = 'Italy';
% country = 'Spain';
% country = 'UK';
% country = 'US';

N = 50000;                 % passed to fn_sim_empirical_cont_par (TODO confirm meaning)
nsimul = 500;              % passed to the simulator; presumably number of replications
roll_window = 14;          % window length (days) passed to fn_est_beta
MF_window = roll_window;   % days between successive joint-estimation end points
ini_days = 7;              % copies of the first beta prepended to betaT in part 1
MA_window = 7;             % window length (days) passed to fn_MA
gamma = 1/14;              % used by fn_est_IT and in the beta/gamma < 3 check
                           % (presumably the recovery rate — confirm)
cstar = 0.01;              % cumulative cases per capita that start joint estimation
MF_ini = 5;                % MF guess used for the first estimation part

% Build the folders with fullfile/filesep instead of hard-coded '\' so the
% script also runs on macOS/Linux (on Windows the strings are unchanged).
% Both paths keep a trailing separator because file names are appended to
% them by plain concatenation later in the script.
path_read = [fullfile('.', 'data_raw', date_string), filesep];
path_write = [fullfile('.', 'results_empirical', ...
    [date_string, '_MF', num2str(MF_window/7), 'W_guess', num2str(MF_ini), ...
     '_N', num2str(N), '_par']), filesep];

if ~exist(path_write, 'dir')
    mkdir(path_write)
end

%% load realized data
% Reads <path_read><country>.csv and uses the columns Date, C and Pop
% (only the first Pop value is used as the population).
fname_read = [path_read, country, '.csv'];
data = readtable(fname_read);
CT = data.C;
pop = data.Pop(1);
% readtable may already have parsed Date as datetime; convert only text.
if ~isdatetime(data.Date)
    data.Date = datetime(data.Date, 'InputFormat', 'ddMMMyyyy');
end
IT = fn_est_IT(CT, gamma);

% Apply fn_MA (MA_window days) to C and the estimated I, then scale per capita.
out_MA = fn_MA([CT, IT], MA_window);
CT_MA = out_MA(:, 1);
IT_MA = out_MA(:, 2);
cT_MA = CT_MA ./ pop;
iT_MA = IT_MA ./ pop;

% C is treated as cumulative: its daily differences (new cases) must be >= 0.
dCT_MA = diff(CT_MA);
if min(dCT_MA) < 0
    error('Negative new cases!')
end
dcT_MAp = dCT_MA ./ pop .* 100000;  % daily new cases per 100k
idx_cross1 = find(dcT_MAp >= 1, 1, 'first');
if isempty(idx_cross1)
    error('Daily new cases per 100k never reach 1 for %s.', country)
end
idx_beg = idx_cross1 + (MA_window - 1) + 1;
date_beg_betahat = data.Date(idx_beg + 1);

% Keep up to roll_window extra days before idx_beg_MA (clipped at the first
% observation) — presumably so the first rolling beta has a full window.
idx_beg_MA = idx_beg - MA_window + 1;
idx_first_use = max(idx_beg_MA - roll_window, 0) + 1;
cT_MA_use = cT_MA(idx_first_use:end);
iT_MA_use = iT_MA(idx_first_use:end);
T = length(cT_MA_use);
% NOTE(review): the date lookups below assume cT_MA has one entry per row of
% data, aligned at the last row — confirm against fn_MA.
date_beg_MA_use = data.Date(end - length(cT_MA_use) + 1);

% Joint estimation starts once cumulative cases per capita exceed cstar.
idx_beg_joint = find(cT_MA_use > cstar, 1, 'first');
if isempty(idx_beg_joint)
    error('Cumulative cases per capita never exceed cstar = %g for %s.', cstar, country)
end
date_beg_joint = data.Date(end - (length(cT_MA_use) - idx_beg_joint + 1) + 1);

%% start joint estimation
% Estimation end points: every MF_window days, starting the day before
% cumulative cases per capita first exceed cstar.
t_est_set = idx_beg_joint-1 : MF_window : T;
date_est_set = date_beg_joint-1 : MF_window : data.Date(end);
if t_est_set(1) < 1
    error('Joint estimation would end before the first usable observation.')
end
n_parts = length(t_est_set);
% File-name tag shared by all outputs (naming scheme unchanged).
file_tag = [num2str(roll_window/7), 'W_nsimul', num2str(nsimul)];

% State handed from one part to the next; part 1 starts from the initial
% MF guess and empty initial simulation states.
MF_latest = MF_ini;
x_ini = [];
y_ini = [];
for j = 1:n_parts
    fprintf('  part %d\n', j);
    t_est_end = t_est_set(j);
    cT_est = cT_MA_use(1:t_est_end);
    iT_est = iT_MA_use(1:t_est_end);

    betahat_rollT = fn_est_beta(cT_est, iT_est, roll_window, MF_latest);
    if j == 1
        % Start at the first estimate with beta/gamma < 3 and pad the
        % front with ini_days copies of that value.
        beta_beg_idx = find(betahat_rollT ./ gamma < 3, 1, 'first');
        if isempty(beta_beg_idx)
            error('No rolling beta estimate with beta/gamma < 3 in part 1.')
        end
        betaT_less3 = betahat_rollT(beta_beg_idx:end);
        betaT = [betaT_less3(1) .* ones(ini_days, 1); betaT_less3];
    else
        % Replace with the estimates between the previous end point and
        % this one; the simulation resumes from the previous part's state.
        t_est_beg = t_est_set(j-1);
        if t_est_beg - roll_window < 1
            error('Part %d: beta window starts before the first rolling estimate.', j)
        end
        betaT = betahat_rollT(t_est_beg-roll_window : t_est_end-roll_window);  % Wm+1 by 1
    end
    fname_beta = [path_write, country, '_beta_', file_tag, '_part', num2str(j)];
    save(fname_beta, 'betahat_rollT', 'betaT');

    %====== run simulation ======
    % One independent Threefry stream per part, selected by index j.
    myStream = RandStream.create('Threefry', 'NumStreams', n_parts, 'StreamIndices', j);
    [parm, results_sim] = fn_sim_empirical_cont_par(myStream, betaT, N, nsimul, x_ini, y_ini);
    fname_sim = [path_write, country, '_sim_', file_tag, ...
        '_ER_inidays', num2str(ini_days), '_part', num2str(j)];
    save(fname_sim, 'parm', 'results_sim');

    %====== calculate MF ======
    cT_cal = results_sim.cT;
    iT_cal = results_sim.iT;
    if j == 1
        % Mean of the last column of simulated cT (presumably across
        % replications) relative to the last realized value.
        tbeg_cal = beta_beg_idx - ini_days + roll_window;
        cT_real_ref = cT_est(tbeg_cal:end);
        MF = mean(cT_cal(:, end)) / cT_real_ref(end);
    else
        cT_real_ref = cT_est(end-length(betaT)+1 : end);  % Wm+1 by 1
        MF = fn_compute_MF(iT_cal, cT_real_ref, betaT);
    end
    fname_MF = [path_write, country, '_MF_', file_tag, '_part', num2str(j)];
    save(fname_MF, 'MF');

    % Hand MF and the final simulation state to the next part in memory
    % instead of reloading the files that were just written.
    MF_latest = MF;
    x_ini = results_sim.x_end;
    y_ini = results_sim.y_end;
end
elapsedMin = toc(t0)/60
delete(gcp('nocreate'))