from typing import Tuple, Dict
import pickle as pkl
import os
import tensorflow as tf
import numpy as np
from glob import glob
from functools import partial

from .vocabs import PFAM_VOCAB
from .pfam_protein_serializer import deserialize_pfam_sequence


def get_pfam_data(directory: str,
                  batch_size: int,
                  max_sequence_length: int,
                  add_cls_token: bool = False,
                  batch_by_sequence_length: bool = True) -> \
        Tuple[tf.data.Dataset,
              tf.data.Dataset]:
    """Build the Pfam training and holdout ``tf.data`` pipelines.

    TFRecord shards are read from ``<directory>/pfam31_whole/tfrecords``.
    Files matching ``*[0-9].tfrecord`` feed the training pipeline and files
    matching ``*valid.tfrecord`` feed the holdout pipeline. A fixed list of
    clan and family accessions is translated to integer ids through the two
    pickled lookup tables in ``<directory>/pfam31_whole``; the training
    pipeline drops every example whose ``'clan'`` or ``'family'`` is in that
    set, and the holdout pipeline keeps only those examples.

    Args:
        directory: Root data directory containing ``pfam31_whole``.
        batch_size: Batch size for ``padded_batch``. Only used when
            ``batch_by_sequence_length`` is False.
        max_sequence_length: Accepted for interface compatibility; it is not
            referenced anywhere in this function.
        add_cls_token: Forwarded to ``deserialize_pfam_sequence`` together
            with ``PFAM_VOCAB['<CLS>']``.
        batch_by_sequence_length: When True, batch with
            ``bucket_by_sequence_length`` keyed on ``'protein_length'`` using
            hard-coded boundaries and batch sizes; otherwise use a plain
            ``padded_batch``.

    Returns:
        ``(train_data, holdout_data)``. The training pipeline is shuffled
        with a buffer of 1024; the holdout pipeline is prefetched instead.

    Raises:
        FileNotFoundError: If ``directory`` is not a directory, if no
            training or validation TFRecord files match, or if a lookup
            pickle cannot be opened.
        KeyError: If a held-out accession is missing from its lookup table.
    """
    if not os.path.isdir(directory):
        raise FileNotFoundError(directory)

    pfam_root = os.path.join(directory, 'pfam31_whole')
    record_dir = os.path.join(pfam_root, 'tfrecords')

    train_files = glob(os.path.join(record_dir, '*[0-9].tfrecord'))
    valid_files = glob(os.path.join(record_dir, '*valid.tfrecord'))
    # Fail early: building a dataset from an empty file list succeeds and
    # would only blow up later, once training starts.
    if not train_files:
        raise FileNotFoundError("No training TFrecord files found in directory")
    if not valid_files:
        raise FileNotFoundError("No validation TFrecord files found in directory")

    def _read_pickle(path):
        # NOTE: pickle executes arbitrary code on load; only point this at a
        # trusted data directory.
        with open(path, 'rb') as handle:
            return pkl.load(handle)

    # Accession -> integer id tables (families first, then clans).
    family_ids: Dict[str, int] = _read_pickle(
        os.path.join(pfam_root, 'pfam31_whole_fams.pkl'))
    clan_ids: Dict[str, int] = _read_pickle(
        os.path.join(pfam_root, 'pfam31_whole_clans.pkl'))

    heldout_clan_names = ['CL0635', 'CL0624', 'CL0355', 'CL0100', 'CL0417', 'CL0630']
    heldout_family_names = ['PF18346', 'PF14604', 'PF18697', 'PF03577', 'PF01112', 'PF03417']

    heldout_clan_ids = {clan_ids[name] for name in heldout_clan_names}
    heldout_family_ids = {family_ids[name] for name in heldout_family_names}

    print('Currently holding out the following families:', *heldout_family_names, sep='\n-')
    print('Currently holding out the following clans: ', *heldout_clan_names, sep='\n-')

    train_filenames = tf.data.Dataset.from_tensor_slices(tf.constant(train_files))
    valid_filenames = tf.data.Dataset.from_tensor_slices(tf.constant(valid_files))

    parse_record = partial(
        deserialize_pfam_sequence,
        add_cls_token=add_cls_token,
        cls_token=PFAM_VOCAB['<CLS>'])

    def _membership_flag(value, members):
        """Return a scalar bool tensor: is ``value`` contained in ``members``?"""
        def _contains(item):
            return item in members

        # The lookup runs in Python, so the static shape has to be restored
        # by hand afterwards.
        flag = tf.py_func(_contains, [value], tf.bool)
        flag.set_shape(())
        return flag

    def _protein_length(example):
        return example['protein_length']

    def _build_pipeline(filenames: tf.data.Dataset, shuffle: bool, is_holdout: bool) -> tf.data.Dataset:
        """Turn a dataset of TFRecord filenames into a batched example dataset."""

        def _keep_example(example):
            belongs_to_heldout = tf.logical_or(
                _membership_flag(example['clan'], heldout_clan_ids),
                _membership_flag(example['family'], heldout_family_ids))
            # Keep the example exactly when its held-out status matches the
            # kind of pipeline being built (XNOR).
            return tf.logical_not(tf.logical_xor(is_holdout, belongs_to_heldout))

        def _records_from_file(fname: tf.Tensor):
            records = tf.data.TFRecordDataset(fname).map(parse_record)
            # Apply the held-out family/clan split per file.
            return records.filter(_keep_example)

        # Read from many TFRecord files concurrently; sloppy=True allows
        # elements to come out in a non-deterministic order.
        interleave = tf.data.experimental.parallel_interleave(
            _records_from_file,
            cycle_length=128,
            sloppy=True,
            buffer_output_elements=32)
        examples = filenames.apply(interleave)

        if shuffle:
            examples = examples.shuffle(1024)
        else:
            examples = examples.prefetch(1024)

        if not batch_by_sequence_length:
            return examples.padded_batch(batch_size, padded_shapes=examples.output_shapes)

        # (upper length boundary, batch size) per bucket.
        # Best boundaries for lstm on dubbel. The previously noted best
        # boundaries for transformer on hope were:
        #   (100, 64), (200, 32), (300, 24), (400, 16), (500, 12),
        #   (700, 8), (900, 6), (1100, 4), (1800, 2)
        bucket_spec = [
            (100, 512),
            (200, 512),
            (500, 256),
            (1000, 128),
            (2000, 64)]
        length_boundaries = tuple(upper for upper, _ in bucket_spec)
        # One extra batch size is required for the overflow bucket that
        # holds everything longer than the last boundary.
        bucket_batch_sizes = tuple(size for _, size in bucket_spec) + (32,)

        bucketing = tf.data.experimental.bucket_by_sequence_length(
            _protein_length,
            length_boundaries,
            bucket_batch_sizes)
        return examples.apply(bucketing)

    train_data = _build_pipeline(train_filenames, shuffle=True, is_holdout=False)
    holdout_data = _build_pipeline(valid_filenames, shuffle=False, is_holdout=True)

    return train_data, holdout_data
