import dataclasses
import functools
from typing import Any

import chex
import jax
import jax.numpy as jnp
import optax
from clu import metrics as clu_metrics
from flax.training import train_state

from tabular_mvdrl.kernels import Kernel
from tabular_mvdrl.mmd import mmd2
from tabular_mvdrl.models import EWPModel, TabularProbabilityModel
from tabular_mvdrl.state import WeightedParticleState
from tabular_mvdrl.trainer import MVDRLTransferTrainer
from tabular_mvdrl.types import MRPTransitionBatch
from tabular_mvdrl.utils import jitpp, support_init
from tabular_mvdrl.utils.discrete_distributions import DiscreteDistribution
from tabular_mvdrl.utils.jitpp import Bind, Donate, Static

# Metric tag under which the mean MMD training loss is recorded: it names the
# Average metric in CatTDTrainer.metrics and keys the loss in train_step.
LOSS_MMD = "loss__mmd"


def random_probability_init(rng: chex.PRNGKey, shape: Any, dtype: Any) -> chex.Array:
    """Sample near-uniform weights that sum to one along the last axis.

    Standard-normal noise is drawn with ``rng``, shrunk by a factor of 0.01
    and pushed through a softmax over the final axis.

    Args:
        rng: PRNG key used to draw the noise.
        shape: Shape of the array to produce.
        dtype: Accepted but not used; the result has whatever dtype
            ``jax.random.normal`` produces by default.

    Returns:
        An array of the requested shape whose entries along the last axis
        are positive and sum to one.
    """
    del dtype  # intentionally ignored, see docstring
    noise = jax.random.normal(rng, shape=shape)
    small_logits = noise * 0.01
    return jax.nn.softmax(small_logits, axis=-1)


@dataclasses.dataclass(frozen=True, kw_only=True)
class CatTDTrainer(MVDRLTransferTrainer[WeightedParticleState]):
    """Categorical TD trainer that fits per-state atom weights with an MMD loss.

    Every state is represented by a set of atoms: locations produced by an
    ``EWPModel`` initialised from ``support_map_initializer`` and weights
    produced by a ``TabularProbabilityModel``. ``train_step`` differentiates
    the loss only with respect to the weight parameters.

    ``self.env`` and ``self.seed`` are read in ``state``; they are not
    declared here and are presumably provided by ``MVDRLTransferTrainer``.
    """

    # Optimizer handed to both the weight state and the support-map state.
    optim: optax.GradientTransformation
    # Kernel passed to ``mmd2`` when computing the loss.
    kernel: Kernel
    # Factor applied to next-state atom locations when forming the TD target.
    discount: float
    # Builds the support map; also used to initialise the locations model.
    support_map_initializer: support_init.SupportMapInitializer
    # When True the weight model is built with ``logits=False`` and
    # ``random_probability_init``; otherwise with ``logits=True``.
    signed: bool = False

    @functools.cached_property
    def num_atoms(self):
        """Number of atoms, read from axis 1 of a freshly built support map."""
        # The key is fixed: only the shape of the result is used here.
        support_map = self.support_map_initializer(jax.random.PRNGKey(0))
        return support_map.shape[1]

    @property
    def identifier(self):
        """Name of the form ``[Signed-]Cat-TD-<num_atoms>``."""
        prefix = "Signed-" if self.signed else ""
        return f"{prefix}Cat-TD-{self.num_atoms}"

    @functools.cached_property
    def metrics(self) -> clu_metrics.Collection:
        """Metric collection holding a running average of the MMD loss."""
        metric_tags = [LOSS_MMD]
        metric_keepers = {
            tag: clu_metrics.Average.from_output(tag) for tag in metric_tags
        }
        return clu_metrics.Collection.create(**metric_keepers)

    @functools.cached_property
    def state(self) -> WeightedParticleState:
        """Build the initial training state (atom locations and atom weights).

        ``PRNGKey(self.seed)`` is split in two: one key initialises the
        locations model, the other the weight model. Both sub-states are
        created with ``self.optim``.
        """
        locs_key, probs_key = jax.random.split(jax.random.PRNGKey(self.seed))

        # Atom locations.
        locs_model = EWPModel(self.env.num_states, self.env.reward_dim, self.num_atoms)
        locs_params = locs_model.init_with_support(
            locs_key, jnp.int32(0), self.support_map_initializer
        )
        support_map_state = train_state.TrainState.create(
            apply_fn=locs_model.apply, params=locs_params, tx=self.optim
        )

        # Atom weights: the signed variant stores weights directly
        # (``logits=False``) starting from near-uniform values.
        if self.signed:
            probs_model = TabularProbabilityModel(
                self.env.num_states,
                self.num_atoms,
                logits=False,
                initializer=random_probability_init,
            )
        else:
            probs_model = TabularProbabilityModel(
                self.env.num_states,
                self.num_atoms,
                logits=True,
            )
        probs_params = probs_model.init(probs_key, jnp.int32(0))
        return WeightedParticleState.create(
            params=probs_params,
            apply_fn=probs_model.apply,
            tx=self.optim,
            support_map=support_map_state,
            metrics=self.metrics.empty(),
        )

    @jitpp.jit
    @staticmethod
    def train_step(
        key: chex.PRNGKey,
        state: Donate[WeightedParticleState],
        batch: MRPTransitionBatch,
        *,
        kernel: Bind[Static[Kernel]],
        discount: Bind[float],
        num_atoms: Bind[int],
    ) -> WeightedParticleState:
        """Take one gradient step on the batch-mean MMD between prediction and target.

        For each transition the prediction uses the atoms at ``o_t`` and the
        target uses the atoms at ``o_tp1`` with locations mapped to
        ``r_t + discount * loc``. Target weights are evaluated with
        ``state.params`` rather than the differentiated argument, so no
        gradient flows through the target.

        NOTE(review): the ``Bind``/``Static``/``Donate`` annotations are
        interpreted by ``jitpp.jit``; presumably ``Bind`` keyword arguments
        are filled from the trainer's same-named attributes -- confirm
        against ``jitpp``.

        Args:
            key: PRNG key; not used by this step.
            state: Current training state.
            batch: Transitions providing ``o_t``, ``r_t`` and ``o_tp1``.
            kernel: Kernel passed to ``mmd2``.
            discount: Factor applied to next-state atom locations.
            num_atoms: Divisor applied to the gradients before the update.

        Returns:
            The result of ``state.apply_gradients`` with the scaled gradients
            and the loss metric for this step.
        """

        def _mmd_loss(
            eta_pred: DiscreteDistribution, eta_target: DiscreteDistribution
        ) -> chex.Scalar:
            return mmd2(
                kernel, eta_pred.locs, eta_target.locs, eta_pred.probs, eta_target.probs
            )

        @jax.value_and_grad
        def loss_fn(params: chex.ArrayTree, batch_: MRPTransitionBatch):
            # Prediction at the current observation. The batch is read through
            # the ``batch_`` argument everywhere rather than mixing it with
            # the enclosing ``batch``.
            locs_t = jax.vmap(state.support_map.apply_fn, in_axes=(None, 0))(
                state.support_map.params, batch_.o_t
            )
            probs_t = jax.vmap(state.apply_fn, in_axes=(None, 0))(params, batch_.o_t)

            # Target at the next observation, evaluated with the state's own
            # parameters so it is constant with respect to ``params``.
            locs_tp1 = jax.vmap(state.support_map.apply_fn, in_axes=(None, 0))(
                state.support_map.params, batch_.o_tp1
            )
            probs_target = jax.vmap(state.apply_fn, in_axes=(None, 0))(
                state.params, batch_.o_tp1
            )
            # The reward gets an extra axis at position 1 so it broadcasts
            # against the next-state locations.
            locs_target = batch_.r_t[:, None, ...] + discount * locs_tp1

            eta_t = jax.vmap(DiscreteDistribution)(locs=locs_t, probs=probs_t)
            eta_target = jax.vmap(DiscreteDistribution)(
                locs=locs_target, probs=probs_target
            )
            return jnp.mean(jax.vmap(_mmd_loss)(eta_t, eta_target))

        loss, grads = loss_fn(state.params, batch)
        grads = jax.tree_util.tree_map(lambda g: g / num_atoms, grads)
        metrics = {LOSS_MMD: loss}
        return state.apply_gradients(
            grads, state.metrics.single_from_model_output(**metrics)
        )

    def return_distribution(
        self, state: WeightedParticleState, i: int
    ) -> DiscreteDistribution:
        """Return the learned return distribution for state index ``i``.

        The weights are clipped to ``[0, 1]`` and renormalised to sum to one
        before being returned.

        Args:
            state: Training state holding the support map and weight model.
            i: State index; converted to ``jnp.int32`` before lookup.

        Returns:
            A ``DiscreteDistribution`` with the state's atom locations and
            the normalised weights.
        """
        i = jnp.int32(i)
        locs = state.support_map.apply_fn(state.support_map.params, i)
        raw_probs = state.apply_fn(state.params, i)
        # Clip and renormalise so the weights are non-negative and sum to one.
        # The normalised weights are what gets returned; the raw model output
        # is not.
        # NOTE(review): assumes at least one raw weight is positive; otherwise
        # the sum is zero and the division yields NaNs.
        probs = jnp.clip(raw_probs, 0, 1)
        probs = probs / jnp.sum(probs)
        return DiscreteDistribution(locs=locs, probs=probs)
