from typing import Optional, Tuple, List
import subprocess

import tensorflow as tf
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import CuDNNLSTM as LSTM
from tensorflow.keras.layers import Embedding, Dropout
from sacred import Ingredient

from rinokeras.layers import Stack


# Sacred ingredient named 'lstm'. The config function and the captured
# ``BidirectionalLSTM.__init__`` in this module are both registered on it.
lstm_hparams = Ingredient('lstm')


@lstm_hparams.config
def configure_lstm():
    """Declare the default hyperparameters of the 'lstm' ingredient.

    The locals assigned here are never read inside this function (hence the
    ``noqa: F841`` markers); the function is registered as a Sacred config
    scope through the decorator. The values equal the keyword defaults of
    ``BidirectionalLSTM.__init__``.
    """
    n_units = 1024  # noqa: F841
    n_layers = 3  # noqa: F841
    dropout = 0.1  # noqa: F841


class BidirectionalLSTM(Model):
    """Encoder built from two independent LSTM stacks, one per direction.

    The input tokens are embedded, fed to a forward stack as-is and to a
    second stack after a length-aware reversal, and the two outputs are
    concatenated along the last axis.
    """

    @lstm_hparams.capture
    def __init__(self,
                 n_symbols: int,
                 n_units: int = 1024,
                 n_layers: int = 3,
                 dropout: Optional[float] = 0.1) -> None:
        """Build the embedding, both LSTM stacks and the output dropout.

        Args:
            n_symbols: Vocabulary size passed to the embedding layer.
            n_units: Number of units of every LSTM layer.
            n_layers: Number of LSTM layers in each direction.
            dropout: Dropout rate applied to the concatenated output.
                ``None`` is treated as ``0``.
        """
        super().__init__()

        if dropout is None:
            dropout = 0

        # Embedding width is fixed at 128.
        self.embedding = Embedding(n_symbols, 128)

        self.forward_lstm = Stack([
            LSTM(n_units,
                 return_sequences=True) for _ in range(n_layers)],
            name='forward_lstm')

        self.reverse_lstm = Stack([
            LSTM(n_units,
                 return_sequences=True) for _ in range(n_layers)],
            name='reverse_lstm')

        self.dropout = Dropout(dropout)

    def call(self, inputs):
        """Encode ``inputs['sequence']`` and store the result in ``inputs``.

        Args:
            inputs: Mapping with the keys ``'sequence'`` and
                ``'protein_length'``.
                NOTE(review): presumably ``sequence`` is an integer tensor of
                shape (batch, length) and ``protein_length`` holds one length
                per batch element -- confirm against the caller.

        Returns:
            The same ``inputs`` object, mutated in place, with the encoder
            output added under the key ``'encoder_output'``.
        """
        sequence = inputs['sequence']
        protein_length = inputs['protein_length']

        # Intermediate tensors are registered in the 'checkpoints' collection
        # (presumably consumed by a gradient-checkpointing utility -- confirm).
        sequence = self.embedding(sequence)
        tf.add_to_collection('checkpoints', sequence)

        forward_output = self.forward_lstm(sequence)
        tf.add_to_collection('checkpoints', forward_output)

        # Reverse only the first ``protein_length`` steps of each example so
        # the reverse stack does not start on padding, then undo the reversal
        # so both directions are aligned position by position.
        reversed_sequence = tf.reverse_sequence(sequence, protein_length, seq_axis=1)
        reverse_output = self.reverse_lstm(reversed_sequence)
        reverse_output = tf.reverse_sequence(reverse_output, protein_length, seq_axis=1)
        tf.add_to_collection('checkpoints', reverse_output)

        encoder_output = tf.concat((forward_output, reverse_output), -1)

        encoder_output = self.dropout(encoder_output)

        inputs['encoder_output'] = encoder_output
        return inputs

    @staticmethod
    def _gpu_memory_gb() -> int:
        """Return the total memory of the first listed GPU, in MiB // 1000.

        Uses the machine-readable ``nvidia-smi`` query instead of scraping
        the human-readable table, so the result does not depend on the table
        layout.

        Raises:
            FileNotFoundError: If ``nvidia-smi`` is not on the PATH.
            subprocess.CalledProcessError: If ``nvidia-smi`` exits non-zero.
            IndexError: If ``nvidia-smi`` reports no GPU.
            ValueError: If the reported memory is not an integer.
        """
        output = subprocess.check_output([
            'nvidia-smi',
            '--query-gpu=memory.total',
            '--format=csv,noheader,nounits'])
        # One line per GPU, each a bare number of MiB; use the first GPU.
        per_gpu_mib = output.decode().split()
        return int(per_gpu_mib[0]) // 1000

    @property
    def boundaries(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return boundaries and GPU-memory-scaled sizes as two arrays.

        NOTE(review): presumably these are sequence-length bucket boundaries
        and per-bucket batch sizes for a bucketing input pipeline -- confirm
        against the caller.

        Returns:
            ``(bounds_array, sizes_array)``. ``bounds_array`` holds the 11
            boundary values below. ``sizes_array`` is an ``int32`` array with
            one more entry than ``bounds_array``; each base size is
            multiplied by ``memsize / 2``, truncated, and raised to at
            least 1.

        Raises:
            Whatever ``_gpu_memory_gb`` raises when the GPU cannot be queried.
        """
        memsize = self._gpu_memory_gb()

        # (boundary, base size) pairs.
        boundaries = [
            (100, 5),
            (200, 5),
            (300, 5),
            (400, 5),
            (600, 5),
            (900, 4),
            (1000, 4),
            (1200, 3),
            (1300, 3),
            (2000, 2),
            (3000, 1)]

        bounds = [bound for bound, _ in boundaries]
        sizes = [size for _, size in boundaries]
        # Extra trailing entry for everything past the last boundary; its
        # base size of 0 is raised to 1 by the clamp below.
        sizes.append(0)

        bounds_array = np.array(bounds)
        sizes_array = np.array(sizes) * memsize / 2

        # Truncate to integers and never return a size below 1.
        sizes_array = np.asarray(sizes_array, np.int32)
        sizes_array[sizes_array <= 0] = 1

        return bounds_array, sizes_array
