import unittest
from unittest import TestCase

import torch
from torch.nn.functional import relu

from src.dl.models.cell_types.fully_connected import OneLayerFullyConnectedCell

BATCH_SIZE = 32
Z_DIM = 100


class TestOneLayerFullyConnectedCell(TestCase):
    """Smoke tests for ``OneLayerFullyConnectedCell``.

    The tests cover three things: the forward pass runs, the number of
    parameters matches the expected value, and every parameter receives a
    nonzero gradient after a forward/backward pass.
    """

    def setUp(self) -> None:
        """Build a fresh cell with ReLU activation before each test."""
        self.cell = OneLayerFullyConnectedCell(z_dim=Z_DIM, activation=relu)

    def test_forward(self) -> None:
        """Check that the forward pass doesn't crash."""
        # All-zero inputs are enough here: only successful execution is checked,
        # not the value of the output.
        zs = torch.zeros((BATCH_SIZE, Z_DIM))
        xs = torch.zeros((BATCH_SIZE, Z_DIM))
        self.cell(z=zs, x=xs)

    def test_parameter_count(self) -> None:
        """Check that the model has the expected number of parameters."""
        total_trainable_params = sum(p.numel() for p in self.cell.parameters() if p.requires_grad)
        total_nontrainable_params = sum(p.numel() for p in self.cell.parameters() if not p.requires_grad)

        # NOTE(review): (Z_DIM + 1) * Z_DIM presumably corresponds to a single
        # Z_DIM -> Z_DIM affine layer (weights plus bias) -- confirm against the
        # cell implementation.
        expected_num_trainable_params = (Z_DIM + 1) * Z_DIM
        self.assertEqual(total_trainable_params, expected_num_trainable_params)

        # Every parameter of the cell is expected to be trainable.
        self.assertEqual(total_nontrainable_params, 0)

    def test_check_all_params_are_used(self) -> None:
        """Make sure that all parameters are used during forward pass.

        A parameter counts as unused when, after backpropagating a loss, its
        gradient is either missing (``None``) or numerically zero.
        """
        # Clear any previously accumulated gradients.
        self.cell.zero_grad()

        # Run forward pass on a random input.
        xs = torch.randn(size=(BATCH_SIZE, Z_DIM))
        zs = torch.randn(size=(BATCH_SIZE, Z_DIM))
        ys = self.cell(zs, xs)

        # Compute a loss and backprop.
        loss = torch.sum(ys**2)
        loss.backward()

        # A parameter that never entered the graph keeps ``grad is None``; that
        # must be reported as a failure rather than crash ``torch.allclose``.
        # The zero check uses ``torch.allclose`` default tolerances, so it means
        # "numerically zero", not "exactly zero".
        unused_params = [
            name
            for name, p in self.cell.named_parameters()
            if p.grad is None or torch.allclose(p.grad, torch.zeros_like(p))
        ]
        self.assertEqual(
            unused_params,
            [],
            msg=f"Parameters with missing or all-zero gradients: {unused_params}",
        )


if __name__ == "__main__":
    # Run the following command to launch the tests:
    #   python -m unittest -v src.dl.models.cell_types.test_fully_connected
    unittest.main()
