import unittest

import jax
import jax.numpy as jnp

from tabular_mvdrl import kernels
from tabular_mvdrl.utils.discrete_distributions import (
    DiscreteDistribution,
    SquaredMMDMetric,
)


class TestKernels(unittest.TestCase):
    """Tests for ``kernels.l1`` and the squared-MMD metric built on top of it."""

    # Absolute tolerance for quantities that are zero analytically but are
    # computed in floating point (e.g. MMD between a distribution and itself).
    _ATOL = 1e-5

    def setUp(self):
        """Draw fixed random points and empirical distributions for each test.

        Every random draw gets its own PRNG key: JAX keys must not be reused,
        since draws made from the same key are statistically correlated.
        """
        self.dim = 2
        self.num_samples = 10
        keys = jax.random.split(jax.random.PRNGKey(0), 5)
        self.rng, x_key, y_key, p_key, q_key = keys
        # Two single points; ``y`` is drawn from a normal shifted by 0.5.
        self.x = jax.random.normal(x_key, shape=(self.dim,))
        self.y = 0.5 + jax.random.normal(y_key, shape=(self.dim,))
        # Two sample sets of different sizes; ``q`` is shifted by 0.5 as well.
        self.p_samples = jax.random.normal(p_key, shape=(self.num_samples, self.dim))
        self.q_samples = 0.5 + jax.random.normal(
            q_key, shape=(2 * self.num_samples, self.dim)
        )
        self.p = DiscreteDistribution.empirical_from(self.p_samples)
        self.q = DiscreteDistribution.empirical_from(self.q_samples)

    def test_l1(self):
        """``l1`` on two points yields a single value convertible to float."""
        # ``.item()`` raises unless the result has exactly one element.
        out = kernels.l1(self.x, self.y).item()
        self.assertIsInstance(out, float)

    def test_l1_grads_normal(self):
        """Away from the kink, d/dx l1(x, y) equals -sign(x - y) elementwise."""
        grads = jax.grad(kernels.l1)(self.x, self.y)
        self.assertEqual(grads.shape, self.x.shape)
        # Compare as Python floats so a failure prints a readable per-element diff.
        self.assertListEqual(
            grads.tolist(), (-jnp.sign(self.x - self.y)).tolist()
        )

    def test_l1_grads_kink(self):
        """At x == y the gradient of ``l1`` with respect to x is exactly zero."""
        grads = jax.grad(kernels.l1)(self.x, self.x)
        self.assertEqual(grads.shape, self.x.shape)
        self.assertEqual(jnp.sum(jnp.abs(grads)).item(), 0.0)

    def test_l1_mmd(self):
        """Squared MMD is a float, and (numerically) zero for identical inputs."""
        mmd_loss = SquaredMMDMetric(kernels.l1)
        mmd = mmd_loss(self.p, self.q).item()
        self.assertIsInstance(mmd, float)
        # MMD(p, p) vanishes analytically; allow for float round-off.
        mmd_self = mmd_loss(self.p, self.p).item()
        self.assertAlmostEqual(mmd_self, 0.0, delta=self._ATOL)

    def test_l1_mmd_grad(self):
        """Gradients of squared MMD are distributions; they vanish at p == q."""
        mmd_loss = SquaredMMDMetric(kernels.l1)
        # jax.grad returns gradients with the same pytree structure as its
        # first argument, so the result should itself be a DiscreteDistribution.
        grads = jax.grad(mmd_loss)(self.p, self.q)
        self.assertIsInstance(grads, DiscreteDistribution)
        grads_self = jax.grad(mmd_loss)(self.p, self.p)
        locs_grad_mass = jnp.sum(jnp.abs(grads_self.locs)).item()
        self.assertAlmostEqual(locs_grad_mass, 0.0, delta=self._ATOL)


if __name__ == "__main__":
    # Run the test suite when this module is executed directly as a script.
    unittest.main(module="__main__")
