from testing import compare_txt_files, compare_txt_g
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from tabulate import tabulate 
import time
from multiprocessing import Pool
import itertools
import re
import glob
import seaborn
import os
import numpy as np
import json
from matplotlib import pyplot as plt


def score_experiment_g(
    checkpoints, experiment, json_file_name="score_dics.json",
):
    """
    Scores each checkpoint of an experiment with compare_txt_g and caches the
    result as a json file inside the experiment directory.

    If the json file already exists it is loaded and returned unchanged, so
    `checkpoints` is ignored in that case. Delete the file to force a rescore.

    Arguments:

    checkpoints {iterable} -- The checkpoints one is interested in scores for.
    experiment {str} -- The experiment (directory under experiments/) to score.
    json_file_name {str} -- The name of the file the scores are cached in.

    Returns:

    The score_dics, whose keys are the checkpoints in string form and whose
    values are whatever compare_txt_g returned for that checkpoint.
    """
    json_path = f"experiments/{experiment}/{json_file_name}"

    # Cached result: read it with a context manager so the handle is closed.
    if os.path.exists(json_path):
        with open(json_path, "r") as infile:
            return json.load(infile)

    score_dics = {}
    for checkpoint in checkpoints:
        print(f"Doing checkpoint {checkpoint} for {experiment}")
        score_dics[str(checkpoint)] = compare_txt_g(
            f"experiments/{experiment}/targets.txt",
            f"experiments/{experiment}/predictions.txt-{checkpoint}",
        )
    with open(json_path, "w") as outfile:
        json.dump(score_dics, outfile)
    return score_dics


def score_experiment_c(
    checkpoints, experiment, json_file_name="score_dics.json", n=100,
):
    """
    Calls compare_txt_files for targets.txt and the predictions corresponding
    to each ckpt in checkpoints. Collects the returned score dictionaries and
    saves them as a json in the experiment directory (always overwriting any
    existing file of that name).

    Arguments:

    checkpoints {list or "glob"} -- The checkpoints one is interested in scores
                          for. The string "glob" scores every file in the
                          experiment directory matching predictions* that ends
                          in a number, in ascending numeric order.
    experiment {str} -- The experiment one wants to score.
    json_file_name {str} -- The name of the file to save the score to.
    n {int} -- Forwarded unchanged to compare_txt_files.

    Returns:

    The score_dics, whose keys are the checkpoints in string form.

    """
    score_dics = {}
    json_path = f"experiments/{experiment}/{json_file_name}"

    # The isinstance guard matters: comparing a numpy array of checkpoints
    # with a string can give an elementwise result that has no single truth
    # value.
    if isinstance(checkpoints, str) and checkpoints == "glob":
        checkpointfiles = glob.glob(f"experiments/{experiment}/predictions*")
        matches = (re.search(r"\d+$", cpf) for cpf in checkpointfiles)
        # Files without a numeric suffix are skipped instead of crashing on
        # None.group(); sorting makes the scoring order deterministic.
        checkpoints = sorted(
            (match.group(0) for match in matches if match is not None), key=int
        )

    for ckpt in checkpoints:
        print(f"Doing checkpoint {ckpt}")
        score_dics[str(ckpt)] = compare_txt_files(
            f"experiments/{experiment}/targets.txt",
            f"experiments/{experiment}/predictions.txt-{ckpt}",
            # presumably the object / container counts to evaluate -- TODO
            # confirm against compare_txt_files
            np.arange(2, 20),
            np.arange(2, 6),
            n=n,
        )
    with open(json_path, "w") as outfile:
        json.dump(score_dics, outfile)
    return score_dics

    # if not os.path.exists(json_path):
    #    for ckpt in checkpoints:
    #        print(f'Doing checkpoint {ckpt}')
    #        score_dic = compare_txt_files(
    #            f"experiments/{experiment}/targets.txt",
    #            f"experiments/{experiment}/predictions.txt-{ckpt}",
    #            np.arange(2, 20),
    #            np.arange(2, 6),
    #            n=n
    #        )
    #        bleus.append(score_dic["avg_bleu_all"])
    #        exacts.append(score_dic["avg_exact_all"])
    #        substrs.append(score_dic["avg_substr_all"])
    #        score_dics[str(ckpt)] = score_dic
    #    with open(json_path, "w") as outfile:
    #        json.dump(score_dics, outfile)
    # else:
    #    score_dics=json.load(open(json_path,'r'))
    # return score_dics


def extract_series(checkpoints, score_dics):
    """
    Collects, for each metric, the sequence of scores over the requested
    checkpoints (in the order the checkpoints are given).

    Arguments:

    checkpoints {list} -- The checkpoints one is interested in scores for.
                          Each is converted with str() before being used as a
                          key of score_dics.
    score_dics {dic} -- The score_dics for an experiment.

    Returns:

    A dictionary with the keys 'bleu', 'exact' and 'substr'; each value is the
    list of the corresponding 'avg_*_all' scores, one per checkpoint.
    """
    # Short metric name -> field name inside a checkpoint's score dictionary.
    metric_to_field = {
        "bleu": "avg_bleu_all",
        "exact": "avg_exact_all",
        "substr": "avg_substr_all",
    }
    series_dic = {metric: [] for metric in metric_to_field}

    for ckpt in checkpoints:
        ckpt_scores = score_dics[str(ckpt)]
        for metric, field in metric_to_field.items():
            series_dic[metric].append(ckpt_scores[field])
    return series_dic


def score_experiments(experiments, checkpoints, n):
    """
    Scores every experiment in `experiments` by calling score_experiment_c,
    which writes the scores as a json into each experiment directory.

    Arguments:

    experiments {list} -- The experiments one wants to score.
    checkpoints {list} -- The checkpoints to score for every experiment.
    n {int} -- Forwarded unchanged to score_experiment_c.

    Returns:

    A dictionary mapping each experiment to the score_dics that
    score_experiment_c returned for it (previously the results were discarded
    and nothing was returned).
    """
    print(f"Getting scores for {experiments}")
    return {
        experiment: score_experiment_c(checkpoints, experiment, n=n)
        for experiment in experiments
    }


def parallel_score_experiments(
    experiments, checkpoints, n,
):
    """
    Scores every (experiment, checkpoint) combination in a process pool.

    Each combination becomes one score_experiments([experiment], [checkpoint],
    n) task. The call previously targeted an undefined name
    (generate_experiments) and failed with a NameError; score_experiments is
    the function whose signature matches the argument tuples built here.

    NOTE(review): score_experiment_c rewrites the experiment's json on every
    call, so passing several checkpoints for the same experiment makes the
    tasks overwrite each other's file -- confirm this is only used with a
    single checkpoint per experiment.

    Arguments:

    experiments {list} -- The experiments one wants to score.
    checkpoints {list} -- The checkpoints to score for every experiment.
    n {int} -- Forwarded unchanged to score_experiments.
    """
    args = [
        ([experiment], [checkpoint], n)
        for experiment, checkpoint in itertools.product(experiments, checkpoints)
    ]
    with Pool() as p:
        start = time.time()
        print(args)
        p.starmap(score_experiments, args)
        print(f"Took {time.time() - start} seconds")


def get_scores(experiments, checkpoints, n=100):
    """
    Scores the given experiments and gathers the per-checkpoint series of each
    metric.

    The scoring call previously used an undefined name (score_experiment) and
    failed with a NameError; score_experiment_c is the function in this file
    whose signature (checkpoints, experiment, n=...) matches the call.

    Arguments:

    experiments {list} -- The experiments one wants to score.
    checkpoints {list} -- The checkpoints to score for every experiment.
    n {int} -- Forwarded unchanged to score_experiment_c.

    Returns:

    A dictionary with the keys 'bleu', 'exact' and 'substr'. Each value is a
    list with one entry per experiment, and each entry is that experiment's
    sequence of scores over the checkpoints.
    """
    print(f"Getting scores for {experiments}")
    keys = "bleu exact substr".split()
    scores_dic = {key: [] for key in keys}
    for experiment in experiments:
        score_dics = score_experiment_c(checkpoints, experiment, n=n)
        series_dic = extract_series(checkpoints, score_dics)
        for key in keys:
            scores_dic[key].append(series_dic[key])
    return scores_dic


def plot_scores_dic(experiments, experiment_labels, scores_dic):
    """
    Plots, for every metric in scores_dic, one line per experiment label and
    saves the figure as accuracies_<metric>.png.

    The figure is saved in the directory of the experiment that lines up with
    the last label (experiments[len(experiment_labels) - 1]). The original
    code got this index from the loop variable after the loop had finished,
    which raised an UnboundLocalError when there were no labels.

    Arguments:

    experiments {list} -- The experiments that were scored.
    experiment_labels {list} -- The legend label for each experiment.
    scores_dic {dic} -- Maps a metric name to a list with one score series per
                        experiment, in the same order as experiment_labels.

    Raises:

    ValueError -- If there are metrics to plot but experiment_labels is empty.
    """
    if not scores_dic:
        return
    if len(experiment_labels) == 0:
        raise ValueError("experiment_labels must contain at least one label")

    target_experiment = experiments[len(experiment_labels) - 1]
    for key, series_per_experiment in scores_dic.items():
        # Drop the previous metric's figure so the lines do not accumulate.
        plt.close()
        for ix, experiment_label in enumerate(experiment_labels):
            plt.plot(series_per_experiment[ix], label=experiment_label)
        plt.legend()
        plt.title(f"{key}")
        plt.xlabel("Checkpoints")
        plt.ylim(0, 1)
        plt.ylabel("Score")
        plt.savefig(f"experiments/{target_experiment}/accuracies_{key}.png")


def get_score_dic_from_experiment(experiment):
    """Loads and returns the content of experiments/<experiment>/score_dics.json."""
    json_path = f"experiments/{experiment}/score_dics.json"
    with open(json_path, "r") as json_file:
        return json.load(json_file)


def get_checkpoint_from_score_dic(score_dic, checkpoint):
    """Returns the entry of score_dic stored under the key `checkpoint`."""
    checkpoint_scores = score_dic[checkpoint]
    return checkpoint_scores


def get_score_from_score_dic(score_dic, score_type):
    """
    Returns the value of score_dic stored under the key `score_type`.

    score_type should be 'avg_bleu_all', 'avg_substr_all', or 'avg_exact_all'.
    """
    score = score_dic[score_type]
    return score


def get_score_dics_from_experiments(experiments):
    """
    Loads the score dictionary of every experiment in `experiments`.

    Returns:

    A dictionary mapping each experiment to the result of
    get_score_dic_from_experiment for it.
    """
    return {
        experiment: get_score_dic_from_experiment(experiment)
        for experiment in experiments
    }


def _autolabel(ax, rects):
    """Attach a text label above each bar in *rects*, displaying its height."""
    for rect in rects:
        height = rect.get_height()
        ax.annotate('{}'.format(height),
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')


def plot_bar_of_experiments(annotate=False):
    """
    Draws one grouped bar chart per score type (bleus, substrs, exacts) for
    checkpoint 1001000 and saves it to results/final_results/fig2/.

    Every chart has one group per model (20nouns, 200nouns, 2000nouns) and one
    bar per data set (nouns, verbs, random). The scores are read from the
    score_dics.json of the experiments c/experiment-<model>-<data>.

    Each figure is now closed after it has been saved; previously a new figure
    was created for every score type and never released.

    Arguments:

    annotate {bool} -- If True, write the height above every bar. Defaults to
                       False, which matches the previous output.
    """
    # Name used for the output file / title -> key in the checkpoint's scores.
    score_keys = {
        "bleus": "avg_bleu_all",
        "substrs": "avg_substr_all",
        "exacts": "avg_exact_all",
    }
    models = "20nouns 200nouns 2000nouns".split()
    datas = "nouns verbs random".split()
    data_labels = {"nouns": "Nouns", "verbs": "Verbs", "random": "Random"}
    checkpoint = '1001000'

    # score_dics[score_type][data] holds one score per model, in models order.
    score_dics = {score_type: {} for score_type in score_keys}
    for data in datas:
        checkpoint_scores = [
            get_score_dic_from_experiment(f"c/experiment-{model}-{data}")[checkpoint]
            for model in models
        ]
        for score_type, score_key in score_keys.items():
            score_dics[score_type][data] = [
                cpscore_dic[score_key] for cpscore_dic in checkpoint_scores
            ]

    x = np.arange(len(models))  # the label locations
    width = 0.15  # the width of the bars
    # One bar per data set, spread around the group's tick position.
    offsets = (-1.1 * width, 0.0, 1.1 * width)

    for score_type in score_keys:
        fig, ax = plt.subplots()
        for data, offset in zip(datas, offsets):
            rects = ax.bar(
                x + offset, score_dics[score_type][data], width,
                label=data_labels[data],
            )
            if annotate:
                _autolabel(ax, rects)

        # Add some text for labels, title and custom x-axis tick labels, etc.
        ax.set_ylabel('%')
        ax.set_title(score_type)
        ax.set_xticks(x)
        ax.set_xticklabels(models)
        # Legend goes outside the axes, to the right of the plot.
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        fig.tight_layout()
        fig.savefig(f'results/final_results/fig2/{score_type}.png')
        plt.close(fig)

    #plt.title(f"{model} Bleus")
    #plt.bar(datas, bleus)
    #plt.show()
    #plt.title(f"{model} Substring Matches")
    #plt.bar(datas, substrs)
    #plt.show()
    #plt.title(f"{model} Exact Matches")
    #plt.bar(datas, exacts)
    #plt.show()


def plot_gib(score_type, output_path='results/final_results/fig4/substr.png'):
    """
    Plots the `score_type` scores of the gibberish and the all-train
    experiments over the checkpoints 1000000..1001000 (step 100), saves the
    figure and shows it.

    Arguments:

    score_type {str} -- The key looked up in every checkpoint's score
                        dictionary.
    output_path {str} -- Where the figure is saved. The default keeps the
                         previous hard-coded file name, which says 'substr'
                         whatever score_type is; pass a matching path when
                         plotting another score type.
    """
    experiment_names = {
        "c/experimentall2000concretecommon/experiment-gibberish-gibberish": "gibberish",
        "c/experimentall2000concretecommon/experiment-all-all-train": "nouns",
    }
    checkpoints = np.arange(1000000, 1001001, 100)
    for experiment, name in experiment_names.items():
        scores = get_scores_from_experiment(experiment, checkpoints, score_type)
        plt.plot(scores, label=name)
    plt.legend()
    plt.savefig(output_path)
    plt.show()




def plot_comp(score_type):
    """
    Plots the `score_type` scores of the experiments cn5050comp, cncomp,
    nccomp and n/ho in one figure and shows it.

    The first three experiments are read at the checkpoints 1002000..1003000
    (step 100); n/ho is read at the same checkpoints shifted down by 2000.
    """
    base_checkpoints = np.arange(1002000, 1003001, 100)
    checkpoints_by_experiment = {
        "cn5050comp": base_checkpoints,
        "cncomp": base_checkpoints,
        "nccomp": base_checkpoints,
        "n/ho": base_checkpoints - 2000,
    }
    for experiment, experiment_checkpoints in checkpoints_by_experiment.items():
        scores = get_scores_from_experiment(
            experiment, experiment_checkpoints, score_type
        )
        plt.plot(scores, label=experiment)
    plt.legend()
    plt.show()


def generate_grand_table(score_type):
    """
    Builds the model x data table of `score_type` scores at checkpoint 1001000.

    score_type should be 'avg_bleu_all', 'avg_substr_all', or 'avg_exact_all'.

    Returns:

    A numpy array with one row per model (all, concrete, common, 2000) and one
    column per data set, read from the score_dics.json of the experiments
    c/experimentall2000concretecommon/experiment-<model>-<data>.
    """
    checkpoint = "1001000"
    models = ["all", "concrete", "common", "2000"]
    datas = [
        "all-train", "concrete-train", "common-train", "2000-train",
        "all-val", "concrete-val", "common-val",
        "verbs", "toverbs", "random", "gibberish",
    ]
    scores = np.zeros((len(models), len(datas)))
    for row, model in enumerate(models):
        prefix = f"c/experimentall2000concretecommon/experiment-{model}"
        for col, data in enumerate(datas):
            experiment_scores = get_score_dic_from_experiment(f"{prefix}-{data}")
            scores[row, col] = experiment_scores[checkpoint][score_type]
    return scores

def save_grand_tables():
    """
    Writes one LaTeX table per score type (exact, substr, bleu) to
    results/final_results/table2/, using generate_grand_table for the values
    and the data set names as column headers.
    """
    results_dir = 'results/final_results/table2/'
    headers = [
        "all-train", "concrete-train", "common-train", "2000-train",
        "all-val", "concrete-val", "common-val",
        "verbs", "toverbs", "random", "gibberish",
    ]
    score_types = ('avg_exact_all', 'avg_substr_all', 'avg_bleu_all')
    for score_type in score_types:
        table = generate_grand_table(score_type)
        out_path = results_dir + f'table-{score_type}-scores-ckpt-1001000.txt'
        with open(out_path, 'w') as out_file:
            out_file.write(tabulate(table, headers, tablefmt='latex'))





def accuracies_heatmaps(
    score_dic,
    checkpoint,
    n_objs_list,
    n_containers_list,
    output_dir='results/final_results/fig1/val',
):
    """
    Draws one heatmap per score type (avg_bleu, avg_substr_match,
    avg_exact_score) over the number of objects and containers and saves each
    as <output_dir>/<score_type>.png.

    Arguments:

    score_dic {dic} -- The score_dics of an experiment. score_dic[checkpoint]
                       must have a key '<n_objs>objs<n_containers>conts' for
                       every combination, holding the three score types.
    checkpoint {str} -- The checkpoint (key of score_dic) to plot.
    n_objs_list {sequence} -- The numbers of objects (x axis).
    n_containers_list {sequence} -- The numbers of containers (y axis). Must
                       support slicing, it is reversed for the tick labels.
    output_dir {str} -- The directory the figures are saved in. The default is
                       the previously hard-coded directory; pass another one
                       (e.g. for train scores) so that successive calls do not
                       overwrite each other's figures.
    """
    score_types = 'avg_bleu avg_substr_match avg_exact_score'.split()
    score_names = 'Bleu Scores.% Substring Matches.% Exact Matches'.split('.')
    grid_shape = (len(n_containers_list), len(n_objs_list))
    # -1 marks cells that have not been filled.
    scores = {score_type: np.full(grid_shape, -1.0) for score_type in score_types}

    checkpoint_scores = score_dic[checkpoint]
    for i, n_containers in enumerate(n_containers_list):
        for j, n_objs in enumerate(n_objs_list):
            ocscore_dic = checkpoint_scores[f'{n_objs}objs{n_containers}conts']
            for score_type in score_types:
                scores[score_type][i, j] = ocscore_dic[score_type]

    for score_type, score_name in zip(score_types, score_names):
        plt.clf()
        print(f'Score Type: {score_type}')
        print(scores[score_type])
        # Rows are flipped so the container count grows upwards in the plot.
        ax = sns.heatmap(
            scores[score_type][::-1, :],
            square=True,
            xticklabels=n_objs_list,
            yticklabels=n_containers_list[::-1],
            vmin=0.0,
            vmax=1.0,
        )

        # Blue frame over the first 8 columns of rows 2-3 of the flipped grid;
        # presumably the 2-9 objects / 2-3 containers region -- TODO confirm.
        ax.add_patch(Rectangle((0, 2), 8, 2, fill=False, edgecolor="blue", lw=3))

        plt.xlabel("Number of Objects")
        plt.ylabel("Number of Containers")
        plt.title(score_name)
        plt.savefig(f'{output_dir}/{score_type}.png')


def get_scores_from_experiment(experiment, checkpoints, score_type, ):
    """
    Loads the score dictionary of `experiment`, prints it, and returns the
    `score_type` score of every checkpoint in `checkpoints`, in order.

    Each checkpoint is converted with str() before it is used as a key.
    """
    score_dics = get_score_dic_from_experiment(experiment)
    print(score_dics)
    return [score_dics[str(ckpt)][score_type] for ckpt in checkpoints]






if __name__ == "__main__":
    # Script entry point: plots the gibberish comparison and then the
    # accuracy heatmaps of the all-all train and val experiments. The
    # commented-out calls are alternative runs (scoring, bar plots, parallel
    # scoring).
    #score_experiment_c(np.arange(1000000, 1001001, 100),'c/experimentall2000concretecommon_/experiment-all-all-train')
    #score_experiment_g(np.arange(1000000, 1001001, 100),'c/experimentall2000concretecommon_/experiment-gibberish-gibberish')
    plot_gib('avg_substr_match')
    #plot_bar_of_experiments()
    #checkpoints = np.arange(1002000,1003001,1000)
    #experiment = 'n/ho'
    #score_experiment_g(checkpoints, experiment, 'score_dics2000-3000.json')

    # NOTE(review): both heatmap calls below save their figures under
    # results/final_results/fig1/val/ with the same file names, so the val
    # figures overwrite the train figures produced by the first call.
    score_dic = get_score_dic_from_experiment('c/experimentall2000concretecommon/experiment-all-all-train')
    accuracies_heatmaps(score_dic, '1001000', np.arange(2,20), np.arange(2,6))

    score_dic = get_score_dic_from_experiment('c/experimentall2000concretecommon/experiment-all-all-val')
    accuracies_heatmaps(score_dic, '1001000', np.arange(2,20), np.arange(2,6))

    #models = "all concrete common 2000".split()
    #datas = "all-train concrete-train common-train 2000-train all-val concrete-val common-val".split()
    #experiments = [
    #    f"c/experimentall2000concretecommon_/experiment-{model}-{data}"
    #    for model, data in list(itertools.product(models, datas))
    #]
    #checkpoints = ["1001000"]
    #parallel_score_experiments(experiments, checkpoints, 100)
