import torch

from easy_to_hard_data import PrefixSumDataset as OriginalPrefixSumDataset

# Partition names accepted by PrefixSumDataset; any other name raises ValueError.
ALLOWABLE_PARTITIONS = ["train", "valid", "test"]

# Fractions of the shuffled data assigned to each partition.
# The train and valid fractions determine the slice boundaries used by
# PrefixSumDataset. TEST_SPLIT_PROPORTION is not read anywhere in the visible
# code: the test partition starts where the valid partition ends.
TRAIN_SPLIT_PROPORTION = 0.8
VALID_SPLIT_PROPORTION = 0.1
TEST_SPLIT_PROPORTION = 0.1


class PrefixSumDataset(OriginalPrefixSumDataset):
    """One deterministic train/valid/test partition of the prefix-sum data.

    The parent class loads the full data for a given ``num_bits``. This
    subclass shuffles it with a fixed seed and keeps only the contiguous
    slice that belongs to the requested partition, so the three partitions
    built from the same data are disjoint and together cover every sample.
    """

    def __init__(self, root, partition_name, num_bits, download):
        """Load the data for ``num_bits`` and keep only one partition of it.

        Args:
            root: Forwarded unchanged to the parent constructor.
            partition_name: One of ``ALLOWABLE_PARTITIONS``.
            num_bits: Requested length; must be one of ``self.lengths``.
            download: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If ``partition_name`` or ``num_bits`` is not supported.
        """
        if partition_name not in ALLOWABLE_PARTITIONS:
            message = f"Only partition names {ALLOWABLE_PARTITIONS} are allowed. "
            raise ValueError(message)

        # Set on the instance before the parent constructor runs.
        # NOTE(review): this shadows any `lengths` attribute the parent class
        # defines -- confirm the parent does not depend on its own value.
        self.lengths = list(range(16, 65))
        self.lengths.extend([72, 128, 256, 512, 1024, 2048, 3072, 4096])

        # Check if the requested number of bits is supported.
        if num_bits not in self.lengths:
            raise ValueError(f"num_bits = {num_bits} is not supported. Supported ones include: {self.lengths}")

        # Read the data and setup some other things. The code below relies on
        # the parent constructor having populated self.inputs and self.targets.
        super().__init__(root=root, num_bits=num_bits, download=download)
        self.split_ids = {num_bits}

        # Shuffle the data. The generator is seeded with a constant so that
        # every instance sees the same permutation, which keeps the
        # partitions reproducible and mutually disjoint.
        num_datapoints = self.inputs.shape[0]
        idx = torch.randperm(num_datapoints, generator=torch.Generator().manual_seed(0))
        self.inputs = self.inputs[idx]
        self.targets = self.targets[idx]

        # Divide into splits: train | valid | test are consecutive slices of
        # the shuffled data.
        train_end = int(TRAIN_SPLIT_PROPORTION * num_datapoints)
        valid_end = int((TRAIN_SPLIT_PROPORTION + VALID_SPLIT_PROPORTION) * num_datapoints)
        bounds = {
            "train": (0, train_end),
            "valid": (train_end, valid_end),
            # The test slice must run to num_datapoints; an end index of -1
            # would silently drop the last sample.
            "test": (valid_end, num_datapoints),
        }
        start, end = bounds[partition_name]

        self.inputs = self.inputs[start:end]
        self.targets = self.targets[start:end]
