% pls_scatter_all
% This script plots the PLS scatter distributions for all 12 dataset/annotator
% pairs (3 datasets x 4 annotators), colored by their labels and ground truth.
% start from a clean console, workspace and figure state
clc
clear
close all

%% configurations
exportFigureFlag = 0;   % passed to the plotting helpers; they export only when it equals 1
methodArray = {'random','cal','dal','dcal-0.3','dcal-0.5','dcal-0.7'};
iterArray = [1,3,5,10]; % percentages at which snapshots are plotted: 1%, 3%, 5%, 10%
nDataset = 3;
nAnnotator = 4;
nMethod = numel(methodArray);


%% load data
% Only the listed variables are read from each MAT-file.
load("E:\research\extract\umapFiles\exlabels_12datasets.mat", ...
    "methods", "labels_saved_all")
load("D:\JHU\OneDrive - Johns Hopkins\EXTRACT\umap\plsPlot\data\hemisphere_dataset_summary.mat", ...
    "choices_all", "choices_gt_all")
load("D:\JHU\OneDrive - Johns Hopkins\EXTRACT\umap\hemisphere_dataset_summary.mat", ...
    "metrics_all");
num_methods = length(methods);

% Row d lists the annotator indices used for dataset d (one column per
% annotator slot, nAnnotator slots per dataset).
annotator_lst = [1, 2, 3, 4;  % dataset 1
                 1, 3, 4, 5;  % dataset 2
                 1, 2, 3, 6]; % dataset 3

%% plot figures
% For every dataset/annotator pair: standardize the dataset features, fit a
% 3-component PLS regression against that annotator's labels, plot the
% resulting X scores colored by those labels, and then plot the label
% snapshots of every method in the same score space.
for datasetCur = 1:nDataset
    features = standardizeFeatures(metrics_all{datasetCur});

    for annCur = 1:nAnnotator
        % map the annotator slot of this dataset to the annotator index
        annotator_cur = annotator_lst(datasetCur,annCur);
        % NOTE(review): the "ground truth" is read from choices_all;
        % choices_gt_all is loaded but never used - confirm this is intended.
        groundTruth = choices_all{annotator_cur,datasetCur};

        % The third output of plsregress (XS, the X scores) provides the
        % plotting coordinates. Both inputs are transposed before the call.
        [~, ~, XS] = plsregress(features.', groundTruth.', 3);
        coordinates = XS;

        fileTag = sprintf('data%d_ann%d',datasetCur,annCur);
        pls_plot_gt(coordinates,groundTruth,exportFigureFlag,fileTag);

        for methodCur = 1:nMethod
            labels_all = labels_saved_all{datasetCur,annCur,methodCur};
            fileTag = sprintf('data%d_ann%d_meth%d',datasetCur,annCur,methodCur);
            pls_plot(coordinates,labels_all.labels_ex,exportFigureFlag,fileTag,iterArray);
            fprintf('Dataset %d, annotator %d, method %d. \n',datasetCur,annCur,methodCur);
            % close every open figure (including the ground-truth one) so
            % that figures do not accumulate across iterations
            close all
        end
    end
end

%% single-case figures
% This section reuses the workspace left behind by the loop above
% (features, groundTruth, coordinates, labels_all, datasetCur), i.e. it plots
% the LAST dataset / annotator / method combination that was processed.

% Label snapshots of the last processed case. Rows are indexed by iteration,
% columns by sample (each row is used as a mask over the rows of coordinates).
labels_ex = labels_all.labels_ex;
% Same convention as pls_plot: the percentages in iterArray are taken
% relative to the number of columns of labels_ex.
nIter = size(labels_ex,2);

figure(Position=[30 30 1400 1000]);
set(gcf,color='w')

% colors
cellColor = [64,176,166]/255;
notCellColor = '#e64b35';
unlabeledColor = '#d4d6d9';

% Same orientation as the plsregress call in the loop above (both inputs
% transposed); XS holds the X scores used for the ground-truth figure below.
[~, ~, XS] = plsregress(features', groundTruth', 3);

for perCur = 1:length(iterArray)
    subplot(1,length(iterArray),perCur)
    % Clamp the row index so that small label matrices never produce an
    % index of 0 or an index past the last available row.
    iterCur = round(iterArray(perCur)/100*nIter);
    iterCur = min(max(iterCur,1),size(labels_ex,1));
    labelsCur = labels_ex(iterCur,:);
    % have to use plot() instead of scatter() because scatter() has issues
    % dealing with overlapping dots; it cannot correctly show the actual
    % distribution.
    % label 0 (gray; presumably the still-unlabeled samples)
    plot(coordinates(labelsCur==0,1),coordinates(labelsCur==0,2),'o',MarkerSize = 2,MarkerEdgeColor=unlabeledColor,MarkerFaceColor=unlabeledColor);hold on
    % label 1 (cells)
    plot(coordinates(labelsCur==1,1),coordinates(labelsCur==1,2),'o',MarkerSize = 3,MarkerEdgeColor='none',MarkerFaceColor=cellColor);
    % label -1 (not cells)
    plot(coordinates(labelsCur==-1,1),coordinates(labelsCur==-1,2),'o',MarkerSize = 3,MarkerEdgeColor='none',MarkerFaceColor=notCellColor);
    axis off
    box on
    title(sprintf('%d%% sorted',iterArray(perCur)))
    hold off
    axis equal
    view(2)
    grid off
    set(gca,"FontSize",14)
    ylim([-15 30])
    xlim([-10 30])
end
if exportFigureFlag == 1
    % datasetCur is numeric, so it must be formatted into the file name
    exportgraphics(gcf,sprintf('plsPercentages_%d.png',datasetCur),Resolution=600)
end


% Ground-truth figure: X scores colored by label (1 = cell, -1 = not cell).
figure
scatter3(XS(groundTruth==1,1),XS(groundTruth==1,2),XS(groundTruth==1,3),8,'filled',MarkerEdgeColor='none',MarkerFaceColor='#0072BD',...
    MarkerFaceAlpha=0.7);hold on
scatter3(XS(groundTruth==-1,1),XS(groundTruth==-1,2),XS(groundTruth==-1,3),8,'filled',MarkerEdgeColor='none',MarkerFaceColor='#cc0000',...
    MarkerFaceAlpha=0.7);

axis off
title('Ground Truth')
hold off
axis equal
view(2)   % look at the 3-D scatter from above, i.e. show components 1 and 2
set(gcf,color='w')
grid off
set(gca,"FontSize",14)
legend('Cell','Not cell')
if exportFigureFlag == 1
    exportgraphics(gcf,'pls_oscar.pdf', 'ContentType','vector')
end

%% helper functions
% pls_plot: label snapshots at the requested percentages, one panel each
function pls_plot(coordinates,labels_ex,exportFigureFlag,fileTag,iterArray)
% PLS_PLOT  Plot label snapshots in PLS score space, one panel per percentage.
%
%   pls_plot(coordinates,labels_ex,exportFigureFlag,fileTag,iterArray)
%
%   coordinates      - matrix whose first two columns are used as the x/y
%                      position of every sample (one row per sample).
%   labels_ex        - label matrix; row i is the label snapshot that is
%                      plotted for iteration i and column j belongs to
%                      sample j. Values 1 and -1 are drawn as cell / not
%                      cell, value 0 is drawn in gray (presumably the
%                      samples that are still unlabeled).
%   exportFigureFlag - when equal to 1 the figure is exported as a vector
%                      PDF into the folder 'plsPlots'.
%   fileTag          - char tag appended to the exported file name.
%   iterArray        - percentages for which a snapshot panel is drawn.
%
%   The percentages are converted to a row index of labels_ex relative to
%   the number of COLUMNS of labels_ex.
%   NOTE(review): this assumes one snapshot row per sorted sample - confirm.

if isempty(labels_ex)
    error('pls_plot:emptyLabels','labels_ex must contain at least one label snapshot.');
end

fig = figure(Position=[30 30 1400 800]);
set(fig,color='w')

nIter = size(labels_ex,2);
nSnapshot = size(labels_ex,1);

% colors
cellColor = [64,176,166]/255;
notCellColor = '#e64b35';
unlabeledColor = '#d4d6d9';

nPanel = length(iterArray);
for perCur = 1:nPanel
    subplot(1,nPanel,perCur)
    % Clamp the row index: floor() yields 0 when fewer than 100/percentage
    % columns exist, and the result may exceed the number of stored rows.
    iterCur = floor(iterArray(perCur)/100*nIter);
    iterCur = min(max(iterCur,1),nSnapshot);
    labelsCur = labels_ex(iterCur,:);

    % have to use plot() instead of scatter() because scatter() has issues
    % dealing with overlapping dots; it cannot correctly show the actual
    % distribution.
    % label 0 (gray)
    plot(coordinates(labelsCur==0,1),coordinates(labelsCur==0,2),'o',MarkerSize = 2,MarkerEdgeColor=unlabeledColor,MarkerFaceColor=unlabeledColor);hold on
    % label 1 (cells)
    plot(coordinates(labelsCur==1,1),coordinates(labelsCur==1,2),'o',MarkerSize = 3,MarkerEdgeColor='none',MarkerFaceColor=cellColor);
    % label -1 (not cells)
    plot(coordinates(labelsCur==-1,1),coordinates(labelsCur==-1,2),'o',MarkerSize = 3,MarkerEdgeColor='none',MarkerFaceColor=notCellColor);
    axis off
    box on
    hold off
    axis equal
    view(2)
    grid off
    set(gca,"FontSize",14)
end

if exportFigureFlag == 1
    % exportgraphics does not create missing folders, so make sure the
    % output folder exists before writing into it
    outDir = 'plsPlots';
    if ~isfolder(outDir)
        mkdir(outDir);
    end
    exportgraphics(fig,fullfile(outDir,['plsPercentages_',fileTag,'.pdf']),'ContentType','vector',Resolution=600)
end
end

function pls_plot_gt(coordinates,groundTruth,exportFigureFlag,fileTag)
% PLS_PLOT_GT  Plot samples in PLS score space colored by their label.
%
%   pls_plot_gt(coordinates,groundTruth,exportFigureFlag,fileTag)
%
%   coordinates      - matrix whose first two columns are used as the x/y
%                      position of every sample (one row per sample).
%   groundTruth      - label vector with one entry per sample; entries equal
%                      to 1 are drawn as cells, entries equal to -1 as
%                      not-cells. Any other value is not drawn.
%   exportFigureFlag - when equal to 1 the figure is exported as a vector
%                      PDF into the folder 'plsPlots'.
%   fileTag          - char tag appended to the exported file name.

fig = figure(Position=[30 30 800 700]);
hold on;   % both classes are drawn into the same axes

% colors
cellColor = '#0072BD';
notCellColor = '#cc0000';

isCell = groundTruth==1;
isNotCell = groundTruth==-1;

% have to use plot() instead of scatter() because scatter() has issues
% dealing with overlapping dots; it cannot correctly show the actual
% distribution.
% cells
plot(coordinates(isCell,1),coordinates(isCell,2),'o',MarkerSize = 3,MarkerEdgeColor='none',MarkerFaceColor=cellColor);
% not cells
plot(coordinates(isNotCell,1),coordinates(isNotCell,2),'o',MarkerSize = 3,MarkerEdgeColor='none',MarkerFaceColor=notCellColor);
axis off
box on
hold off
axis equal
view(2)
set(fig,color='w')
grid off
set(gca,"FontSize",14)

if exportFigureFlag == 1
    % exportgraphics does not create missing folders, so make sure the
    % output folder exists before writing into it
    outDir = 'plsPlots';
    if ~isfolder(outDir)
        mkdir(outDir);
    end
    exportgraphics(fig,fullfile(outDir,['groundtruth_',fileTag,'.pdf']),'ContentType','vector',Resolution=600)
end
end


function [features] = standardizeFeatures(features)
% STANDARDIZEFEATURES  Z-score every row of a feature matrix.
%
%   features = standardizeFeatures(features) first replaces all infinite
%   entries (+Inf and -Inf) with 0 and then standardizes each ROW of the
%   matrix with zscore (flag 0). The output has the same size as the input.

% zero out infinite entries so they cannot propagate into mean / std
features(isinf(features)) = 0;

% zscore works down the columns of the transposed matrix, i.e. along the
% rows of the original one; transpose back to restore the orientation
features = zscore(features', 0, 1)';
end