import operator
from collections import defaultdict
from glob import glob
from typing import Any, Dict, Iterable, List, Tuple

import numpy as np

from tape.analysis import get_config, get_best_metric


def best_per_point(
        records: Iterable[Tuple[Any, Any, float]]) -> Dict[Tuple[Any, Any], float]:
    """Reduce ``(key1, key2, metric)`` records to the best metric per grid point.

    Args:
        records: One ``(key1_value, key2_value, metric)`` tuple per run.

    Returns:
        A dict mapping ``(key1_value, key2_value)`` to the largest metric seen
        for that pair. Negative metrics are kept as-is; a zero-initialised
        accumulator would silently clamp them to 0.
    """
    best: Dict[Tuple[Any, Any], float] = {}
    for key1, key2, metric in records:
        point = (key1, key2)
        if point not in best or metric > best[point]:
            best[point] = metric
    return best


def build_grid(
        data_values: Dict[Tuple[Any, Any], float]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Lay out per-point values on a dense 2-D grid.

    Args:
        data_values: Mapping from ``(key1_value, key2_value)`` to a metric.

    Returns:
        ``(key1_array, key2_array, data)``:
        - ``key1_array`` and ``key2_array`` are the sorted distinct values of
          each key.
        - ``data[i, j]`` holds the metric for ``(key1_array[i], key2_array[j])``.
        - Combinations that were never run are NaN rather than 0, so a real
          0.0 result stays distinguishable from a missing cell.
    """
    key1_sorted = sorted({key1 for key1, _ in data_values})
    key2_sorted = sorted({key2 for _, key2 in data_values})
    # Value -> position lookups avoid an O(n) np.where scan per point.
    row_of = {value: i for i, value in enumerate(key1_sorted)}
    col_of = {value: j for j, value in enumerate(key2_sorted)}

    data = np.full((len(key1_sorted), len(key2_sorted)), np.nan)
    for (key1, key2), value in data_values.items():
        data[row_of[key1], col_of[key2]] = value
    return np.array(key1_sorted), np.array(key2_sorted), data


if __name__ == '__main__':
    # Plot a heatmap of the best secondary-structure metric over a 2-D
    # hyperparameter grid, read from results/<directory>/secondary_structure_<model>*.
    import matplotlib.pyplot as plt
    import seaborn as sns
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('directory')
    parser.add_argument('model')
    parser.add_argument('key1')
    parser.add_argument('key2')
    args = parser.parse_args()

    pattern = f'results/{args.directory}/secondary_structure_{args.model}*'
    directories = glob(pattern)
    if not directories:
        # Otherwise max() below fails with an opaque "empty sequence" error.
        raise SystemExit(f'No result directories match {pattern!r}')

    records: List[Tuple[Any, Any, float]] = []
    for directory in directories:
        config = get_config(directory)
        records.append((
            config[args.model][args.key1],
            config[args.model][args.key2],
            get_best_metric(directory),
        ))

    data_values = best_per_point(records)
    key1_array, key2_array, data = build_grid(data_values)

    point, acc = max(data_values.items(), key=operator.itemgetter(1))
    print('Max Accuracy:', acc, f'{args.key1}:', point[0], f'{args.key2}:', point[1])

    # Transpose so key1 runs along the x-axis and key2 along the y-axis.
    data = data.T
    sns.heatmap(
        data, xticklabels=key1_array, yticklabels=key2_array, mask=np.isnan(data),
        vmin=np.nanmin(data), vmax=np.nanmax(data), cmap='OrRd',
        annot=True)
    plt.xlabel(args.key1)
    plt.ylabel(args.key2)
    plt.title('Accuracy on secondary structure prediction gridsearch')

    plt.show()
