import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import json

# Global matplotlib text sizes, enlarged so the saved figures stay legible.
# Only BIGGER_SIZE is applied below; the smaller sizes are kept for callers
# that may want to tune individual elements.
SMALLER_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 15

plt.rcParams.update({
    "font.size": BIGGER_SIZE,          # default size for all text
    "axes.titlesize": BIGGER_SIZE,     # per-axes title
    "axes.labelsize": BIGGER_SIZE,     # x / y axis labels
    "xtick.labelsize": BIGGER_SIZE,    # x tick labels
    "ytick.labelsize": BIGGER_SIZE,    # y tick labels
    "legend.fontsize": BIGGER_SIZE,    # legend entries
    "figure.titlesize": BIGGER_SIZE,   # figure-level (suptitle) title
})

# example commands (log file paths are relative to root_dir):
"""
python3 scripts/plotting/structured/plot_sigma.py --file_path_label_dict '{
"cifar10/num_queries_2_pad_size_8_shifting_True_wm_init_identity_bs_fixed_fqbf_0.25_fqwm_False_linf_pert_0.01353/cifar_resnet110/sigma_log.txt": "fq_25_sq_75_fqwm_False",
"cifar10/num_queries_2_pad_size_8_shifting_True_wm_init_identity_bs_fixed_fqbf_0.5_fqwm_False_linf_pert_0.01353/cifar_resnet110/sigma_log.txt": "fq_50_sq_50_fqwm_False"}' \
--plot_save_dir "plots" \
--plot_save_file_name "vanilla_vs_2_query_sigma_fqwm_False_pad_size_8_shifting_True.png" --plot_title "sigma (=1.125)"
"""

"""
python3 scripts/plotting/structured/plot_sigma.py --file_path_label_dict '{
"cifar10/num_queries_2_pad_size_8_shifting_True_wm_init_identity_bs_fixed_fqbf_0.25_fqwm_False_linf_pert_0.01353/cifar_resnet110/sigma_log.txt": "fq_25_sq_75_fqwm_False",
"cifar10/num_queries_2_pad_size_8_shifting_True_wm_init_identity_bs_fixed_fqbf_0.5_fqwm_False_linf_pert_0.01353/cifar_resnet110/sigma_log.txt": "fq_50_sq_50_fqwm_False"}' \
--plot_save_dir "plots/vanilla_vs_2_query_fqwm_False_pad_size_8_shifting_True" \
--plot_save_file_name "sigma.png" \
--plot_title "vanilla sigma=1.125 (48*48*3 images)"
"""

# Directory that the log paths in --file_path_label_dict are relative to.
# Edit this placeholder, or override it per run with --root_dir.
root_dir = "your path to enter"

argparser = ArgumentParser()
argparser.add_argument("--file_path_label_dict", type=str, help="dict with key as the train log file and value as the label for that line in the plot", required=True)
argparser.add_argument("--plot_save_dir", type=str, help="save directory of the plot", required=True)
argparser.add_argument("--plot_save_file_name", type=str, help="file name of the plot", required=True)
argparser.add_argument("--plot_title", type=str, help="title of the plot", required=True)
argparser.add_argument("--root_dir", type=str, default=root_dir, help="directory the log file paths are relative to (defaults to the root_dir constant in this script)")
argparser.add_argument("--xtick_step", type=int, default=2000, help="spacing between x-axis ticks, in iterations")
args = argparser.parse_args()

if args.xtick_step <= 0:
    argparser.error("--xtick_step must be a positive integer")

# Parse the JSON mapping {log file path: legend label prefix}. Validate it
# before touching the filesystem so bad input does not leave empty dirs behind.
try:
    file_path_label_dict = json.loads(args.file_path_label_dict)
except json.JSONDecodeError as err:
    argparser.error(f"--file_path_label_dict is not valid JSON: {err}")
if not isinstance(file_path_label_dict, dict) or not file_path_label_dict:
    argparser.error("--file_path_label_dict must be a non-empty JSON object")

os.makedirs(args.plot_save_dir, exist_ok=True)

# Length of the longest log read. The x-axis ticks must cover every plotted
# line, not just the last file read in the loop.
max_num_iterations = 0

# Read each log as delimited text and draw its 'sigma_1' (solid) and
# 'sigma_2' (dashed) columns. The row index is used as the iteration number.
for file_path, label in file_path_label_dict.items():
    sigma_log_file = os.path.join(args.root_dir, file_path)
    # A separator longer than one character is interpreted as a regex, which
    # only the python engine supports; request it explicitly instead of
    # relying on pandas' warning-emitting fallback.
    sigma_log_csv = pd.read_csv(sigma_log_file, delimiter=" \t ", engine="python")
    iterations = range(len(sigma_log_csv))
    plt.plot(iterations, sigma_log_csv['sigma_1'], label=label+"_sigma_q_1", linestyle="solid")
    plt.plot(iterations, sigma_log_csv['sigma_2'], label=label+"_sigma_q_2", linestyle="dashed")
    max_num_iterations = max(max_num_iterations, len(sigma_log_csv))

# set other configs of plot
plt.xticks(np.arange(0, max_num_iterations, args.xtick_step))
plt.xlabel("iterations")
plt.ylabel("sigma")
plt.legend()
plt.title(args.plot_title)
plt.savefig(os.path.join(args.plot_save_dir, args.plot_save_file_name))
plt.close()
