% this file is the runner of the active learning averaging over all 12
% annotators using consensus as the ground truth
% Start from an empty workspace and a clean command window. This matters
% because the save() call at the end of this script has no variable list and
% therefore writes every variable that exists in the workspace to disk.
clear;clc;
% Load the dataset summary straight into the workspace. The code below reads
% choices_all, choices_gt_all and metrics_all without defining them, so they
% are presumably provided by this MAT-file (TODO confirm against the file).
load('hemisphere_dataset_summary.mat')
%% annotator indices
% Row d holds the four annotator indices evaluated for dataset d. Each entry
% is later used as the first subscript into choices_all / choices_gt_all.
annotator_lst = [1 2 3 4     % dataset 1
                 1 3 4 5     % dataset 2
                 1 2 3 6];   % dataset 3

%% hyperparameters
% ratio and config are forwarded unchanged to play_active_learning_new.
% NOTE(review): their exact meaning is defined by that function; the names
% suggest a labelling ratio, a z-scoring switch, a step size and a repeat
% count, but this script does not establish that - confirm before relying on it.
ratio = 0.5;

config = struct( ...
    "DO_ZSCORING", true, ...
    "n",           1, ...
    "repeat",      10);

% Parameter-free strategies; each method is described by a struct whose
% "name" field identifies it.
method_rand = struct("name", "random");
method_cal  = struct("name", "cal");
method_dal  = struct("name", "dal");
methods     = {method_rand, method_cal, method_dal};

% dcal: one configuration per weight on the grid 0.1, 0.2, ..., 0.9.
% (1:9)/10 yields exactly the same doubles as the decimal literals because
% each quotient is a single correctly rounded division; do NOT replace it
% with 0.1:0.1:0.9, which accumulates rounding error.
weights_lst = (1:9) / 10;
for weight = weights_lst
    method_dcal = struct("name", "dcal", "weight", weight);
    methods     = [methods, {method_dcal}]; %#ok<AGROW>
end

% Multi-armed bandit (EXP3) configurations: one entry per
% (reward function, gamma) pair, appended reward-major, i.e. every gamma of
% reward 1 first, then every gamma of reward 2, and so on.
% NOTE(review): [3, 7, 8, 10] is assumed to select four of the predefined
% reward functions inside get_reward_funcs - confirm against that helper.
[reward_func_lst, reward_name_lst] = get_reward_funcs([3, 7, 8, 10]);
gamma_lst = [0.3, 0.5];

for i = 1:length(reward_func_lst)
    % The reward lookup does not depend on gamma, so do it once per reward.
    reward_func = reward_func_lst{i};
    reward_name = reward_name_lst(i);

    for g = 1:length(gamma_lst)
        gamma      = gamma_lst(g);
        method_mab = struct( ...
            "name",        "mab-exp3", ...
            "gamma",       gamma, ...
            "reward_func", reward_func, ...
            "reward_name", reward_name);
        methods{end+1} = method_mab; %#ok<SAGROW>
    end
end

% Multi-armed bandit (UCB) configurations: one entry per alpha value.
% Only sqrt(2) is evaluated at the moment; extend alpha_lst to sweep more.
alpha_lst = sqrt(2);
for i = 1:length(alpha_lst)
    alpha      = alpha_lst(i);
    method_ucb = struct( ...
        "name",  "mab-ucb", ...
        "alpha", alpha);
    methods{end+1} = method_ucb; %#ok<SAGROW>
end

%clear g i gamma method_cal method_cal method_dcal method_mab method_ucb reward_func reward_name weight;

num_methods = numel(methods);
%% ActSort
% Run every configured method for each (dataset, annotator) pair and collect
% the evaluation metrics returned by play_active_learning_new.
%
% eval_lst{d, ann, k} holds the result for dataset d, the ann-th annotator
% listed for that dataset in annotator_lst, and methods{k}.

% Derive the loop bounds from annotator_lst so that adding a dataset or an
% annotator column cannot silently be ignored by hard-coded limits.
[num_datasets, num_annotators] = size(annotator_lst);

% Preallocate the sliced parfor output instead of growing it from {}.
eval_lst = cell(num_datasets, num_annotators, num_methods);

for d = 1:num_datasets
    % The metrics depend on the dataset only, so fetch them once per dataset.
    metrics = metrics_all{d};
    for ann = 1:num_annotators
        ann_idx    = annotator_lst(d, ann);
        choices    = choices_all{ann_idx, d};
        choices_gt = choices_gt_all{ann_idx, d};

        % The methods are independent of each other, so they are evaluated
        % in parallel; d and ann are broadcast, k slices methods/eval_lst.
        parfor k = 1:num_methods
            method  = methods{k};
            dspname = get_legend_name(method);
            % Convert the timestamp to text explicitly: datetime is not one
            % of fprintf's documented argument types.
            fprintf("%s: Working on dataset %i, annotator %i using method %s\n", ...
                string(datetime("now")), d, ann, dspname);
            % Only the second output is used. All three outputs are still
            % requested so the callee sees the same nargout as before.
            [~, eval_metrics, ~] = play_active_learning_new( ...
                metrics, choices, choices_gt, ratio, method, config);
            eval_lst{d, ann, k} = eval_metrics;
        end
    end
end

% Persist the complete workspace (inputs, method list and results). No
% variable list is given on purpose: downstream analysis may rely on any of
% the saved variables. -v7.3 is the MAT-file format that supports variables
% larger than 2 GB.
save("sort50_allmethods.mat", "-v7.3")


% %%
% best_mab = [];
% name_lst = {};
% % figure;
% for k=[10:44]
%     method = methods{k};
%     dspname = get_legend_name(method);
% %     plot(avg_tnr(k,:), 'DisplayName',dspname)
%     best_mab(end+1) = avg_auc(k,end);
%     name_lst{end+1} = dspname;
% %     hold on
% end
% % legend()
% [sortedValues, sortedIndices] = sort(best_mab, 'descend');
% best_indices = sortedIndices(1:5);
% best_name = name_lst(best_indices);