import torch
import torchvision
import os
import csv
import config
import numpy as np
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image


class ToNumpy():
    """Callable transform that converts its input to a NumPy array.

    Two-dimensional results get a trailing axis of length one appended, so
    an ``(H, W)`` input becomes ``(H, W, 1)``. Arrays of any other rank are
    returned unchanged.
    """

    def __call__(self, x):
        arr = np.array(x)
        if arr.ndim != 2:
            return arr
        # Append a singleton axis at position 2 (a view, not a copy).
        return arr[:, :, np.newaxis]


def get_transform(opt, train=True):
    """Build the torchvision preprocessing pipeline for ``opt.dataset``.

    The pipeline always resizes to ``(opt.input_height, opt.input_width)``
    and converts to a tensor. When ``train`` is true, a padded random crop
    (padding of ``opt.input_height // 8``), a random rotation of up to 10
    degrees and a random horizontal flip are inserted between the two.
    A dataset-specific ``Normalize`` is appended for 'cifar10' and 'mnist';
    'gtsrb' gets none.

    Raises:
        Exception: if ``opt.dataset`` is not 'cifar10', 'mnist' or 'gtsrb'.
    """
    target_size = (opt.input_height, opt.input_width)
    pipeline = [transforms.Resize(target_size)]

    if train:
        # Augmentations are applied on the PIL image, before ToTensor.
        pipeline += [
            transforms.RandomCrop(target_size, padding=opt.input_height // 8),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(p=0.5),
        ]

    pipeline.append(transforms.ToTensor())

    name = opt.dataset
    if name == 'cifar10':
        pipeline.append(
            transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        )
    elif name == 'mnist':
        pipeline.append(transforms.Normalize([0.5], [0.5]))
    elif name != 'gtsrb':
        # 'gtsrb' is valid but intentionally left un-normalized.
        raise Exception("Invalid Dataset")

    return transforms.Compose(pipeline)


class GTSRB(data.Dataset):
    """Dataset over GTSRB images indexed by ';'-delimited annotation CSVs.

    Image paths and integer labels are read eagerly in ``__init__``; images
    themselves are opened lazily in ``__getitem__``.

    Args:
        opt: options object; only ``opt.data_root`` is read here.
        train: if true, read ``<data_root>/GTSRB/Train`` (one sub-folder and
            one ``GT-<class>.csv`` per class, classes 0..42); otherwise read
            ``<data_root>/GTSRB/Test/GT-final_test.csv``.
        transforms: callable applied to each opened PIL image.
    """

    def __init__(self, opt, train, transforms):
        super(GTSRB, self).__init__()
        if train:
            self.data_folder = os.path.join(opt.data_root, 'GTSRB/Train')
            self.images, self.labels = self._get_data_train_list()
        else:
            self.data_folder = os.path.join(opt.data_root, 'GTSRB/Test')
            self.images, self.labels = self._get_data_test_list()

        self.transforms = transforms

    @staticmethod
    def _read_annotations(csv_path, image_dir):
        """Parse one annotation CSV into ``(image_paths, labels)``.

        The first row is treated as a header and skipped. In every other
        row, column 0 is the image file name (joined onto ``image_dir``)
        and column 7 is the integer class label. Blank rows are ignored.
        """
        images = []
        labels = []
        # `with` guarantees the handle is closed even if a row fails to
        # parse; newline='' is what the csv module expects for its input.
        with open(csv_path, newline='') as gt_file:
            reader = csv.reader(gt_file, delimiter=';')
            next(reader, None)  # skip the header row (tolerates empty files)
            for row in reader:
                if not row:
                    # A trailing blank line would otherwise raise IndexError.
                    continue
                images.append(os.path.join(image_dir, row[0]))
                labels.append(int(row[7]))
        return images, labels

    def _get_data_train_list(self):
        """Collect image paths and labels from the 43 per-class folders."""
        images = []
        labels = []
        for c in range(43):
            class_name = format(c, '05d')
            class_dir = os.path.join(self.data_folder, class_name)
            csv_path = os.path.join(class_dir, 'GT-' + class_name + '.csv')
            class_images, class_labels = self._read_annotations(csv_path, class_dir)
            images.extend(class_images)
            labels.extend(class_labels)
        return images, labels

    def _get_data_test_list(self):
        """Collect image paths and labels from ``GT-final_test.csv``."""
        csv_path = os.path.join(self.data_folder, 'GT-final_test.csv')
        return self._read_annotations(csv_path, self.data_folder)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        image = Image.open(self.images[index])
        image = self.transforms(image)
        label = self.labels[index]
        return image, label


def get_dataloader(opt, train=True):
    """Create a shuffling ``DataLoader`` over the dataset named by ``opt.dataset``.

    The preprocessing pipeline comes from ``get_transform(opt, train)``.
    MNIST and CIFAR-10 are fetched into ``opt.data_root`` if missing
    (``download=True``). Batch size and worker count are taken from
    ``opt.bs`` and ``opt.num_workers``; shuffling is enabled for both the
    train and the test split.

    Raises:
        Exception: if ``opt.dataset`` is not 'gtsrb', 'mnist' or 'cifar10'.
    """
    pipeline = get_transform(opt, train)

    name = opt.dataset
    if name == 'gtsrb':
        dataset = GTSRB(opt, train, pipeline)
    elif name == 'mnist':
        dataset = torchvision.datasets.MNIST(
            root=opt.data_root, train=train, transform=pipeline, download=True
        )
    elif name == 'cifar10':
        dataset = torchvision.datasets.CIFAR10(
            root=opt.data_root, train=train, transform=pipeline, download=True
        )
    else:
        raise Exception('Invalid dataset')

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.bs,
        num_workers=opt.num_workers,
        shuffle=True,
    )


def get_dataset(opt, train=True):
    """Return the dataset named by ``opt.dataset`` with NumPy-array samples.

    Every sample is passed through ``ToNumpy``. GTSRB images are first
    resized to ``(opt.input_height, opt.input_width)``; MNIST and CIFAR-10
    images are not resized and are downloaded into ``opt.data_root`` if
    missing.

    Raises:
        Exception: if ``opt.dataset`` is not 'gtsrb', 'mnist' or 'cifar10'.
    """
    name = opt.dataset

    if name == 'gtsrb':
        resize_then_numpy = transforms.Compose([
            transforms.Resize((opt.input_height, opt.input_width)),
            ToNumpy(),
        ])
        return GTSRB(opt, train, transforms=resize_then_numpy)

    if name == 'mnist':
        return torchvision.datasets.MNIST(
            opt.data_root, train, transform=ToNumpy(), download=True
        )

    if name == 'cifar10':
        return torchvision.datasets.CIFAR10(
            opt.data_root, train, transform=ToNumpy(), download=True
        )

    raise Exception('Invalid dataset')


def main():
    """Smoke-test the test-split dataloader by iterating over it once.

    Options are obtained from ``config.get_arguments().parse_args()``
    (presumably an argparse parser reading the command line; confirm in
    ``config``). Each batch is unpacked into ``(images, labels)`` and then
    discarded; nothing is returned or printed.
    """
    opt = config.get_arguments().parse_args()
    # get_dataloader builds its own transform pipeline, so no separate
    # get_transform call is needed here (the previous one was unused and
    # shadowed the imported `transforms` module).
    dataloader = get_dataloader(opt, False)
    for _images, _labels in dataloader:
        pass
    

# Run the dataloader smoke test only when executed as a script.
if __name__ == '__main__':
    main()
