import os
import numpy as np
import matplotlib
import argparse

# matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from typing import *
import pandas as pd
# import seaborn as sns
import math

# Font sizes (in points) available for the matplotlibrc overrides below.
SMALLER_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 20

# Every text element is set to BIGGER_SIZE for better, bigger, clear plots.
_RC_FONT_KEYS = (
    ('font', 'size'),         # controls default text sizes
    ('axes', 'titlesize'),    # fontsize of the axes title
    ('xtick', 'labelsize'),   # fontsize of the x tick labels
    ('ytick', 'labelsize'),   # fontsize of the y tick labels
    ('legend', 'fontsize'),   # legend fontsize
    ('figure', 'titlesize'),  # fontsize of the figure title
)
for _rc_group, _rc_key in _RC_FONT_KEYS:
    plt.rc(_rc_group, **{_rc_key: BIGGER_SIZE})

# sns.set()

# Command-line options: the output directory and the file name of the plot.
_CLI_OPTIONS = (
    ('--outdir', 'dir path for file name'),
    ('--plot_file_name', 'plot file name'),
)
parser = argparse.ArgumentParser(description='Plot some plots')
for _cli_flag, _cli_help in _CLI_OPTIONS:
    parser.add_argument(_cli_flag, type=str, help=_cli_help)

class Accuracy(object):
    """Interface for an accuracy quantity evaluated at a sequence of radii."""

    def at_radii(self, radii: np.ndarray, adv: str = "l2") -> np.ndarray:
        """Return one value per entry of ``radii``; subclasses must override.

        ``adv`` is part of the base signature because
        ``plot_certified_accuracy`` passes it to every quantity it plots.
        """
        raise NotImplementedError()

    @staticmethod
    def _certified_fraction(df: pd.DataFrame, radius: float, adv: str = "l2") -> float:
        """Return the mean of ``correct AND (logged radius >= radius)``.

        With ``adv == "linf"`` the logged ``radius`` column is multiplied by
        255 before the comparison; any other value compares it unchanged.
        """
        logged = df["radius"] * 255.0 if adv == "linf" else df["radius"]
        return (df["correct"] & (logged >= radius)).mean()


class ApproximateAccuracy(Accuracy):
    """Fraction of rows in a certification log that are correct at a radius."""

    def __init__(self, data_file_path: str):
        """Remember the log file to read.

        The parameter used to be spelled ``data_file_paths`` while the body
        read ``data_file_path``, so every construction raised NameError.
        """
        self.data_file_path = data_file_path

    def at_radii(self, radii: np.ndarray, adv: str = "l2") -> np.ndarray:
        """Read the log file and return the certified fraction at each radius."""
        # " \t " is longer than one character, so pandas treats it as a regex
        # that only the python engine can parse; naming the engine explicitly
        # avoids the ParserWarning emitted when pandas falls back on its own.
        df = pd.read_csv(self.data_file_path, delimiter=" \t ", engine="python")
        return np.array([self.at_radius(df, radius, adv) for radius in radii])

    def at_radius(self, df: pd.DataFrame, radius: float, adv: str = "l2"):
        """Return the certified fraction of ``df`` at a single ``radius``."""
        return self._certified_fraction(df, radius, adv)


class HighProbAccuracy(Accuracy):
    """Certified fraction reduced by terms that depend on ``alpha`` and ``rho``.

    NOTE(review): presumably a high-probability lower bound on the certified
    accuracy, as the class name suggests -- confirm against the source of the
    formula.
    """

    def __init__(self, data_file_path: str, alpha: float, rho: float):
        self.data_file_path = data_file_path
        self.alpha = alpha
        self.rho = rho

    def at_radii(self, radii: np.ndarray, adv: str = "l2") -> np.ndarray:
        """Read the tab-separated log file and evaluate every radius.

        ``adv`` is accepted (default ``"l2"``) so this class can be used with
        ``plot_certified_accuracy``, which always passes it.
        """
        df = pd.read_csv(self.data_file_path, delimiter="\t")
        return np.array([self.at_radius(df, radius, adv) for radius in radii])

    def at_radius(self, df: pd.DataFrame, radius: float, adv: str = "l2"):
        """Return the certified fraction at ``radius`` minus the slack terms."""
        mean = self._certified_fraction(df, radius, adv)
        num_examples = len(df)
        log_inv_rho = math.log(1 / self.rho)
        return (mean - self.alpha
                - math.sqrt(self.alpha * (1 - self.alpha) * log_inv_rho / num_examples)
                - log_inv_rho / (3 * num_examples))

class Line(object):
    """Bundle of one plotted quantity with its legend text and drawing options.

    ``quantity`` supplies the y-values through ``at_radii``; ``legend`` is the
    legend entry; ``plot_fmt`` is the matplotlib format string handed to
    ``plt.plot``; ``scale_x`` multiplies the radii on the x axis in
    ``plot_certified_accuracy``. ``color`` and ``style`` are only stored here.
    """

    def __init__(self,
                 quantity: Accuracy,
                 legend: str,
                 plot_fmt: str = "",
                 scale_x: float = 1,
                 color: str = "red",
                 style: str = "dotted"):
        # What is drawn and how it is labelled.
        self.quantity, self.legend = quantity, legend
        # How the curve is formatted and how its x-values are scaled.
        self.plot_fmt, self.scale_x = plot_fmt, scale_x
        # Appearance hints kept on the instance.
        self.color, self.style = color, style

def plot_certified_accuracy(outdir: str,
                            plot_file_name: str,
                            title: str,
                            max_radius: float,
                            lines: List[Line],
                            radius_step: float = 0.001,
                            xlabel: str = "radius",
                            ylabel: str = "certified accuracy",
                            adv: str = "l2", figtext: str = "None") -> None:
    """Plot each line's quantity against the radius and save the figure.

    Args:
        outdir: directory for the saved figure; created if missing. An empty
            string saves relative to the current directory.
        plot_file_name: file name of the figure inside ``outdir``.
        title: figure title.
        max_radius: upper x-limit; radii run from 0 in ``radius_step`` steps.
        lines: curves to draw; each supplies ``quantity``, ``legend``,
            ``plot_fmt`` and ``scale_x``.
        radius_step: spacing of the evaluated radii.
        xlabel: x-axis label.
        ylabel: y-axis label.
        adv: forwarded to every ``quantity.at_radii``; ``"linf"`` also selects
            an x tick spacing of 1 instead of 0.5.
        figtext: optional text placed at figure coordinates (0.05, 0.05);
            the string ``"None"`` (the default) or ``None`` disables it.
    """
    # exist_ok removes the check-then-create race of exists() + makedirs().
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    radii = np.arange(0, max_radius + radius_step, radius_step)
    tick_frequency = 1 if adv == "linf" else 0.5

    fig = plt.figure()
    try:
        for line in lines:
            plt.plot(radii * line.scale_x, line.quantity.at_radii(radii, adv), line.plot_fmt)

        plt.ylim((0, 1))
        plt.xlim((0, max_radius))
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        if figtext is not None and figtext != "None":
            plt.figtext(0.05, 0.05, figtext)
        plt.xticks(np.arange(0, max_radius+0.01, tick_frequency))
        plt.legend([method.legend for method in lines], loc='upper right')
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, plot_file_name), dpi=300)
    finally:
        # Close the figure even when a quantity or the save raises, so failed
        # calls do not accumulate open figures.
        plt.close(fig)

def smallplot_certified_accuracy(outfile: str, title: str, max_radius: float,
                                 methods: List[Line], radius_step: float = 0.01, xticks=0.5) -> None:
    """Draw a compact certified-accuracy figure and save it as ``outfile + ".pdf"``.

    One curve is drawn per entry of ``methods`` over radii from 0 to
    ``max_radius`` in ``radius_step`` steps; major x ticks are placed every
    ``xticks``. ``title`` is accepted but not drawn.
    """
    grid = np.arange(0, max_radius + radius_step, radius_step)

    plt.figure()
    for entry in methods:
        values = entry.quantity.at_radii(grid)
        plt.plot(grid, values, entry.plot_fmt)

    # Axis ranges, labels and tick styling.
    plt.ylim((0, 1))
    plt.xlim((0, max_radius))
    plt.xlabel("Attack size", fontsize=22)
    plt.ylabel("certified accuracy", fontsize=22)
    plt.tick_params(labelsize=20)
    axes = plt.gca()
    axes.xaxis.set_major_locator(plt.MultipleLocator(xticks))

    # Legend entries follow the order in which the curves were drawn.
    legend_labels = [entry.legend for entry in methods]
    plt.legend(legend_labels, loc='upper right', fontsize=20)

    plt.tight_layout()
    plt.savefig(outfile + ".pdf")
    plt.close()

def latex_table_certified_accuracy(outfile: str, radius_start: float, radius_stop: float, radius_step: float,
                                   methods: List["Line"]):
    """Write the body of a LaTeX table of accuracies, one row per method.

    The first line lists the radii, followed by ``\\midrule`` and one row per
    method (legend, then one two-decimal value per radius). In every column
    the first method with the highest value is wrapped in ``\\textbf``.

    Args:
        outfile: path of the file to (over)write.
        radius_start, radius_stop, radius_step: define the radii as
            ``np.arange(radius_start, radius_stop + radius_step, radius_step)``.
        methods: rows of the table; each supplies ``legend`` and
            ``quantity.at_radii``.
    """
    radii = np.arange(radius_start, radius_stop + radius_step, radius_step)
    accuracies = np.zeros((len(methods), len(radii)))
    for i, method in enumerate(methods):
        accuracies[i, :] = method.quantity.at_radii(radii)

    # Row index of the best method for each radius, computed once instead of
    # once per cell. argmax rejects an empty axis, so skip it without methods.
    best_rows = accuracies.argmax(axis=0) if len(methods) else []

    # The context manager closes the file even if a write raises.
    with open(outfile, 'w') as f:
        for radius in radii:
            f.write("& $r = {:.3}$".format(radius))
        f.write("\\\\\n")

        # Escaped backslash: "\m" is not a valid escape sequence.
        f.write("\\midrule\n")

        for i, method in enumerate(methods):
            f.write(method.legend)
            for j, best in enumerate(best_rows):
                cell = "{:.2f}".format(accuracies[i, j])
                if i == best:
                    cell = r"\textbf{" + cell + "}"
                f.write(" & " + cell)
            f.write("\\\\\n")

def markdown_table_certified_accuracy(outfile: str, radius_start: float, radius_stop: float, radius_step: float,
                                      methods: List["Line"]):
    """Write a Markdown table of accuracies, one row per method.

    The header row lists the radii, followed by the ``---`` separator row and
    one row per method (bold legend, then one two-decimal value per radius).
    In every column the first method with the highest value is marked with
    ``<b>*</b>``.

    Args:
        outfile: path of the file to (over)write.
        radius_start, radius_stop, radius_step: define the radii as
            ``np.arange(radius_start, radius_stop + radius_step, radius_step)``.
        methods: rows of the table; each supplies ``legend`` and
            ``quantity.at_radii``.
    """
    radii = np.arange(radius_start, radius_stop + radius_step, radius_step)
    accuracies = np.zeros((len(methods), len(radii)))
    for i, method in enumerate(methods):
        accuracies[i, :] = method.quantity.at_radii(radii)

    # Row index of the best method for each radius, computed once instead of
    # once per cell. argmax rejects an empty axis, so skip it without methods.
    best_rows = accuracies.argmax(axis=0) if len(methods) else []

    # The context manager closes the file even if a write raises.
    with open(outfile, 'w') as f:
        f.write("|  | ")
        for radius in radii:
            f.write("r = {:.3} |".format(radius))
        f.write("\n")

        # One separator cell per radius column.
        f.write("| --- | " + " --- |" * len(radii))
        f.write("\n")

        for i, method in enumerate(methods):
            f.write("<b> {} </b>| ".format(method.legend))
            for j, best in enumerate(best_rows):
                if i == best:
                    txt = "{:.2f}<b>*</b> |".format(accuracies[i, j])
                else:
                    txt = "{:.2f} |".format(accuracies[i, j])
                f.write(txt)
            f.write("\n")

def at_radius(df: pd.DataFrame, radius: float, adv: str = "l2"):
    """Return the mean of ``correct AND (logged radius >= radius)`` over ``df``.

    For ``adv == "linf"`` the ``radius`` column is multiplied by 255 before
    the comparison; for any other value it is compared unchanged.
    """
    logged_radius = df["radius"] * 255.0 if adv == "linf" else df["radius"]
    certified = df["correct"] & (logged_radius >= radius)
    return certified.mean()

def certified_acc_at_radii(data_file_path, radii, adv) -> np.ndarray:
    """Read a certification log and evaluate ``at_radius`` at every radius.

    Args:
        data_file_path: file whose fields are separated by ``" \\t "``; the
            ``correct`` and ``radius`` columns are read by ``at_radius``.
        radii: iterable of radius thresholds.
        adv: forwarded to ``at_radius``.

    Returns:
        Array with one value per entry of ``radii``.
    """
    # " \t " is longer than one character, so pandas treats it as a regex
    # that only the python engine can parse; naming the engine explicitly
    # avoids the ParserWarning emitted when pandas falls back on its own.
    df = pd.read_csv(data_file_path, delimiter=" \t ", engine="python")
    return np.array([at_radius(df, radius, adv) for radius in radii])
    
if __name__ == "__main__":

    # Parse and validate the CLI first so a missing option fails before any
    # log file is read, instead of as a TypeError when the figure is saved.
    args = parser.parse_args()
    if not args.outdir or not args.plot_file_name:
        parser.error("--outdir and --plot_file_name are both required")

    ##----------------------------------------------------------------
    ## FIRST GENERATE MEAN CERTIFIED ACC (OVER N RANDOM SEEDS)
    ##----------------------------------------------------------------

    adv = "linf"

    if adv == "linf":
        x_label = r"$L_{\infty}$ radius"
        max_radius = 16/255
    elif adv == "l2":
        x_label = r"$L_{2}$ radius"
        max_radius = 3.0
    else:
        raise ValueError("unsupported adv: {!r}".format(adv))

    figtext = "None"

    radius_step = 0.001
    radii = np.arange(0, max_radius + radius_step, radius_step)

    seed_arr = ["seed_1", "seed_2", "seed_3"]

    vanilla_path = "logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/num_queries_1/mask_model_None/pad_size_8/num_image_locations_8/background_black/mask_init_None/budget_split_None/first_query_budget_frac_1.0/first_query_with_mask_False/linf_pert_0.00902/average_queries_False/mask_output_None/"
    static_mask_path = "logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/num_queries_1/mask_model_None/pad_size_8/num_image_locations_8/background_black/mask_init_identity/budget_split_None/first_query_budget_frac_1.0/first_query_with_mask_True/linf_pert_0.00902/average_queries_False/mask_output_None/"
    adaptive_path = "logs/mask_idea/runs/cifar10/base_classifier_cifar_resnet110/num_queries_2/mask_model_modified_resnet/pad_size_8/num_image_locations_8/background_black/mask_init_random/budget_split_learnt/first_query_budget_frac_0.5/first_query_with_mask_False/linf_pert_0.00902/average_queries_True/mask_output_sigmoid/"

    # (legend label, experiment directory, per-seed run sub-directory)
    experiments = [
        ("vanilla", vanilla_path, "vanilla"),
        ("static mask", static_mask_path, "static_learnt_mask_1_query"),
        ("adaptive", adaptive_path, "fqbf_lr_0.0001_mom_0.9_wd_0_scheduler_off_t_scheduler_on"),
    ]

    mean_curves = []
    for label, base_path, run_name in experiments:
        # NOTE(review): the logs are evaluated with adv="l2" (no x255 scaling)
        # even though the plot is labelled L-inf; the radii above already stop
        # at 16/255 -- confirm this is intended.
        per_seed_acc = [
            certified_acc_at_radii(
                os.path.join(base_path, seed, run_name, "certification_log_50000.txt"),
                radii, adv="l2")
            for seed in seed_arr
        ]
        mean_certified_acc = np.vstack(per_seed_acc).mean(axis=0)
        np.save(os.path.join(base_path, "mean_certified_acc.npy"), mean_certified_acc)
        print(mean_certified_acc.mean())
        mean_curves.append((label, mean_certified_acc))

    ##----------------------------------------------------------------
    ## THEN PLOT THE GENERATED MEAN CERTIFIED ACC
    ##----------------------------------------------------------------

    # exist_ok removes the check-then-create race of exists() + makedirs().
    os.makedirs(args.outdir, exist_ok=True)

    plt.figure()

    for label, mean_certified_acc in mean_curves:
        plt.plot(radii, mean_certified_acc, label=label)

    tick_frequency = 0.001 if adv == "linf" else 0.5

    plt.ylim((0, 1))
    plt.xlim((0, max_radius))
    # The original call had no label argument at all (a SyntaxError); x_label
    # is the label chosen for the threat model above.
    plt.xlabel(x_label, labelpad=20)
    plt.ylabel("Certified Accuracy", labelpad=20)
    if figtext != "None":
        plt.figtext(0.05, 0.05, figtext)
    plt.xticks(np.arange(0, max_radius+0.01, tick_frequency))
    plt.legend(loc='upper right')
    plt.tight_layout()
    plt.savefig(os.path.join(args.outdir, args.plot_file_name), dpi=300)
    plt.close()