"""
Copyright (c) ANONYMOUS
All rights reserved.

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

import jax
import jax.numpy as jnp


def get_idx(curr_obs, observations, atol=1e-5):
    """Return the index of the first entry of ``observations`` matching ``curr_obs``.

    Two observations match when every element differs by strictly less than
    ``atol``. ``curr_obs`` must have the same shape as each entry of
    ``observations``.

    Args:
        curr_obs: Observation to look up.
        observations: A list of arrays, or a stacked array whose leading axis
            enumerates the candidate observations.
        atol: Strict absolute tolerance of the elementwise comparison
            (defaults to the historical 1e-5).

    Returns:
        A scalar integer array holding the index of the first match, or -1
        when nothing matches (including when ``observations`` is empty).
        Selection is done with ``jnp.where`` rather than Python control flow,
        so the function stays usable under ``jax.jit`` / ``jax.vmap``.
    """
    # len() only reads the static leading dimension, so this is trace-safe.
    if len(observations) == 0:
        return jnp.full((), -1, dtype=int)
    stacked = jnp.stack(observations)
    match = jax.vmap(lambda x, y: jnp.all(jnp.abs(x - y) < atol), in_axes=(0, None))(
        stacked, curr_obs
    )
    # argmax over booleans returns the first True; fall back to -1 if none is True.
    return jnp.where(jnp.any(match), jnp.argmax(match, axis=0), -1)


def dfs(env, observe, step_mdp, curr_state, observations, transition, reward, reward_probs):
    """Recursively enumerate every state reachable from ``curr_state``.

    The four tables are grown in place whenever a new observation is found:

    * ``observations`` -- unique observations; a state's index is its position
      in this list, so the first state ever visited gets index 0.
    * ``transition`` -- nested list indexed ``[from_idx][to_idx][action]``. The
      entry is set to 1 for the single successor returned by ``step_mdp``, so
      dynamics are treated as deterministic.
    * ``reward`` -- nested list indexed ``[state_idx][action]`` holding the
      ``reward`` field of the transition returned by ``step_mdp``.
    * ``reward_probs`` -- nested list indexed ``[state_idx][action]`` holding
      the reward-probability vector returned by ``step_mdp``.

    Terminal states (``curr_state.done``) keep an all-zero transition row and
    zero reward. Their reward distribution puts all mass on the first entry of
    ``env.reward_values``.

    Args:
        env: Environment exposing ``num_actions`` and ``reward_values``.
        observe: Callable mapping a state to its observation.
        step_mdp: Callable ``(state, action) -> (next_state, transition,
            reward_probs)``; ``transition`` must expose ``observation`` and
            ``reward``.
        curr_state: State to expand.
        observations, transition, reward, reward_probs: Tables mutated in place.

    Raises:
        RuntimeError: if a successor's observation is not found in
            ``observations`` after it was explored.

    Note:
        The traversal is recursive, so very deep MDPs can hit Python's
        recursion limit.
    """
    curr_obs = observe(curr_state)
    curr_obs_idx = get_idx(curr_obs, observations) if len(observations) > 0 else -1

    # Already visited. Index 0 (the initial state) is a valid hit as well; a
    # ``> 0`` test would re-append the initial state every time it is reached
    # again and recurse forever.
    if curr_obs_idx >= 0:
        return
    curr_obs_idx = len(observations)
    observations.append(curr_obs)
    reward.append([0] * env.num_actions)
    # New row with one (zero) column per previously known state...
    transition.append([[0] * env.num_actions for _ in range(curr_obs_idx)])
    reward_probs.append([[] for _ in range(env.num_actions)])

    # ...then add the new state's column to every row, including its own.
    for i in range(curr_obs_idx + 1):
        transition[i].append([0] * env.num_actions)

    if curr_state.done:
        print("done reached!")
        # Terminal states have no successors (row already zero; kept explicit).
        transition[curr_obs_idx][curr_obs_idx] = [0] * env.num_actions
        reward_probs[curr_obs_idx] = [[1.] + [0.]*(len(env.reward_values)-1) for _ in range(env.num_actions)]

        return

    for action in range(env.num_actions):
        next_state, next_transition, next_reward_probs = step_mdp(curr_state, action)
        dfs(env, observe, step_mdp, next_state, observations, transition, reward, reward_probs)

        next_obs_idx = get_idx(next_transition.observation, observations)
        # Explicit check (not assert): under -O an index of -1 would silently
        # write into the last state's column.
        if next_obs_idx < 0:
            raise RuntimeError("successor observation was not registered by dfs")
        reward[curr_obs_idx][action] = next_transition.reward
        reward_probs[curr_obs_idx][action] = list(next_reward_probs)
        transition[curr_obs_idx][next_obs_idx][action] = 1
    return


# Intended for environments with one-hot observations and deterministic dynamics.
def get_mdp(env):
    """Enumerate ``env`` from its reset state and return its tabular MDP.

    The reachable state space is explored with ``dfs``, using jitted versions
    of ``env._observe`` and ``env.step_mdp``. State index 0 is the state
    returned by ``env.reset(None)``.

    Returns:
        A tuple ``(observations, transition, reward, reward_probs)``:

        * ``observations`` -- unique observations stacked along axis 0.
        * ``transition`` -- shape ``(S, A, S)``, indexed
          ``[from_state, action, to_state]`` (``dfs`` builds it as
          ``[from][to][action]``; the transpose reorders the axes).
        * ``reward`` -- shape ``(S, A)``, average reward per state-action pair.
        * ``reward_probs`` -- shape ``(S, A, R)``, probability of each entry of
          ``env.reward_values`` per state-action pair.
    """
    print("Building MDP")
    observations = []
    init_state, _ = env.reset(None)
    transition = []
    reward = []  # average rewards
    reward_probs = []  # probabilities for each reward
    observe = jax.jit(env._observe)
    step_mdp = jax.jit(env.step_mdp)
    dfs(env, observe, step_mdp, init_state, observations, transition, reward, reward_probs)
    print("DONE!")
    return (
        jnp.stack(observations),
        jnp.transpose(jnp.array(transition), (0, 2, 1)),
        jnp.array(reward),
        jnp.array(reward_probs),
    )


class MDP_DP:
    """Exact tabular MDP quantities computed by dynamic programming.

    ``__init__`` enumerates the environment with ``get_mdp`` and stores the
    resulting tables. State indices follow the order in which ``dfs`` first
    visited each observation, so index 0 is the reset state (see
    ``self.init_state``).

    Methods taking ``policy_prob`` expect a ``(num_state, num_actions)`` table
    of action probabilities. It is used directly as weights and is never
    passed through a softmax. Finite-horizon recursions run with
    ``jax.lax.scan``; the public methods use ``self.max_trial`` steps.
    """

    def __init__(self, env):
        # Horizon used by the public value / successor / advantage methods.
        self.max_trial = env.length
        observation, transition, reward, reward_probs = get_mdp(env)
        # Unique observations stacked along axis 0; row i is state i.
        self.mdp_observation = observation
        # Shape (S, A, S): mdp_transition[s, a, s'] is 1 for the successor of
        # (s, a) found by dfs, 0 elsewhere; terminal-state rows are all zero.
        self.mdp_transition = transition
        self.mdp_reward = reward # average rewards of state-action pairs
        # Shape (S, A, R): probability of each entry of env.reward_values for (s, a).
        self.mdp_reward_probs=reward_probs
        # Shape (R,): the reward values listed by the environment.
        self.mdp_reward_values = jnp.array(env.reward_values)
        self.num_state = observation.shape[0]
        self.num_actions = env.num_actions
        self.obs_shape = env.observation_shape
        # One-hot vector of state 0, i.e. the first state dfs visited (the reset state).
        self.init_state = jax.nn.one_hot(0, self.num_state)

        # NOTE(review): disabled legacy code, kept for reference only.
        # self.dummy_mdp_reward = jnp.zeros_like(self.mdp_reward) # dirty hack for getting the reward-based
        # # contribution coeff when stochastic rewards are used. The idea is to store the unique
        # # reward values that the environment can return, in a shape that is compatible with later functions
        # for i in range(len(env.reward_values)):
        #     self.dummy_mdp_reward = self.dummy_mdp_reward.at[i//self.mdp_reward.shape[1], i%self.mdp_reward.shape[
        #         1]].set(env.reward_values[i])

    def observation_to_state(self, observation):
        """Return the one-hot state vector of ``observation``.

        The result is all zeros if the observation is not an MDP state
        (get_idx returns -1).
        """
        return jax.nn.one_hot(get_idx(observation, self.mdp_observation), self.num_state)

    def hindsight_object_to_hindsight_state(self, hindsight_object, abstract_fn):
        """
        Return the one-hot encoding of ``hindsight_object`` among all hindsight objects.

        abstract_fn(observation, action, reward_value) is evaluated for every
        (state, action, reward value) triple. The flattened results define an
        index space of size num_state * num_actions * len(mdp_reward_values),
        ordered state-major, then by action, then by reward value.

        The returned vector has a 1 at the first entry matching
        ``hindsight_object`` (within get_idx's tolerance), or is all zeros if
        nothing matches.
        """
        # Shape (S, A, R, ...): outer vmap over states, middle over actions, inner over reward values.
        all_hindsight_objects = jax.vmap(
            jax.vmap(jax.vmap(abstract_fn, in_axes=(None, None, 0)), in_axes=(None, 0, None)), in_axes=(0, None, None)
        )(self.mdp_observation, jnp.arange(self.num_actions), self.mdp_reward_values)

        # Flatten to (S*A*R, D); this assumes abstract_fn returns a 1-D array — TODO confirm.
        all_hindsight_objects = jnp.reshape(
            all_hindsight_objects, (-1, all_hindsight_objects.shape[-1])
        )
        hs_idx = get_idx(hindsight_object, all_hindsight_objects)

        return jax.nn.one_hot(hs_idx, self.num_state * self.num_actions * len(self.mdp_reward_values))

    def get_all_hindsight_object_state(self, abstract_fn):
        """
        Map every (s,a,r) to its hindsight one-hot feature and marginalize over r.

        Output shape is (S, A, H), where H = S * A * R is the number of hindsight
        objects. The r dimension is summed out, weighted by mdp_reward_probs, so
        entry [s, a, h] is the probability of hindsight object h given (s, a).
        """
        all_hindsight_objects = jax.vmap(
            jax.vmap(jax.vmap(abstract_fn, in_axes=(None, None, 0)), in_axes=(None, 0, None)), in_axes=(0, None, None)
        )(self.mdp_observation, jnp.arange(self.num_actions), self.mdp_reward_values)

        # This has shape SxAxRxH
        all_hindsight_states = jax.vmap(
                jax.vmap(
                jax.vmap(self.hindsight_object_to_hindsight_state, in_axes=(0, None)),
                in_axes=(0, None)),
                in_axes=(0, None)
        )(all_hindsight_objects, abstract_fn)

        # Per (s, a): (R,) reward probabilities @ (R, H) one-hots -> (H,).
        marginalized_states = jax.vmap(jax.vmap(lambda a,b: a@b))(self.mdp_reward_probs, all_hindsight_states)
        return marginalized_states

    def _get_state_value(self, state, policy_prob, horizon):
        """
        state: 1d tensor of state occupancy
        policy_prob: 2d tensor of (s, a) probability
        horizon: number of steps to roll out (a static Python int, used by jnp.arange)

        returns: scalar, expected sum of rewards over ``horizon`` steps when the
            occupancy ``state`` is propagated under the policy (step 0 included)
        """
        # (S, S) state-to-state transition matrix under the policy.
        policy_transition = jax.vmap(lambda a, b: a @ b)(policy_prob, self.mdp_transition)

        batch_inner_prod = jax.vmap(lambda a, b: a @ b)
        # (S,) expected immediate reward of each state under the policy.
        av_rewards = batch_inner_prod(policy_prob, self.mdp_reward)

        def get_summed_reward(curr_state, timestep):
            # curr_reward = curr_state @ batch_inner_prod(policy_prob, self.mdp_reward)
            curr_reward = curr_state @ av_rewards
            next_state = curr_state @ policy_transition
            return next_state, curr_reward

        carry, timestep_reward = jax.lax.scan(get_summed_reward, state, jnp.arange(horizon))
        return timestep_reward.sum(0)

    def _get_state_action_value(self, state, policy_prob, horizon):
        """
        state: 1d tensor of state occupancy
        policy_prob: 2d tensor of (s, a) probability

        returns: the action-value for state, a (num_actions,) vector: for each
            action, its immediate reward plus the policy's value over the
            remaining ``horizon - 1`` steps
        """

        def _get_action_value(one_hot_action):
            # (S, S) transition matrix when this action is taken in every state.
            action_transition = jax.vmap(lambda a, b: a @ b, in_axes=(None, 0))(
                one_hot_action, self.mdp_transition
            )

            curr_reward = state @ self.mdp_reward @ one_hot_action
            next_state = state @ action_transition

            value = self._get_state_value(next_state, policy_prob, horizon - 1)

            return value + curr_reward

        return jax.vmap(_get_action_value)(jnp.eye(self.num_actions))

    def _get_state_successor(self, curr_state, policy_prob, horizon):
        """
        curr_state: one hot encoding of the target state
        policy_prob: 2d tensor of (s, a) action probabilities (used directly, not logits)

        returns: (num_state,) vector over starting states s:
            sum_{k=1..horizon} P(s_{t+k}=curr_state | s_t=s) under the policy
        """
        policy_transition = jax.vmap(lambda a, b: a @ b)(policy_prob, self.mdp_transition)
        contribution = jnp.zeros(self.num_state)

        def _get_contribution(curr_contrib, timestep):
            # Backward recursion: c_{n+1} = P_pi (e_target + c_n).
            next_contrib = policy_transition @ (curr_state + curr_contrib)
            return next_contrib, None

        contribution, _ = jax.lax.scan(_get_contribution, contribution, jnp.arange(horizon))

        return contribution

    def _get_state_action_successor(self, curr_state, policy_prob, horizon):
        """
        curr_state: one hot encoding of the target state
        policy_prob: 2d tensor of (s, a) action probabilities (used directly, not logits)

        returns: (num_state, num_actions) tensor over starting pairs (s, a):
            sum_{k=1..horizon} P(s_{t+k}=curr_state | s_t=s, a_t=a), with the
            policy followed after the first action
        """
        contribution = jnp.zeros((self.num_state, self.num_actions))

        batch_inner_prod = jax.vmap(lambda a, b: a @ b)

        def _get_contribution(curr_contrib, timestep):
            # c_{n+1}[s, a] = sum_s' T[s, a, s'] (e_target[s'] + sum_a' pi(a'|s') c_n[s', a'])
            next_contrib = self.mdp_transition @ (
                curr_state + batch_inner_prod(curr_contrib, policy_prob)
            )
            return next_contrib, None

        contribution, _ = jax.lax.scan(_get_contribution, contribution, jnp.arange(horizon))

        return contribution

    def get_advantage(self, policy_prob):
        """Return the (num_state, num_actions) advantage Q(s, a) - V(s) over horizon max_trial."""
        value = self.get_value(policy_prob, None, states=jnp.eye(self.num_state))
        return self.get_advantage_from_value(policy_prob, value)

    def get_advantage_from_value(self, policy_prob, value):
        """Return Q(s, a) - value[s] for every state, given a precomputed (num_state,) value."""
        action_value = self.get_action_value(policy_prob, None, states=jnp.eye(self.num_state))
        return action_value - jnp.expand_dims(value, -1)

    def get_advantage_from_hindsight(self, policy_prob, contribution, g_trick):
        """Compute advantages for every state from contribution coefficients.

        contribution: 4d tensor, see the shape comment below.
        g_trick: if truthy, drop the t=0 term of the rollout and add the exact
            immediate-reward advantage r(s, a) - sum_a' pi(a'|s) r(s, a') instead.

        returns: (num_state, num_actions) tensor. The per-timestep terms are
            summed over self.max_trial steps of the policy rollout started from
            each past state.
        """
        # contribution has shape past_state, curr_state, curr_action, past_action
        policy_transition = jax.vmap(lambda a, b: a @ b)(policy_prob, self.mdp_transition)
        batch_inner_prod = jax.vmap((lambda a, b: a@b))
        # Weight each curr_action by pi(curr_action | curr_state).
        # NOTE(review): the reward factor is commented out, so `contribution` is
        # presumably already reward-weighted — confirm with callers.
        reward_coeff = contribution * jnp.expand_dims(
            policy_prob, -1
        )  # * jnp.expand_dims(self.mdp_reward, -1)

        # Sum over curr_action -> (past_state, curr_state, past_action).
        reward_coeff = reward_coeff.sum(-2)

        def get_summed_reward(curr_state, timestep):
            # curr_state[p] is the distribution over states t steps after starting in p.
            curr_adv = batch_inner_prod(
                curr_state , reward_coeff
            )
            next_state = curr_state @ policy_transition
            return next_state, curr_adv

        # Start from the identity: one rollout per past state. timestep_adv is (T, S, A).
        carry, timestep_adv = jax.lax.scan(
            get_summed_reward, (jnp.eye(self.num_state)), jnp.arange(self.max_trial)
        )

        if g_trick:
            reward_adv = self.mdp_reward - jnp.expand_dims(
                jax.vmap(lambda a, b: a @ b)(policy_prob, self.mdp_reward), -1
            )
            return timestep_adv[1:].sum(0) + reward_adv

        return timestep_adv.sum(0)

    def get_value(self, policy_prob, observations, states=None):
        """Return V(s) over horizon max_trial for each row of ``states``.

        When ``states`` is None, it is built from the one-hot encodings of
        ``observations``.
        """
        if states is None:
            states = jax.vmap(self.observation_to_state)(observations)
        state_value = jax.vmap(self._get_state_value, in_axes=(0, None, None))(
            states, policy_prob, self.max_trial
        )
        return state_value

    def get_action_value(self, policy_prob, observations, states=None):
        """Return Q(s, .) (one value per action) over horizon max_trial for each row of ``states``.

        When ``states`` is None, it is built from the one-hot encodings of
        ``observations``.
        """
        if states is None:
            states = jax.vmap(self.observation_to_state)(observations)
        state_action_value = jax.vmap(self._get_state_action_value, in_axes=(0, None, None))(
            states, policy_prob, self.max_trial
        )
        return state_action_value

    def get_reward(self, observations, states=None):
        """Return the average rewards r(s, .) for each row of ``states`` (built from ``observations`` if None)."""
        if states is None:
            states = jax.vmap(self.observation_to_state)(observations)
        return states @ self.mdp_reward

    # NOTE(review): earlier state-based version of get_contribution, kept commented out for reference.
    # def get_contribution(self, policy_prob, observations, next_observations):
    #     state_successor = jax.vmap(self._get_state_successor, in_axes=(0, None, None))(
    #         jnp.eye(self.num_state), policy_prob, self.max_trial
    #     )
    #     state_action_successor = jax.vmap(
    #         self._get_state_action_successor, in_axes=(0, None, None)
    #     )(jnp.eye(self.num_state), policy_prob, self.max_trial)
    #
    #     state_contribution = jnp.where(
    #         state_action_successor == 0, 0, state_action_successor / jnp.expand_dims(state_successor, -1)
    #     )
    #
    #     states = jax.vmap(self.observation_to_state)(observations)
    #     next_states = jax.vmap(self.observation_to_state)(next_observations)
    #     state_contribution = jax.vmap(lambda a, b: a @ b, in_axes=(None, 0))(states, state_contribution)
    #     state_contribution = jnp.transpose(
    #         jax.vmap(lambda a, b: a @ b, in_axes=(None, 0))(
    #             next_states, jnp.transpose(state_contribution, (2, 0, 1))
    #         ),
    #         (2, 1, 0),
    #     )
    #
    #     return state_contribution

    def get_contribution(self, policy_prob, observations, hindsight_objects, abstract_fn, g_trick):
        """Compute contribution coefficients between states and hindsight objects.

        For state s, action a and hindsight index h (see
        hindsight_object_to_hindsight_state), this computes the ratio

            [sum_k P(H_k = h | s_0=s, a_0=a)] / [sum_k P(H_k = h | s_0=s)]

        with k = 1..self.max_trial, plus k = 0 when g_trick is falsy. The ratio
        is set to 0 wherever the numerator is 0. The table is then evaluated
        at the given observations and hindsight objects.

        returns: tensor of shape (len(observations), len(hindsight_objects), num_actions).
        """
        batch_inner_prod = jax.vmap(lambda a, b: a @ b, in_axes=(0, 0))
        # (target_state, start_state)
        state_successor = jax.vmap(self._get_state_successor, in_axes=(0, None, None))(
            jnp.eye(self.num_state), policy_prob, self.max_trial
        )
        # (target_state, start_state, start_action)
        state_action_successor = jax.vmap(
            self._get_state_action_successor, in_axes=(0, None, None)
        )(jnp.eye(self.num_state), policy_prob, self.max_trial)

        states = jax.vmap(self.observation_to_state)(observations)
        next_states = jax.vmap(self.hindsight_object_to_hindsight_state, in_axes=(0, None))(
            hindsight_objects, abstract_fn
        )

        # (S, A, H): P(h | s, a)
        hindsight_objects_table = self.get_all_hindsight_object_state(abstract_fn)
        # (H, S) denominator: P(h | s') under the policy, contracted with the successor measure.
        state_contrib_reward = jnp.tensordot(
            jnp.transpose(batch_inner_prod(policy_prob, hindsight_objects_table), (1, 0)),
            jnp.eye(self.num_state) * (1 - g_trick) + state_successor,
            axes=1,
        )
        # (H, S, A) numerator; the k = 0 term is P(h | s, a) itself.
        state_action_contrib_reward = jnp.tensordot(
            jnp.transpose(batch_inner_prod(policy_prob, hindsight_objects_table), (1, 0)),
            state_action_successor,
            axes=1,
        ) + jnp.transpose(hindsight_objects_table, axes=(2, 0, 1)) * (1 - g_trick)

        coeff = jnp.where(
            state_action_contrib_reward == 0,
            0,
            state_action_contrib_reward / jnp.expand_dims(state_contrib_reward, -1),
        )
        # Select observed states: (H, S, A) -> (H, N, A).
        coeff = jax.vmap(lambda a, b: a @ b, in_axes=(None, 0))(states, coeff)
        # Select hindsight objects: (A, H, N) -> (A, M, N), then reorder to (N, M, A).
        coeff = jnp.transpose(
            jax.vmap(lambda a, b: a @ b, in_axes=(None, 0))(
                next_states, jnp.transpose(coeff, (2, 0, 1))
            ),
            (2, 1, 0),
        )

        return coeff

    def get_contribution_hindsight_object_from_state_contribution(
        self, policy_prob, observations, hindsight_objects, abstract_fn, coeff
    ):
        """Turn state-based contribution coefficients into hindsight-object ones.

        coeff: presumably (s, s', a) state contribution coefficients (matching
            the shape notes below) — TODO confirm against callers.

        Each coeff[s, s', a] is averaged over s'. The weights are
        P(s' | s, h') proportional to [sum_{k=1..max_trial} P(s_k = s' | s)] * P(h' | s'),
        normalized over s'. The result is evaluated at the given observations
        and hindsight objects.

        returns: tensor of shape (len(observations), len(hindsight_objects), num_actions).
        """
        # This only supports g_trick based contribution
        # (the successor measure below excludes the k = 0 step).

        state_contrib = jax.vmap(self._get_state_successor, in_axes=(0, None, None))(
            jnp.eye(self.num_state), policy_prob, self.max_trial
        )

        hindsight_objects_table = self.get_all_hindsight_object_state(abstract_fn)
        batch_inner_prod = jax.vmap(lambda a, b: a.T @ b)
        hindsight_objects_collapsed = batch_inner_prod(policy_prob, hindsight_objects_table)
        marginal = jnp.expand_dims(state_contrib.T, axis=-1) * hindsight_objects_collapsed
        # hindsight: (s',a',h')
        # hindsight collapsed: (s', h')=P(h'|s')
        # state_contrib: (s', s)=P(s'|s)
        # Normalize over s' (axis 1) -> P(s' | s, h'); entries with zero mass stay 0.
        marginal = jnp.where(
            marginal == 0, 0, marginal / jnp.expand_dims(marginal.sum(axis=1), axis=1)
        )

        # marginal: (s, s', h')
        # coeff: (s, s', a)
        states = jax.vmap(self.observation_to_state)(observations)
        next_states = jax.vmap(self.hindsight_object_to_hindsight_state, in_axes=(0, None))(
            hindsight_objects, abstract_fn
        )

        # Per s: (s', h').T @ (s', a) -> (h', a); stacked and reordered to (h', s, a).
        reward_coeff = jnp.transpose(batch_inner_prod(marginal, coeff), (1, 0, 2))

        # Select observed states: (h', s, a) -> (h', N, a).
        coeff = jax.vmap(lambda a, b: a @ b, in_axes=(None, 0))(states, reward_coeff)

        # Select hindsight objects and reorder to (N, M, a).
        coeff = jnp.transpose(
            jax.vmap(lambda a, b: a @ b, in_axes=(None, 0))(
                next_states, jnp.transpose(coeff, (2, 0, 1))
            ),
            (2, 1, 0),
        )
        return coeff
