% This file is the runner of the active learning for cross-mice experiment.
% It covers 6 dataset pairs, with 16 annotator pairs for each dataset pair.
clear;clc;
% parpool('local',6)

% ---- experiment settings (consumed further down in this script) ----
dp                 = 1;     % row of days_lst (dataset pair) processed by this run
d1_ratio           = 0.05;  % ratio argument passed to play_active_learning_new for dataset 1
d2_ratio           = 0.08;  % ratio argument passed to play_active_learning_new for dataset 2
threshold          = -1;    % copied into pretrained.threshold before fine-tuning
balance            = 0;     % stored as config.balance
balance_pretrained = 1;     % stored as config.balance_pretrained
align              = 0;     % stored as config.align
repeat             = 1;     % stored as config_finetune.repeat

%% load data and create folders
% Make the ActSort sources (and all sub-folders) visible on the MATLAB path.
home_path = '../../../ActSort/';
addpath(genpath(home_path));

% Output directory: report it if already present, otherwise create it.
folderName = './results_test';
if exist(folderName, 'dir')
    fprintf('[INFO] : Directory "%s" already exists.\n', folderName);
else
    mkdir(folderName);
    fprintf('[INFO] : Directory "%s" has been created.\n', folderName);
end

% rng(19991112)
% Loads every variable of the summary file into the workspace.
% NOTE(review): metrics_all, choices_all and choices_gt_all are used later
% without being defined in this script, so presumably they come from this
% file - confirm.
load('../../data/hemisphere_dataset_summary.mat')
%% define the annotator indices
% One (first day, second day) pair per row; only row dp is kept for this run.
days_lst = [1,2];%; 1,3; 2,1; 2,3; 3,1; 3,2];
days_lst = days_lst(dp,:);

% Row = dataset, column = annotator slot; the entries are used as row
% indices into choices_all / choices_gt_all.
annotator_lst = [1, 2, 3, 4; % dataset1
                 1, 3, 4, 5; % dataset2
                 1, 2, 3, 6];% dataset3
%% define hyperparameters
config.DO_ZSCORING = true;
config.n = 1;
config.balance = balance; % for balancing dataset when training from scratch
config.balance_pretrained = balance_pretrained; % for balancing dataset when pretraining
config.align = align;

% The pre-training and fine-tuning configs share the base config and differ
% only in their repeat value.
config_pretrain = config;
config_pretrain.repeat = 1;
config_finetune = config;
config_finetune.repeat = repeat;

% Baseline query strategies.
method_rand = struct("name", "random", "continue", false);
method_cal  = struct("name", "cal", "continue", false);
method_dal  = struct("name", "dal", "continue", false);
methods     = {method_rand, method_cal, method_dal};

% One "dcal" variant per weight, appended after the baselines.
weights_lst = [0.3, 0.7];
dcal_methods = arrayfun(@(w) struct("name", "dcal", "weight", w, "continue", false), ...
                        weights_lst, 'UniformOutput', false);
methods = [methods, dcal_methods];

% % initializing multi-arm bandit exp3 method
%  [reward_func_lst, reward_name_lst] = get_reward_funcs([3, 7, 8, 10]); % return all predefined reward function0
%  gamma_lst = [0.3, 0.5];
%  for i=1:length(reward_func_lst)
%      for g=1:length(gamma_lst)
%          gamma       = gamma_lst(g);
%          reward_func = reward_func_lst{i};
%          reward_name = reward_name_lst(i);
%          method_mab  = struct("name", "mab-exp3", "gamma", gamma, ...
%                        "reward_func", reward_func, "reward_name", reward_name);
%          methods{end+1} = method_mab;
%      end
%  end
% 
% % initializing multi-arm ubc method
% alpha_lst = [sqrt(2)];
% for i=1:length(alpha_lst)
%     alpha      = alpha_lst(i);
%     method_ucb = struct("name", "mab-ucb", "alpha", alpha);
%     methods{end+1} = method_ucb;
% end

% Drop temporaries so they do not end up in the saved workspace (names that
% do not exist are silently ignored by clear).
clear g i gamma method_cal method_dcal method_mab reward_func reward_name weight dcal_methods;

num_dparis  = size(days_lst,1);
num_methods = length(methods);

%% ActSort
% For every dataset pair (d1 -> d2) and annotator pair this section
%   1) trains on d1 with the "cal" method (pre-training),
%   2) runs the selected method on d2 with the pre-trained model attached
%      (method.continue / method.mdl / method.pretrained), and
%   3) runs the same method on d2 from scratch for comparison.
% The evaluation metrics of 2) and 3) are stored in eval_lst_ft and
% eval_lst_sc, indexed as {dataset pair, annotator of d1, annotator of d2,
% method}. Cells of combinations that are not run stay empty.
eval_lst_ft = cell(num_dparis, 4, 4, num_methods);
eval_lst_sc = cell(num_dparis, 4, 4, num_methods);
for dpair = 1:num_dparis
    d1 = days_lst(dpair, 1);
    d2 = days_lst(dpair, 2);
    metrics_d1    = metrics_all{d1};
    metrics_d2    = metrics_all{d2};
    for ann1 = 1:1
        ann1_name = annotator_lst(d1, ann1);
        choices1    = choices_all{ann1_name, d1};
        choices1_gt = choices_gt_all{ann1_name, d1};
        for ann2 = 1:1
            ann2_name = annotator_lst(d2, ann2);
            choices2    = choices_all{ann2_name, d2};
            choices2_gt = choices_gt_all{ann2_name, d2}; 
            
            % Only the first method is exercised in this test runner.
            k = 1;
            method = methods{k};
            dspname = get_legend_name(method);

            % Checkpoint lookup is disabled, so the load branch below is
            % currently unreachable (fullPath is not defined while the
            % lookup stays commented out).
%             [fullPath, ckpt_exist_flag] = check_ckpt_exist(folderName, dspname, dpair, ann1, ann2);
            ckpt_exist_flag = false;
            if ckpt_exist_flag
                fprintf("%s: Data Pair %i Annotators %i-%i Exist! Load eval_metrics ===>", datetime("now"), dpair, ann1, ann2);
                load(fullPath, 'eval_metrics_ft')
                eval_lst_ft{dpair, ann1, ann2, k} = eval_metrics_ft;
                fprintf("update eval_lst. \n")
            else
                % --- 1) pre-train on dataset d1 ---
                pretrain_method = struct("name", "cal", "continue", false);
                pm_dspname = get_legend_name(pretrain_method);
                fprintf("%s: Start training on dataset %i (%.2f) with Annotators %i-%i using method %s ...", datetime("now"), d1, d1_ratio, ann1, ann2, pm_dspname);
                [~, ~, dataset] = play_active_learning_new(metrics_d1, choices1, choices1_gt,...
                                                            d1_ratio, pretrain_method, config_pretrain);

                % --- 2) fine-tune on dataset d2 ---
                fprintf("Fine-tuning on dataset %i (%.2f) using method %s ...", d2, d2_ratio, dspname);
                pretrained           = dataset;
                pretrained.threshold = threshold;
%                 labels_exml = combine_exml_labels(pretrained);
%                 pretrained.labels_ml = labels_exml;
%                   pretrained = balance_pretrained_dataset(pretrained);
                method.continue = true;
                method.mdl      = dataset.mdl;
                method.pretrained  = pretrained;
                [~, eval_metrics_ft, dataset_ft] = play_active_learning_new(metrics_d2, choices2,choices2_gt,...
                                                                            d2_ratio, method, config_finetune);

                % --- 3) train on dataset d2 from scratch ---
                fprintf("Start training on dataset %i (%.2f)", d2, d2_ratio);
                method = methods{k}; % reset: drops the continue/mdl/pretrained fields set above
                % NOTE(review): config has no 'repeat' field, unlike
                % config_pretrain / config_finetune - confirm that
                % play_active_learning_new copes with that.
                [~, eval_metrics_sc, dataset_sc] = play_active_learning_new(metrics_d2, choices2, choices2_gt,...
                                                                            d2_ratio, method, config);
                eval_lst_ft{dpair, ann1, ann2, k} = eval_metrics_ft;
                eval_lst_sc{dpair, ann1, ann2, k} = eval_metrics_sc;
                fprintf('\n');
                % Checkpoint saving is disabled. The "saved" message must stay
                % disabled with it: ckpt_name is only defined by the
                % commented-out line, so printing it raised an
                % undefined-variable error after every training run.
%                 ckpt_name = folderName+"/ckpt-"+dspname+"-dp"+num2str(dpair)+"-ann"+num2str(ann1)+"-ann"+num2str(ann2)+".mat";
%                 save(ckpt_name, 'eval_metrics_ft', '-v7.3');
%                 fprintf("[INFO] : checkpoint %s saved\n", ckpt_name);
            end
        end
    end 
end
%% save
% Store the whole workspace in the results folder, one file per dataset pair.
% folderName is a char vector, so `+` performs element-wise numeric addition
% of character codes (and fails because the operands differ in length)
% instead of concatenating; build the path with fullfile/sprintf instead.
save_name = fullfile(folderName, sprintf('runner_crossmice_aug_dp%d.mat', dp));
save(save_name, '-v7.3')

%% average over the annotator pairs and dataset pairs
% Metric fields read from each eval_metrics struct, paired with the title of
% the figure they are plotted in. The order fixes the order of the figures.
metric_fields = {'AUC', 'ACC', 'Recall', 'Precision', 'TPR', 'TNR', 'Fscore'};
metric_titles = {'AUC', 'ACC', 'Recall', 'Precision', ...
                 'True Positive Rate', 'True Negative Rate', 'F-score'};
num_metrics = numel(metric_fields);

% Common horizon: the shortest ACC trace among the first run of each pair.
H_lst = zeros(1,num_dparis); % horizon list
for dpair=1:num_dparis
    H_lst(dpair) = size(eval_lst_ft{dpair,1,1,1}.ACC, 2);
end
H = min(H_lst);

% Accumulators: avg_ft / avg_sc hold one (num_methods x H) matrix per metric
% for the fine-tuned and from-scratch runs; avg_human holds one scalar per
% metric for the human baseline.
avg_ft    = struct();
avg_sc    = struct();
avg_human = struct();
for m=1:num_metrics
    avg_ft.(metric_fields{m})    = zeros(num_methods, H);
    avg_sc.(metric_fields{m})    = zeros(num_methods, H);
    avg_human.(metric_fields{m}) = 0;
end

% Sum every run that was actually executed. Empty cells (combinations that
% were skipped) are ignored instead of being dereferenced, and each method is
% normalised by its own run count rather than by a hard-coded constant.
num_runs = zeros(num_methods, 1);
for k=1:num_methods
    for dpair=1:num_dparis
        for ann1=1:size(eval_lst_ft, 2)
            for ann2=1:size(eval_lst_ft, 3)
                run_ft = eval_lst_ft{dpair, ann1, ann2, k};
                run_sc = eval_lst_sc{dpair, ann1, ann2, k};
                if isempty(run_ft) || isempty(run_sc)
                    continue;
                end
                for m=1:num_metrics
                    f = metric_fields{m};
                    avg_ft.(f)(k,:) = avg_ft.(f)(k,:) + run_ft.(f)(1:H);
                    avg_sc.(f)(k,:) = avg_sc.(f)(k,:) + run_sc.(f)(1:H);
                end
                num_runs(k) = num_runs(k) + 1;
            end
        end
    end
    % A method without any run yields 0/0 = NaN here; it is skipped when
    % plotting.
    for m=1:num_metrics
        f = metric_fields{m};
        avg_ft.(f)(k,:) = avg_ft.(f)(k,:) ./ num_runs(k);
        avg_sc.(f)(k,:) = avg_sc.(f)(k,:) ./ num_runs(k);
    end
end

% Human baseline: each annotator's own choices on the second dataset are
% scored against the ground truth and averaged over annotators and pairs.
num_ann2 = size(annotator_lst, 2);
for dpair=1:num_dparis
    d2 = days_lst(dpair, 2);
    for ann2=1:num_ann2
        ann2_name = annotator_lst(d2, ann2);
        choices2    = choices_all{ann2_name, d2};
        choices2_gt = choices_gt_all{ann2_name, d2};

        eval_metrics_human = get_ex_accuracy(choices2', choices2_gt');

        % Every metric, AUC included, is accumulated into avg_human (the
        % AUC sum used to be written to a misspelled variable and was lost).
        for m=1:num_metrics
            f = metric_fields{m};
            avg_human.(f) = avg_human.(f) + eval_metrics_human.(f) / (num_dparis * num_ann2);
        end
    end
end

%% plotting
color_map = [0,      0.4470, 0.7410; % blue
             0.8500, 0.3250, 0.0980; % orange
             0.9290, 0.6940, 0.1250; % yellow
             0.4940, 0.1840, 0.5560; % purple
             0.4660, 0.6740, 0.1880; % green
             0.6350, 0.0780, 0.1840];% red
x = 1:1:H;
x = x./10;

% One figure per metric: solid line = fine-tuned, dashed line = trained from
% scratch, horizontal red dashed line = human baseline.
for m=1:num_metrics
    f = metric_fields{m};
    figure;
    for k=1:num_methods
        if num_runs(k) == 0
            continue; % nothing was averaged for this method
        end
        method = methods{k};
        dspname = get_legend_name(method);
        % Wrap around the palette so more than 6 methods cannot overrun it.
        line_color = color_map(mod(k-1, size(color_map,1)) + 1, :);
        plot(x, avg_ft.(f)(k,:), 'DisplayName',strcat(dspname, '-finetune'), 'Color',line_color, 'LineStyle','-')
        hold on
        plot(x, avg_sc.(f)(k,:), 'DisplayName',strcat(dspname, '-scratchtrain'), 'Color',line_color, 'LineStyle','--')
    end
    legend()
    title(metric_titles{m})
    line([1, H]./10, [avg_human.(f), avg_human.(f)], 'Color', 'r', 'LineStyle', '--', 'DisplayName', 'human');
    hold off
end

function [fullPath, ckpt_exist] = check_ckpt_exist(folderName, dp, ann1, ann2)
% CHECK_CKPT_EXIST  Build a checkpoint path and report whether it exists.
%   [fullPath, ckpt_exist] = check_ckpt_exist(folderName, dp, ann1, ann2)
%
%   Inputs:
%     folderName - directory that holds the checkpoint files
%     dp, ann1, ann2 - integers inserted into the file name
%                      'ckpt-dp<dp>-ann<ann1>-ann<ann2>.mat'
%   Outputs:
%     fullPath   - folderName joined with the checkpoint file name
%     ckpt_exist - logical true when exist(fullPath, 'file') reports a file

    ckpt_file = sprintf('ckpt-dp%i-ann%i-ann%i.mat', dp, ann1, ann2);
    fullPath  = fullfile(folderName, ckpt_file);

    % exist(..., 'file') returns 2 when the name resolves to a file.
    ckpt_exist = (exist(fullPath, 'file') == 2);
end