import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Font-size presets for matplotlib's rc configuration; only BIGGER_SIZE is
# applied below so that all text in the figure is large and readable.
SMALLER_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 20

# (rc group, option) pairs that are all set to BIGGER_SIZE, in order:
# default text, axes title, x/y labels, x ticks, y ticks, legend, figure title.
for _rc_group, _rc_option in (
    ('font', 'size'),
    ('axes', 'titlesize'),
    ('axes', 'labelsize'),
    ('xtick', 'labelsize'),
    ('ytick', 'labelsize'),
    ('legend', 'fontsize'),
    ('figure', 'titlesize'),
):
    plt.rc(_rc_group, **{_rc_option: BIGGER_SIZE})

# Root directory under which the per-run training logs are looked up.
common_path = "logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/"
# Output directory for the bar plot and the text dump of the averages.
plot_dir = "plots/paper/varying_num_image_locations"

# exist_ok=True makes the creation idempotent and avoids the race between a
# separate existence check and the directory creation.
os.makedirs(plot_dir, exist_ok=True)

# Seed directories whose results are combined for every bar.
seed_arr = ["seed_" + str(seed_idx) for seed_idx in range(1, 4)]
# Values of the "num_image_locations" setting shown along the x axis.
x_axis_arr = ["2", "4", "random"]

# Width of a single bar.
barWidth = 0.3

# Per-method averages across the seeds, keyed by method name:
#   "vanilla"     -> vanilla
#   "static_mask" -> static mask (1 query)
#   "adaptive"    -> adaptive (2 query)
# Each inner dict is filled later with one entry per x-axis setting.
avg_test_acc_dict = {
    method_name: {} for method_name in ("vanilla", "static_mask", "adaptive")
}

# Error bars per method, each stored as [lower offsets, upper offsets].
# Every method gets its own independent pair of lists.
vanilla_error_bars = [[] for _ in range(2)]
static_mask_error_bars = [[] for _ in range(2)]
adaptive_error_bars = [[] for _ in range(2)]

def _summarize_setting(path_head, path_tail, run_name, x_var):
    """Summarize the best test accuracy across all seeds for one setting.

    For every seed in ``seed_arr`` the run's ``train_log.txt`` is read and the
    maximum of its ``test_acc`` column, rounded to 2 decimals, is collected.

    Args:
        path_head: Path fragment between ``common_path`` and the
            ``num_image_locations_*`` directory.
        path_tail: Path fragment between the ``num_image_locations_*``
            directory and the seed directory.
        run_name: Name of the run directory inside each seed directory.
        x_var: Entry of ``x_axis_arr`` selecting the
            ``num_image_locations_*`` directory.

    Returns:
        Tuple ``(mean, lower, upper)`` where ``mean`` is the mean of the
        per-seed best accuracies and ``lower`` / ``upper`` are the absolute
        distances from that mean to the smallest / largest per-seed value.
    """
    best_acc_per_seed = []
    for seed in seed_arr:
        train_log_path = os.path.join(
            common_path,
            path_head,
            "num_image_locations_" + x_var,
            path_tail,
            seed,
            run_name,
            "train_log.txt",
        )
        # The multi-character separator is interpreted as a regex, which only
        # the python parser engine supports; request it explicitly instead of
        # relying on pandas' fallback (which emits a ParserWarning).
        train_log = pd.read_csv(train_log_path, delimiter=" \t ", engine="python")
        best_acc_per_seed.append(round(train_log['test_acc'].max(), 2))

    # Average over ALL seeds (not just the value of the last seed read).
    mean_acc = np.mean(best_acc_per_seed)
    lower = np.abs(np.min(best_acc_per_seed) - mean_acc)
    upper = np.abs(np.max(best_acc_per_seed) - mean_acc)
    return mean_acc, lower, upper


# One entry per plotted method:
# (key in avg_test_acc_dict, error-bar lists, path head, path tail, run dir).
_METHOD_RUNS = (
    (
        "vanilla",
        vanilla_error_bars,
        "num_queries_1/mask_model_None/pad_size_8/",
        "background_black/mask_init_None/budget_split_None/first_query_budget_frac_1.0/first_query_with_mask_False/linf_pert_0.00902/average_queries_False/mask_output_None/",
        "vanilla",
    ),
    (
        "static_mask",
        static_mask_error_bars,
        "num_queries_1/mask_model_None/pad_size_8/",
        "background_black/mask_init_identity/budget_split_None/first_query_budget_frac_1.0/first_query_with_mask_True/linf_pert_0.00902/average_queries_False/mask_output_None/",
        "static_learnt_mask_1_query",
    ),
    (
        "adaptive",
        adaptive_error_bars,
        "num_queries_2/mask_model_modified_resnet/pad_size_8/",
        "background_black/mask_init_random/budget_split_learnt/first_query_budget_frac_0.5/first_query_with_mask_False/linf_pert_0.00902/average_queries_True/mask_output_sigmoid/",
        "fqbf_lr_0.0001_mom_0.9_wd_0_scheduler_off_t_scheduler_on",
    ),
)

# Fill the averages and error bars for the vanilla, static mask and adaptive
# bars, one x-axis setting at a time.
for _method, _error_bars, _path_head, _path_tail, _run_name in _METHOD_RUNS:
    for x_var in x_axis_arr:
        avg_test_acc, _lower, _upper = _summarize_setting(
            _path_head, _path_tail, _run_name, x_var
        )

        # Store a built-in float so that str(avg_test_acc_dict) prints plain
        # numbers regardless of how the installed NumPy represents scalars.
        avg_test_acc_dict[_method]["num_image_locations_" + x_var] = round(float(avg_test_acc), 2)

        # Distances from the mean to the min and max, used as error bars.
        _error_bars[0].append(_lower)
        _error_bars[1].append(_upper)

# The x positions of the bars: one group per x-axis setting, with the three
# methods placed side by side inside each group.
r1 = np.arange(len(avg_test_acc_dict["vanilla"]))
r2 = r1 + barWidth
r3 = r1 + 2 * barWidth

# Draw the vanilla, static mask and adaptive bars with their error bars.
# The dict views are materialized as lists because a dict view is not an
# array-like that NumPy can convert element-wise.
for _positions, _method, _color, _error_bars, _label in (
    (r1, "vanilla", 'cornflowerblue', vanilla_error_bars, 'vanilla'),
    (r2, "static_mask", 'tomato', static_mask_error_bars, 'static mask'),
    (r3, "adaptive", 'limegreen', adaptive_error_bars, 'adaptive'),
):
    plt.bar(
        _positions,
        list(avg_test_acc_dict[_method].values()),
        width=barWidth,
        color=_color,
        edgecolor='black',
        yerr=_error_bars,
        capsize=7,
        label=_label,
    )

# Put each tick under the middle bar of its group.
plt.xticks(r2, x_axis_arr)
plt.xlabel("image positions", labelpad=25)
plt.ylabel('Test accuracy', labelpad=25)
plt.legend(loc='best')
plt.tight_layout()
plt.savefig(os.path.join(plot_dir, "mean_test_accuracy_bar_plot.png"))

# Dump the plotted averages next to the figure for later reference.
with open(os.path.join(plot_dir, "avg_test_acc_dict.txt"), "w", encoding="utf-8") as f:
    f.write(str(avg_test_acc_dict))