from typing import Dict, Optional, Tuple, List
import string
import tensorflow as tf
from itertools import zip_longest
import pickle as pkl

from tqdm import tqdm

from .tf_data_utils import to_features, to_sequence_features


def grouper(iterable, n, fillvalue=None):
    """
    Collect items from `iterable` into consecutive tuples of length `n`.

    The last tuple is padded with `fillvalue` when the input runs out
    before the tuple is complete.
    """
    # Every slot of the zip pulls from the SAME iterator, so each output
    # tuple consumes n consecutive items from the input.
    shared = iter(iterable)
    return zip_longest(*(shared for _ in range(n)), fillvalue=fillvalue)


def read_transmembrane_data(filename: str) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """
    Parses a transmembrane dataset stored in 3-line format and splits it 80/20.

    Every record occupies three consecutive lines: a header line (the text after
    the last '>' becomes the record name), a sequence line and a tag line holding
    one tag character per sequence position.

    Args:
        filename: path to the 3-line formatted file.

    Returns:
        A (train, val) tuple of dicts, each with the keys 'seqs', 'names' and
        'tags'. The first 80% of the records (in file order) go to train, the
        remainder to val.

    Raises:
        ValueError: if a record has no tag line, or if its sequence and tag
            lines differ in length.
    """
    names: List[str] = []
    seqs: List[str] = []
    tags: List[str] = []

    with open(filename) as f:
        # Passing the same file iterator three times makes zip_longest yield
        # consecutive (header, sequence, tag) line triples; a short final
        # group is padded with None.
        for header, seq, tag in zip_longest(*[iter(f)] * 3):
            # Hack for some dangling lines: a lone trailing line ends the data
            if not seq:
                break

            if tag is None:
                # Two trailing blank lines are just padding, not a broken record
                if not header.strip() and not seq.strip():
                    break
                raise ValueError(f'Record {header.strip()!r} in {filename} has no tag line')

            name = header.strip().split('>')[-1]
            seq = seq.strip()
            tag = tag.strip()

            # Explicit raise instead of assert so the check survives `python -O`
            if len(seq) != len(tag):
                raise ValueError(
                    f'Tag and Sequence must have same length (record {name!r}: '
                    f'{len(seq)} residues vs {len(tag)} tags)')

            names.append(name)
            seqs.append(seq)
            tags.append(tag)

    train_split = int(0.8 * len(seqs))

    train = {'seqs': seqs[:train_split],
             'names': names[:train_split],
             'tags': tags[:train_split]}
    val = {'seqs': seqs[train_split:],
           'names': names[train_split:],
           'tags': tags[train_split:]}

    return train, val


def convert_transmembrane_sequences_to_tfrecords(outfile: str,
                                                 vocab: Optional[Dict[str, int]] = None) -> None:
    """
    Converts the TOPCONS2 transmembrane datasets into train/valid TFRecord files.

    Reads the four '.3line' files under 'data/transmembrane/TOPCONS2_datasets/',
    serializes every record, and writes:
        <base>_train.tfrecords, <base>_valid.tfrecords,
        <base>.vocab (pickled vocab) and <base>_tag_dict.pkl (pickled tag -> int map),
    where <base> is `outfile` with its final extension removed.

    Args:
        outfile: output path; only its final extension is stripped, so
            directories and names containing dots are preserved.
        vocab: residue -> integer id mapping. When None, a vocab holding only the
            four special tokens is used; serialization raises ValueError for any
            residue missing from the vocab, so real data needs a full vocab.
    """
    import os  # local import: only needed for the extension handling below

    # splitext removes just the last extension. The previous `rsplit('.')[0]`
    # split on EVERY dot, turning './out.tfrecords' into '' and
    # 'a.b.tfrecords' into 'a'.
    base = os.path.splitext(outfile)[0]

    if vocab is None:
        vocab = {"<PAD>": 0, "<MASK>": 1, "<CLS>": 2, "<SEP>": 3}

    # Filled in lazily by serialize_transmembrane_sequence as new tags appear
    tag_dict: Dict[str, int] = {}

    dataset_types = ['Globular', 'Globular+SP', 'TM', 'SP+TM']
    file_paths = ['data/transmembrane/TOPCONS2_datasets/' + name + '.3line' for name in dataset_types]

    # Both writers stay open for the whole run because each input file
    # contributes records to both the train and the validation output.
    with tf.python_io.TFRecordWriter(base + '_train.tfrecords') as train_writer, \
            tf.python_io.TFRecordWriter(base + '_valid.tfrecords') as val_writer:

        # Go one data file at a time
        for filepath, dataset_type in zip(file_paths, dataset_types):
            print(filepath)
            train, val = read_transmembrane_data(filepath)

            print(f'Dumping {dataset_type} Data Set...')

            # Train records first, then validation records, for this file
            for split, writer in ((train, train_writer), (val, val_writer)):
                for seq, name, tag in tqdm(zip(split['seqs'], split['names'], split['tags'])):
                    serialized_example, vocab, tag_dict = serialize_transmembrane_sequence(
                        seq.strip(), tag, name, dataset_type, vocab, tag_dict)
                    writer.write(serialized_example)

    with open(base + '.vocab', 'wb') as f:
        pkl.dump(vocab, f)

    with open(base + '_tag_dict.pkl', 'wb') as f:
        pkl.dump(tag_dict, f)


def serialize_transmembrane_sequence(sequence: str,
                                     tags: str,
                                     name: str,
                                     label: str,
                                     vocab: Dict[str, int],
                                     tag_dict: Dict[str, int]) -> \
        Tuple[bytes, Dict[str, int], Dict[str, int]]:
    """
    Encodes one protein record as a serialized tf.train.SequenceExample.

    Residues are mapped to integers through `vocab`; tag characters are mapped
    through `tag_dict`, which is extended in place with the next free index
    whenever an unseen tag character is encountered.

    Returns:
        (serialized example bytes, vocab, tag_dict) - the two dicts are the same
        objects that were passed in.

    Raises:
        ValueError: if `sequence` or `tags` contains whitespace, or if a residue
            is missing from `vocab`.
    """
    def _reject_whitespace(symbol):
        if symbol in string.whitespace:
            raise ValueError("whitespace found in string")

    # Residues: every character must already be known to the vocab
    residue_ids = []
    for residue in sequence:
        _reject_whitespace(residue)

        residue_id = vocab.get(residue)
        if residue_id is None:
            raise ValueError("Pfam vocab is incompatible")

        residue_ids.append(residue_id)

    # Tags: unseen characters get the next free index (the current dict size)
    tag_ids = []
    for symbol in tags:
        _reject_whitespace(symbol)

        tag_id = tag_dict.get(symbol)
        if tag_id is None:
            tag_id = tag_dict[symbol] = len(tag_dict)

        tag_ids.append(tag_id)

    # Per-protein scalars go in the context, per-position data in feature lists
    context = to_features(
        name=name.encode('UTF-8'),
        protein_type=label.encode('UTF-8'),
        protein_length=len(residue_ids))
    feature_lists = to_sequence_features(
        protein_sequence=residue_ids,
        output_sequence=tag_ids)

    serialized = tf.train.SequenceExample(
        context=context, feature_lists=feature_lists).SerializeToString()
    return serialized, vocab, tag_dict


def deserialize_transmembrane_sequence(example):
    """
    Parses one serialized SequenceExample into a dict of tensors.

    The returned dict has the keys 'sequence', 'output_sequence',
    'protein_length', 'protein_type' and 'name'. The integer features are
    cast from int64 to int32.
    """
    context_spec = {
        'name': tf.FixedLenFeature([], tf.string),
        'protein_length': tf.FixedLenFeature([1], tf.int64),
        'protein_type': tf.FixedLenFeature([], tf.string)
    }
    sequence_spec = {
        'protein_sequence': tf.FixedLenSequenceFeature([1], tf.int64),
        'output_sequence': tf.FixedLenSequenceFeature([1], tf.int64)
    }

    parsed_context, parsed_sequences = tf.parse_single_sequence_example(
        example,
        context_features=context_spec,
        sequence_features=sequence_spec
    )

    # protein_length is parsed with shape [1]; take its only element
    length = tf.to_int32(parsed_context['protein_length'][0])

    # Each sequence step is parsed with shape [1]; keep index 0 of that axis
    residues = tf.to_int32(parsed_sequences['protein_sequence'][:, 0])
    tag_ids = tf.to_int32(parsed_sequences['output_sequence'][:, 0])

    return dict(
        sequence=residues,
        output_sequence=tag_ids,
        protein_length=length,
        protein_type=parsed_context['protein_type'],
        name=parsed_context['name'],
    )


if __name__ == '__main__':
    # Command-line entry point: convert the transmembrane datasets to TFRecords.
    import argparse
    parser = argparse.ArgumentParser(description='convert protein sequences to tfrecords')
    # Required: the converter derives every output path from this value
    parser.add_argument('--outfile', type=str, required=True, help='name of outfile')
    parser.add_argument('--vocab', type=str, default=None, help='path to existing vocab file')
    args = parser.parse_args()

    # Bind vocab up front; previously it was undefined (NameError) whenever
    # --vocab was omitted. None lets the converter use its built-in default.
    vocab = None
    if args.vocab:
        # NOTE: pickle.load can execute arbitrary code - only pass trusted vocab files
        with open(args.vocab, 'rb') as f:
            vocab = pkl.load(f)
    convert_transmembrane_sequences_to_tfrecords(args.outfile, vocab)
