from typing import Dict, List
import string

import tensorflow as tf

from .tf_data_utils import to_features, to_sequence_features


def serialize_remote_homology_sequence(sequence: str,
                                       seq_id: str,
                                       class_label: int,
                                       fold_label: int,
                                       superfamily_label: int,
                                       family_label: int,
                                       pssm: List[List[int]],
                                       ss: List[int],
                                       sa: List[int],
                                       vocab: Dict[str, int]) -> bytes:
    """Encode one remote-homology protein record as a serialized ``tf.train.SequenceExample``.

    The per-protein scalars (id, length, and the four labels) go into the
    example's context; the per-residue data (token ids, secondary structure,
    solvent accessibility, PSSM) go into its feature lists.

    Args:
        sequence: Amino-acid string; every character must be a key of ``vocab``.
        seq_id: Protein identifier, stored UTF-8 encoded as ``sequence_id``.
        class_label: Integer label stored in the context.
        fold_label: Integer label stored in the context.
        superfamily_label: Integer label stored in the context.
        family_label: Integer label stored in the context.
        pssm: Per-residue profile rows, stored as the ``pssm`` feature list.
            NOTE(review): the parser in this module reads rows of exactly 20
            values; confirm callers always supply that width.
        ss: Per-residue secondary-structure labels.
        sa: Per-residue solvent-accessibility labels.
        vocab: Mapping from amino-acid character to integer token id.

    Returns:
        The serialized ``SequenceExample`` as bytes.

    Raises:
        ValueError: If ``sequence`` contains a whitespace character, or a
            character that is not a key of ``vocab``.
    """
    int_sequence = []
    for aa in sequence:
        if aa in string.whitespace:
            raise ValueError("whitespace found in string")

        # .get + explicit None check reports an unknown residue as a clear
        # ValueError (and still accepts a token id of 0).
        aa_idx = vocab.get(aa)
        if aa_idx is None:
            raise ValueError(f'{aa} not in vocab')

        int_sequence.append(aa_idx)

    # protein_length is derived from the encoded sequence, not from ss/sa/pssm.
    # NOTE(review): lengths of ss, sa and pssm are not checked against it.
    protein_context = to_features(
        sequence_id=seq_id.encode('UTF-8'),
        protein_length=len(int_sequence),
        class_label=class_label,
        fold_label=fold_label,
        superfamily_label=superfamily_label,
        family_label=family_label)

    protein_features = to_sequence_features(sequence=int_sequence,
                                            secondary_structure=ss,
                                            solvent_accessibility=sa,
                                            pssm=pssm)

    example = tf.train.SequenceExample(context=protein_context, feature_lists=protein_features)
    return example.SerializeToString()


def deserialize_remote_homology_sequence(example):
    """Parse one serialized remote-homology ``SequenceExample`` into a dict of tensors.

    Reads the context keys ``sequence_id``, ``protein_length``, ``class_label``,
    ``fold_label``, ``superfamily_label`` and ``family_label``, and the sequence
    features ``sequence``, ``secondary_structure``, ``solvent_accessibility`` and
    ``pssm`` (20 int64 values per residue).

    Args:
        example: Scalar string tensor holding a serialized ``tf.train.SequenceExample``.

    Returns:
        dict with:
            ``sequence_id``: scalar ``tf.string``.
            ``sequence``: ``int32`` token ids, shape ``[L]``.
            ``protein_length``: scalar ``int32``.
            ``class_label``, ``fold_label``, ``superfamily_label``,
            ``family_label``: scalar ``int32``.
            ``secondary_structure``: one-hot ``float32``, shape ``[L, 3]``.
            ``solvent_accessibility``: one-hot ``float32``, shape ``[L, 2]``.
            ``hmm_profile``: the PSSM as ``float32``, shape ``[L, 20]``.
    """
    context_spec = {
        'sequence_id': tf.FixedLenFeature([], tf.string),
        'protein_length': tf.FixedLenFeature([], tf.int64),
        'class_label': tf.FixedLenFeature([], tf.int64),
        'fold_label': tf.FixedLenFeature([], tf.int64),
        'superfamily_label': tf.FixedLenFeature([], tf.int64),
        'family_label': tf.FixedLenFeature([], tf.int64)
    }

    sequence_spec = {
        'sequence': tf.FixedLenSequenceFeature([], tf.int64),
        'secondary_structure': tf.FixedLenSequenceFeature([], tf.int64),
        'solvent_accessibility': tf.FixedLenSequenceFeature([], tf.int64),
        'pssm': tf.FixedLenSequenceFeature([20], tf.int64)
    }

    context, features = tf.parse_single_sequence_example(
        example,
        context_features=context_spec,
        sequence_features=sequence_spec
    )

    # tf.cast everywhere (tf.to_int32 is deprecated and yields the same result).
    sequence = tf.cast(features['sequence'], tf.int32)
    sequence_id = context['sequence_id']
    protein_length = tf.cast(context['protein_length'], tf.int32)
    class_label = tf.cast(context['class_label'], tf.int32)
    fold_label = tf.cast(context['fold_label'], tf.int32)
    superfamily_label = tf.cast(context['superfamily_label'], tf.int32)
    family_label = tf.cast(context['family_label'], tf.int32)

    secondary_structure = tf.one_hot(features['secondary_structure'], 3)
    solvent_accessibility = tf.one_hot(features['solvent_accessibility'], 2)
    pssm = tf.cast(features['pssm'], tf.float32)  # floats since that's what downstream model expects

    # NOTE: only the PSSM is exposed as the alignment-based profile; secondary
    # structure and solvent accessibility are returned as separate keys rather
    # than concatenated into it.
    profile_features = pssm

    return {'sequence_id': sequence_id,
            'sequence': sequence,
            'protein_length': protein_length,
            'class_label': class_label,
            'fold_label': fold_label,
            'superfamily_label': superfamily_label,
            'family_label': family_label,
            'secondary_structure': secondary_structure,
            'solvent_accessibility': solvent_accessibility,
            'hmm_profile': profile_features}  # alignment-based (profile) features
