"""Dataloading utilities, for MNIST, FashionMNIST, CIFAR-10/100, and SVHN.
"""
import math

import numpy as np

from torch.utils.data import DataLoader
from torch.utils.data import Subset
import torchvision.datasets as datasets
import torchvision.transforms as transforms


def load_all_minibatches(dataloader):
  """Materialize every minibatch yielded by ``dataloader`` into a list.

  The iterable is consumed exactly once, so a shuffling loader is frozen into
  the single order produced by that pass.

  Args:
    dataloader: Any iterable of minibatches (typically a ``DataLoader``).

  Returns:
    A new list with the minibatches in iteration order.
  """
  return list(dataloader)


def load_mnist(batch_size, val_split=True, normalize=True, preload=False):
  """Build MNIST train / validation / test loaders.

  Args:
    batch_size: Minibatch size for every returned loader.
    val_split: If True, the 60,000-image training split is divided into its
      first 50,000 images (train) and its last 10,000 images (validation).
      If False, the whole training split is used and no validation loader is
      built.
    normalize: If True, apply ``Normalize((0.1307,), (0.3081,))`` after
      ``ToTensor``.
    preload: If True, each loader is iterated once and replaced by the list
      of minibatches it produced (the train shuffle order is then fixed).

  Returns:
    ``(train_loader, val_loader, test_loader)``; ``val_loader`` is None when
    ``val_split`` is False. With ``preload`` the loaders are lists.
  """
  transformations = [transforms.ToTensor()]
  if normalize:
    transformations.append(transforms.Normalize((0.1307,), (0.3081,)))
  transform = transforms.Compose(transformations)

  full_trainset = datasets.MNIST(
      root='./data/mnist', train=True, download=True, transform=transform
  )
  testset = datasets.MNIST(
      root='./data/mnist', train=False, download=True, transform=transform
  )

  if val_split:
    # Split the training set into 50,000 training and 10,000 validation images.
    # Index-based Subsets replace the old in-place slicing of
    # ``train_data``/``train_labels``: in newer torchvision those are read-only
    # aliases of ``data``/``targets``, so assigning to them raises
    # AttributeError. One dataset instance suffices since both halves share
    # the same transform.
    num_train = 50000
    trainset = Subset(full_trainset, range(num_train))
    valset = Subset(full_trainset, range(num_train, len(full_trainset)))
  else:
    # Full 60,000-image training set.
    trainset = full_trainset
    valset = None

  def make_loader(dataset, shuffle):
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
        num_workers=0
    )

  train_dataloader = make_loader(trainset, shuffle=True)
  val_dataloader = None if valset is None else make_loader(valset, shuffle=False)
  # 10,000 test images.
  test_dataloader = make_loader(testset, shuffle=False)

  if preload:
    train_dataloader = load_all_minibatches(train_dataloader)
    if val_dataloader is not None:
      val_dataloader = load_all_minibatches(val_dataloader)
    test_dataloader = load_all_minibatches(test_dataloader)

  return train_dataloader, val_dataloader, test_dataloader


def load_fashion_mnist(batch_size, val_split=True, normalize=True, preload=False):
  """Build FashionMNIST train / validation / test loaders.

  Args:
    batch_size: Minibatch size for every returned loader.
    val_split: If True, the 60,000-image training split is divided into its
      first 50,000 images (train) and its last 10,000 images (validation).
      If False, the whole training split is used and no validation loader is
      built.
    normalize: If True, apply ``Normalize((0.1307,), (0.3081,))`` after
      ``ToTensor``.
    preload: If True, each loader is iterated once and replaced by the list
      of minibatches it produced (the train shuffle order is then fixed).

  Returns:
    ``(train_loader, val_loader, test_loader)``; ``val_loader`` is None when
    ``val_split`` is False. With ``preload`` the loaders are lists.
  """
  transformations = [transforms.ToTensor()]
  if normalize:
    # NOTE(review): these are the same mean/std used for MNIST; confirm they
    # are intended for FashionMNIST (kept unchanged to preserve results).
    transformations.append(transforms.Normalize((0.1307,), (0.3081,)))
  transform = transforms.Compose(transformations)

  full_trainset = datasets.FashionMNIST(
      root='./data/fashion', train=True, download=True, transform=transform
  )
  testset = datasets.FashionMNIST(
      root='./data/fashion', train=False, download=True, transform=transform
  )

  if val_split:
    # Split the training set into 50,000 training and 10,000 validation images.
    # Index-based Subsets replace the old in-place slicing of
    # ``train_data``/``train_labels``: in newer torchvision those are read-only
    # aliases of ``data``/``targets``, so assigning to them raises
    # AttributeError. Both halves share one dataset instance (same transform).
    num_train = 50000
    trainset = Subset(full_trainset, range(num_train))
    valset = Subset(full_trainset, range(num_train, len(full_trainset)))
  else:
    # Full 60,000-image training set.
    trainset = full_trainset
    valset = None

  def make_loader(dataset, shuffle):
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0
    )

  train_dataloader = make_loader(trainset, shuffle=True)
  val_dataloader = None if valset is None else make_loader(valset, shuffle=False)
  # 10,000 test images.
  test_dataloader = make_loader(testset, shuffle=False)

  if preload:
    train_dataloader = load_all_minibatches(train_dataloader)
    if val_dataloader is not None:
      val_dataloader = load_all_minibatches(val_dataloader)
    test_dataloader = load_all_minibatches(test_dataloader)

  return train_dataloader, val_dataloader, test_dataloader


def load_cifar10(batch_size, val_split=True, normalize=True, augmentation=False, preload=False):
  """Build CIFAR-10 train / validation / test loaders.

  Args:
    batch_size: Minibatch size for every returned loader.
    val_split: If True, the 50,000-image training split is divided into its
      first 45,000 images (train) and its last 5,000 images (validation).
      If False, the whole training split is used and no validation loader is
      built.
    normalize: If True, apply per-channel mean/std normalization after
      ``ToTensor`` (constants below).
    augmentation: If True, training images only get ``RandomCrop(32,
      padding=4)`` and ``RandomHorizontalFlip``; validation and test images
      are never augmented.
    preload: If True, each loader is iterated once and replaced by the list
      of minibatches it produced (the train shuffle/augmentation is then
      fixed).

  Returns:
    ``(train_loader, val_loader, test_loader)``; ``val_loader`` is None when
    ``val_split`` is False. With ``preload`` the loaders are lists.
  """
  train_transforms = []
  test_transforms = []

  if augmentation:
    train_transforms.append(transforms.RandomCrop(32, padding=4))
    train_transforms.append(transforms.RandomHorizontalFlip())

  train_transforms.append(transforms.ToTensor())
  test_transforms.append(transforms.ToTensor())

  if normalize:
    # Distinct name so the ``normalize`` flag is not shadowed by the transform.
    normalization = transforms.Normalize(
      mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
      std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
    )
    train_transforms.append(normalization)
    test_transforms.append(normalization)

  train_transform = transforms.Compose(train_transforms)
  test_transform = transforms.Compose(test_transforms)

  full_trainset = datasets.CIFAR10(
      root='./data/cifar10', train=True, download=True, transform=train_transform
  )
  testset = datasets.CIFAR10(
      root='./data/cifar10', train=False, download=True, transform=test_transform
  )

  if val_split:
    # Split the training set into 45,000 training and 5,000 validation images.
    # Index-based Subsets replace the old slicing of ``train_data``/
    # ``train_labels``, attributes that newer torchvision CIFAR10 no longer
    # has (reading them raises AttributeError).
    num_train = 45000
    trainset = Subset(full_trainset, range(num_train))
    # Validation images come from the training split but must use the
    # un-augmented test transform, hence a second dataset instance.
    unaugmented_trainset = datasets.CIFAR10(
        root='./data/cifar10', train=True, download=True, transform=test_transform
    )
    valset = Subset(
        unaugmented_trainset, range(num_train, len(unaugmented_trainset))
    )
    num_workers = 0
  else:
    # Full 50,000-image training set; worker count kept as before.
    trainset = full_trainset
    valset = None
    num_workers = 2

  def make_loader(dataset, shuffle):
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
        num_workers=num_workers
    )

  train_dataloader = make_loader(trainset, shuffle=True)
  val_dataloader = None if valset is None else make_loader(valset, shuffle=False)
  # 10,000 test images.
  test_dataloader = make_loader(testset, shuffle=False)

  if preload:
    train_dataloader = load_all_minibatches(train_dataloader)
    if val_dataloader is not None:
      val_dataloader = load_all_minibatches(val_dataloader)
    test_dataloader = load_all_minibatches(test_dataloader)

  return train_dataloader, val_dataloader, test_dataloader


def load_cifar100(batch_size, val_split=True, normalize=True, augmentation=False, preload=False):
  """Build CIFAR-100 train / validation / test loaders.

  Args:
    batch_size: Minibatch size for every returned loader.
    val_split: If True, the 50,000-image training split is divided into its
      first 45,000 images (train) and its last 5,000 images (validation).
      If False, the whole training split is used and no validation loader is
      built.
    normalize: If True, apply per-channel mean/std normalization after
      ``ToTensor`` (constants below).
    augmentation: If True, training images only get ``RandomCrop(32,
      padding=4)`` and ``RandomHorizontalFlip``; validation and test images
      are never augmented.
    preload: If True, each loader is iterated once and replaced by the list
      of minibatches it produced (the train shuffle/augmentation is then
      fixed).

  Returns:
    ``(train_loader, val_loader, test_loader)``; ``val_loader`` is None when
    ``val_split`` is False. With ``preload`` the loaders are lists.
  """
  train_transforms = []
  test_transforms = []

  if augmentation:
    train_transforms.append(transforms.RandomCrop(32, padding=4))
    train_transforms.append(transforms.RandomHorizontalFlip())

  train_transforms.append(transforms.ToTensor())
  test_transforms.append(transforms.ToTensor())

  if normalize:
    # NOTE(review): these are the same constants load_cifar10 uses; confirm
    # they are intended for CIFAR-100 (kept unchanged to preserve results).
    # Distinct name so the ``normalize`` flag is not shadowed by the transform.
    normalization = transforms.Normalize(
      mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
      std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
    )
    train_transforms.append(normalization)
    test_transforms.append(normalization)

  train_transform = transforms.Compose(train_transforms)
  test_transform = transforms.Compose(test_transforms)

  full_trainset = datasets.CIFAR100(
      root='./data/cifar100', train=True, download=True, transform=train_transform
  )
  testset = datasets.CIFAR100(
      root='./data/cifar100', train=False, download=True, transform=test_transform
  )

  if val_split:
    # Split the training set into 45,000 training and 5,000 validation images.
    # Index-based Subsets replace the old slicing of ``train_data``/
    # ``train_labels``, attributes that newer torchvision CIFAR100 no longer
    # has (reading them raises AttributeError).
    num_train = 45000
    trainset = Subset(full_trainset, range(num_train))
    # Validation images come from the training split but must use the
    # un-augmented test transform, hence a second dataset instance.
    unaugmented_trainset = datasets.CIFAR100(
        root='./data/cifar100', train=True, download=True, transform=test_transform
    )
    valset = Subset(
        unaugmented_trainset, range(num_train, len(unaugmented_trainset))
    )
  else:
    # Full 50,000-image training set.
    trainset = full_trainset
    valset = None

  def make_loader(dataset, shuffle):
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0
    )

  train_dataloader = make_loader(trainset, shuffle=True)
  val_dataloader = None if valset is None else make_loader(valset, shuffle=False)
  # 10,000 test images.
  test_dataloader = make_loader(testset, shuffle=False)

  if preload:
    train_dataloader = load_all_minibatches(train_dataloader)
    if val_dataloader is not None:
      val_dataloader = load_all_minibatches(val_dataloader)
    test_dataloader = load_all_minibatches(test_dataloader)

  return train_dataloader, val_dataloader, test_dataloader


def load_svhn(batch_size, val_split=True, normalize=True, use_extra_data=False, preload=False):
  """Build SVHN train / validation / test loaders.

  Args:
    batch_size: Minibatch size for every returned loader.
    val_split: If True, the first 90% (floored) of the training data is used
      for training and the remaining 10% for validation. If False, all of it
      is used for training and no validation loader is built.
    normalize: If True, apply per-channel mean/std normalization after
      ``ToTensor`` (constants below).
    use_extra_data: If True, the ``extra`` split is appended to the ``train``
      split before any train/validation split is made.
    preload: If True, each loader is iterated once and replaced by the list
      of minibatches it produced (the train shuffle order is then fixed).
      Previously this flag was ignored when ``val_split`` was False.

  Returns:
    ``(train_loader, val_loader, test_loader)``; ``val_loader`` is None when
    ``val_split`` is False. With ``preload`` the loaders are lists.
  """
  transformations = [transforms.ToTensor()]
  if normalize:
    # Distinct name so the ``normalize`` flag is not shadowed by the transform.
    normalization = transforms.Normalize(
      mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
      std=[x / 255.0 for x in [50.1, 50.6, 50.8]]
    )
    transformations.append(normalization)
  transform = transforms.Compose(transformations)

  trainset = datasets.SVHN(
      root='./data/svhn', split='train', transform=transform, download=True
  )

  if use_extra_data:
    extra_dataset = datasets.SVHN(
        root='./data/svhn', split='extra', transform=transform, download=True
    )
    # Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
    trainset.data = np.concatenate([trainset.data, extra_dataset.data], axis=0)
    trainset.labels = np.concatenate(
        [trainset.labels, extra_dataset.labels], axis=0
    )
    # np.concatenate copied the extra arrays; release the originals so they
    # are not held in memory while the test split is loaded.
    del extra_dataset

  testset = datasets.SVHN(
      root='./data/svhn', split='test', download=True, transform=transform
  )

  if val_split:
    # Use 90% of the training data as the train set and 10% as the validation
    # set. Both are index-based views of the single (possibly combined)
    # dataset, which uses the same transform for both halves; this avoids
    # re-reading the train split from disk just to overwrite its arrays.
    train_ratio = 0.9
    num_examples = len(trainset.data)
    num_train = math.floor(num_examples * train_ratio)
    valset = Subset(trainset, range(num_train, num_examples))
    trainset = Subset(trainset, range(num_train))
  else:
    valset = None

  def make_loader(dataset, shuffle):
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
        num_workers=0
    )

  train_dataloader = make_loader(trainset, shuffle=True)
  val_dataloader = None if valset is None else make_loader(valset, shuffle=False)
  test_dataloader = make_loader(testset, shuffle=False)

  # Honoured for both branches (the non-split branch used to skip it).
  if preload:
    train_dataloader = load_all_minibatches(train_dataloader)
    if val_dataloader is not None:
      val_dataloader = load_all_minibatches(val_dataloader)
    test_dataloader = load_all_minibatches(test_dataloader)

  return train_dataloader, val_dataloader, test_dataloader
