import abc

import tensorflow as tf
import tensorflow_probability as tfp
import tree


class BasePolicy:
    """Common interface for policies that map inputs to actions.

    Subclasses implement the batched methods ``actions``, ``log_probs`` and
    ``probs``. The singular wrappers ``action``, ``log_prob`` and ``prob``
    give every argument leaf a leading axis of length one, call the batched
    method, and index ``[0]`` into every leaf of the result.

    NOTE(review): the class neither derives from ``abc.ABC`` nor uses
    ``abc.ABCMeta``, so ``@abc.abstractmethod`` is not enforced when the
    class is instantiated; an unimplemented method only fails when called,
    by raising ``NotImplementedError``.

    Args:
        input_shapes: Stored as ``self._input_shapes``.
        output_shape: Stored as ``self._output_shape``.
        name: Stored as ``self._name``.
    """

    def __init__(self, input_shapes, output_shape, name='policy'):
        self._name = name
        self._input_shapes = input_shapes
        self._output_shape = output_shape

    def _run_unbatched(self, batched_method, args, kwargs):
        """Call ``batched_method`` on one unbatched example.

        Every leaf of ``args``/``kwargs`` is expanded with ``leaf[None, ...]``
        before the call, and every leaf of the batched result is reduced with
        ``leaf[0]`` afterwards, so leaves must support NumPy-style indexing
        (e.g. arrays or tensors). ``args``/``kwargs`` are taken as a tuple and
        a dict (not splatted) so caller keywords can never collide with this
        helper's own parameter names.
        """
        def add_batch_axis(leaf):
            return leaf[None, ...]

        def drop_batch_axis(leaf):
            return leaf[0]

        batched_args, batched_kwargs = tree.map_structure(
            add_batch_axis, (args, kwargs))
        batched_output = batched_method(*batched_args, **batched_kwargs)
        return tree.map_structure(drop_batch_axis, batched_output)

    @abc.abstractmethod
    def actions(self, inputs):
        """Return actions for a batch of inputs (e.g. observations)."""
        raise NotImplementedError

    def action(self, *args, **kwargs):
        """Return the action for one unbatched input (e.g. observation)."""
        return self._run_unbatched(self.actions, args, kwargs)

    @abc.abstractmethod
    def log_probs(self, inputs, actions):
        """Return log probabilities for a batch of actions."""
        raise NotImplementedError

    def log_prob(self, *args, **kwargs):
        """Return the log probability of one unbatched action."""
        return self._run_unbatched(self.log_probs, args, kwargs)

    @abc.abstractmethod
    def probs(self, inputs, actions):
        """Return probabilities for a batch of actions."""
        raise NotImplementedError

    def prob(self, *args, **kwargs):
        """Return the probability of one unbatched action."""
        return self._run_unbatched(self.probs, args, kwargs)

    def get_config(self):
        """Return a config dict; the base class contributes no entries."""
        return dict()


class ContinuousPolicy(BasePolicy):
    """Base class for policies over a continuous action space.

    Only action spaces already scaled to (-1, 1) are supported. This class
    stores the action range and selects an action post-processing bijector;
    it does not implement ``actions``, ``log_probs`` or ``probs`` itself.

    Args:
        action_range: Action bounds; must equal ``[[-1], [1]]`` after
            broadcasting (presumably a ``(2, action_dim)`` array of lower and
            upper bounds -- TODO confirm against callers).
        *args: Forwarded positionally to ``BasePolicy.__init__``.
        squash: If True, actions are post-processed with
            ``tfp.bijectors.Tanh``; if False, with
            ``tfp.bijectors.Identity``.
        **kwargs: Forwarded to ``BasePolicy.__init__``.

    Raises:
        ValueError: If ``action_range`` is not the (-1, 1) range.
    """

    def __init__(self,
                 action_range,
                 *args,
                 squash=True,
                 **kwargs):
        # Compare in float32: `==` between tensors of different dtypes
        # (e.g. a float32 range against an int32 constant) raises instead of
        # comparing values. A real exception (rather than `assert`) keeps the
        # check active under `python -O`, which strips assert statements.
        is_unit_range = tf.reduce_all(
            tf.cast(action_range, tf.float32)
            == tf.constant([[-1.0], [1.0]]))
        if not is_unit_range:
            raise ValueError(
                "The action space should be scaled to (-1, 1)."
                " TODO: We should support non-scaled actions spaces.")
        self._action_range = action_range
        self._squash = squash
        # Dict lookup (rather than a truthiness test) rejects non-boolean
        # `squash` values such as None with a KeyError.
        self._action_post_processor = {
            True: tfp.bijectors.Tanh(),
            False: tfp.bijectors.Identity(),
        }[squash]

        super(ContinuousPolicy, self).__init__(*args, **kwargs)

    def get_config(self):
        """Return the base config extended with the range and squash flag."""
        base_config = super(ContinuousPolicy, self).get_config()
        config = {
            **base_config,
            'action_range': self._action_range,
            'squash': self._squash,
        }
        return config


class DiscretePolicy(BasePolicy):
    """Tabular policy over discrete inputs and actions.

    ``probability_table[i, a]`` is read as the probability of choosing action
    ``a`` for input index ``i`` (rows presumably sum to one -- not checked
    here).

    Args:
        probability_table: Table of action probabilities indexed by
            ``[input, action]``; converted with ``tf.convert_to_tensor``.
        *args: Forwarded positionally to ``BasePolicy.__init__``.
            NOTE(review): any positional argument binds to ``input_shapes``
            and collides with the keyword passed below, raising TypeError.
        input_shapes: Forwarded to ``BasePolicy.__init__``.
        output_shape: Forwarded to ``BasePolicy.__init__``.
        **kwargs: Forwarded to ``BasePolicy.__init__``.
    """

    def __init__(self,
                 probability_table,
                 *args,
                 input_shapes=(),
                 output_shape=(),
                 **kwargs):
        self.probability_table = tf.convert_to_tensor(probability_table)
        super(DiscretePolicy, self).__init__(
            *args,
            input_shapes=input_shapes,
            output_shape=output_shape,
            **kwargs)

    def actions(self, inputs):
        """Sample one action per input index.

        Args:
            inputs: Integer input indices. ``tf.random.categorical`` requires
                2-D logits, so this should be a 1-D batch of indices.

        Returns:
            An int64 tensor (the ``tf.random.categorical`` default dtype) of
            sampled action indices, one per input.
        """
        probabilities = tf.gather(self.probability_table, inputs)
        # Zero probabilities become -inf logits, which are never sampled.
        logits = tf.math.log(probabilities)
        actions = tf.random.categorical(logits, 1)[..., 0]
        return actions

    def probs(self, inputs, actions):
        """Compute probabilities for given actions.

        Args:
            inputs: Integer input indices.
            actions: Integer action indices, broadcast-compatible in shape
                with ``inputs`` for stacking.

        Returns:
            ``probability_table[inputs, actions]`` gathered element-wise.
        """
        # `tf.stack` requires matching dtypes, but `actions()` returns int64
        # while integer inputs commonly default to int32; cast both so the
        # pairs can be stacked into `gather_nd` indices.
        indices = tf.stack(
            (tf.cast(inputs, tf.int64), tf.cast(actions, tf.int64)),
            axis=-1)
        return tf.gather_nd(self.probability_table, indices)

    def log_probs(self, inputs, actions):
        """Compute log probabilities for given actions.

        Zero-probability actions yield ``-inf``.
        """
        return tf.math.log(self.probs(inputs, actions))

    def get_config(self):
        """Return the base config extended with the probability table."""
        base_config = super(DiscretePolicy, self).get_config()
        config = {
            **base_config,
            'probability_table': self.probability_table,
        }
        return config
