% This file runs the active-learning experiments for each of the 4
% annotators, using the consensus of the other 3 annotators as the ground
% truth, and averages the results over the 4 annotators.
clear;clc;

load('twophoton_motion.mat')

% Labels of each annotator, in a fixed order: cheetah, dragon, koala, panda.
% NOTE(review): these *_choices variables are expected to be provided by
% twophoton_motion.mat — confirm.
annotators = {cheetah_choices, dragon_choices, koala_choices, panda_choices};

% Leave-one-out ground truth: for annotator idx, add up the labels of the
% other three annotators (in index order), threshold the sum at > 0 and map
% the logical result to +1 / -1.
evaluators = cell(1,4);
for idx = 1:4
    others = setdiff(1:4, idx);
    vote = annotators{others(1)};
    for src = others(2:end)
        vote = vote + annotators{src};
    end
    evaluators{idx} = 2*(vote>0)-1;
end
% Drop loop temporaries so the workspace saved later matches the original.
clear idx others vote src


%% define the hyperparameters
% Passed unchanged to play_active_learning_new as its `ratio` argument.
ratio = 0.5;

% Baseline selection methods.
method_rand = struct("name", "random");
method_cal  = struct("name", "cal");
method_dal  = struct("name", "dal");
methods     = {method_rand, method_cal, method_dal};

% One "dcal" method entry per weight. The loop iterates over the weights
% themselves, so changing this list automatically changes the method list.
weights_lst = [0.3, 0.5, 0.7];
for weight = weights_lst
    method_dcal = struct("name", "dcal", "weight", weight);
    methods{end+1} = method_dcal; %#ok<SAGROW>
end
num_methods = length(methods);
%% ActSort 
% Run every (lam, annotator, method) combination and store the returned
% evaluation metrics in eval_lst{lam index, annotator index, method index}.
num_ann = numel(annotators);

config.DO_ZSCORING = true;
config.n = 1;
config.repeat = 10;
config.balance = 1;
% Values swept into config.lam (assigned per iteration of the loop below).
lam_all = [1];

% Preallocate with one slot per lam value, annotator and method.
eval_lst = cell(numel(lam_all), num_ann, num_methods);

for ll = 1:numel(lam_all)
    config.lam = lam_all(ll);
    for ann = 1:num_ann
        choices    = annotators{ann};
        choices_gt = evaluators{ann};
        for k = 1:num_methods
            method = methods{k};
            dspname = get_legend_name(method);
            % Convert the timestamp explicitly and end each message with a
            % newline so progress lines do not run together.
            fprintf("%s: Working on lam %d. annotator %i using method %s\n", ...
                string(datetime("now")), ll, ann, dspname);
            % NOTE(review): `metrics` is presumably loaded from
            % twophoton_motion.mat — confirm. `valid` and `dataset` are kept
            % (not replaced by ~) because the save below stores the workspace.
            [valid, eval_metrics, dataset] = play_active_learning_new(metrics,choices,choices_gt,ratio,method,config);
            eval_lst{ll, ann, k} = eval_metrics;
        end
    end
end

% Saves the entire workspace (no variable list given).
save("sort50_oscar_balance1.mat", "-v7.3")

%% average over the annotators
% Average every evaluation curve over all annotators stored in eval_lst,
% for the lam index selected by lam_pick.
lam_pick = 1;
num_ann  = size(eval_lst, 2);

% NOTE(review): H is the column count of ACC, and each metric is then read
% with linear indexing (1:H). This assumes every metric is a row vector —
% confirm against play_active_learning_new.
H = size(eval_lst{lam_pick,1,1}.ACC,2);
avg_acc       = zeros(num_methods, H);
avg_tpr       = zeros(num_methods, H);
avg_tnr       = zeros(num_methods, H);
avg_precision = zeros(num_methods, H);
avg_recall    = zeros(num_methods, H);
avg_fscore    = zeros(num_methods, H);
avg_auc       = zeros(num_methods, H);

for k = 1:num_methods
    for ann = 1:num_ann
        res = eval_lst{lam_pick, ann, k};
        avg_acc(k,:)       = avg_acc(k,:)       + res.ACC(1:H);
        avg_tpr(k,:)       = avg_tpr(k,:)       + res.TPR(1:H);
        avg_tnr(k,:)       = avg_tnr(k,:)       + res.TNR(1:H);
        avg_precision(k,:) = avg_precision(k,:) + res.Precision(1:H);
        avg_recall(k,:)    = avg_recall(k,:)    + res.Recall(1:H);
        avg_fscore(k,:)    = avg_fscore(k,:)    + res.Fscore(1:H);
        avg_auc(k,:)       = avg_auc(k,:)       + res.AUC(1:H);
    end
    % Divide by the actual number of annotators rather than a hard-coded 4.
    avg_acc(k,:)       = avg_acc(k,:)       ./ num_ann;
    avg_tpr(k,:)       = avg_tpr(k,:)       ./ num_ann;
    avg_tnr(k,:)       = avg_tnr(k,:)       ./ num_ann;
    avg_precision(k,:) = avg_precision(k,:) ./ num_ann;
    avg_recall(k,:)    = avg_recall(k,:)    ./ num_ann;
    avg_fscore(k,:)    = avg_fscore(k,:)    ./ num_ann;
    avg_auc(k,:)       = avg_auc(k,:)       ./ num_ann;
end

% Plot the annotator-averaged AUC curve of every method, one line each.
figure;
% Iterate over all methods; a hard-coded count would silently drop entries
% (there are 6 methods: 3 baselines + 3 "dcal" weights).
for k = 1:num_methods
    method = methods{k};
    dspname = get_legend_name(method);
    plot(avg_auc(k,:), 'DisplayName', dspname)
    hold on
end
hold off
legend()


