import unittest
from unittest import TestCase

import numpy as np
import torch

from src.utils.misc import is_scalar


class Test(TestCase):
    def test_is_scalar(self):
        """Check that ``is_scalar`` accepts the intended objects.

        Each case pairs an object with the expected ``is_scalar`` result.
        The test then checks that every object ``is_scalar`` accepts can
        actually be converted with ``float()``.
        """
        # Case label -> (object, ground truth for whether it is scalar).
        test_cases = {
            "a": (int(1), True),
            "b": (float(1), True),
            "c": (bool(1), True),
            "d": (np.array([1]), True),
            "e": (np.array([1])[None], True),
            "f": (np.array([[1, 1]]), False),
            "g": (torch.tensor([1]), True),
            "h": (torch.tensor([1])[None], True),
            "i": (torch.tensor([1, 1]), False),
            "j": ([1, 1], False),
            "k": ({"a": 1, "b": 2}, False),
        }

        # Check that the predicted outputs are correct.
        for name, (test_obj, ground_truth) in test_cases.items():
            with self.subTest(case=name):
                self.assertEqual(is_scalar(test_obj), ground_truth)

        # Check that all of the accepted values can be converted to floats.
        # Unpack each case so the check runs on the object itself, not on the
        # (object, expected) tuple, which would never be considered scalar.
        for name, (test_obj, _) in test_cases.items():
            if not is_scalar(test_obj):
                continue
            with self.subTest(case=name):
                try:
                    float(test_obj)
                except (TypeError, ValueError) as err:
                    # float() raises TypeError or ValueError depending on the
                    # object type; either means is_scalar accepted it wrongly.
                    self.fail(
                        f"Accepted object {name!r} ({test_obj!r}) cannot be "
                        f"turned into float: {err}"
                    )


if __name__ == "__main__":
    # Run from the repository root:
    #     python -m unittest -v src.utils.test_misc
    unittest.main()
