"""Subgrid generation routines for learning."""
import collections as col
import logging
import random

import numpy as np

import src.util.directions as di

# Recognized atom types, each mapped to the grid channel it is written into.
# Element symbols not listed here are treated as unrecognized.
atom_to_channel = dict(C=0, O=1, N=2, S=3)


def grid_size(model_params):
    """Return the number of voxels along each edge of the grid.

    The span is twice ``model_params['radius_ang']`` plus one, divided by
    ``model_params['resolution']`` and rounded to the nearest integer
    (built-in ``round``, so exact halves go to the even neighbour).
    """
    span = model_params['radius_ang'] * 2 + 1
    voxels = span / model_params['resolution']
    return int(round(voxels))


def grid_shape(model_params):
    """Return the grid shape as ``(size, size, size, channels)``.

    ``size`` comes from ``grid_size`` and ``channels`` from ``channel_size``.
    """
    edge = grid_size(model_params)
    return (edge,) * 3 + (channel_size(model_params),)


def channel_size(model_params):
    """Return the number of channels in a grid.

    The count is derived from ``atom_to_channel`` (highest channel index + 1),
    so the grid depth always matches the channel indices assigned to
    recognized atoms. It equals 4 for the current mapping.

    Args:
        model_params: Currently unused. Accepted so the signature matches
            ``grid_size`` and ``grid_shape``.
    """
    return max(atom_to_channel.values()) + 1


def _recognized(x, dict):
    """If atom type is recognized, return it.  Else, return empty string."""
    if x in dict.keys():
        return x
    else:
        return ''


class TFSubgridGenerator(object):
    """Build randomly rotated, voxelized atom grids as TensorFlow ops.

    TensorFlow is imported inside each method rather than at module level,
    so importing this module does not itself require TensorFlow.
    """

    def __init__(self, model_params, num_directions, num_rolls):
        """Precompute the pool of rotation matrices to sample from.

        Args:
            model_params: dict read by ``grid_size``/``channel_size``. Its
                'resolution' entry is also read directly by ``get_gridded``.
            num_directions: Number of direction vectors passed to
                ``di.fibonacci``.
            num_rolls: Number of up vectors (rolls) generated per direction.
        """
        import tensorflow as tf
        self.model_params = model_params
        self.num_rolls = num_rolls
        self.num_directions = num_directions
        self.uvs = di.fibonacci(self.num_directions)
        self.ups = di.generate_all_up_vectors(self.uvs, self.num_rolls)
        # Indexed as rot_mats[direction][roll] in get_gridded; presumably
        # each entry is a 3x3 rotation matrix -- TODO confirm against di.
        self.rot_mats = tf.cast(
            di.get_all_rot_mats(self.uvs, self.ups), tf.float32)

    def map_elements_to_int(self, elements):
        """Encode element symbols as channel indices.

        Args:
            elements: String tensor of element symbols (assumed 1-D, since
                only the first column of ``tf.where`` is used -- TODO
                confirm).

        Returns:
            int32 tensor with the same shape as ``elements``. Each entry is
            the channel from ``atom_to_channel`` for recognized symbols, and
            -1 for any other symbol.
        """
        import tensorflow as tf
        all_locs, all_vals = [], []
        for atom, channel in atom_to_channel.items():
            # Store channel + 1 because scatter_nd fills untouched slots with
            # 0. After subtracting 1 below, those slots become -1
            # (unrecognized) instead of colliding with channel 0.
            tmp_channel = tf.constant(channel + 1, dtype=tf.int32)
            locs = tf.cast(
                tf.where(tf.equal(elements, tf.constant(atom)))[:, 0],
                tf.int32)
            all_locs.append(locs)
            all_vals.append(tf.fill(tf.shape(locs), tmp_channel))
        indices = tf.expand_dims(tf.concat(all_locs, 0), 1)
        updates = tf.concat(all_vals, 0)
        encoding = tf.scatter_nd(indices, updates, tf.shape(elements))
        return encoding - 1

    def get_gridded(self, center, positions, elements):
        """Voxelize atoms around ``center`` under a randomly sampled rotation.

        Args:
            center: Coordinates of the grid center (presumably shape (3,)).
            positions: Atom coordinates (presumably shape (N, 3), in the same
                units as 'radius_ang' and 'resolution').
            elements: Element symbols, one per row of ``positions``.

        Returns:
            A tuple ``(grid, direction, roll)``:
            - ``grid`` is a float32 tensor of shape
              ``(size, size, size, channel_size(model_params))``. It holds
              1.0 at every occupied (voxel, channel) and 0.0 elsewhere.
            - ``direction`` and ``roll`` are the sampled int32 indices of the
              rotation that was applied.
            Atoms outside the grid or with unrecognized elements are dropped.
        """
        import tensorflow as tf

        resolution = self.model_params['resolution']
        size = grid_size(self.model_params)
        num_channels = channel_size(self.model_params)
        # Half the side length actually spanned by `size` voxels. This can
        # differ slightly from radius_ang because grid_size rounds.
        true_radius = size * resolution / 2.0

        direction = tf.random_uniform(
            (), minval=0, maxval=self.num_directions, dtype=tf.int32)
        roll = tf.random_uniform(
            (), minval=0, maxval=self.num_rolls, dtype=tf.int32)
        rot_mat = self.rot_mats[direction][roll]

        # Row-vector form of rot_mat @ p for every (centered) position p.
        rotated_positions = tf.matmul(
            positions - center, rot_mat, transpose_b=True)

        # Half-open bins [k * res, (k + 1) * res). tf.round rounds halves to
        # even, so the former round(x - 0.5) sent points lying exactly on odd
        # bin boundaries to the lower bin and even ones to the upper bin.
        # floor is consistent.
        at = tf.cast(
            tf.floor((rotated_positions + true_radius) / resolution),
            tf.int32)
        in_bounds = tf.logical_and(
            tf.reduce_all(at >= 0, axis=1),
            tf.reduce_all(at < size, axis=1))
        sel = tf.where(in_bounds)[:, 0]
        lat = self.map_elements_to_int(tf.gather(elements, sel))
        at = tf.gather(at, sel)

        # Keep only atoms whose element maps to a channel (-1 = unknown).
        recognized = tf.where(tf.not_equal(lat, -1))[:, 0]
        lat = tf.gather(lat, recognized)
        at = tf.gather(at, recognized)

        # Each row is (x, y, z, channel). Duplicate rows (several atoms in
        # one voxel) are tolerated via validate_indices=False, and every
        # value written is 1.
        pos = tf.concat((at, tf.expand_dims(lat, 1)), 1)
        grid = tf.cast(
            tf.sparse_to_dense(pos, (size, size, size, num_channels), 1,
                               validate_indices=False),
            tf.float32)
        return grid, direction, roll

    def get_gridded_pair(self, region_pair):
        """Grid both regions of a pair and attach the results to the dict.

        ``region_pair`` is modified in place and also returned:
        - Added: 'roll', 'direction', 'grid' and 'uid', each stacking the
          values for region 0 and region 1 along a new leading axis.
        - Removed: 'positions0', 'positions1', 'elements0' and 'elements1'.

        Args:
            region_pair: dict with 'center', 'positions0', 'positions1',
                'elements0', 'elements1', 'pdb_name', 'model', 'chain' and
                'residue'. The per-region values are indexed by 0 and 1, and
                the uid fields are presumably string tensors -- TODO
                confirm.
        """
        import tensorflow as tf
        grids, directions, rolls = [], [], []
        for i in (0, 1):
            grid, direction, roll = self.get_gridded(
                region_pair['center'][i],
                region_pair['positions%d' % i],
                region_pair['elements%d' % i])
            grids.append(grid)
            directions.append(direction)
            rolls.append(roll)
        uids = [self._region_uid(region_pair, i) for i in (0, 1)]

        region_pair['roll'] = tf.stack(rolls)
        region_pair['direction'] = tf.stack(directions)
        region_pair['grid'] = tf.stack(grids)
        region_pair['uid'] = tf.stack(uids)
        for i in (0, 1):
            del region_pair['positions%d' % i]
            del region_pair['elements%d' % i]
        return region_pair

    @staticmethod
    def _region_uid(region_pair, i):
        """Join pdb_name, model, chain and residue of region ``i`` with '_'."""
        return (region_pair['pdb_name'][i] + '_' +
                region_pair['model'][i] + '_' +
                region_pair['chain'][i] + '_' +
                region_pair['residue'][i])
