from typing import Tuple, List
import os
import tensorflow as tf
import numpy as np
from .arnold_protein_serializer import deserialize_arnold_sequence


def get_localization_data(boundaries: Tuple[List[int], List[int]], directory: str, max_sequence_length: int) -> \
        Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Build bucketed train/valid datasets for the supervised localization task.

    Records are read from ``<directory>/supervised/localization/`` and decoded with
    ``deserialize_arnold_sequence``. Decoded examples are grouped into batches by
    their ``'protein_length'`` field via
    ``tf.data.experimental.bucket_by_sequence_length``.

    Args:
        boundaries: pair ``(bucket_boundaries, bucket_batch_sizes)``, forwarded in
            that order to ``bucket_by_sequence_length``.
        directory: root data directory.
        max_sequence_length: accepted for interface compatibility but currently
            unused (no length filter is applied).

    Returns:
        ``(train_data, valid_data)``. The training split is shuffled with a
        1024-element buffer before bucketing. The validation split is not shuffled
        and is prefetched with a 1024-element buffer instead.

    Raises:
        FileNotFoundError: if the train or valid TFRecord file is missing (the
            train file is checked first).
    """
    task_dir = os.path.join(directory, 'supervised', 'localization')
    train_path = os.path.join(task_dir, 'localization_train.tfrecords')
    valid_path = os.path.join(task_dir, 'localization_valid.tfrecords')

    for required in (train_path, valid_path):
        if not os.path.exists(required):
            raise FileNotFoundError(required)

    raw_train = tf.data.TFRecordDataset(train_path)
    raw_valid = tf.data.TFRecordDataset(valid_path)

    def protein_length(example):
        # Bucketing key; presumably the deserializer always emits this field — confirm.
        return example['protein_length']

    def build(records: tf.data.Dataset, is_training: bool) -> tf.data.Dataset:
        parsed = records.map(deserialize_arnold_sequence, num_parallel_calls=128)
        if is_training:
            staged = parsed.shuffle(1024)
        else:
            staged = parsed.prefetch(1024)
        bucketing = tf.data.experimental.bucket_by_sequence_length(
            protein_length,
            boundaries[0],
            boundaries[1])
        return staged.apply(bucketing)

    return build(raw_train, True), build(raw_valid, False)


def get_thermostability_data(directory: str, batch_size: int, max_sequence_length: int) -> \
        Tuple[tf.data.Dataset, tf.data.Dataset]:

    train_file = os.path.join(directory, 'supervised', 'thermostability', 'thermostability_train.tfrecords')
    valid_file = os.path.join(directory, 'supervised', 'thermostability', 'thermostability_valid.tfrecords')

    if not os.path.exists(train_file):
        raise FileNotFoundError(train_file)
    if not os.path.exists(valid_file):
        raise FileNotFoundError(valid_file)

    train_data = tf.data.TFRecordDataset(train_file)
    valid_data = tf.data.TFRecordDataset(valid_file)

    def prepare_dataset(dataset: tf.data.Dataset, shuffle: bool) -> tf.data.Dataset:
        dataset = dataset.map(deserialize_arnold_sequence, batch_size)
        # dataset = dataset.filter(lambda example: example['protein_length'] < max_sequence_length)
        dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
        bucket_boundaries = np.arange(100, max_sequence_length + 100, 100)
        centers = np.arange(50, max_sequence_length + 100, 100)
        ratio = (centers[-1]) / (centers)
        ratio = ratio * batch_size
        ratio = np.asarray(ratio, np.int32)
        batch_fun = tf.data.experimental.bucket_by_sequence_length(
            lambda example: example['protein_length'],
            bucket_boundaries,
            ratio)
        dataset = dataset.apply(batch_fun)
        return dataset

    train_data = prepare_dataset(train_data, shuffle=True)
    valid_data = prepare_dataset(valid_data, shuffle=False)

    return train_data, valid_data
