import numpy as np
from src.evaluation.evaluation_pipeline.evaluate_realizations import *
from src.evaluation.aux.load_results import *

import matplotlib.pyplot as plt
import os
import argparse
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import config as config

# This script plots figure 2 of the main paper.

# Colour / marker palettes keyed by display method name.
my_pal = config.COLOR
my_marker = config.MARKER

parser = argparse.ArgumentParser(description='task1 for plotting figure 2 in the main paper')
parser.add_argument('-t', default="time", type=str, help='please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45')

args = parser.parse_args()
# Timestamp suffix of the result folders produced by the .sh experiment runs;
# it is spliced into every results-folder name used at the bottom of this script.
time = args.t

if time == "time":
    # "time" is the sentinel default, i.e. -t was not supplied. Report a usage
    # error (non-zero exit status) instead of silently exiting with status 0.
    parser.error("please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45")


# Short aliases for the display names of every method (taken from config).
n_RS = config.n_RS
n_Oracle = config.n_CAMS_best_policy
n_QBC = config.n_qbc
n_IWAL = config.n_iwal
n_MP = config.n_mp
n_CQBC = config.n_contextual_qbc
n_CIWAL = config.n_contextual_iwal
n_CAMS = config.n_CAMS_identity
n_test = config.n_CAMS_test

def rename_method_list(methods):
    """Translate internal method keys into their display names.

    Parameters
    ----------
    methods : iterable of str
        Method identifiers as stored in the result archives: ``"rs"``,
        ``"qbc"``, ``"iwal"``, ``"mp"``, ``"contextual_qbc"``,
        ``"contextual_iwal"``, ``"CAMS_best_policy"``, ``"CAMS_identity"``
        or ``"CAMS_test"``.

    Returns
    -------
    list
        The display name (the module-level ``n_*`` aliases taken from
        ``config``) for each key, in input order.

    Raises
    ------
    ValueError
        If a key is not one of the identifiers listed above.
    """
    display_names = {
        "rs": n_RS,
        "qbc": n_QBC,
        "iwal": n_IWAL,
        "mp": n_MP,
        "contextual_qbc": n_CQBC,
        "contextual_iwal": n_CIWAL,
        "CAMS_best_policy": n_Oracle,
        "CAMS_identity": n_CAMS,
        "CAMS_test": n_test,
    }
    renamed = []
    for item in methods:
        # Fail loudly (non-zero exit): an unknown key means the results
        # archive and this script disagree about the set of methods.
        if item not in display_names:
            raise ValueError(f"unknown method name: {item!r}")
        renamed.append(display_names[item])
    return renamed


def find_nearest(array, value):
    """Return the element of ``array`` whose value is closest to ``value``.

    ``array`` may be any array-like; ties go to the earliest such element.
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[np.argmin(distances)]



def organize_plot(dataset_name, budget, folder_name ,my_pal=my_pal):
    """Build the cost-effectiveness and relative-loss plots of figure 2 for one dataset.

    Inputs are read relative to the current working directory:

    * ``resources/contextual_data/<dataset_name>/predictions.out`` and
      ``oracle.out`` (loaded with ``np.loadtxt``; predictions are indexed as
      ``[instance, model]``, oracle as ``[instance]``).
    * ``resources/results/<folder_name>/data.npz``,
      ``experiment_results_budget<budget>.npz`` and ``eval_results.npz``.

    Outputs written under ``./task1/`` (the directory must already exist):

    * ``task1_<dataset_name>_cost_effective.csv``: mean ``c_regret`` per
      (``budget_fixed``, ``method``), rounded to 0 decimals.
    * ``<dataset_name>_shade_line_plot.{png,pdf}``: cumulative regret vs.
      query cost (figure 2, bottom).
    * ``legend.{png,pdf}``: a standalone legend figure (same file name on
      every call, so it is overwritten).
    * ``<dataset_name>_shade_relative_cumulative_losss_budget_<budget>.{png,pdf}``:
      regret w.r.t. the best model on the stream vs. round (figure 2, top).

    Prints a lot of diagnostic output. Uses the global ``np.random`` state for
    tie-breaking and posterior sampling, so the plots depend on the RNG state.

    Parameters
    ----------
    dataset_name : str
        Dataset folder name; also used in log output and output file names.
    budget : int
        Budget to evaluate. Selects ``experiment_results_budget<budget>.npz``
        and must appear in ``data.npz['budgets']`` (otherwise the
        ``np.where(...)[0][0]`` lookup near the end raises IndexError).
    folder_name : str
        Results folder under ``resources/results``.
    my_pal : mapping
        Display method name -> line colour (indexed as ``my_pal[n_RS]`` etc.).

    Returns
    -------
    None
    """

    # Preprocess and load data from experiments
    path_ = os.getcwd() + "/resources/contextual_data/" + dataset_name

    predictions_arr = np.loadtxt(str(path_) + "/predictions.out")
    oracle_arr = np.loadtxt(str(path_) + "/oracle.out")
    oracle_arr = np.asarray(oracle_arr)

    path = os.getcwd() + "/resources/results/" + folder_name + "/"

    file_list = os.listdir(path)
    print(dataset_name,":", file_list)

    # data output
    data = np.load(path + "data.npz")
    num_reals = data["num_reals"]
    print(dataset_name,":", num_reals)
    num_instances = data["num_instances"]
    num_models = data["num_models"]
    # Raw method keys (e.g. "mp", "CAMS_identity"), not display names.
    methods = data["methods"]
    budget_raw = data["budgets"]
    experiment_result = np.load(path + "experiment_results_budget" + str(budget) + ".npz")

    idx_log = experiment_result['idx_log']  # labelled_instances: if algo decide to query
    idx_budget_log = experiment_result['idx_budget_log']  # U_t_budget: query under budget
    ct_log = experiment_result['ct_log']  # ct_log: how many instance: all 1
    streaming_instances_log = experiment_result['streaming_instances_log']
    hidden_loss_log = experiment_result['hidden_loss_log']  # loss each query
    posterior_log = experiment_result['posterior_log']
    posterior_log_ap = experiment_result["posterior_log_ap"]
    posterior_log_ap_identity = experiment_result["posterior_log_ap_identity"]
    posterior_log_ap_test = experiment_result["posterior_log_ap_test"]
    posterior_log_contextual_qbc = experiment_result["posterior_log_contextual_qbc"]
    posterior_log_contextual_iwal = experiment_result["posterior_log_contextual_iwal"]
    # NOTE(review): ``eval`` shadows the Python builtin inside this function.
    eval = np.load(path + "eval_results.npz")
    box_budget = eval["box_budget"]
    box_budget_actual = eval["box_budget_actual"]

    eval_cumulative_loss = eval["cumulative_loss"]
    query_regardles_budget_detail = eval["query_regardles_budget_detail"]

    max_method=eval['max_method']
    max_budget_actual=eval['max_budget_actual']
    max_cumulative_loss=eval['max_cumulative_loss']
    # Convert to display names so rows of box_method (also renamed below) can be
    # looked up with max_method.index(...).
    max_method = rename_method_list(max_method)
    print(max_budget_actual)

    # For each method's maximum actual budget, keep the last entry of
    # budget_raw (in iteration order) that does not exceed it, or 0 if none.
    # Presumably budget_raw is ascending, making this the largest grid budget
    # <= the actual maximum -- TODO confirm.
    max_bar_query=[]
    for item in max_budget_actual:
        min_bar=0
        for j in budget_raw:
            if  item >= j:
                min_bar=j
        max_bar_query.append(min_bar)

    print(dataset_name,":", budget_raw)

    
    box_cumulative_loss = eval["box_cumulative_loss"]
    box_method = eval["box_method"]
    box_method = rename_method_list(box_method)
    print(dataset_name,":", box_method)

    box_df_shading = {"budget": box_budget_actual,"budget_fixed": box_budget, "c_regret": box_cumulative_loss, "method": box_method}
    box_df_shading = pd.DataFrame(box_df_shading)

    reshape_budget=[]
    reshape_budget_fixed=[]

    # Snap each row's actual budget to the nearest point of the grid
    # budget_raw + [this method's max actual budget]. If it snaps to the grid
    # point just below the method's maximum (max_bar_query), use the maximum
    # actual budget itself instead.
    for index, row in box_df_shading.iterrows():
        print(row['budget'],row['budget_fixed'], row['c_regret'], row['method'])
        reshape_budget.append(row['budget'])
        budget_w_max=np.concatenate((budget_raw,[max_budget_actual[max_method.index(row['method'])]]))
        round_value=find_nearest(budget_w_max ,row['budget'])

        if round_value == max_bar_query[max_method.index(row['method'])]:
            reshape_budget_fixed.append(max_budget_actual[max_method.index(row['method'])])
        else:
            reshape_budget_fixed.append(round_value)


    box_df_shading = {"budget": reshape_budget, "budget_fixed": reshape_budget_fixed, "c_regret": box_cumulative_loss, "method": box_method}
    box_df_shading = pd.DataFrame(box_df_shading)


    # Intended: replace each (method, grid budget) bucket's budget_fixed by the
    # mean actual budget of the rows in that bucket.
    # NOTE(review): ``methods`` holds raw keys (compared with "CAMS_identity"
    # etc. further down), while the "method" column holds display names from
    # rename_method_list. Unless config's display names equal the raw keys,
    # the mask below is always empty and this loop changes nothing -- confirm.
    # ``.loc[x]`` / ``.iloc[[x], ...]`` with the np.where tuple rely on the
    # default RangeIndex created by pd.DataFrame above.
    for item in methods:
        for budget_ in budget_raw:
            print(item)
            x = np.where((box_df_shading["method"]==item) & (box_df_shading["budget_fixed"]== budget_))
            y = box_df_shading.loc[x]["budget"].mean()
            box_df_shading.iloc[[x], [box_df_shading.columns.get_loc("budget_fixed")]]=y

    # Unique (budget_fixed, method, c_regret) rows; averaged per (budget_fixed, method) for the CSV.
    shade_df_2=box_df_shading.filter(["budget_fixed","method","c_regret"],axis=1).drop_duplicates().reset_index(drop=True)
    cost_effective_table=shade_df_2.groupby(['budget_fixed','method'])["c_regret"].mean().reset_index().round(0)
    cost_effective_table.to_csv("./task1/task1_" +dataset_name+"_cost_effective.csv")
    # NOTE(review): figures created in this function are never closed; with
    # several datasets per run, open figures accumulate in memory.
    plt.figure(figsize=(10, 10), dpi=300)
    #    sns.set(font_scale = 5)
    # All nine lines are drawn on the same current axes; ``line_`` is that
    # axes and is reused below to collect legend handles. ``ci`` is the
    # pre-0.12 seaborn confidence-interval argument.
    line_ = sns.lineplot(x="budget_fixed", y="c_regret", label = n_RS, data=shade_df_2[shade_df_2["method"]==n_RS],color=my_pal[n_RS],  ci=63, linewidth=1)
    sns.lineplot(x="budget_fixed", y="c_regret", label = n_Oracle, data=shade_df_2[shade_df_2["method"]==n_Oracle],color=my_pal[n_Oracle], ci=63, linewidth=1)
    sns.lineplot(x="budget_fixed", y="c_regret", label = n_QBC, data=shade_df_2[shade_df_2["method"]==n_QBC],color=my_pal[n_QBC], ci=63, linewidth=1)
    sns.lineplot(x="budget_fixed", y="c_regret", label = n_IWAL, data=shade_df_2[shade_df_2["method"]==n_IWAL],color=my_pal[n_IWAL], ci=63, linewidth=1)
    sns.lineplot(x="budget_fixed", y="c_regret", label = n_MP, data=shade_df_2[shade_df_2["method"]==n_MP],color=my_pal[n_MP],  ci=63, linewidth=1)
    sns.lineplot(x="budget_fixed", y="c_regret", label = n_CQBC, data=shade_df_2[shade_df_2["method"]==n_CQBC],color=my_pal[n_CQBC], ci=63, linewidth=1)
    sns.lineplot(x="budget_fixed", y="c_regret", label = n_CIWAL, data=shade_df_2[shade_df_2["method"]==n_CIWAL],color=my_pal[n_CIWAL],  ci=63, linewidth=1)
    sns.lineplot(x="budget_fixed", y="c_regret", label = n_CAMS, data=shade_df_2[shade_df_2["method"]==n_CAMS],color=my_pal[n_CAMS],  ci=63, linewidth=4)
    sns.lineplot(x="budget_fixed", y="c_regret", label = n_test, data=shade_df_2[shade_df_2["method"]==n_test],color=my_pal[n_test],  ci=63, linewidth=4)

    #generate plot for figure 2 (bottom)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.xlabel("Query cost", fontsize=30)
    plt.ylabel("", fontsize=30)
    # Each legend() call replaces the previous one; only the last call is
    # effective. legend('') passes an empty label list, so no entries are shown
    # (the real legend is saved separately below).
    plt.legend(loc=2)
    plt.legend(fontsize=18,title=None)
    plt.legend('')
    plt.savefig("./task1/"+dataset_name + "_shade_line_plot.png", bbox_inches='tight', pad_inches=0.01)
    plt.savefig("./task1/"+dataset_name + "_shade_line_plot.pdf", bbox_inches='tight', pad_inches=0.01)


    #save legend
    fig = plt.figure(figsize=(10, 10), dpi=300)
    handles,labels= line_.get_legend_handles_labels()

    fig.legend(handles,labels,ncol=8, loc='center')
    fig.savefig("./task1/" +'legend.png', bbox_inches='tight', pad_inches=0)
    fig.savefig("./task1/" +'legend.pdf', bbox_inches='tight', pad_inches=0)

    # Per-(realization, method, round) accumulators.
    # NOTE(review): these three arrays are filled below but never read afterwards.
    regret_t = np.zeros((num_reals, len(methods), num_instances))
    sampled_regret_t = np.zeros((num_reals, len(methods), num_instances))
    cumulative_loss_t = np.zeros((num_reals, len(methods), num_instances))

    # Long-format records for the relative-loss plot (one row per method/realization/round).
    relative_shade_methods=[]
    relative_shade_regret=[]
    relative_shade_round=[]
    relative_shade_instance=[]


    for real_idx in range(num_reals):

        # Instance indices streamed in this realization (column real_idx).
        streaming_first_realizaiton = streaming_instances_log[:, real_idx]

        predictions = predictions_arr[streaming_first_realizaiton, :]
        oracle = oracle_arr[streaming_first_realizaiton]

        # Best model(s) on this stream by compute_precisions; ties broken at random.
        true_precisions = compute_precisions(predictions, oracle, num_models)
        true_winner = np.where(np.equal(true_precisions, np.max(true_precisions)))[0]
        winner_randint = np.random.randint(len(true_winner))
        true_winner_random = true_winner[winner_randint]

        # NOTE(review): duplicate draw -- only this second draw is used, but the
        # first still advances the global RNG, so removing it changes every
        # later random choice (and hence the plots).
        winner_randint = np.random.randint(len(true_winner))
        true_winner_random = true_winner[winner_randint]

        for num in range(len(methods)):
            zt_real = idx_budget_log[:, real_idx, num]  # num method in first realization
            posterior_real = posterior_log[:, :, real_idx]
            posterior_real_ap = posterior_log_ap[:, :, real_idx]
            posterior_real_ap_identity = posterior_log_ap_identity[:, :, real_idx]
            posterior_real_ap_test = posterior_log_ap_test[:, :, real_idx]
            posterior_real_contextual_qbc = posterior_log_contextual_qbc[:, :, real_idx]
            posterior_real_contextual_iwal = posterior_log_contextual_iwal[:, :, real_idx]

            labelled_ins = np.ravel(np.asarray(zt_real.nonzero()))  # the indices whose labels are queried
            num_labelled = np.size(labelled_ins)  # number of queries for this realization ~budget in interest
            # No queries at all: fall back to a single pseudo-label at index 0.
            if num_labelled == 0:
                labelled_ins = 0
                num_labelled = 1

            cumulative_regrets = []
            sampled_regret_real = 0
            regret_real = 0
            cumulative_loss_real = 0

            method = methods[num]
            print(dataset_name,"method", method)
            for t in np.arange(num_instances):

                # Methods with a logged posterior: winners are its argmax (possibly several).
                if method == "CAMS_best_policy":
                    posterior_t = posterior_real_ap[t, :]
                    arg_winners_t = np.where(np.equal(posterior_t, np.max(posterior_t)))[0]

                elif method == "CAMS_identity":
                    posterior_t = posterior_real_ap_identity[t, :]
                    arg_winners_t = np.where(np.equal(posterior_t, np.max(posterior_t)))[0]

                elif method == "CAMS_test":
                    posterior_t = posterior_real_ap_test[t, :]
                    arg_winners_t = np.where(np.equal(posterior_t, np.max(posterior_t)))[0]

                elif method == "contextual_qbc":
                    posterior_t = posterior_real_contextual_qbc[t, :]
                    arg_winners_t = np.where(np.equal(posterior_t, np.max(posterior_t)))[0]

                elif method == "contextual_iwal":
                    posterior_t = posterior_real_contextual_iwal[t, :]
                    arg_winners_t = np.where(np.equal(posterior_t, np.max(posterior_t)))[0]

                elif method == 'mp':  # If MP, use its own posterior
                    # print(method)
                    posterior_t = posterior_real[t, :]
                    arg_winners_t = np.where(np.equal(posterior_t, np.max(posterior_t)))[0]

                else:  # else, check the weighted losses
                    # Uniform posterior; winners are the models with minimum
                    # loss (compute_loss) on the instances labelled before round t.
                    posterior_t = np.ones(num_models) / num_models
                    # NOTE(review): with exactly one query this uses instance 0,
                    # not the round where that query happened -- confirm intended.
                    if num_labelled == 1:
                        labelled_instances_t = 0
                    else:
                        idx_labelled_instances_transient = np.where(labelled_ins.reshape(num_labelled, 1) < t)[
                            0]  # find the location of labelled points that are smaller than t
                        labelled_instances_t = labelled_ins[
                            idx_labelled_instances_transient]  # find all labelled points so far

                    weighted_losses_t = compute_loss(predictions[labelled_instances_t, :], oracle[labelled_instances_t],
                                                     num_models)
                    # With at most one labelled instance every model ties.
                    if np.size(labelled_instances_t) > 1:
                        if np.sum(weighted_losses_t) == 0:  # if no true positive yet, set the posterior uniform
                            arg_winners_t = np.arange(num_models)
                        else:
                            arg_winners_t = \
                                np.where(np.equal(weighted_losses_t.reshape(num_models, 1), np.min(weighted_losses_t)))[
                                    0]
                    else:
                        arg_winners_t = np.arange(num_models)

                # If multi winners, choose randomly
                len_winners = np.size(arg_winners_t)

                if len_winners > 1:
                    idx_winner_t = np.random.choice(len_winners, 1)
                    winner_t = arg_winners_t[idx_winner_t]
                else:
                    winner_t = arg_winners_t

                # Accumulate the error of returned model
                loss_winner = int((predictions[t, int(winner_t)] != oracle[t]) * 1)
                # Accumulate the error of true winner
                loss_true = int((predictions[t, int(true_winner_random)] != oracle[t]) * 1)

                # Sampled regret time
                m_star = np.random.choice(list(range(num_models)), p=posterior_t)
                # Incur hidden loss
                loss_sampled = (predictions[t, m_star] != oracle[t]) * 1

                # 0/1 loss of every model on instance t; np.min(val) is 0 when
                # any model is correct.
                orac_rep = np.repeat(int(oracle[t]), len(predictions[t, :]))
                val = (predictions[t, :] != orac_rep) * 1

                # Running totals: loss vs. the best model on this instance,
                # regret vs. the stream's true winner, and sampled regret.
                cumulative_loss_real += (loss_winner - np.min(val))
                regret_real += (loss_winner - loss_true)
                sampled_regret_real += (loss_sampled - loss_true)
                # print(regret_real)
                regret_t[real_idx, num, t] = regret_real
                sampled_regret_t[real_idx, num, t] = sampled_regret_real
                cumulative_loss_t[real_idx, num, t] = cumulative_loss_real

                relative_shade_methods.append(methods[num])
                relative_shade_regret.append(regret_real)
                relative_shade_instance.append(real_idx)
                relative_shade_round.append(t)


    #evaluate query probability
    # NOTE(review): the query_shade_* lists built here are never used in this
    # function (query_cost_plot repeats the same computation and plots it).
    query_shade_methods=[]
    query_shade_counts=[]
    query_shade_round=[]
    query_shade_instance=[]

    budget_idx=np.where(budget_raw == budget)[0][0]

    for num in range(len(methods)):
        for real_idx in range(num_reals):
            cnt=0
            for t in np.arange(num_instances):
                cnt = cnt+ query_regardles_budget_detail[budget_idx,t,real_idx,num]
                query_shade_methods.append(methods[num])
                query_shade_round.append(t)
                query_shade_instance.append(real_idx)
                query_shade_counts.append(cnt)

    relative_shade_methods=rename_method_list(relative_shade_methods)

    shade_relative_loss = {"method": relative_shade_methods,"relative_loss": relative_shade_regret, "round": relative_shade_round,"simulation":relative_shade_instance}
    shade_relative_loss = pd.DataFrame(shade_relative_loss)

    for item in methods:
        print(dataset_name,":", item)

    plt.figure(figsize=(10, 10), dpi=300)


    sns.lineplot(x="round", y="relative_loss", label = n_RS, data=shade_relative_loss[shade_relative_loss["method"]==n_RS],color=my_pal[n_RS],ci=63, linewidth=1)
    sns.lineplot(x="round", y="relative_loss", label = n_Oracle, data=shade_relative_loss[shade_relative_loss["method"]==n_Oracle],color=my_pal[n_Oracle],ci=63, linewidth=1)
    sns.lineplot(x="round", y="relative_loss", label = n_QBC, data=shade_relative_loss[shade_relative_loss["method"]==n_QBC],color=my_pal[n_QBC],ci=63, linewidth=1)
    sns.lineplot(x="round", y="relative_loss", label = n_IWAL, data=shade_relative_loss[shade_relative_loss["method"]==n_IWAL],color=my_pal[n_IWAL],ci=63, linewidth=1)
    sns.lineplot(x="round", y="relative_loss", label = n_MP, data=shade_relative_loss[shade_relative_loss["method"]==n_MP],color=my_pal[n_MP],ci=63, linewidth=1)
    sns.lineplot(x="round", y="relative_loss", label = n_CQBC, data=shade_relative_loss[shade_relative_loss["method"]==n_CQBC],color=my_pal[n_CQBC],ci=63, linewidth=1)
    sns.lineplot(x="round", y="relative_loss", label = n_CIWAL, data=shade_relative_loss[shade_relative_loss["method"]==n_CIWAL],color=my_pal[n_CIWAL],ci=63, linewidth=1)
    sns.lineplot(x="round", y="relative_loss", label = n_CAMS, data=shade_relative_loss[shade_relative_loss["method"]==n_CAMS],color=my_pal[n_CAMS],ci=63, linewidth=4)
    sns.lineplot(x="round", y="relative_loss", label = n_test, data=shade_relative_loss[shade_relative_loss["method"]==n_test],color=my_pal[n_test],ci=63, linewidth=4)

    #plot relative cumulative loss(figure 2(top))
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.xlabel("Round", fontsize=30)
    plt.ylabel("", fontsize=30)
    # Only the final legend('') call takes effect (empty legend).
    plt.legend(loc=2)
    plt.legend(fontsize=20,title=None)
    plt.legend('')
    plt.savefig("./task1/"+dataset_name + "_shade_relative_cumulative_losss_budget_" + str(budget) + ".png", bbox_inches='tight', pad_inches=0.01)
    plt.savefig("./task1/"+dataset_name + "_shade_relative_cumulative_losss_budget_" + str(budget) + ".pdf", bbox_inches='tight', pad_inches=0.01)



# Generate the query-cost plot for figure 2 (middle).
def query_cost_plot(dataset_name, budget, folder_name, my_pal=my_pal):
    """Plot the cumulative number of queries per round for every method.

    Reads ``data.npz`` and ``eval_results.npz`` from
    ``<cwd>/resources/results/<folder_name>/``. For the stored budget equal to
    ``budget``, it accumulates the per-round entries of
    ``query_regardles_budget_detail`` (indexed as
    ``[budget, round, realization, method]``) over rounds. This is done
    separately for each method and realization. The curves (one per method,
    aggregated across realizations by ``sns.lineplot`` with ``ci=63``) are
    saved to ``./task1/<dataset_name>_shade_query_budget_<budget>.png`` and
    ``.pdf``. The ``./task1`` directory must already exist.

    Parameters
    ----------
    dataset_name : str
        Label used in log output and in the output file names.
    budget : int
        Budget to plot; must be one of ``data.npz['budgets']``.
    folder_name : str
        Results folder under ``resources/results``.
    my_pal : mapping
        Display method name -> line colour.

    Raises
    ------
    IndexError
        If ``budget`` is not among the stored budgets.
    ValueError
        If the results contain an unknown method key.
    """
    path = os.getcwd() + "/resources/results/" + folder_name + "/"
    file_list = os.listdir(path)
    print(dataset_name, ":", file_list)

    # data output
    data = np.load(path + "data.npz")
    num_reals = data["num_reals"]
    print(dataset_name, ":", num_reals)
    num_instances = data["num_instances"]
    methods = data["methods"]
    budget_raw = data["budgets"]

    eval_results = np.load(path + "eval_results.npz")
    query_detail = eval_results["query_regardles_budget_detail"]

    budget_idx = np.where(budget_raw == budget)[0][0]

    # Long-format records: one row per (method, realization, round).
    query_shade_methods = []
    query_shade_counts = []
    query_shade_round = []
    query_shade_instance = []
    rounds = np.arange(num_instances)
    n_rounds = len(rounds)
    for num, method in enumerate(methods):
        for real_idx in range(num_reals):
            # Running number of queries issued up to and including each round.
            counts = np.cumsum(query_detail[budget_idx, rounds, real_idx, num])
            query_shade_methods.extend([method] * n_rounds)
            query_shade_round.extend(rounds)
            query_shade_instance.extend([real_idx] * n_rounds)
            query_shade_counts.extend(counts)

    query_shade_methods = rename_method_list(query_shade_methods)

    shade_query = pd.DataFrame({"method": query_shade_methods, "counts": query_shade_counts,
                                "round": query_shade_round, "simulation": query_shade_instance})

    # (display name, line width) in drawing order; the CAMS variants are drawn thicker.
    line_styles = [(n_RS, 1), (n_Oracle, 1), (n_QBC, 1), (n_IWAL, 1), (n_MP, 1),
                   (n_CQBC, 1), (n_CIWAL, 1), (n_CAMS, 4), (n_test, 4)]

    fig = plt.figure(figsize=(10, 10), dpi=300)
    for name, width in line_styles:
        sns.lineplot(x="round", y="counts", label=name, data=shade_query[shade_query["method"] == name],
                     color=my_pal[name], ci=63, linewidth=width)

    plt.xticks(fontsize=20)
    plt.yticks(fontsize=15)
    plt.xlabel("Round", fontsize=30)
    plt.ylabel("", fontsize=30)
    # An empty label list gives a legend with no entries (the shared legend
    # figure is produced by organize_plot).
    plt.legend('')
    plt.savefig("./task1/" + dataset_name + "_shade_query_budget_" + str(budget) + ".png", bbox_inches='tight',
                pad_inches=0.01)
    plt.savefig("./task1/" + dataset_name + "_shade_query_budget_" + str(budget) + ".pdf", bbox_inches='tight',
                pad_inches=0.01)
    # Release the figure; otherwise every call leaves a 3000x3000 px figure open.
    plt.close(fig)

    print("budget:", budget)
    print("folder_name:", folder_name)






# ############
# Script driver: every plot and CSV is written under ./task1, and savefig /
# to_csv fail if that directory does not exist yet.
os.makedirs("./task1", exist_ok=True)

# Each entry: (dataset name, budget, results-folder prefix, results-folder suffix).
# The results folder is prefix + time + suffix, where ``time`` comes from -t.
COST_EFFECTIVE_RUNS = [
    ("DRIFT", 400, "drift_contextual_streamsize3000_numreals100", "_which_methods11011011011_policy[1]"),
    ("VERTEBRAL", 30, "VERTEBRAL_contextual_streamsize80_numreals300", "_which_methods11011011011_policy[0]"),
    ("HIV", 400, "HIV_contextual_streamsize4000_numreals200", "_which_methods11011011011_policy[0]"),
    ("CIFAR10", 200, "cifar_contextual_streamsize10000_numreals20", "_which_methods11011011011_policy[11]"),
]
QUERY_COST_RUNS = [
    ("DRIFT", 2000, "drift_contextual_streamsize3000_numreals100", "_which_methods11011011011_policy[1]"),
    ("HIV", 2000, "HIV_contextual_streamsize4000_numreals200", "_which_methods11011011011_policy[0]"),
    ("CIFAR10", 1200, "cifar_contextual_streamsize10000_numreals20", "_which_methods11011011011_policy[11]"),
    ("VERTEBRAL", 80, "VERTEBRAL_contextual_streamsize80_numreals300", "_which_methods11011011011_policy[0]"),
]

# Generate relative-loss and cost-effectiveness plots (figure 2 top/bottom).
for dataset_name, budget, folder_prefix, folder_suffix in COST_EFFECTIVE_RUNS:
    folder_name = folder_prefix + time + folder_suffix
    organize_plot(dataset_name, budget, folder_name)

# Generate query-cost plots (figure 2 middle).
for dataset_name, budget, folder_prefix, folder_suffix in QUERY_COST_RUNS:
    folder_name = folder_prefix + time + folder_suffix
    query_cost_plot(dataset_name, budget, folder_name)
