import unittest
from unittest import TestCase

import numpy as np

from src.data.datasets.prefix_sum.prefix_sum import PrefixSumDataset

# Root directory handed to PrefixSumDataset. The path is relative, so it is
# resolved against the current working directory; tests in this module are
# meant to be run from the repository root.
DATASETS_PATH = "./data"


class TestPrefixSumDataset(TestCase):
    """Sanity checks for the prefix-sum dataset stored under ``DATASETS_PATH``.

    Every test constructs the dataset with ``download=False``, so the data for
    ``NUM_BITS`` must already exist on disk before the tests are run.
    """

    # Value passed as ``num_bits`` to every dataset built by this test case.
    NUM_BITS = 16

    def _load(self, partition):
        """Return the ``partition`` split of the dataset (no download)."""
        return PrefixSumDataset(root=DATASETS_PATH,
                                partition_name=partition,
                                num_bits=self.NUM_BITS,
                                download=False)

    def test_shape(self):
        """The first training sample has x of shape (1, NUM_BITS) and y of shape (NUM_BITS,)."""
        num_bits = self.NUM_BITS
        x, y = self._load("train")[0]
        self.assertEqual(tuple(x.shape), (1, num_bits))
        self.assertEqual(tuple(y.shape), (num_bits,))

    def test_partitions_are_disjoint(self):
        """No input sample is shared between any two of train/valid/test."""

        def get_distinct_inputs(partition):
            print(f"Extracting unique inputs from the {partition} partition of the dataset. ")
            dataset = self._load(partition)
            # Key each sample by its raw bytes. This is exact and cheap, whereas
            # np.array2string is slow per sample and elides arrays larger than
            # its print threshold with "...", which could make distinct samples
            # look identical.
            return {sample.tobytes() for sample in dataset.inputs.numpy()}

        distinct_inputs = {partition: get_distinct_inputs(partition)
                           for partition in ("train", "valid", "test")}

        # subTest reports every overlapping pair instead of stopping at the first.
        for first, second in (("train", "valid"), ("valid", "test"), ("test", "train")):
            with self.subTest(first=first, second=second):
                overlap = distinct_inputs[first] & distinct_inputs[second]
                self.assertEqual(
                    len(overlap), 0,
                    f"{len(overlap)} input(s) shared between {first} and {second}")


if __name__ == "__main__":
    # Launch from the repository root so relative data paths resolve, e.g.:
    #     python -m src.data.datasets.prefix_sum.test_prefix_sum
    unittest.main()
