import math
import os

import torchvision.datasets
import torchvision.transforms as T

from data.transform import TransformFixMatch

# Per-channel normalization statistics handed to T.Normalize by the
# transform builders below.
# ImageNet mean/std: the default for every dataset class in this module
# (FixMatchDataset always uses them; the others use them unless
# gan_mean_std=True).
in_mean = [0.485, 0.456, 0.406]
in_std = [0.229, 0.224, 0.225]
# A mean and std of 0.5 compute (x - 0.5) / 0.5, i.e. they map inputs in [0, 1]
# to [-1, 1]. Selected with gan_mean_std=True.
gan_mean = [0.5, 0.5, 0.5]
gan_std = [0.5, 0.5, 0.5]


def cls_test_transforms(size, mean, std):
    """Build the deterministic evaluation transform pipeline.

    The image is resized (``T.Resize``) to the smallest power of two that is
    >= ``size`` (e.g. 224 -> 256), center-cropped to ``size``, converted to a
    tensor and normalized with ``mean``/``std``.

    Note: when ``size`` is itself a power of two, the resize target equals
    ``size``.

    Args:
        size: Final crop size in pixels.
        mean: Per-channel mean passed to ``T.Normalize``.
        std: Per-channel std passed to ``T.Normalize``.

    Returns:
        A ``T.Compose`` of the four steps above.
    """
    exponent = math.ceil(math.log2(size))
    resize_to = int(math.pow(2, exponent))
    steps = [
        T.Resize(resize_to),
        T.CenterCrop(size),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
    return T.Compose(steps)


def cls_train_transforms(size, mean, std):
    """Build the randomly augmented training transform pipeline.

    Applies a random resized crop to ``size`` and a random horizontal flip,
    then converts to a tensor and normalizes with ``mean``/``std``.

    Args:
        size: Output crop size in pixels.
        mean: Per-channel mean passed to ``T.Normalize``.
        std: Per-channel std passed to ``T.Normalize``.

    Returns:
        A ``T.Compose`` of the augmentation and tensor/normalization steps.
    """
    augmentation = [T.RandomResizedCrop(size), T.RandomHorizontalFlip()]
    tensorize = [T.ToTensor(), T.Normalize(mean=mean, std=std)]
    return T.Compose(augmentation + tensorize)


class ImageNet(torchvision.datasets.ImageFolder):
    """ImageNet classification dataset loaded via torchvision's ``ImageFolder``.

    Reads ``<root>/train`` for training and ``<root>/val`` for evaluation.

    Args:
        root: Dataset root containing ``train`` and ``val`` sub-directories.
        test: If True, load ``val`` with the deterministic
            ``cls_test_transforms`` pipeline; otherwise load ``train`` with
            the randomly augmented ``cls_train_transforms`` pipeline.
        size: Output crop size in pixels.
        gan_mean_std: Normalize with ``gan_mean``/``gan_std`` (0.5 per channel)
            instead of the ImageNet ``in_mean``/``in_std``.
    """

    def __init__(self, root="/dataset/imagenet", test=False, size=224, gan_mean_std=False):
        mean, std = (gan_mean, gan_std) if gan_mean_std else (in_mean, in_std)
        if test:
            split_data_dir = os.path.join(root, "val")
            transform = cls_test_transforms(size=size, mean=mean, std=std)
        else:
            split_data_dir = os.path.join(root, "train")
            # The training split uses random augmentation (RandomResizedCrop +
            # horizontal flip), matching GenericDataset's train split. Using the
            # evaluation transforms here would silently disable augmentation.
            transform = cls_train_transforms(size=size, mean=mean, std=std)
        super(ImageNet, self).__init__(root=split_data_dir, transform=transform)


class GenericDataset(torchvision.datasets.ImageFolder):
    """``ImageFolder`` dataset whose root holds ``train`` and ``test`` splits.

    Args:
        root: Directory containing ``train`` and ``test`` sub-directories.
        size: Output crop size in pixels.
        test: If True, load ``test`` with ``cls_test_transforms``; otherwise
            load ``train``.
        fixmatch: For the train split only, use ``TransformFixMatch`` instead
            of ``cls_train_transforms``. Ignored when ``test`` is True.
        gan_mean_std: Normalize with ``gan_mean``/``gan_std`` (0.5 per channel)
            instead of the ImageNet ``in_mean``/``in_std``.
    """

    def __init__(self, root, size=224, test=False, fixmatch=False, gan_mean_std=False):
        if gan_mean_std:
            mean, std = gan_mean, gan_std
        else:
            mean, std = in_mean, in_std

        if test:
            split_data_dir = os.path.join(root, "test")
            transform = cls_test_transforms(size=size, mean=mean, std=std)
        elif fixmatch:
            split_data_dir = os.path.join(root, "train")
            transform = TransformFixMatch(size=size, mean=mean, std=std)
        else:
            split_data_dir = os.path.join(root, "train")
            transform = cls_train_transforms(size=size, mean=mean, std=std)

        super().__init__(root=split_data_dir, transform=transform)


class FixMatchDataset(torchvision.datasets.ImageFolder):
    """``ImageFolder`` dataset that always uses ``TransformFixMatch`` for training.

    Normalization always uses the ImageNet ``in_mean``/``in_std``.

    Args:
        root: Directory containing ``train`` and ``test`` sub-directories.
        size: Output crop size in pixels.
        test: If True, load ``test`` with ``cls_test_transforms``; otherwise
            load ``train`` with ``TransformFixMatch``.
    """

    def __init__(self, root, size=224, test=False):
        if test:
            data_dir = os.path.join(root, "test")
            pipeline = cls_test_transforms(size=size, mean=in_mean, std=in_std)
        else:
            data_dir = os.path.join(root, "train")
            pipeline = TransformFixMatch(size=size, mean=in_mean, std=in_std)
        super().__init__(root=data_dir, transform=pipeline)


class StanfordCars(GenericDataset):
    """Stanford Cars; a ``GenericDataset`` with default root ``/dataset/StanfordCars``."""

    def __init__(self, root="/dataset/StanfordCars", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class Birds(GenericDataset):
    """CUB-200-2011 birds; a ``GenericDataset`` with default root ``/dataset/CUB-200-2011``."""

    def __init__(self, root="/dataset/CUB-200-2011", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class DTD(GenericDataset):
    """DTD; a ``GenericDataset`` with default root ``/dataset/DTD``."""

    def __init__(self, root="/dataset/DTD", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class Aircraft(GenericDataset):
    """FGVC-Aircraft; a ``GenericDataset`` with default root ``/dataset/FGVC-Aircraft``."""

    def __init__(self, root="/dataset/FGVC-Aircraft", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class Indoor67(GenericDataset):
    """Indoor67; a ``GenericDataset`` with default root ``/dataset/Indoor67``."""

    def __init__(self, root="/dataset/Indoor67", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class Flower(GenericDataset):
    """Oxford Flower 102; a ``GenericDataset`` with default root ``/dataset/OxfordFlower102``."""

    def __init__(self, root="/dataset/OxfordFlower102", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class TinyImageNet(GenericDataset):
    """Tiny ImageNet; a ``GenericDataset`` with default root ``/dataset/tiny-imagenet-200``."""

    def __init__(self, root="/dataset/tiny-imagenet-200", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class Dogs(GenericDataset):
    """Stanford Dogs; a ``GenericDataset`` with default root ``/dataset/StanfordDogs``."""

    def __init__(self, root="/dataset/StanfordDogs", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class Caltech_256_60(GenericDataset):
    """Caltech-256-60; a ``GenericDataset`` with default root ``/dataset/Caltech-256-60``."""

    def __init__(self, root="/dataset/Caltech-256-60", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)


class Pets(GenericDataset):
    """Oxford Pets; a ``GenericDataset`` with default root ``/dataset/OxfordPets``."""

    def __init__(self, root="/dataset/OxfordPets", size=224, test=False,
                 fixmatch=False, gan_mean_std=False):
        super().__init__(root=root, size=size, test=test,
                         fixmatch=fixmatch, gan_mean_std=gan_mean_std)
