% Start from an empty workspace: the save() calls further down write the
% whole workspace, so only variables created by this script should exist.
clear;

% Seed the global random stream from the current time so that every run of
% the simulation draws a different sample.  'shuffle' is the documented
% RandStream option for a time-based seed.  It replaces the older
% sum(100*clock) idiom, which produces a non-integer value (the seconds
% field of clock is fractional) and can repeat for different wall-clock
% times because the six clock fields are simply added together.
% NOTE: despite its name, `seed` holds the RandStream object itself; the
% numeric seed that was actually used can be read from seed.Seed.
seed = RandStream('mt19937ar', 'Seed', 'shuffle');
RandStream.setGlobalStream(seed);



% panel dimensions: N cross-sectional units, each observed for T periods
N = 100;
T = 40;

% number of regressors / slope coefficients per unit
p = 2;

% Index vectors for the stacked (N*T)-by-1 panel.  Observations are ordered
% unit by unit: the time index cycles through 1..T within every unit and
% the unit index stays constant over each run of T consecutive rows.
time_ind = repmat((1:T)', N, 1);
unit_ind = kron((1:N)', ones(T, 1));

% grouping properties
% K groups of units; every group has its own p-by-1 slope vector
K = 8;

b_true_group_1 = [-4;  4];
b_true_group_2 = [-3;  3];
b_true_group_3 = [-2;  2];
b_true_group_4 = [-1;  1];
b_true_group_5 = [ 1; -1];
b_true_group_6 = [ 2; -2];
b_true_group_7 = [ 3; -3];
b_true_group_8 = [ 4; -4];

% column g of b_true_group is the slope vector of group g
b_true_group = horzcat(b_true_group_1, b_true_group_2, b_true_group_3, ...
    b_true_group_4, b_true_group_5, b_true_group_6, b_true_group_7, ...
    b_true_group_8);

% The first n*K units are split into K consecutive blocks of n units;
% block g belongs to group g.  b_true holds the slopes unit by unit
% (p-by-N) and group_true the group label of every unit (N-by-1).
n = 10;
group_true = ones(N, 1);
b_true = zeros(p, N);
for ii = 1:K
    ind2 = ii*n;
    ind1 = ind2 - n + 1;
    group_true(ind1:ind2) = ii;
    b_true(:, ind1:ind2) = b_true_group(:, ii*ones(1, n));
end

% the k units left over after the K blocks are put into group 1
k = N - K*n;
group_true(end-k+1:end) = 1;
b_true(:, end-k+1:end) = b_true_group(:, ones(1, k));

% group label of every stacked observation: each unit's label repeated T times
group_true_exact = reshape(repmat(group_true', T, 1), N*T, 1);

% generate data
% Every replication draws p standard-normal regressor components and one
% standard-normal effect per unit that is constant over the unit's T
% periods.  A adds 0.2 times that unit effect to each regressor, so x is
% correlated with the fixed effect that also enters y.
repeat_num = 100;
A = [1,0; 0,1; 0.2,0.2];   % identical in every replication

% Counting down means the first assignment, design(repeat_num), already
% creates the struct array at its final size.
for num = repeat_num:-1:1
    temp1 = normrnd(0, 1, [N*T, p]);
    temp2 = kron(normrnd(0, 1, [N, 1]), ones(T, 1)); % fixed effect
    temp3 = [temp1, temp2];
    x = temp3 * A;

    % outcome = group-specific slopes * x + fixed effect + N(0,1) noise,
    % filled in one group at a time (one noise draw per group)
    y = zeros(N*T, 1);
    for ii = 1:K
        s = nnz(group_true_exact == ii);   % observations in group ii
        y(group_true_exact == ii) = ...
            x(group_true_exact == ii, :) * b_true_group(:, ii) ...
            + temp2(group_true_exact == ii) ...
            + normrnd(0, 1, [s, 1]);
    end

    design(num).x = x;
    design(num).y = y;
end

% useful matrix
% Within-unit demeaning operator for the stacked (N*T)-row panel:
%   j_bar : T-by-T averaging matrix, every entry equal to 1/T
%   j_all : block diagonal with N copies of j_bar; j_all*v replaces each
%           entry of v by the mean over its unit
%   q_all : I - j_all; q_all*v subtracts the unit mean from every entry
% Only the N diagonal T-by-T blocks are non-zero, so the operators are
% built as sparse matrices.  Dense versions are (N*T)-by-(N*T) -- roughly
% 128 MB each for N = 100, T = 40 -- and, because save() below writes the
% whole workspace, they would be rewritten to disk at every checkpoint.
% A sparse operator times a full matrix still returns a full matrix.
j_bar = ones(T)./T;
j_all = kron(speye(N), j_bar);
q_all = speye(N*T) - j_all;


% Estimation loop.  For every simulated panel: demean the data within
% units, compute the oracle, pooled fixed-effect and unit-wise estimators,
% run the panel aCARDS estimator over a grid of tuning parameters, and
% refit at the grid points selected by two information criteria.
% Counting down lets result(repeat_num) allocate the struct array at once.
for num = repeat_num:-1:1
    display(num);
    x_orig = design(num).x;
    y_orig = design(num).y;

    % within transformation: remove the unit means (and with them the
    % fixed effect)
    x = q_all * x_orig;
    y = q_all * y_orig;

    % oracle estimator: estimated with the true group membership.
    % The second output gets its own variable so that the design constant
    % K (number of groups), which is part of the workspace written by
    % save(), is not overwritten.
    [b_oracle, oracle_group] = oracle_panel(y, x, group_true, T);
    result(num).b_oracle = b_oracle';
    result(num).oracle_group = oracle_group;

    % usual fixed effect estimator
    b_fe = panel_fe(y, x, T);
    result(num).b_fe = b_fe;

    % unit-wise estimator: every unit is treated as its own group
    % NOTE(review): the group argument here is the row vector 1:N whereas
    % the oracle call above passes the N-by-1 column group_true -- confirm
    % oracle_panel accepts both orientations.
    b_unitwise = oracle_panel(y, x, 1:N, T);
    result(num).b_unitwise = b_unitwise';

    % panel aCARDS estimator: tuning-parameter grid
    eta = 0.02;
    delta_choice = 20:10:50;
    lam_within_choice = 0.04:0.01:0.08;
    lam_neighbor_choice = 0.09:0.01:0.13;

    % One row per grid point; columns are
    %   1 delta, 2 lam_within, 3 lam_neighbor, 4 sigma2, 5 df,
    %   6 log(sigma2) + df*p*log(N*T)/(N*T)      (stored below as "bic")
    %   7 log(sigma2) + (1/2)*df*p/sqrt(N*T)     (stored below as "ic2")
    LA = length(delta_choice)*length(lam_within_choice)*length(lam_neighbor_choice);
    result(num).RSS_aCARDS_panel = zeros(LA, 7);
    L = 1;   % next row of the grid table to fill
    for s = 1:length(delta_choice)
        display(s);
        for j = 1:length(lam_within_choice)
            for i = 1:length(lam_neighbor_choice)
                lam_neighbor = lam_neighbor_choice(i);
                lam_within = lam_within_choice(j);
                delta = delta_choice(s);
                % b_unitwise is passed for both the 4th and the 5th
                % argument; their roles inside cardsest_panel are not
                % visible here -- TODO confirm against its documentation.
                [b, df] = cardsest_panel(x, y, T, b_unitwise, b_unitwise, delta, eta, lam_within, lam_neighbor, 1e-4, 30, 'scad', 'ampl');
                % repeat every unit's coefficients over its T rows so the
                % fitted values are the row sums of x.*b_all
                b_all = kron(b', ones(T,1));
                % residual variance with df*p subtracted from the N*T
                % observations in the denominator
                sigma2 = sum((y - sum(x.*b_all,2)).^2) / (N*T - df*p);
                result(num).RSS_aCARDS_panel(L,:) = [delta, lam_within, lam_neighbor, sigma2, df, log(sigma2)+ df*p*log(N*T)/(N*T), log(sigma2)+ (1/2)*df*p/sqrt(N*T)];
                L = L+1;
            end
        end
        % checkpoint: write the whole workspace after every delta value
        save('dgp3_R100N100T40eta2per.mat');
    end

    % Refit at the grid row chosen by minselect on criterion column 6.
    % NOTE(review): the meaning of minselect's third argument (5, the df
    % column) is not visible here -- presumably a tie-breaker; confirm.
    rss = result(num).RSS_aCARDS_panel;
    LL = minselect(rss, 6, 5);
    [b, df] = cardsest_panel(x, y, T, b_unitwise, b_unitwise, rss(LL,1), eta, rss(LL,2), rss(LL,3), 1e-4, 30, 'scad', 'ampl');
    result(num).b_aCARDS_bic = b';
    result(num).K_aCARDS_bic = df;
    result(num).param_aCARDS_bic = rss(LL, 1:3);

    % same refit for the row chosen on criterion column 7
    LL = minselect(rss, 7, 5);
    [b, df] = cardsest_panel(x, y, T, b_unitwise, b_unitwise, rss(LL,1), eta, rss(LL,2), rss(LL,3), 1e-4, 30, 'scad', 'ampl');
    result(num).b_aCARDS_ic2 = b';
    result(num).K_aCARDS_ic2 = df;
    result(num).param_aCARDS_ic2 = rss(LL, 1:3);
end

% Gather, for every replication, the df value stored with the selected
% aCARDS fits: column 1 from the "bic" fit, column 2 from the "ic2" fit.
% [result.field] expands the struct array into a row of its scalar
% entries, which is then turned into a column.
K_num = [[result.K_aCARDS_bic].', [result.K_aCARDS_ic2].'];

% final save of the complete workspace
save('dgp3_R100N100T40eta2per.mat');