from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import torch


def _build_transforms(dataset_name, data_augment):
    """Return ``(transform_train, transform_test)`` for *dataset_name*.

    MNIST and FashionMNIST use one shared transform for training and
    evaluation (``data_augment`` is ignored for them).  The CIFAR datasets
    add a padded random crop and a random horizontal flip to the training
    transform when ``data_augment`` is true.

    Raises:
        ValueError: if *dataset_name* has no transform recipe.
    """
    if dataset_name == "mnist":
        plain = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return plain, plain

    if dataset_name == "fashionmnist":
        # No normalization is applied to FashionMNIST.
        plain = transforms.Compose([transforms.ToTensor()])
        return plain, plain

    if dataset_name == "cifar10":
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
        )
    elif dataset_name == "cifar100":
        # The same statistics are used for train, validation and test so the
        # model is evaluated on inputs normalized exactly like its training
        # inputs.
        normalize = transforms.Normalize((0.5071, 0.4867, 0.4408),
                                         (0.2675, 0.2565, 0.2761))
    else:
        raise ValueError("Unknown dataset name: {!r}".format(dataset_name))

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if not data_augment:
        return transform_test, transform_test

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    return transform_train, transform_test


def load_data(dataset_name, batch_size, val_data_size, shuffle=True, data_augment=True):
    """Build train, validation and test ``DataLoader`` objects for a dataset.

    Args:
        dataset_name: One of ``"mnist"``, ``"fashionmnist"``, ``"cifar10"``
            or ``"cifar100"``.
        batch_size: Batch size used for all returned loaders.
        val_data_size: Number of training samples to hold out for
            validation.  When it is not positive, no validation split is
            made and the returned validation loader is ``None``.
        shuffle: Whether the training loader shuffles its data.  The
            validation and test loaders never shuffle.
        data_augment: Whether the CIFAR training transform includes random
            crop / horizontal flip augmentation.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.

    Raises:
        ValueError: if *dataset_name* is not a supported dataset.
    """
    dataset_map = {
        "mnist": datasets.MNIST,
        "fashionmnist": datasets.FashionMNIST,
        "cifar10": datasets.CIFAR10,
        "cifar100": datasets.CIFAR100,
    }

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if dataset_name not in dataset_map:
        raise ValueError(
            "Unknown dataset name: {!r}; expected one of {}".format(
                dataset_name, sorted(dataset_map)))
    dataset_cls = dataset_map[dataset_name]

    transform_train, transform_test = _build_transforms(dataset_name, data_augment)

    data = dataset_cls("./data", train=True, download=True,
                       transform=transform_train)
    test_data = dataset_cls("./data", train=False, download=True,
                            transform=transform_test)

    val_loader = None
    if val_data_size > 0:
        # Use the real training-set size rather than assuming 50000 samples,
        # so training sets of any size are split over all of their samples.
        num_train = len(data)
        # Indices are drawn from NumPy's global RNG, without replacement.
        removed = np.random.choice(num_train, val_data_size, replace=False)

        keep_mask = np.ones(num_train, dtype=bool)
        keep_mask[removed] = False
        keep = np.flatnonzero(keep_mask).tolist()

        # Validation samples come from a copy of the training set that uses
        # the evaluation transform, so they are never augmented.
        data_no_augment = dataset_cls(root='data/', train=True, download=True,
                                      transform=transform_test)

        data = torch.utils.data.Subset(data, keep)
        val_data = torch.utils.data.Subset(data_no_augment, removed.tolist())
        print("Train dataset size: ", len(data))
        print("Val dataset size: ", len(val_data))

        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

    train_loader = DataLoader(data, batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


def next_batch(data_iter, data_loader):
    """Fetch the next ``(inputs, targets)`` pair, restarting when exhausted.

    Args:
        data_iter: Iterator currently being consumed.
        data_loader: Iterable used to create a fresh iterator once
            *data_iter* raises ``StopIteration``.

    Returns:
        ``(inputs, targets, data_iter)`` where ``data_iter`` is the iterator
        to pass to the next call (a new one if a restart happened).

    Raises:
        StopIteration: if a freshly created iterator over *data_loader*
            yields nothing.
    """
    # The builtin next() is used because Python 3 iterators expose
    # ``__next__`` and are not required to provide a ``.next()`` method.
    try:
        inputs, targets = next(data_iter)
    except StopIteration:
        data_iter = iter(data_loader)
        inputs, targets = next(data_iter)
    return inputs, targets, data_iter


class DataSampler:
    """Endlessly draw batches from a loader and move them to a device.

    When the underlying iterator is exhausted, a new one is created from the
    loader, so ``next()`` keeps returning batches across passes over the data.
    """

    def __init__(self, loader, device):
        """Store the loader and target device and open the first iterator.

        Args:
            loader: Iterable yielding ``(inputs, targets)`` pairs.
            device: Value passed to ``.to(...)`` on each returned item.
        """
        self.loader = loader
        self.data_iter = iter(self.loader)
        self.device = device

    def next(self):
        """Return the next ``(inputs, targets)`` pair, moved to ``self.device``.

        Raises:
            StopIteration: if a freshly created iterator over the loader
                yields nothing.
        """
        # Inside the method body ``next`` is the builtin; it is used because
        # Python 3 iterators are not required to provide a ``.next()`` method.
        try:
            inputs, targets = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.loader)
            inputs, targets = next(self.data_iter)
        return inputs.to(self.device), targets.to(self.device)
