from typing import Any

import fiddle as fdl
import jax

from tabular_mvdrl.kernels import energy_distance
from tabular_mvdrl.trainer import MVDRLTransferTrainer
from tabular_mvdrl.utils import support_init
from tabular_mvdrl.utils.discrete_distributions import (
    SquaredMMDMetric,
    SupremalMetric,
    Wasserstein2Metric,
)


def beta_rewards(base: fdl.Buildable[MVDRLTransferTrainer[Any]]):
    """Point the environment's cumulant prior at a Beta(0.1, 0.1) sampler.

    The environment config is built once solely to read ``reward_dim``,
    which fixes the sample shape to ``(reward_dim,)``.

    Args:
        base: Trainer config to modify in place; ``base.env.cumulant_prior``
            is overwritten.
    """
    reward_dim = fdl.build(base.env).reward_dim
    beta_prior = fdl.Partial(jax.random.beta, a=0.1, b=0.1, shape=(reward_dim,))
    base.env.cumulant_prior = beta_prior


def bins_per_dim(
    base: fdl.Buildable[MVDRLTransferTrainer[Any]], bins_per_dim: int = 10
):
    """Set the support resolution of the trainer config, in place.

    If ``base`` exposes ``support_map_initializer``, the value is written to
    ``base.support_map_initializer.support_init.bins_per_dim``. Otherwise
    ``base.num_atoms`` is set to ``bins_per_dim**2``.

    Args:
        base: Trainer config to modify in place.
        bins_per_dim: Number of bins along each dimension.
    """
    if not hasattr(base, "support_map_initializer"):
        # NOTE(review): the exponent 2 presumably assumes a 2-dimensional
        # reward -- confirm before using with other reward dimensions.
        base.num_atoms = bins_per_dim**2
        return

    grid_cfg = base.support_map_initializer.support_init
    grid_cfg.bins_per_dim = bins_per_dim


def num_atoms(base: fdl.Buildable[MVDRLTransferTrainer[Any]], num_atoms: int = 100):
    """Set the number of atoms used by the trainer config, in place.

    If ``base`` has no ``support_map_initializer``, ``base.num_atoms`` is
    simply overwritten. Otherwise the initializer is replaced with an
    ``independent_map`` over a ``uniform_random_support`` of ``num_atoms``
    atoms, reusing ``minval``/``maxval`` from the existing ``support_init``
    config and taking ``reward_dim``/``num_states`` from a freshly built
    environment.

    Args:
        base: Trainer config to modify in place.
        num_atoms: Number of atoms to configure.
    """
    if not hasattr(base, "support_map_initializer"):
        base.num_atoms = num_atoms
        return

    # The environment is instantiated only to read its dimensions.
    env = fdl.build(base.env)
    previous_support = base.support_map_initializer.support_init
    random_support = fdl.Config(
        support_init.uniform_random_support,
        d=env.reward_dim,
        num_atoms=num_atoms,
        minval=previous_support.minval,
        maxval=previous_support.maxval,
    )
    base.support_map_initializer = fdl.Config(
        support_init.independent_map,
        support_init=random_support,
        n=env.num_states,
    )


def eval_by_wasserstein(base: fdl.Buildable[MVDRLTransferTrainer[Any]]):
    """Evaluate returns with a ``SupremalMetric`` over ``Wasserstein2Metric``.

    The base metric is configured with ``epsilon=1e-4``.

    Args:
        base: Trainer config to modify in place; ``base.return_metric`` is
            overwritten.
    """
    wasserstein_cfg = fdl.Config(Wasserstein2Metric, epsilon=1e-4)
    base.return_metric = fdl.Config(SupremalMetric, base_metric=wasserstein_cfg)


def eval_by_cramer(base: fdl.Buildable[MVDRLTransferTrainer[Any]]):
    """Evaluate returns with a ``SupremalMetric`` over ``SquaredMMDMetric``.

    The MMD kernel is ``energy_distance`` partially applied with
    ``alpha=1.0``.

    Args:
        base: Trainer config to modify in place; ``base.return_metric`` is
            overwritten.
    """
    energy_kernel = fdl.Partial(energy_distance, alpha=1.0)
    mmd_cfg = fdl.Config(SquaredMMDMetric, kernel=energy_kernel)
    base.return_metric = fdl.Config(SupremalMetric, base_metric=mmd_cfg)
