import numpy as np
from src.evaluation.aux.load_results import *
from matplotlib.ticker import LogFormatter
import matplotlib.ticker as ticker
from matplotlib.ticker import ScalarFormatter

import matplotlib.pyplot as plt
import os
import argparse
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import config as config

# Plots the figures of Section I.1.5 ("Robustly recover from complete malicious
# experts environment") for the task5 experiments.

parser = argparse.ArgumentParser(
    description='task5 for plotting figures in I.1.5 Robustly recover from complete malicious experts environment')
parser.add_argument('-t', default="time", type=str,
                    help='please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45')

args = parser.parse_args()
# Timestamp suffix (e.g. "_Date-2022-05-18_Time-16-45") that is concatenated
# into every result-folder name built below.
time = args.t

# "time" is the argparse default, i.e. -t was not supplied. parser.error()
# prints the usage line plus the message to stderr and exits with a non-zero
# status, so a calling shell script can detect the failure.
if time == "time":
    parser.error("please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45")

# Method names and the name -> colour mapping, taken from the project config.
my_pal = config.COLOR
n_RS = config.n_RS
n_Oracle = config.n_CAMS_best_policy
n_QBC = config.n_qbc
n_IWAL = config.n_iwal
n_MP = config.n_mp
n_CQBC = config.n_contextual_qbc
n_CIWAL = config.n_contextual_iwal
n_CAMS = config.n_CAMS_identity
n_test = config.n_CAMS_test
n_EXP4 = config.n_EXP4


def find_nearest(array, value):
    """Return the element of ``array`` whose value is closest to ``value``.

    ``array`` is converted with ``np.asarray``. When several elements are
    equally close, the first one is returned (``argmin`` semantics).
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[np.argmin(distances)]


def rename_method_list(methods, policy_name):
    """Return a new list containing ``policy_name`` once per entry of ``methods``.

    Only the number of entries in ``methods`` matters; their values are ignored.
    """
    return [policy_name for _ in methods]


def organize_plot(dataset_name, folder_name, budget, policy_name="policy_name"):
    """Load one run's evaluation results as a tidy frame for the shaded line plots.

    Reads ``data.npz`` and ``eval_results.npz`` from
    ``<cwd>/resources/results/<folder_name>/``. Every row is labelled with
    ``policy_name``, so frames from several runs can be concatenated and told
    apart by their ``method`` column.

    Each row's actual query budget is first snapped to the nearest value among
    the run's budget grid plus its maximal actual budget. Some rows' nearest
    value equals the grid floor of the maximal actual budget (the last grid
    value it reaches, or 0 if none). Those rows are moved onto the maximal
    actual budget instead. Every row that ended up on a grid value then gets
    the mean actual budget of that bucket, so the x positions reflect the
    budget actually spent.

    Args:
        dataset_name: Label printed next to the number of realisations.
        folder_name: Sub-folder of ``resources/results`` holding the archives.
        budget: Unused; kept so existing calls keep working.
        policy_name: Label written into the ``method`` column.

    Returns:
        DataFrame with columns ``budget_fixed``, ``method`` and ``c_regret``,
        de-duplicated and re-indexed from 0.

    Raises:
        ValueError: if the results contain rows but no ``max_method`` entries.
    """
    path = os.path.join(os.getcwd(), "resources", "results", folder_name)

    # NpzFile keeps the archive open; the context manager releases it once the
    # needed arrays have been read out.
    with np.load(os.path.join(path, "data.npz")) as data:
        num_reals = data["num_reals"]
        budget_raw = np.asarray(data["budgets"])
    print(dataset_name, ":", num_reals)

    with np.load(os.path.join(path, "eval_results.npz")) as results:
        box_budget_actual = results["box_budget_actual"]
        box_cumulative_loss = results["box_cumulative_loss"]
        box_method = results["box_method"]
        max_method = rename_method_list(results["max_method"], policy_name)
        max_budget_actual = results["max_budget_actual"]

    budget_fixed = []
    if len(box_budget_actual):
        # Every row and every max_method entry carries policy_name, so the
        # lookup resolves to the first entry.
        idx = max_method.index(policy_name)
        max_actual = max_budget_actual[idx]

        # Grid floor of the maximal actual budget: the last grid value it reaches.
        max_floor = 0
        for level in budget_raw:
            if max_actual >= level:
                max_floor = level

        candidates = np.concatenate((budget_raw, [max_actual]))
        for actual in box_budget_actual:
            nearest = find_nearest(candidates, actual)
            budget_fixed.append(max_actual if nearest == max_floor else nearest)

    frame = pd.DataFrame({
        "budget": box_budget_actual,
        # float so the bucket means assigned below do not force a dtype change
        "budget_fixed": np.asarray(budget_fixed, dtype=float),
        "c_regret": box_cumulative_loss,
        "method": np.repeat(policy_name, len(box_method)),
    })

    # Replace every grid bucket with the mean actual budget of its rows. All
    # rows share the label policy_name, so bucketing on the grid value alone
    # is sufficient (filtering on the run's original method names would miss
    # the relabelled rows).
    for level in budget_raw:
        in_bucket = frame["budget_fixed"] == level
        if in_bucket.any():
            frame.loc[in_bucket, "budget_fixed"] = frame.loc[in_bucket, "budget"].mean()

    shade_df_2 = frame[["budget_fixed", "method", "c_regret"]].drop_duplicates().reset_index(drop=True)

    print(shade_df_2)
    return shade_df_2


################################

# VERTEBRAL
dataset_name = "VERTEBRAL"
budget = 60
classifier_10 = "VERTEBRAL_contextual_streamsize80_numreals200" + time + "_which_methods00000010000_policy[10]"
policy_12 = "VERTEBRAL_contextual_streamsize80_numreals200" + time + "_which_methods00000010000_policy[12]"
policy_15 = "VERTEBRAL_contextual_streamsize80_numreals200" + time + "_which_methods00000010000_policy[15]"
CAMS = "VERTEBRAL_contextual_streamsize80_numreals200" + time + "_which_methods00000010000_policy[6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]"
EXP4 = "VERTEBRAL_contextual_streamsize80_numreals200" + time + "_which_methods00000010000_policy[6]"

shade_df_10 = organize_plot(dataset_name, classifier_10, budget, "classifier_10")
shade_df_12 = organize_plot(dataset_name, policy_12, budget, "policy_12")
shade_df_15 = organize_plot(dataset_name, policy_15, budget, "policy_15")
shade_df_cams = organize_plot(dataset_name, CAMS, budget, "CAMS")
shade_df_exp4 = organize_plot(dataset_name, EXP4, budget, "EXP4")

# DataFrame.append was removed in pandas 2.0; pd.concat stacks the frames in the
# same order and keeps their indices, exactly like the chained append did.
shade_df = pd.concat([shade_df_cams, shade_df_10, shade_df_12, shade_df_15, shade_df_exp4])

plt.figure(figsize=(10, 10), dpi=300)

ax = sns.lineplot(x="budget_fixed", y="c_regret", label="best classifier",
                  data=shade_df[shade_df["method"] == "classifier_10"], ci=63, color="purple", linewidth=3)
# Individual policy runs are drawn as thin, unlabelled grey lines.
for policy_label in ("policy_12", "policy_15"):
    sns.lineplot(x="budget_fixed", y="c_regret", data=shade_df[shade_df["method"] == policy_label], ci=63,
                 color="gray", linewidth=1)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_CAMS, data=shade_df[shade_df["method"] == n_CAMS],
             color=my_pal[n_CAMS], ci=63, linewidth=3)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_EXP4, data=shade_df[shade_df["method"] == n_EXP4],
             color=my_pal[n_EXP4], ci=63, linewidth=3)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Query cost", fontsize=30)
plt.ylabel("", fontsize=30)
# Per-dataset plots carry no legend; the legend is exported as its own image at
# the end of the script. Remove the legend seaborn attached rather than drawing
# an empty one.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

os.makedirs("./task5", exist_ok=True)
plt.savefig("./task5/" + dataset_name + "_task5_malicious_policy_shade_line_plot.png", bbox_inches='tight',
            pad_inches=0.01)
plt.savefig("./task5/" + dataset_name + "_task5_malicious_policy_shade_line_plot.pdf", bbox_inches='tight',
            pad_inches=0.01)
# Release the large 300-dpi figure before the next section opens another.
plt.close()

print("completed !", dataset_name)

###########################################################################################

# HIV
dataset_name = "HIV"
budget = 100

classifier_11 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[11]"
policy_15 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[15]"
policy_16 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[16]"
policy_17 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[17]"
policy_18 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[18]"
policy_19 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[19]"
policy_20 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[20]"
policy_21 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[21]"
policy_22 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[22]"

CAMS = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]"
EXP4 = "HIV_contextual_streamsize4000_numreals200" + time + "_which_methods00000010000_policy[15, 16, 17, 18, 19, 20, 21, 22]"

shade_df_11 = organize_plot(dataset_name, classifier_11, budget, "classifier_11")
shade_df_15 = organize_plot(dataset_name, policy_15, budget, "policy_15")
shade_df_16 = organize_plot(dataset_name, policy_16, budget, "policy_16")
shade_df_17 = organize_plot(dataset_name, policy_17, budget, "policy_17")
shade_df_18 = organize_plot(dataset_name, policy_18, budget, "policy_18")
shade_df_19 = organize_plot(dataset_name, policy_19, budget, "policy_19")
shade_df_20 = organize_plot(dataset_name, policy_20, budget, "policy_20")
shade_df_21 = organize_plot(dataset_name, policy_21, budget, "policy_21")
shade_df_22 = organize_plot(dataset_name, policy_22, budget, "policy_22")
shade_df_cams = organize_plot(dataset_name, CAMS, budget, "CAMS")
shade_df_exp4 = organize_plot(dataset_name, EXP4, budget, "EXP4")

# DataFrame.append was removed in pandas 2.0; pd.concat stacks the frames in the
# same order and keeps their indices, exactly like the chained append did.
shade_df = pd.concat([shade_df_cams, shade_df_exp4, shade_df_11, shade_df_15, shade_df_16, shade_df_17,
                      shade_df_18, shade_df_19, shade_df_20, shade_df_21, shade_df_22])

plt.figure(figsize=(10, 10), dpi=300)

ax = sns.lineplot(x="budget_fixed", y="c_regret", label="best classifier",
                  data=shade_df[shade_df["method"] == "classifier_11"], ci=63, color="purple", linewidth=3)
# Individual policy runs (only the ones loaded above) are drawn as thin,
# unlabelled grey lines.
for policy_label in ("policy_15", "policy_16", "policy_17", "policy_18",
                     "policy_19", "policy_20", "policy_21", "policy_22"):
    sns.lineplot(x="budget_fixed", y="c_regret", data=shade_df[shade_df["method"] == policy_label], ci=63,
                 color="gray", linewidth=1)

sns.lineplot(x="budget_fixed", y="c_regret", label=n_CAMS, data=shade_df[shade_df["method"] == n_CAMS],
             color=my_pal[n_CAMS], ci=63, linewidth=3)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_EXP4, data=shade_df[shade_df["method"] == n_EXP4],
             color=my_pal[n_EXP4], ci=63, linewidth=3)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Query cost", fontsize=30)
plt.ylabel("", fontsize=30)
# Per-dataset plots carry no legend; the legend is exported as its own image at
# the end of the script. Remove the legend seaborn attached rather than drawing
# an empty one.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

os.makedirs("./task5", exist_ok=True)
plt.savefig("./task5/" + dataset_name + "_task5_malicious_policy_shade_line_plot.png", bbox_inches='tight',
            pad_inches=0.01)
plt.savefig("./task5/" + dataset_name + "_task5_malicious_policy_shade_line_plot.pdf", bbox_inches='tight',
            pad_inches=0.01)
# Release the large 300-dpi figure before the next section opens another.
plt.close()

print("completed !", dataset_name)

###########################################################################################

# DRIFT
dataset_name = "DRIFT"
# organize_plot ignores this argument; it is set explicitly so this section
# does not rely on a value left behind by an earlier section.
budget = 100
classifier_17 = "drift_contextual_streamsize3000_numreals100" + time + "_which_methods00000010000_policy[17]"

policy_18 = "drift_contextual_streamsize3000_numreals100" + time + "_which_methods00000010000_policy[18]"
policy_19 = "drift_contextual_streamsize3000_numreals100" + time + "_which_methods00000010000_policy[19]"
policy_20 = "drift_contextual_streamsize3000_numreals100" + time + "_which_methods00000010000_policy[20]"

CAMS = "drift_contextual_streamsize3000_numreals100" + time + "_which_methods00000010000_policy[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]"
EXP4 = "drift_contextual_streamsize3000_numreals100" + time + "_which_methods00000010000_policy[18, 19, 20]"

shade_df_17 = organize_plot(dataset_name, classifier_17, budget, "classifier_17")
shade_df_18 = organize_plot(dataset_name, policy_18, budget, "policy_18")
shade_df_19 = organize_plot(dataset_name, policy_19, budget, "policy_19")
# Load the DRIFT policy_20 run (not a leftover folder from another dataset).
shade_df_20 = organize_plot(dataset_name, policy_20, budget, "policy_20")
shade_df_cams = organize_plot(dataset_name, CAMS, budget, "CAMS")
shade_df_exp4 = organize_plot(dataset_name, EXP4, budget, "EXP4")

# DataFrame.append was removed in pandas 2.0; pd.concat stacks the frames in the
# same order and keeps their indices, exactly like the chained append did.
shade_df = pd.concat([shade_df_cams, shade_df_17, shade_df_18, shade_df_19, shade_df_20, shade_df_exp4])

plt.figure(figsize=(10, 10), dpi=300)
ax = sns.lineplot(x="budget_fixed", y="c_regret", label="best classifier",
                  data=shade_df[shade_df["method"] == "classifier_17"], ci=63, color="purple", linewidth=3)
# Individual policy runs are drawn as unlabelled grey lines; policy_20 is loaded
# and concatenated above but not drawn.
for policy_label in ("policy_18", "policy_19"):
    sns.lineplot(x="budget_fixed", y="c_regret", data=shade_df[shade_df["method"] == policy_label], ci=63,
                 color="gray", linewidth=10)

sns.lineplot(x="budget_fixed", y="c_regret", label=n_CAMS, data=shade_df[shade_df["method"] == n_CAMS],
             color=my_pal[n_CAMS], ci=63, linewidth=3)
sns.lineplot(x="budget_fixed", y="c_regret", label=n_EXP4, data=shade_df[shade_df["method"] == n_EXP4],
             color=my_pal[n_EXP4], ci=63, linewidth=5)

ax.set_ylim(500, 1600)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Query cost", fontsize=30)
plt.ylabel("", fontsize=30)
# Per-dataset plots carry no legend; the legend is exported as its own image at
# the end of the script. Remove the legend seaborn attached rather than drawing
# an empty one.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

os.makedirs("./task5", exist_ok=True)
plt.savefig("./task5/" + dataset_name + "_task5_malicious_policy_shade_line_plot.png", bbox_inches='tight',
            pad_inches=0.01)
plt.savefig("./task5/" + dataset_name + "_task5_malicious_policy_shade_line_plot.pdf", bbox_inches='tight',
            pad_inches=0.01)
# Release the large 300-dpi figure before the next section opens another.
plt.close()

print("completed !", dataset_name)

###########################################################################################
# plots for CIFAR10

dataset_name = "CIFAR10"
budget = 100
classifier_20 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[20]"

policy_80 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[80]"
policy_81 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[81]"
policy_82 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[82]"
policy_83 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[83]"
policy_84 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[84]"
policy_85 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[85]"
policy_86 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[86]"
policy_87 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[87]"
policy_88 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[88]"
policy_89 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[89]"
policy_90 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[90]"

CAMS = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000001000_policy[6]"
EXP4 = "cifar_contextual_streamsize10000_numreals10" + time + "_which_methods00000010000_policy[80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90]"

shade_df_20 = organize_plot(dataset_name, classifier_20, budget, "classifier_20")
shade_df_80 = organize_plot(dataset_name, policy_80, budget, "policy_80")
shade_df_81 = organize_plot(dataset_name, policy_81, budget, "policy_81")
shade_df_82 = organize_plot(dataset_name, policy_82, budget, "policy_82")
shade_df_83 = organize_plot(dataset_name, policy_83, budget, "policy_83")
shade_df_84 = organize_plot(dataset_name, policy_84, budget, "policy_84")
shade_df_85 = organize_plot(dataset_name, policy_85, budget, "policy_85")
shade_df_86 = organize_plot(dataset_name, policy_86, budget, "policy_86")
shade_df_87 = organize_plot(dataset_name, policy_87, budget, "policy_87")
shade_df_88 = organize_plot(dataset_name, policy_88, budget, "policy_88")
shade_df_89 = organize_plot(dataset_name, policy_89, budget, "policy_89")
shade_df_90 = organize_plot(dataset_name, policy_90, budget, "policy_90")

shade_df_cams = organize_plot(dataset_name, CAMS, budget, "CAMS")
shade_df_exp4 = organize_plot(dataset_name, EXP4, budget, "EXP4")

# DataFrame.append was removed in pandas 2.0; pd.concat stacks the frames in the
# same order and keeps their indices, exactly like the chained append did.
shade_df = pd.concat([shade_df_cams, shade_df_exp4, shade_df_20, shade_df_80, shade_df_81, shade_df_82,
                      shade_df_83, shade_df_84, shade_df_85, shade_df_86, shade_df_87, shade_df_88,
                      shade_df_89, shade_df_90])

plt.figure(figsize=(10, 10), dpi=300)

ax = sns.lineplot(x="budget_fixed", y="c_regret", label="best classifier",
                  data=shade_df[shade_df["method"] == "classifier_20"], ci=63, color="purple", linewidth=3)
# policy_80 carries the legend entry for the whole group of grey policy lines.
sns.lineplot(x="budget_fixed", y="c_regret", label="top sub-optimal policies",
             data=shade_df[shade_df["method"] == "policy_80"], ci=63, color="gray", linewidth=3)
for policy_label in ("policy_81", "policy_82", "policy_83", "policy_84", "policy_85",
                     "policy_86", "policy_87", "policy_88", "policy_89", "policy_90"):
    sns.lineplot(x="budget_fixed", y="c_regret", data=shade_df[shade_df["method"] == policy_label], ci=63,
                 color="gray", linewidth=1)

sns.lineplot(x="budget_fixed", y="c_regret", label=n_CAMS, data=shade_df[shade_df["method"] == n_CAMS],
             color=my_pal[n_CAMS], ci=63, linewidth=3)
# The EXP4 rows are shown under the legend label "CAMS-conventional".
sns.lineplot(x="budget_fixed", y="c_regret", label="CAMS-conventional", data=shade_df[shade_df["method"] == n_EXP4],
             color=my_pal[n_EXP4], ci=63, linewidth=3)

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel("Query cost", fontsize=30)
plt.ylabel("", fontsize=30)
# The per-dataset plot carries no legend (it is exported separately below).
# Remove the legend seaborn attached rather than drawing an empty one; the
# labelled lines themselves stay, so get_legend_handles_labels still sees them.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

os.makedirs("./task5", exist_ok=True)
plt.savefig("./task5/" + dataset_name + "_task5_malicious_policy_shade_line_plot.png", bbox_inches='tight',
            pad_inches=0.01)
plt.savefig("./task5/" + dataset_name + "_task5_malicious_policy_shade_line_plot.pdf", bbox_inches='tight',
            pad_inches=0.01)

# save legend: a figure containing only the legend built from this plot's labelled lines
fig = plt.figure(figsize=(10, 10), dpi=300)
handles, labels = ax.get_legend_handles_labels()

fig.legend(handles, labels, ncol=8, loc='center')
fig.savefig("./task5/" + 'legend_task5_malicious_policies.png', bbox_inches='tight', pad_inches=0)
fig.savefig("./task5/" + 'legend_task5_malicious_policies.pdf', bbox_inches='tight', pad_inches=0)
# Release both the plot figure and the legend figure.
plt.close("all")

print("completed !", dataset_name)





