from typing import Tuple
import os
import tensorflow as tf
import numpy as np
from .secondary_structure_protein_serializer import deserialize_secondary_structure_sequence_eight_classes
from .secondary_structure_protein_serializer import deserialize_secondary_structure_sequence_three_classes


def get_secondary_structure_data(directory: str,
                                 batch_size: int,
                                 max_sequence_length: int,
                                 eight_class_version: bool = True) -> \
        Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Build length-bucketed train/validation pipelines for secondary structure.

    Records are read from ``<directory>/supervised/secondary_structure/``,
    decoded with the eight-class or three-class deserializer, and grouped by
    ``protein_length`` into buckets whose batch size shrinks as length grows.

    Args:
        directory: Root data directory that contains
            ``supervised/secondary_structure/*.tfrecords``.
        batch_size: Forwarded to ``Dataset.map`` as ``num_parallel_calls``.
            NOTE(review): despite its name it does NOT set the batch size; the
            per-bucket batch sizes are fixed (see ``bucket_batch_sizes``).
        max_sequence_length: Currently unused — the length filter that would
            consume it is disabled. Kept for interface compatibility.
        eight_class_version: Decode eight-class labels if True, otherwise
            three-class labels.

    Returns:
        ``(train_data, valid_data)``. The training split is shuffled with a
        1024-element buffer; both splits are bucketed and prefetched.

    Raises:
        FileNotFoundError: If any of the train, valid, CASP12, TS115 or CB513
            record files is missing. The three test sets are only checked for
            presence; they are not returned.
    """
    data_dir = os.path.join(directory, 'supervised', 'secondary_structure')
    train_file = os.path.join(data_dir, 'secondary_structure_train.tfrecords')
    valid_file = os.path.join(data_dir, 'secondary_structure_valid.tfrecords')
    test_files = [os.path.join(data_dir, 'secondary_structure_{}.tfrecords'.format(name))
                  for name in ('casp12', 'ts115', 'cb513')]

    # Validate all five files up front, in the same order as before, so a
    # missing file is reported before any pipeline is built.
    for path in [train_file, valid_file] + test_files:
        if not os.path.exists(path):
            raise FileNotFoundError(path)

    deserialize = (deserialize_secondary_structure_sequence_eight_classes
                   if eight_class_version
                   else deserialize_secondary_structure_sequence_three_classes)

    # Boundaries every 100 residues up to 2000 (20 boundaries -> 21 buckets).
    # Each bucket's batch size is floor(2050 / bucket_centre): 41 for lengths
    # < 100 down to 1 for the longest buckets, so batch_size * centre <= 2050.
    bucket_boundaries = np.arange(100, 2000 + 100, 100)
    bucket_centers = np.arange(50, 2000 + 100, 100)
    bucket_batch_sizes = np.asarray(bucket_centers[-1] / bucket_centers, np.int32)

    def prepare_dataset(dataset: tf.data.Dataset, shuffle: bool) -> tf.data.Dataset:
        """Decode, optionally shuffle, bucket by length, and prefetch ``dataset``."""
        # map() keeps element order by default, so parallelism does not change output.
        dataset = dataset.map(deserialize, num_parallel_calls=batch_size)
        if shuffle:
            dataset = dataset.shuffle(1024)
        dataset = dataset.apply(tf.data.experimental.bucket_by_sequence_length(
            lambda example: example['protein_length'],
            bucket_boundaries,
            bucket_batch_sizes))
        # Prefetch after batching so input preparation overlaps model work on
        # every split, including the shuffled training split.
        return dataset.prefetch(tf.data.experimental.AUTOTUNE)

    train_data = prepare_dataset(tf.data.TFRecordDataset(train_file), shuffle=True)
    valid_data = prepare_dataset(tf.data.TFRecordDataset(valid_file), shuffle=False)
    return train_data, valid_data
