clear
clc
% Speed-decoding experiment on striatal D1 / D2 session files.
%
% For every .mat session file in each cohort folder, cells are split into six
% categories (rows of valid_all):
%   1    cells with valid == 0 in the file ("Bad")
%   2    cells with valid == 1 in the file ("Human")
%   3-6  cells that play_active_learning scores > 0.5 when run with label
%        ratios 0.01, 0.05, 0.1 and 1 ("One", "Five", "Ten", "All")
% Each category is subsampled N times down to the size of the smallest
% category. The selected cells' deconvolved/thresholded/smoothed activity and
% the smoothed centroid speed are passed to run_speed_decoders, and its first
% three outputs are stored as rmse, accuracy_spearman and accuracy_pearson
% (each num_files x 6 x N).
%
% Output: speed_experiment.mat containing accuracy_all, where accuracy_all{1}
% holds the D1 results and accuracy_all{2} the D2 results, each as
% {accuracy_pearson, accuracy_spearman, rmse}.
%
% Each session file must provide the variables metrics, valid, centroids, T
% and S. play_active_learning, estimate_noise_std and run_speed_decoders are
% defined elsewhere.

N = 1000;                              % random subsamples per category
accuracy_all = cell(1,2);              % {D1 results, D2 results}

% Cohort folders live under one root; each holds one .mat file per mouse.
data_root = '/Users/dinc/Desktop/paper codes/ActSort/data/striatum_data';
cohorts   = {'D1', 'D2'};

num_frames  = 17999;                   % frames kept per session
num_groups  = 6;                       % Bad, Human, One, Five, Ten, All
ratio_all   = [0.01, 0.05, 0.1, 1];    % label ratios for play_active_learning
speed_scale = 5/210*31;                % NOTE(review): displacement unit conversion - confirm meaning
speed_win   = 30;                      % causal moving-average window for speed (frames)
time_decay  = 2;                       % decay constant of the deconvolution kernel (frames)
smooth_len  = 2.5;                     % width parameter of the Gaussian smoothing kernel

% Loop-invariant kernels, built once instead of per cell / per session.
decay_kernel = exp(-(1:5*time_decay)/time_decay);
decay_pad    = zeros(1, numel(decay_kernel) - 1);   % keeps deconv output length == num_frames
speed_filt   = ones(1, speed_win) / speed_win;
gauss_x      = linspace(-5*smooth_len, 5*smooth_len, 10*smooth_len + 1);
% NOTE(review): exp(-x.^2/(2*smooth_len)) is a Gaussian with variance
% smooth_len (sigma = sqrt(2.5)), not sigma = smooth_len, and the kernel has
% an even number of taps (26), so 'same' convolution shifts by half a frame.
% It is also unnormalized. Kept unchanged to preserve existing results.
gauss_filt   = exp(-gauss_x.^2 / (2 * smooth_len));

for c_idx = 1:numel(cohorts)
    root_folder = fullfile(data_root, cohorts{c_idx});

    % List only real .mat files, in name order. Skipping the first two dir()
    % entries is unsafe: hidden files such as .DS_Store also appear.
    files = dir(fullfile(root_folder, '*.mat'));
    files = files(~[files.isdir]);
    [~, order] = sort({files.name});
    files = files(order);
    num_files = numel(files);

    accuracy_pearson  = zeros(num_files, num_groups, N);
    accuracy_spearman = zeros(num_files, num_groups, N);
    rmse              = zeros(num_files, num_groups, N);

    for k = 1:num_files
        % Load into a struct so variables inside the file cannot overwrite
        % script variables (k, N, files, ...).
        data      = load(fullfile(root_folder, files(k).name));
        metrics   = data.metrics;
        valid     = data.valid;
        centroids = data.centroids;

        % ---- Cell categories -------------------------------------------
        num_cells = size(metrics, 2);
        valid_all = false(num_groups, num_cells);
        valid_all(1,:) = valid == 0;
        valid_all(2,:) = valid == 1;
        choices = 2*valid - 1;         % labels 0/1 become -1/+1

        for o = 1:numel(ratio_all)
            predicted = play_active_learning(metrics, choices, ratio_all(o));
            valid_all(o+2,:) = predicted > 0.5;
        end

        % All categories are subsampled to the same size for a fair comparison.
        num_pick = min(sum(valid_all, 2));
        if num_pick == 0
            warning('speed_experiment:emptyCategory', ...
                'Mouse %d (%s): at least one cell category is empty; results set to NaN.', ...
                k, files(k).name);
            accuracy_pearson(k,:,:)  = NaN;
            accuracy_spearman(k,:,:) = NaN;
            rmse(k,:,:)              = NaN;
            continue
        end

        % ---- Speed target ----------------------------------------------
        step   = diff(centroids, 1, 1);                 % frame-to-frame displacement
        s      = sqrt(sum(step.^2, 2)) * speed_scale;
        c      = conv(s, speed_filt);                   % full conv: first entries are a causal moving average
        speeds = c(1:num_frames, :);

        % ---- Activity: deconvolve, threshold at 5x noise -----------------
        T_ica = double(data.T(:, 1:num_frames));
        k_ica = size(data.S, 3);                        % number of components (3rd dim of S)
        E_ica = zeros(size(T_ica));
        parfor idx = 1:k_ica
            tr_dec = deconv([T_ica(idx,:), decay_pad], decay_kernel);
            noise  = estimate_noise_std(tr_dec);
            tr_dec(tr_dec < 5*noise) = 0;
            E_ica(idx,:) = tr_dec;
        end

        % Transpose to (frames x components) and smooth each column in time.
        E_ica = conv2(gauss_filt, 1, single(E_ica)', 'same');

        % ---- Repeated subsampled decoding ------------------------------
        parfor num_samp = 1:N
            for i = 1:num_groups
                cols   = find(valid_all(i,:));
                pick   = cols(randperm(numel(cols), num_pick));
                traces = E_ica(:, pick);
                [rmse(k,i,num_samp), accuracy_spearman(k,i,num_samp), ...
                    accuracy_pearson(k,i,num_samp), ~] = ...
                    run_speed_decoders(traces, speeds);
            end
        end

        % Mean over samples; mean along dim 3 stays 1x6 even when N == 1.
        a = mean(accuracy_pearson(k,:,:), 3);

        fprintf("%s: Mouse %d. Cells %d. " + ...
            "Bad: %.2f, Human: %.2f "...
            + "One: %.2f, Five: %.2f,  Ten: %.2f, All: %.2f. \n",...
             datestr(now),k,num_pick,a(1),a(2),a(3),a(4),a(5),a(6));
    end

    % Same 1x3 layout as before: {pearson, spearman, rmse}; index 1 = D1, 2 = D2.
    accuracy_all{c_idx} = {accuracy_pearson, accuracy_spearman, rmse};
end

save('speed_experiment.mat','accuracy_all','-v7.3');