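% This script analyzes the cells each method picks from the classifier's boundary (uncertain-probability) region: it plots the fraction of boundary cells picked, the mean feature distance among picked boundary cells, and annotator agreement on picked cells, and runs pairwise signed-rank tests.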
clc; clear;
%% load data
% "methods" is used as a cell array (methods{k}); "labels_saved_all" is read
% as labels_saved_all{d, ann, k}.q_idxs in the boundary-cell section.
load("exlabels_12datasets.mat", "methods", "labels_saved_all")
load("../../data/hemisphere_dataset_summary.mat", "metrics_all", "choices_all", "choices_gt_all")

% Seed the global RNG so the classifier runs below are reproducible
% (assumes get_classifier_preds draws from the global stream - TODO confirm).
rng(42)

num_methods = length(methods);

% Row d lists the four annotator indices used for dataset d; the column
% position is the "ann" loop index used throughout the script.
annotator_lst = [1, 2, 3, 4; % dataset 1
                 1, 3, 4, 5; % dataset 2
                 1, 2, 3, 6];% dataset 3

% Legend (display) names for each method, passed to every plotting call.
% Preallocated under the same name the loop fills and the plots read.
method_dspnames = cell(1, num_methods);
for k = 1:num_methods
    method_dspnames{k} = get_legend_name(methods{k});
end
%% configuration
nIters = 10; % number of iterations passed to get_classifier_preds
% Open interval of predicted probability treated as the "boundary" region.
boundary = [0.25, 0.75];

% One entry per (dataset, annotator) pair, indexed as cell_probs_all{d, ann}.
% Sized from annotator_lst so the layout matches the {d, ann} indexing
% (rows = datasets, columns = annotators).
[num_datasets, num_annotators] = size(annotator_lst);
cell_probs_all = cell(num_datasets, num_annotators);

%% get cell's predicted probability
for d = 1:num_datasets
    % Preprocessing depends only on the dataset, so it is done once per d:
    % transpose metrics so rows are cells, zero out Inf entries, then
    % z-score each column (feature) with the sample standard deviation.
    features = metrics_all{d}';
    features(isinf(features)) = 0;
    features = zscore(features, 0, 1);

    for ann = 1:num_annotators
        ann_idx = annotator_lst(d, ann);
        choices = choices_all{ann_idx, d};

        fprintf("Working on Dataset %i and Annotation %i ", d, ann);
        [~, cell_probs_each] = get_classifier_preds(features, choices, nIters);

        % Average over dimension 1 (presumably one row per iteration -
        % confirm against get_classifier_preds) to get one probability per cell.
        cell_probs_all{d, ann} = mean(cell_probs_each, 1);
        fprintf("\n")
    end
end
%% get boundary cells and pairwise distance for each method
ratio_lst  = [0.01, 0.03, 0.05, 0.1];
num_ratios = length(ratio_lst);

[num_datasets, num_annotators] = size(annotator_lst);
% Last dimension of the result arrays: one slot per (dataset, annotator) run.
num_runs = num_datasets * num_annotators;

picked_boundary_cells      = cell(num_datasets, num_annotators, num_ratios, num_methods);
num_picked_boundary_cells  = zeros(num_methods, num_ratios, num_runs);
num_boundary_cells_percent = zeros(num_methods, num_ratios, num_runs);
distance_picked_cells      = zeros(num_methods, num_ratios, num_runs);

% 3 agreement types: 0-4, 1-3, 2-2
percent4agreements_inpicked = zeros(num_methods, num_ratios, 3, num_runs);
percent4agreements_invote   = zeros(num_methods, num_ratios, 3, num_runs);

% True where a predicted probability lies strictly inside the boundary interval.
is_boundary = @(p) p > boundary(1) & p < boundary(2);

for d = 1:num_datasets
    % Same preprocessing as the probability section: rows are cells,
    % Inf -> 0, z-score each feature column.
    features = metrics_all{d}';
    features(isinf(features)) = 0;
    features = zscore(features, 0, 1);

    choices_ofd = get_choices_of_dataset(choices_all, annotator_lst, d); % the four annotators choices for dataset d
    for ann = 1:num_annotators
        run_idx = (d-1) * num_annotators + ann; % flat (dataset, annotator) index
        % classifier's predicted cell probabilities for this run
        cell_probs = cell_probs_all{d, ann};
        num_total_boundary_cells = sum(is_boundary(cell_probs));
        num_total_cells          = length(cell_probs);
        for k = 1:num_methods
            method_name = get_legend_name(methods{k});
            for r = 1:num_ratios
                ratio = ratio_lst(r);
                fprintf("Working on Dataset %i Annotator %i Method %s Ratio %s%%.", d, ann, method_name, num2str(ratio*100))

                % The method's first floor(N * ratio) queried cells.
                q_idxs_all = labels_saved_all{d, ann, k}.q_idxs;
                num_ratio_cells = floor(num_total_cells * ratio);
                assert(num_ratio_cells <= length(q_idxs_all), "you sorted less than your defined ratio");
                q_idxs_ratio = q_idxs_all(1:num_ratio_cells);

                % Picked cells whose predicted probability is inside the boundary.
                picked_cells_idxs_inbound = q_idxs_ratio(is_boundary(cell_probs(q_idxs_ratio)));
                num_picked = length(picked_cells_idxs_inbound);

                picked_boundary_cells{d, ann, r, k} = picked_cells_idxs_inbound;
                num_picked_boundary_cells(k, r, run_idx) = num_picked;
                % NaN (0/0) when the run has no boundary cells at all: the
                % fraction is undefined there.
                num_boundary_cells_percent(k, r, run_idx) = num_picked / num_total_boundary_cells;

                % Mean Euclidean distance over unordered pairs of picked
                % boundary cells. pdist lists each pair once and, unlike
                % nonzeros(triu(pdist2(...))), keeps genuine zero distances
                % between identical feature vectors. With fewer than two
                % cells there are no pairs; mean of empty is NaN, recorded as 0.
                mean_distance = mean(pdist(features(picked_cells_idxs_inbound, :)));
                if isnan(mean_distance)
                    mean_distance = 0;
                end
                distance_picked_cells(k, r, run_idx) = mean_distance;

                [agreements_inpicked, agreements_invote] = get_agreements(q_idxs_ratio, choices_ofd);
                percent4agreements_inpicked(k, r, :, run_idx) = agreements_inpicked;
                percent4agreements_invote(k, r, :, run_idx)   = agreements_invote;
            end
            fprintf("\n")
        end
    end
end
%% plot percentage of boundary cells that are picked by each method
plot_percentage(num_boundary_cells_percent, ratio_lst, method_dspnames, "percent_boundary_cells_picked")

%% pairwise significance of the boundary-cell percentage (LaTeX table rows)
% Signed-rank test between every pair of methods at one ratio, paired over
% the (dataset, annotator) runs, then corrected with bonferroni_holm.
r = 2; % index into ratio_lst
pvalues = [];
for i = 1:num_methods
    for j = i+1:num_methods
        pvalues(end+1) = signrank(squeeze(num_boundary_cells_percent(i, r, :)), ...
                                  squeeze(num_boundary_cells_percent(j, r, :)));
    end
end
[isSignificant, adjusted_pvals, alpha] = bonferroni_holm(pvalues);

% One LaTeX row per method: label, then N/A on/below the diagonal and the
% adjusted p-value above it (same pair order as the loop above).
index = 0;
for i = 1:num_methods
    fprintf("%s", get_legend_name(methods{i}));
    for j = 1:num_methods
        if i >= j
            fprintf(" & N/A")
        else
            index = index + 1;
            fprintf(" & %.4f", adjusted_pvals(index));
        end
    end
    % In an fprintf format "\\" prints one backslash, so four are needed
    % to emit the LaTeX row terminator "\\".
    fprintf(" \\\\\n")
end
%% plot the feature distance between picked cells
plot_percentage(distance_picked_cells, ratio_lst, method_dspnames, "distance_picked_cells")

%% pairwise significance of the mean feature distance
% Signed-rank test between every pair of methods at one ratio, paired over
% the (dataset, annotator) runs, then corrected with bonferroni_holm.
r = 2; % index into ratio_lst
pvalues = [];
for i = 1:num_methods
    for j = i+1:num_methods
        pvalues(end+1) = signrank(squeeze(distance_picked_cells(i, r, :)), ...
                                  squeeze(distance_picked_cells(j, r, :)));
    end
end
[isSignificant, adjusted_pvals, alpha] = bonferroni_holm(pvalues);

% Print the adjusted p-values in the same pair order they were computed.
index = 0;
for i = 1:num_methods
    methodi = get_legend_name(methods{i});
    for j = i+1:num_methods
        index = index + 1;
        methodj = get_legend_name(methods{j});
        fprintf("%s-%s has p-value: %.4f\n", methodi, methodj, adjusted_pvals(index));
    end
end
%% Plot agreement in picked bar plot
% Slice 3 of the agreement dimension is the "2-2" type; squeeze drops that
% singleton dimension, leaving (method, ratio, run).
plot_percentage( ...
    squeeze(percent4agreements_inpicked(:, :, 3, :)), ...
    ratio_lst, method_dspnames, "percent_picked_is2-2")
%% Plot agreement in vote bar plot
% Same 2-2 slice, taken from the vote-based agreement array.
plot_percentage( ...
    squeeze(percent4agreements_invote(:, :, 3, :)), ...
    ratio_lst, method_dspnames, "percent_2-2_ispicked")
%%
% For each method, compare the three agreement types (0-4, 1-3, 2-2) of
% percent4agreements_invote pairwise at one ratio with a signed-rank test
% paired over runs, then correct the three p-values with bonferroni_holm.
r = 2; % index into ratio_lst, same as the other significance sections
type_pairs = [1, 2; 1, 3; 2, 3]; % agreement-type index pairs compared
for i = 1:num_methods
    methodi = get_legend_name(methods{i});
    pvalues = zeros(1, size(type_pairs, 1));
    for t = 1:size(type_pairs, 1)
        pvalues(t) = signrank(squeeze(percent4agreements_invote(i, r, type_pairs(t, 1), :)), ...
                              squeeze(percent4agreements_invote(i, r, type_pairs(t, 2), :)));
    end
    [isSignificant, adjusted_pvals, alpha] = bonferroni_holm(pvalues);
    fprintf("%s has p value %.3f, %.3f, %.3f\n", methodi, adjusted_pvals(1), adjusted_pvals(2), adjusted_pvals(3))
end
%% plot every agreement type
% Pass the full (method, ratio, agreement type, run) arrays: first the
% picked-cell agreement, then the vote-based agreement.
plot_agreements_all(percent4agreements_inpicked, ratio_lst, method_dspnames, ...
    "percent_picked_is2-2_all")
plot_agreements_all(percent4agreements_invote, ratio_lst, method_dspnames, ...
    "percent_2-2_ispicked_all")