from typing import Tuple, Dict, List

import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
import rinokeras as rk

from tape.data_utils import deserialize_transmembrane_sequence
from .Task import SequenceToSequenceClassificationTask


class TransmembraneTask(SequenceToSequenceClassificationTask):
    """Four-class per-position sequence labelling task with a per-sequence metric.

    The training loss is a length-masked sparse softmax cross-entropy over the
    per-position logits. The reported metric (``'ACC'``) is computed by
    :meth:`transmembrane_accuracy`, which scores each sequence as a whole
    instead of counting correct positions.
    """

    def __init__(self):
        n_classes = 4
        super().__init__(
            key_metric='ACC',
            supervised=True,
            deserialization_func=deserialize_transmembrane_sequence,
            n_classes=n_classes,
            label_name='output_sequence',
            input_name='encoder_output',
            output_name='sequence_logits')

    def transmembrane_accuracy(self, labels, predictions, sequence_length):
        """Return the fraction of sequences in the batch that are predicted correctly.

        A sequence counts as correct when all of the following hold after
        truncating labels and predictions to the sequence's length:

        * both agree on whether the first position carries label ``1``;
        * both contain the same number of contiguous runs of label ``3``;
        * the i-th true run and the i-th predicted run satisfy
          ``min(end) - max(start) >= 5`` (inclusive indices), for every run.

        Presumably label ``1`` marks a signal peptide and label ``3`` a
        transmembrane segment (going by the helper names) -- confirm against
        the dataset's label encoding.

        Args:
            labels: Integer label tensor, indexed as (sequence, position).
            predictions: Integer predicted-label tensor, same layout as ``labels``.
            sequence_length: Per-sequence number of valid positions.

        Returns:
            Scalar tensor of dtype ``K.floatx()``: the mean of the per-sequence
            correctness flags.
        """

        def check_signal(values):
            # An empty sequence has no first position, hence no signal label.
            return bool(values.size > 0 and values[0] == 1)

        def get_regions(values):
            # Builtin ``int`` is used because the ``np.int`` alias was
            # deprecated in NumPy 1.20 and removed in NumPy 1.24.
            mask = np.asarray(values == 3, dtype=int)

            if mask.size == 0:
                no_regions = np.array([], dtype=int)
                return 0, (no_regions, no_regions)

            # +1 where a run of 3s begins, -1 at the last position of a run.
            steps = np.diff(mask)
            start = np.where(steps == 1)[0] + 1
            end = np.where(steps == -1)[0]

            # Runs touching either border produce no transition; add them by hand.
            if mask[0]:
                start = np.concatenate((np.array([0]), start))

            if mask[-1]:
                end = np.concatenate((end, np.array([len(values) - 1])))

            assert len(start) == len(end)

            return len(start), (start, end)

        def check_overlap(true_regions, pred_regions):
            # Callers guarantee both region lists have the same length, so the
            # i-th true run is compared against the i-th predicted run.
            starts = np.maximum(true_regions[0], pred_regions[0])
            ends = np.minimum(true_regions[1], pred_regions[1])

            overlap = ends - starts

            return np.all(overlap >= 5)

        def check_correct(labels_numpy, predictions_numpy, seqlen_numpy):
            acc = []
            for lab_numpy, pred_numpy, seqlen in zip(
                    labels_numpy, predictions_numpy, seqlen_numpy):
                # Drop padding positions before comparing.
                lab_numpy = lab_numpy[:seqlen]
                pred_numpy = pred_numpy[:seqlen]
                if check_signal(lab_numpy) != check_signal(pred_numpy):
                    acc.append(False)
                    continue

                num_true_regions, true_regions = get_regions(lab_numpy)
                num_pred_regions, pred_regions = get_regions(pred_numpy)

                if num_true_regions != num_pred_regions:
                    acc.append(False)
                    continue

                if num_true_regions == 0:
                    acc.append(True)
                    continue

                acc.append(check_overlap(true_regions, pred_regions))
            # Explicit dtype: the py_func output is declared as tf.bool, and an
            # empty list would otherwise be inferred as float64.
            return np.array(acc, dtype=bool)

        acc = tf.py_func(check_correct, [labels, predictions, sequence_length], tf.bool)
        return tf.reduce_mean(K.cast(acc, K.floatx()))

    def loss_function(self,
                      inputs: Dict[str, tf.Tensor],
                      outputs: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
        """Compute the masked cross-entropy loss and the sequence-level accuracy.

        Args:
            inputs: Batch dictionary; read keys are ``self._label_name`` and
                ``'protein_length'``.
            outputs: Model outputs; the logits are read from ``self._output_name``.

        Returns:
            A ``(loss, metrics)`` tuple where ``metrics`` maps ``self.key_metric``
            to the value returned by :meth:`transmembrane_accuracy`.
        """
        labels = inputs[self._label_name]
        logits = outputs[self._output_name]
        # The mask is used as per-position loss weights; presumably it is zero
        # on padding positions -- confirm against rinokeras.
        mask = rk.utils.convert_sequence_length_to_sequence_mask(
            labels, inputs['protein_length'])
        loss = tf.losses.sparse_softmax_cross_entropy(
            labels, logits, weights=tf.cast(mask, logits.dtype))

        # Match the label dtype so labels and predictions compare directly.
        predictions = tf.argmax(logits, -1, output_type=labels.dtype)
        accuracy = self.transmembrane_accuracy(labels, predictions, inputs['protein_length'])

        metrics = {self.key_metric: accuracy}
        return loss, metrics

    def prepare_dataset(self,
                        dataset: tf.data.Dataset,
                        buckets: List[int],
                        batch_sizes: List[int],
                        shuffle: bool = False) -> tf.data.Dataset:
        """Deserialize the examples and batch them by ``'protein_length'``.

        Args:
            dataset: Dataset of serialized examples.
            buckets: Bucket boundaries passed to ``bucket_by_sequence_length``.
            batch_sizes: Per-bucket batch sizes passed to
                ``bucket_by_sequence_length``.
            shuffle: When true, shuffle with a 1024-element buffer; otherwise
                prefetch 1024 elements instead.

        Returns:
            The deserialized, length-bucketed and batched dataset.
        """
        dataset = dataset.map(self._deserialization_func, num_parallel_calls=128)
        # NOTE: a filter keeping only examples with protein_length < 4000 on
        # the shuffled path is currently disabled:
        #     dataset = dataset.filter(lambda example: example['protein_length'] < 4000)
        dataset = dataset.shuffle(1024) if shuffle else dataset.prefetch(1024)
        batch_fun = tf.data.experimental.bucket_by_sequence_length(
            lambda example: example['protein_length'],
            buckets,
            batch_sizes)
        dataset = dataset.apply(batch_fun)
        return dataset
