import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import json

# Font sizes for the matplotlibrc overrides below. Only BIGGER_SIZE is
# currently applied; the smaller sizes are kept so they can be swapped in.
SMALLER_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 16

# Raise every text element of the figure to BIGGER_SIZE for larger, clearer
# plots. Each entry is (rc group, rc key) passed to plt.rc in this order.
for _rc_group, _rc_key in (
    ("font", "size"),            # default text size
    ("axes", "titlesize"),       # axes title
    ("axes", "labelsize"),       # x and y axis labels
    ("xtick", "labelsize"),      # x tick labels
    ("ytick", "labelsize"),      # y tick labels
    ("legend", "fontsize"),      # legend text
    ("figure", "titlesize"),     # figure title
):
    plt.rc(_rc_group, **{_rc_key: BIGGER_SIZE})
del _rc_group, _rc_key

# example command:
"""
python3 scripts/plotting/structured/plot_test_accuracy.py --file_path_label_dict '{
"logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/num_queries_1/mask_model_None/pad_size_8/num_image_locations_edges/background_black/mask_init_identity/budget_split_None/first_query_budget_frac_1.0/first_query_with_mask_True/linf_pert_0.00902/average_queries_False/mask_output_None/seed_1/static_learnt_mask_1_query/train_log.txt": "static mask",
"logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/num_queries_2/mask_model_modified_resnet/pad_size_8/num_image_locations_edges/background_black/mask_init_random/budget_split_learnt/first_query_budget_frac_0.5/first_query_with_mask_False/linf_pert_0.00902/average_queries_True/mask_output_sigmoid/seed_1/adaptive/train_log.txt": "adaptive"}' \
--plot_save_dir "plots/static_mask_vs_adaptive_test_acc_pad_size_8_num_image_locations_edges_bg_black" \
--plot_save_file_name "test_accuracy.png"
"""

def read_train_log(path):
    """Load one train log file into a DataFrame.

    Columns are separated by the literal three-character sequence
    space-tab-space. pandas treats a multi-character separator as a regular
    expression, which only the Python parsing engine supports, so that
    engine is requested explicitly rather than relying on pandas' fallback
    (which emits a ParserWarning).

    Args:
        path: path to the log file; its header row must name the columns
            (this script reads 'epoch' and 'test_acc').

    Returns:
        pandas.DataFrame with one row per logged line.
    """
    return pd.read_csv(path, delimiter=" \t ", engine="python")


def epoch_ticks(max_epoch, step=10):
    """Return x tick positions 0, step, 2*step, ... not exceeding max_epoch.

    max_epoch itself is included when it is a multiple of step.
    """
    return np.arange(0, max_epoch + 1, step)


def main(argv=None):
    """Plot test accuracy vs. epoch for one or more train logs.

    --file_path_label_dict is a JSON object whose keys are train log paths
    (joined onto $PROJECT_DIR; absolute paths are used as-is by os.path.join)
    and whose values are the legend labels for the corresponding curves.
    The figure is saved to <plot_save_dir>/<plot_save_file_name>.

    Args:
        argv: optional argument list; defaults to sys.argv[1:] via argparse.
    """
    argparser = ArgumentParser()
    argparser.add_argument("--file_path_label_dict", type=str, help="dict with key as the train log file and value as the label for that line in the plot", required=True)
    argparser.add_argument("--plot_save_dir", type=str, help="save directory of the plot", required=True)
    argparser.add_argument("--plot_save_file_name", type=str, help="file name of the plot", required=True)
    args = argparser.parse_args(argv)

    # Relative log paths are resolved against $PROJECT_DIR. If it is unset,
    # fall back to the current working directory instead of crashing with a
    # bare KeyError; a wrong path still fails loudly in read_csv.
    root_dir = os.environ.get('PROJECT_DIR', os.getcwd())

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.plot_save_dir, exist_ok=True)

    file_path_label_dict = json.loads(args.file_path_label_dict)
    if not isinstance(file_path_label_dict, dict) or not file_path_label_dict:
        argparser.error("--file_path_label_dict must be a non-empty JSON object")

    max_epoch = 0
    for file_path, label in file_path_label_dict.items():
        train_log = read_train_log(os.path.join(root_dir, file_path))
        plt.plot(train_log['epoch'], train_log['test_acc'], label=label)
        # Track the largest epoch across *all* logs so the ticks cover every
        # curve, not just whichever file happened to be read last.
        max_epoch = max(max_epoch, train_log['epoch'].max())

    plt.xticks(epoch_ticks(max_epoch))
    plt.xlabel("epochs")
    plt.ylabel("test accuracy")
    plt.legend()
    plt.savefig(os.path.join(args.plot_save_dir, args.plot_save_file_name))
    plt.close()


if __name__ == "__main__":
    main()
