import chex
import fiddle as fdl
import jax
import jax.numpy as jnp

from tabular_mvdrl.agents.cat_projected_td import CatProjectedTDTrainer

from . import cat_projected_td as cat_projected_td_config


def dirichlet_prior(key: chex.PRNGKey, n: int, alpha: float) -> chex.Array:
    """Draw one sample from a symmetric Dirichlet distribution.

    Args:
        key: JAX PRNG key used for the draw.
        n: Length of the concentration vector.
        alpha: Concentration value shared by every entry.

    Returns:
        The result of `jax.random.dirichlet` for the concentration vector
        `alpha * ones(n)`.
    """
    # Every component gets the same concentration, i.e. a symmetric prior.
    concentration = alpha * jnp.ones(n)
    return jax.random.dirichlet(key, concentration)


def base(**kwargs) -> fdl.Config[CatProjectedTDTrainer]:
    """Build the `base` categorical projected-TD config with `signed` enabled.

    Args:
        **kwargs: Forwarded unchanged to `cat_projected_td_config.base`.

    Returns:
        The config returned by `cat_projected_td_config.base`, with its
        `signed` attribute set to True.
    """
    signed_cfg = cat_projected_td_config.base(**kwargs)
    # The only difference from the wrapped config is the `signed` flag.
    signed_cfg.signed = True
    return signed_cfg


def rowland(**kwargs) -> fdl.Config[CatProjectedTDTrainer]:
    """Build the `rowland` categorical projected-TD config with `signed` enabled.

    Args:
        **kwargs: Forwarded unchanged to `cat_projected_td_config.rowland`.

    Returns:
        The config returned by `cat_projected_td_config.rowland`, with its
        `signed` attribute set to True.
    """
    signed_cfg = cat_projected_td_config.rowland(**kwargs)
    # The only difference from the wrapped config is the `signed` flag.
    signed_cfg.signed = True
    return signed_cfg


def rowland_multivariate(**kwargs) -> fdl.Config[CatProjectedTDTrainer]:
    """Build the `rowland_multivariate` config with `signed` enabled.

    Args:
        **kwargs: Forwarded unchanged to
            `cat_projected_td_config.rowland_multivariate`.

    Returns:
        The config returned by `cat_projected_td_config.rowland_multivariate`,
        with its `signed` attribute set to True.
    """
    signed_cfg = cat_projected_td_config.rowland_multivariate(**kwargs)
    # The only difference from the wrapped config is the `signed` flag.
    signed_cfg.signed = True
    return signed_cfg


def extrapolated_rowland() -> fdl.Config[CatProjectedTDTrainer]:
    """Build the `extrapolated_rowland` config with `signed` enabled.

    Takes no arguments; the wrapped builder is called without any.

    Returns:
        The config returned by `cat_projected_td_config.extrapolated_rowland`,
        with its `signed` attribute set to True.
    """
    signed_cfg = cat_projected_td_config.extrapolated_rowland()
    # The only difference from the wrapped config is the `signed` flag.
    signed_cfg.signed = True
    return signed_cfg


### FIDDLERS


def finite_horizon(cfg: fdl.Config[CatProjectedTDTrainer], horizon=4):
    """Apply the categorical projected-TD `finite_horizon` fiddler to `cfg`.

    Args:
        cfg: Trainer config handed to the delegated fiddler. Nothing is
            returned, so the delegate is expected to modify it in place.
        horizon: Horizon value forwarded to the delegated fiddler
            (defaults to 4).
    """
    # Forward `horizon`: it used to be accepted here but silently dropped, so
    # callers always got the delegate's own default regardless of what they
    # passed.
    # NOTE(review): assumes `cat_projected_td_config.finite_horizon` accepts a
    # `horizon` keyword matching this signature — confirm against that module.
    cat_projected_td_config.finite_horizon(cfg, horizon=horizon)


def terminal_reward(**kwargs):
    """Build the `terminal_reward` config with `signed` enabled.

    Args:
        **kwargs: Forwarded unchanged to
            `cat_projected_td_config.terminal_reward`.

    Returns:
        Whatever `cat_projected_td_config.terminal_reward` returns, with its
        `signed` attribute set to True.
    """
    signed_cfg = cat_projected_td_config.terminal_reward(**kwargs)
    # The only difference from the wrapped config is the `signed` flag.
    signed_cfg.signed = True
    return signed_cfg


def l1_kernel(cfg: fdl.Config[CatProjectedTDTrainer]):
    """Apply the categorical projected-TD `l1_kernel` fiddler to `cfg`.

    Args:
        cfg: Trainer config handed to the delegated fiddler. Nothing is
            returned, so the delegate is expected to modify it in place.
    """
    # Pure delegation: the signed variant adds nothing of its own here.
    delegate = cat_projected_td_config.l1_kernel
    delegate(cfg)
