from spaghettini import quick_register
from typing import Sequence, Union, Callable

from torch.utils.data.dataloader import DataLoader

from src.data.datasets.dataset_utils import ConcatDatasetWithSplitIds

# Values accepted for the `mixture_mode` argument of get_multi_split_dataloaders:
#   "mixture"  -> a single loader over the concatenation of all requested splits.
#   "separate" -> one loader per requested split.
SUPPORTED_MIXTURE_MODES = ["mixture", "separate"]


@quick_register
def get_multi_split_dataloaders(dataset_getter: Callable, partition_name: str,
                                mixture_mode: str, batch_size: int, split_ids: Sequence[int],
                                num_workers_per_dataset: int = 0, shuffle: bool = True) \
        -> Union[DataLoader, Sequence[DataLoader]]:
    """Build dataloader(s) over one dataset per requested split.

    One dataset is created per entry of ``split_ids`` by calling
    ``dataset_getter(dataset_partition_name=partition_name, split_id=split_id)``.
    The datasets are then either merged behind a single loader or wrapped in one
    loader each, depending on ``mixture_mode``.

    Args:
        dataset_getter: Callable accepting the keyword arguments
            ``dataset_partition_name`` and ``split_id`` and returning a dataset.
        partition_name: Forwarded to ``dataset_getter`` as ``dataset_partition_name``.
        mixture_mode: One of ``SUPPORTED_MIXTURE_MODES``. ``"mixture"`` returns a
            single loader over a ``ConcatDatasetWithSplitIds`` of all datasets;
            ``"separate"`` returns a list with one loader per split.
        batch_size: Batch size used by every loader.
        split_ids: Split identifiers; one dataset is built for each, in order.
        num_workers_per_dataset: Worker count per dataset. In ``"mixture"`` mode the
            single loader gets ``num_workers_per_dataset * number_of_datasets`` workers.
        shuffle: Forwarded to every ``DataLoader``.

    Returns:
        A single ``DataLoader`` in ``"mixture"`` mode, or a list of ``DataLoader``
        objects (ordered like ``split_ids``) in ``"separate"`` mode.

    Raises:
        ValueError: If ``mixture_mode`` is not in ``SUPPORTED_MIXTURE_MODES``.
        NotImplementedError: If ``mixture_mode`` is listed as supported but no
            construction branch exists for it here.
    """
    if mixture_mode not in SUPPORTED_MIXTURE_MODES:
        raise ValueError(f"Mixture mode {mixture_mode} not amongst supported modes {SUPPORTED_MIXTURE_MODES}.")

    # Get the datasets (one per split).
    datasets_list = [dataset_getter(dataset_partition_name=partition_name, split_id=split_id)
                     for split_id in split_ids]

    if mixture_mode == "mixture":
        # A single loader serves every split, so it receives the combined worker budget.
        mixture_dataset = ConcatDatasetWithSplitIds(datasets_list)
        num_workers = num_workers_per_dataset * len(datasets_list)
        return DataLoader(dataset=mixture_dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=True)

    # Exact comparison: a substring test (`in "separate"`) would also accept
    # values such as "" or "sep", and an `assert` is stripped under `python -O`.
    if mixture_mode != "separate":
        raise NotImplementedError(f"Mixture mode {mixture_mode} is declared as supported but is not handled.")

    # Return (a list of) separate dataloaders (i.e. avoid mixing).
    return [DataLoader(curr_dataset, batch_size=batch_size, shuffle=shuffle,
                       num_workers=num_workers_per_dataset, pin_memory=True)
            for curr_dataset in datasets_list]


# Script entry point: only runs when the module is executed directly, not on import.
if __name__ == "__main__":
    """
    Run command:
    python -m src.data.data_loading.multi_split_loaders
    """
    # Selector for which manual development snippet to run.
    dev_num = 0

    if dev_num == 0:
        # Placeholder: no development snippet is currently attached to case 0.
        pass
