import pickle as pkl
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from tape.data_utils.vocabs import SS_8_DICT, SS_3_DICT


def _consolidate_data(outputs, include_hidden: bool = False):
    """Merge batched model outputs into one concatenated array per key.

    Array-concatenating counterpart of ``consolidate_data``: instead of a list
    of per-protein entries, every key ends up as a single ``np.ndarray``
    concatenated along axis 0.

    Args:
        outputs: Iterable of batch results. Only element ``[0]`` of each item
            is used; it is expected to be a mapping from key to a numpy array
            and must include ``'protein_length'`` (the true length of each
            protein in the batch).
        include_hidden: When False (default), the ``'encoder_outputs'`` entry
            is removed from the result if it is present.

    Returns:
        dict mapping each key to the concatenation of its batches:

        * 1-D arrays and 2-D float32/float64 arrays are concatenated as-is;
        * every other array is split per protein (first axis), each piece is
          trimmed to that protein's length, and the pieces are concatenated.
    """
    data = defaultdict(list)

    for output in outputs:
        output = output[0]
        lengths = output['protein_length']
        for key, array in output.items():
            keep_whole = array.ndim == 1 or (
                array.ndim == 2 and array.dtype in [np.float32, np.float64])
            if keep_whole:
                data[key].append(array)
            else:
                # Strip padding: cut each protein to its true length.
                for length, arr in zip(lengths, array):
                    data[key].append(arr[:length])

    out_data = {key: np.concatenate(arrays, 0) for key, arrays in data.items()}

    if not include_hidden:
        # pop() rather than del: outputs without hidden states (or an empty
        # ``outputs``) must not raise KeyError.
        out_data.pop('encoder_outputs', None)

    return out_data


def consolidate_data(outputs, include_hidden: bool = False):
    """Regroup batched model outputs into per-protein entries.

    Args:
        outputs: Iterable of batch results. Only element ``[0]`` of each item
            is used; it is expected to be a mapping from key to a batch of
            values and must include ``'protein_length'`` (the true length of
            each protein in the batch).
        include_hidden: When False (default), the ``'encoder_output'`` entry
            is removed from the result if it is present.

    Returns:
        dict mapping each key to a list with one entry per protein, in input
        order. Scalars and 1-D float32/float64 vectors are kept as-is. Every
        other entry is trimmed to ``protein_length`` along its first axis to
        strip padding.
    """
    data = defaultdict(list)  # type: ignore

    for output in outputs:
        output = output[0]
        lengths = output['protein_length']
        for key, protein_batch in output.items():
            for protein_length, protein_data in zip(lengths, protein_batch):
                if np.isscalar(protein_data):
                    data[key].append(protein_data)
                elif protein_data.ndim == 1 and protein_data.dtype in [np.float32, np.float64]:
                    data[key].append(protein_data)
                else:
                    data[key].append(protein_data[:protein_length])

    data = dict(data)

    if not include_hidden:
        # pop() rather than del: outputs without hidden states (or an empty
        # ``outputs``) must not raise KeyError.
        # NOTE(review): _consolidate_data drops 'encoder_outputs' (plural);
        # confirm which key the model actually emits.
        data.pop('encoder_output', None)

    return data


def make_confusion_matrix(labels, logits):
    """Plot a confusion matrix of secondary-structure predictions.

    Args:
        labels: Sequence of integer class-index arrays (e.g. one per protein);
            they are concatenated along axis 0.
        logits: Sequence of logit arrays matching ``labels``. Class scores lie
            on the last axis, and the predicted class is their argmax.

    Rows of the matrix are true labels and columns are predictions; the figure
    is displayed with ``plt.show()``. Tick labels come from ``SS_3_DICT`` or
    ``SS_8_DICT`` when the number of classes matches that vocabulary's size;
    otherwise seaborn's default numeric ticks are used.
    """
    labels = np.concatenate(labels, 0)
    logits = np.concatenate(logits, 0)
    predictions = np.argmax(logits, -1)
    num_classes = logits.shape[-1]

    # Integer counts accumulated in one vectorized call instead of a
    # per-residue Python loop. np.add.at (unlike a fancy-indexed ``+= 1``)
    # correctly counts repeated (label, prediction) pairs.
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(confusion, (labels, predictions), 1)

    # Choose tick labels whose count matches the matrix size, so an 8-class
    # matrix is not labelled with the 3-class vocabulary.
    if num_classes == len(SS_3_DICT):
        ticklabels = list(SS_3_DICT.values())
    elif num_classes == len(SS_8_DICT):
        ticklabels = list(SS_8_DICT.values())
    else:
        ticklabels = 'auto'

    # fmt='d' prints exact counts; seaborn's default '.2g' would render e.g.
    # 1234 as '1.2e+03'.
    sns.heatmap(confusion, annot=True, fmt='d',
                xticklabels=ticklabels, yticklabels=ticklabels)
    plt.xlabel('Prediction')
    plt.ylabel('Label')
    plt.title('Confusion Matrix Secondary Structure')
    plt.yticks(rotation=0)

    plt.show()


if __name__ == '__main__':
    # CLI entry point: load a pickle of model outputs, regroup it per protein
    # and plot the secondary-structure confusion matrix.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('outfile')
    outfile = cli.parse_args().outfile

    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only pass files produced by a trusted evaluation run.
    with open(outfile, 'rb') as handle:
        outputs = pkl.load(handle)

    consolidated = consolidate_data(outputs)
    make_confusion_matrix(consolidated['output_sequence'],
                          consolidated['sequence_logits'])
