import numpy as np
# from src.evaluation.evaluation_pipeline.evaluate_method import *
from src.evaluation.evaluation_pipeline.evaluate_realizations import *
from src.evaluation.aux.load_results import *

import matplotlib.pyplot as plt
import os
import argparse
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import config as config

# This script plots the figures for Appendix I.1.4: comparing CAMS and Model
# Picker in a context-free environment.
#
# Usage example: python <this script> -t _Date-2022-05-25_Time-13-51
# The -t value is the time suffix of the result folders produced by the .sh runs;
# it is concatenated into every result-folder name below.

parser = argparse.ArgumentParser(
    description='task7: plotting figures for Appendix I.1.4 Comparing CAMS and Model Picker in a context-free environment')
parser.add_argument('-t', default="time", type=str,
                    help='please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45')

args = parser.parse_args()
time = args.t

# "time" is a sentinel default meaning -t was not supplied. parser.error prints
# the usage line plus this message to stderr and exits with a non-zero status,
# so a missing argument is not reported as a successful run.
if time == "time":
    parser.error("please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45")

# Short aliases for the display names / palette defined in config.
my_pal = config.COLOR
n_RS = config.n_RS
n_Oracle = config.n_CAMS_best_policy
n_QBC = config.n_qbc
n_IWAL = config.n_iwal
n_MP = config.n_mp
n_CQBC = config.n_contextual_qbc
n_CIWAL = config.n_contextual_iwal
n_CAMS = config.n_CAMS_identity
n_test = config.n_CAMS_test
n_test = config.n_CAMS_test


def rename_method_list(methods):
    """Map raw method keys to their display names.

    Args:
        methods: iterable of raw method keys (e.g. ``"rs"``, ``"mp"``,
            ``"CAMS_identity"``).

    Returns:
        list: the corresponding module-level display names (``n_RS``,
        ``n_MP``, ...), in input order.

    Raises:
        ValueError: if an entry is not one of the known method keys.
    """
    display_names = {
        "rs": n_RS,
        "qbc": n_QBC,
        "iwal": n_IWAL,
        "mp": n_MP,
        "contextual_qbc": n_CQBC,
        "contextual_iwal": n_CIWAL,
        "CAMS_best_policy": n_Oracle,
        "CAMS_identity": n_CAMS,
        "CAMS_test": n_test,
    }
    renamed = []
    for item in methods:
        if item not in display_names:
            # Raise rather than call the site-provided exit(): exit() is meant
            # for the interactive shell and would end the script with status 0.
            raise ValueError(f"error: unknown method name {item!r}")
        renamed.append(display_names[item])
    return renamed


def find_nearest(array, value):
    """Return the element of *array* closest to *value*.

    When several elements are equally close, the first one wins
    (``argmin`` returns the first minimum).
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[np.argmin(distances)]


def organize_plot(folder_name):
    """Summarise a CAMS / Model Picker run as one DataFrame for line plotting.

    Loads ``data.npz`` (``methods``, ``budgets``) and ``eval_results.npz``
    (``box_budget_actual``, ``box_cumulative_loss``, ``box_method``) from
    ``<cwd>/resources/results/<folder_name>/``.

    Each evaluated entry becomes one row with columns:
      * ``budget_fixed``: the entry of ``budgets`` nearest to the row's actual
        query cost. For rows matched by a (method, budget) group below, it is
        replaced by that group's mean actual query cost.
      * ``h_err_bar``: for matched groups, the largest distance from that mean
        to the group's min/max actual cost. Otherwise 0.
      * ``v_err_bar`` / ``v_mean``: for matched groups, ``1.95 * std / sqrt(n)``
        and the mean of ``c_regret``. Otherwise 0.
      * ``method``: display name produced by ``rename_method_list``.
      * ``c_regret``: the entry's cumulative loss.

    Args:
        folder_name: name of the result folder under ``resources/results``.

    Returns:
        pandas.DataFrame with the columns above, exact duplicate rows dropped
        and a fresh ``RangeIndex``.
    """
    path = os.path.join(os.getcwd(), "resources", "results", folder_name)
    # np.load on an .npz returns an NpzFile that keeps the archive open; the
    # ``with`` blocks close it. Indexed arrays are fully read into memory, so
    # they remain valid after the archive is closed.
    with np.load(os.path.join(path, "data.npz")) as data:
        methods = data["methods"]
        budget_raw = data["budgets"]
    with np.load(os.path.join(path, "eval_results.npz")) as results:
        budget_actual = results["box_budget_actual"]
        cumulative_loss = results["box_cumulative_loss"]
        box_method = results["box_method"]
    method_names = rename_method_list(box_method)

    # The statistic columns are float from the start because group means are
    # written back into them below; pandas is deprecating silent int -> float
    # upcasts on assignment.
    shading = pd.DataFrame({
        "budget": budget_actual,
        "budget_fixed": np.array([find_nearest(budget_raw, b) for b in budget_actual], dtype=float),
        "c_regret": cumulative_loss,
        "method": method_names,
        "h_err_bar": 0.0,
        "v_err_bar": 0.0,
        "v_mean": 0.0,
    })

    for item in methods:
        for budget_ in budget_raw:
            # Recomputed every iteration on purpose: earlier iterations overwrite
            # ``budget_fixed``, and the grouping depends on that processing order.
            # NOTE(review): ``methods`` comes straight from data.npz while the
            # ``method`` column holds names mapped through rename_method_list; if
            # data.npz stores raw keys (e.g. "mp"), this mask never matches and the
            # statistic columns stay 0 - confirm which behaviour is intended.
            mask = (shading["method"] == item) & (shading["budget_fixed"] == budget_)
            if not mask.any():
                continue
            group = shading.loc[mask]

            mean_budget = group["budget"].mean()
            h_bar = np.maximum(abs(mean_budget - group["budget"].min()),
                               abs(mean_budget - group["budget"].max()))
            regret = group["c_regret"]

            shading.loc[mask, "budget_fixed"] = mean_budget
            shading.loc[mask, "h_err_bar"] = h_bar
            # NOTE(review): 1.95 looks like an approximation of the 95% normal
            # quantile (1.96) - confirm.
            shading.loc[mask, "v_err_bar"] = 1.95 * regret.std() / np.sqrt(regret.count())
            shading.loc[mask, "v_mean"] = regret.mean()

    columns = ["budget_fixed", "h_err_bar", "v_err_bar", "v_mean", "method", "c_regret"]
    return shading.filter(columns, axis=1).drop_duplicates().reset_index(drop=True)


def organize_plot_classifier(folder_name, budgets_mp_ap):
    """Build a cumulative-loss-vs-budget frame for a single fixed-classifier run.

    Loads ``data.npz`` and ``eval_results.npz`` from
    ``<cwd>/resources/results/<folder_name>/``. Pairs the sorted unique values
    of *budgets_mp_ap* with the leading entries of the run's cumulative loss,
    so the classifier's curve shares the x positions of the reference run.

    Args:
        folder_name: name of the result folder under ``resources/results``.
        budgets_mp_ap: budgets (x positions) of the reference run. Duplicates
            are collapsed with ``np.unique``, which also sorts them.

    Returns:
        pandas.DataFrame with columns ``budget``, ``c_regret`` and ``method``.
        Every ``method`` entry is ``n_Oracle``.
    """
    path = os.path.join(os.getcwd(), "resources", "results", folder_name)
    # ``with`` closes the NpzFile archives; indexed arrays stay valid afterwards.
    with np.load(os.path.join(path, "data.npz")) as data:
        n_budgets = len(data["budgets"])
    with np.load(os.path.join(path, "eval_results.npz")) as results:
        # Row 0 of the transpose, i.e. the first column of ``cumulative_loss``
        # when it is 2-D. NOTE(review): presumably one column per method - confirm.
        cumulative_loss = np.transpose(results["cumulative_loss"])[0]

    budgets = np.unique(budgets_mp_ap)
    n_points = len(budgets)

    return pd.DataFrame({
        "budget": budgets,
        "c_regret": cumulative_loss[0:n_points],
        "method": np.repeat(n_Oracle, n_budgets)[0:n_points],
    })


################################################
# VERTEBRAL

dataset_name = "VERTEBRAL"
run_prefix = "VERTEBRAL_contextual_streamsize80_numreals50" + time

shade_df_cams_mp = organize_plot(run_prefix + "_which_methods10000001000_policy[0]")
x_budgets = shade_df_cams_mp["budget_fixed"]

# Fixed-classifier reference runs. Per-policy annotations (presumably final
# cumulative losses - confirm): 0: 56.21, 1: 20.5, 2: 40., 3: 5.76, 4: 0.08,
# 5: 35. Only policies 4 (best), 3 and 1 are drawn, so only those are loaded.
shade_df_1 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[1]", x_budgets)
shade_df_3 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[3]", x_budgets)
shade_df_4 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[4]", x_budgets)

fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
line_ = sns.lineplot(x="budget_fixed", y="c_regret", label=n_MP,
                     data=shade_df_cams_mp[shade_df_cams_mp["method"] == n_MP], color=my_pal[n_MP], ci=63,
                     linewidth=3, ax=ax)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_CAMS, data=shade_df_cams_mp[shade_df_cams_mp["method"] == n_CAMS],
             color=my_pal[n_CAMS], ci=63, linewidth=4, ax=ax)
sns.lineplot(x="budget", y="c_regret", label="best classifier", data=shade_df_4[shade_df_4["method"] == n_Oracle],
             color=my_pal[n_Oracle], ci=63, linewidth=3, ax=ax)
sns.lineplot(x="budget", y="c_regret", data=shade_df_3[shade_df_3["method"] == n_Oracle], color="grey", ci=63,
             linewidth=1, ax=ax)
sns.lineplot(x="budget", y="c_regret", data=shade_df_1[shade_df_1["method"] == n_Oracle], color="grey", ci=63,
             linewidth=1, ax=ax)

ax.set_ylim(10, 60)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Query cost", fontsize=30)
plt.ylabel("", fontsize=30)

# Hide the per-plot legend; the legend is exported to its own file below.
# Removing the Legend artist is the supported way to do this.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

os.makedirs("./task7", exist_ok=True)
plt.savefig("./task7/" + dataset_name + "_task7_shade_line_plot.png", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./task7/" + dataset_name + "_task7_shade_line_plot.pdf", bbox_inches='tight', pad_inches=0.01)

print("completed ", dataset_name)

# Save the legend as a standalone figure. get_legend_handles_labels collects the
# labelled lines from the axes itself, so it is unaffected by the removal above.
legend_fig = plt.figure(figsize=(10, 10), dpi=300)
handles, labels = line_.get_legend_handles_labels()
legend_fig.legend(handles, labels, ncol=8, loc='center')
legend_fig.savefig("./task7/" + 'legend_task7_no_policy.png', bbox_inches='tight', pad_inches=0)
legend_fig.savefig("./task7/" + 'legend_task7_no_policy.pdf', bbox_inches='tight', pad_inches=0)

# Release both figures; pyplot otherwise keeps every figure alive.
plt.close(legend_fig)
plt.close(fig)

#############################################################################
# DRIFT

dataset_name = "DRIFT"
run_prefix = "drift_contextual_streamsize3000_numreals100" + time

shade_df_cams_mp = organize_plot(run_prefix + "_which_methods10000001000_policy[1]")
x_budgets = shade_df_cams_mp["budget_fixed"]

# Fixed-classifier reference runs: policy 9 is drawn as the best classifier,
# policies 3 and 2 as grey lines. Policy 8's run is not drawn, so it is not loaded.
shade_df_9 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[9]", x_budgets)
shade_df_3 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[3]", x_budgets)
shade_df_2 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[2]", x_budgets)

fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_MP, data=shade_df_cams_mp[shade_df_cams_mp["method"] == n_MP],
             color=my_pal[n_MP], ci=63, linewidth=3, ax=ax)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_CAMS, data=shade_df_cams_mp[shade_df_cams_mp["method"] == n_CAMS],
             color=my_pal[n_CAMS], ci=63, linewidth=4, ax=ax)
sns.lineplot(x="budget", y="c_regret", label="best classifier", data=shade_df_9[shade_df_9["method"] == n_Oracle],
             color=my_pal[n_Oracle], ci=63, linewidth=3, ax=ax)
sns.lineplot(x="budget", y="c_regret", data=shade_df_3[shade_df_3["method"] == n_Oracle], color="grey", ci=63,
             linewidth=1, ax=ax)
sns.lineplot(x="budget", y="c_regret", data=shade_df_2[shade_df_2["method"] == n_Oracle], color="grey", ci=63,
             linewidth=1, ax=ax)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Query cost", fontsize=22)
plt.ylabel("", fontsize=22)

# Hide the per-plot legend; the shared legend is saved as a separate file.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

os.makedirs("./task7", exist_ok=True)
plt.savefig("./task7/" + dataset_name + "_task7_shade_line_plot.png", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./task7/" + dataset_name + "_task7_shade_line_plot.pdf", bbox_inches='tight', pad_inches=0.01)
plt.close(fig)

print("completed ", dataset_name)

#########################################
# HIV

dataset_name = "HIV"
run_prefix = "HIV_contextual_streamsize4000_numreals200" + time

shade_df_cams_mp = organize_plot(run_prefix + "_which_methods10000001000_policy[0]")
x_budgets = shade_df_cams_mp["budget_fixed"]

# Fixed-classifier reference runs: policy 0 is drawn as the best classifier and
# policy 1 as a grey line. Policies 2 and 3 are not drawn, so they are not loaded.
shade_df_0 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[0]", x_budgets)
shade_df_1 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[1]", x_budgets)

fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_MP, data=shade_df_cams_mp[shade_df_cams_mp["method"] == n_MP],
             color=my_pal[n_MP], ci=63, linewidth=3, ax=ax)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_CAMS, data=shade_df_cams_mp[shade_df_cams_mp["method"] == n_CAMS],
             color=my_pal[n_CAMS], ci=63, linewidth=4, ax=ax)
sns.lineplot(x="budget", y="c_regret", label="best classifier", data=shade_df_0[shade_df_0["method"] == n_Oracle],
             color=my_pal[n_Oracle], ci=63, linewidth=3, ax=ax)
sns.lineplot(x="budget", y="c_regret", data=shade_df_1[shade_df_1["method"] == n_Oracle], color="grey", ci=63,
             linewidth=1, ax=ax)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Query cost", fontsize=22)
plt.ylabel("", fontsize=22)

# Hide the per-plot legend; the shared legend is saved as a separate file.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

os.makedirs("./task7", exist_ok=True)
plt.savefig("./task7/" + dataset_name + "_task7_shade_line_plot.png", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./task7/" + dataset_name + "_task7_shade_line_plot.pdf", bbox_inches='tight', pad_inches=0.01)
plt.close(fig)

print("completed ", dataset_name)

#####################################################################
# CIFAR10

dataset_name = "CIFAR10"
run_prefix = "cifar_contextual_streamsize10000_numreals20" + time

shade_df_cams_mp = organize_plot(run_prefix + "_which_methods10000001000_policy[11]")
x_budgets = shade_df_cams_mp["budget_fixed"]

# Fixed-classifier reference runs: policy 20 is drawn as the best classifier,
# policies 14, 17 and 12 as grey "top sub-optimal" lines. Policy 4's run is not
# drawn, so it is not loaded.
shade_df_20 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[20]", x_budgets)
shade_df_14 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[14]", x_budgets)
shade_df_17 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[17]", x_budgets)
shade_df_12 = organize_plot_classifier(run_prefix + "_which_methods00000010000_policy[12]", x_budgets)

fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_MP, data=shade_df_cams_mp[shade_df_cams_mp["method"] == n_MP],
             color=my_pal[n_MP], ci=63, linewidth=3, ax=ax)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_CAMS, data=shade_df_cams_mp[shade_df_cams_mp["method"] == n_CAMS],
             color=my_pal[n_CAMS], ci=63, linewidth=4, ax=ax)
sns.lineplot(x="budget", y="c_regret", label="best classifier", data=shade_df_20[shade_df_20["method"] == n_Oracle],
             color=my_pal[n_Oracle], ci=63, linewidth=3, ax=ax)
sns.lineplot(x="budget", y="c_regret", label="top sub-optimal classifiers",
             data=shade_df_14[shade_df_14["method"] == n_Oracle], color="grey", ci=63, linewidth=1, ax=ax)
sns.lineplot(x="budget", y="c_regret", data=shade_df_17[shade_df_17["method"] == n_Oracle], color="grey", ci=63,
             linewidth=1, ax=ax)
sns.lineplot(x="budget", y="c_regret", data=shade_df_12[shade_df_12["method"] == n_Oracle], color="grey", ci=63,
             linewidth=1, ax=ax)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Query cost", fontsize=22)
plt.ylabel("", fontsize=22)

# Hide the per-plot legend; the shared legend is saved as a separate file.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

os.makedirs("./task7", exist_ok=True)
plt.savefig("./task7/" + dataset_name + "_task7_shade_line_plot.png", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./task7/" + dataset_name + "_task7_shade_line_plot.pdf", bbox_inches='tight', pad_inches=0.01)
plt.close(fig)

print("completed ", dataset_name)





