import torchvision
import torch
from pathlib import Path

# Dataset names listed as supported by this module.
# NOTE(review): load_data's CIFAR branch is keyed on the string 'CIFAR10';
# verify that the 'CIFAR' spelling used here is accepted there before passing
# entries of this list straight to load_data.
DATASETS = ['FashionMNIST', 'CIFAR']


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-loaded, index-aligned inputs and labels.

    Sample ``i`` is the pair ``(list_IDs[i], labels[i])``.
    """

    def __init__(self, list_IDs, labels):
        """Store the inputs and their labels.

        Args:
            list_IDs: indexable, sized collection of inputs (e.g. the tensor
                returned by ``load_images``).
            labels: indexable collection of labels aligned with *list_IDs*.
        """
        self.labels = labels
        self.list_IDs = list_IDs

    def __len__(self):
        """Return the number of samples, i.e. ``len(list_IDs)``."""
        # DataLoader's default samplers call len(dataset), and the torch base
        # class provides no __len__, so without this iteration fails.
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return the ``(input, label)`` pair stored at *index*."""
        return self.list_IDs[index], self.labels[index]


def load_images(loader):
    """Drain *loader* and concatenate its batches into two tensors.

    Args:
        loader: iterable yielding ``(data, target)`` pairs of tensors, such as
            a ``torch.utils.data.DataLoader``.

    Returns:
        ``(image, label)``: every ``data`` batch concatenated along dim 0 and
        every ``target`` batch concatenated along dim 0, in iteration order.
        ``torch.cat`` raises if *loader* yields no batches.
    """
    batches = list(loader)
    images = torch.cat([data for data, _ in batches], dim=0)
    labels = torch.cat([target for _, target in batches], dim=0)
    return images, labels


def load_data(dataset, batchsize, ROOT_DIR, *, n_adv=1000):
    """Build train, test and adversarial-test DataLoaders for *dataset*.

    Args:
        dataset: ``'FashionMNIST'`` or ``'CIFAR10'``. ``'CIFAR'`` (the
            spelling in the module-level ``DATASETS`` list) is accepted as an
            alias for ``'CIFAR10'``.
        batchsize: batch size of the returned train and test loaders.
        ROOT_DIR: directory whose ``data`` sub-folder the torchvision dataset
            is downloaded into / read from.
        n_adv: number of leading test samples served by the adversarial
            loader (keyword-only; defaults to the previously hard-coded 1000).

    Returns:
        ``(train_loader, test_loader, test_for_adv_loader)``: a shuffled
        training loader, an in-order loader over the whole in-memory test set,
        and an in-order loader yielding the first *n_adv* test samples one at
        a time.

    Raises:
        ValueError: if *dataset* is not a supported name.
    """
    if dataset == 'FashionMNIST':
        dataset_cls = torchvision.datasets.FashionMNIST
        # Flatten each image tensor to 1-D.
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Lambda(lambda x: torch.flatten(x)),
        ])
    elif dataset in ('CIFAR10', 'CIFAR'):
        dataset_cls = torchvision.datasets.CIFAR10
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
    else:
        # Previously fell through and crashed with UnboundLocalError below.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'FashionMNIST' or 'CIFAR10'")

    data_dir = Path(ROOT_DIR, 'data')
    train_loader = torch.utils.data.DataLoader(
        dataset_cls(data_dir, train=True, download=True, transform=transform),
        batch_size=batchsize, shuffle=True)
    # The fixed batch of 1000 only speeds up reading the test split into
    # memory; the tensors are re-wrapped below with the caller's batch size.
    raw_test_loader = torch.utils.data.DataLoader(
        dataset_cls(data_dir, train=False, download=True, transform=transform),
        batch_size=1000, shuffle=False)

    # Materialise the test split so the adversarial subset can be sliced off.
    test_x, test_y = load_images(raw_test_loader)
    test_loader = torch.utils.data.DataLoader(
        Dataset(test_x, test_y), batch_size=batchsize, shuffle=False)
    adv_test_data = Dataset(test_x[:n_adv], test_y[:n_adv])
    test_for_adv_loader = torch.utils.data.DataLoader(
        adv_test_data, batch_size=1, shuffle=False)

    return train_loader, test_loader, test_for_adv_loader