import math
import pdb

import h5py as h5
import torchvision.transforms as T
from torch.utils.data import Dataset

# Default per-channel (R, G, B) normalization statistics for the "cls"
# transforms. These match the mean/std torchvision documents for its
# ImageNet-pretrained models.
in_mean, in_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]


def gan_transforms():
    """Build the preprocessing pipeline used in "gan" mode.

    The input is converted to a PIL image, then to a tensor, and every channel
    is normalized with mean 0.5 and std 0.5. For 8-bit input, this maps
    ToTensor's [0, 1] range onto [-1, 1].
    """
    return T.Compose([
        T.ToPILImage(),
        T.ToTensor(),
        # Two distinct lists, as in a per-argument literal.
        T.Normalize([0.5] * 3, [0.5] * 3),
    ])


def cls_test_transforms(size, mean, std):
    """Build the evaluation pipeline for "cls" mode.

    The image's shorter side is resized to the smallest power of two that is
    >= ``size``. The result is center-cropped to ``size``, converted to a
    tensor, and normalized with ``mean``/``std``.

    Args:
        size: crop size passed to ``CenterCrop``.
        mean: per-channel mean for ``Normalize``.
        std: per-channel std for ``Normalize``.
    """
    exponent = math.ceil(math.log2(size))
    resize_to = int(math.pow(2, exponent))
    steps = [
        T.ToPILImage(),
        T.Resize(resize_to),
        T.CenterCrop(size),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
    return T.Compose(steps)


def cls_train_transforms(size, mean, std):
    """Build the training pipeline for "cls" mode.

    The image is converted to PIL and augmented with a random resized crop to
    ``size`` and a random horizontal flip. It is then converted to a tensor
    and normalized with ``mean``/``std``.
    """
    to_pil = T.ToPILImage()
    augment = [T.RandomResizedCrop(size), T.RandomHorizontalFlip()]
    finish = [T.ToTensor(), T.Normalize(mean=mean, std=std)]
    return T.Compose([to_pil, *augment, *finish])


def get_transform(mode, mean, std, size, test):
    """Return the preprocessing pipeline for ``mode``.

    Args:
        mode: ``"gan"`` for the GAN pipeline (``mean``, ``std``, ``size`` and
            ``test`` are ignored), or ``"cls"`` for classification.
        mean: per-channel normalization mean used by ``"cls"``.
        std: per-channel normalization std used by ``"cls"``.
        size: output crop size; required for ``"cls"``.
        test: for ``"cls"``, select the evaluation transforms instead of the
            training transforms.

    Raises:
        ValueError: ``mode`` is ``"cls"`` and ``size`` is None.
        NotImplementedError: ``mode`` is not a recognized mode.
    """
    if mode == "gan":
        return gan_transforms()
    if mode == "cls":
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let ``None`` reach the transforms.
        if size is None:
            raise ValueError('transform mode "cls" requires a non-None size')
        builder = cls_test_transforms if test else cls_train_transforms
        return builder(size, mean, std)
    raise NotImplementedError("Unsupported transform mode: {!r}".format(mode))


class HDF5(Dataset):
    """Map-style dataset backed by an HDF5 file with "imgs" and "labels" datasets.

    Each item is ``(trsf(img), int(label))``, where ``trsf`` is built by
    ``get_transform(mode, mean, std, size, test)``.

    Args:
        hdf5_path: path of the HDF5 file to read.
        mode: transform family, ``"gan"`` or ``"cls"``.
        load_data_in_memory: if True, read all of "imgs" and "labels" into
            memory once at construction; otherwise the file is reopened for
            every item.
        test: for ``"cls"``, use evaluation transforms instead of training ones.
        mean: per-channel normalization mean for ``"cls"``.
        std: per-channel normalization std for ``"cls"``.
        size: crop size; required for ``"cls"``.
        dataset_name: stored on the instance; not otherwise used in this class.
    """

    def __init__(self, hdf5_path, mode="gan", load_data_in_memory=False, test=False, mean=in_mean, std=in_std, size=None, dataset_name=None):
        super().__init__()
        self.dataset_name = dataset_name
        self.hdf5_path = hdf5_path
        self.load_data_in_memory = load_data_in_memory
        # Built before touching the file, so a bad mode fails without any I/O.
        self.trsf = get_transform(mode, mean, std, size, test)
        self.load_dataset()

    def load_dataset(self):
        """Record the item count and optionally cache both datasets in memory."""
        with h5.File(self.hdf5_path, "r") as handle:
            # Both keys are looked up eagerly, so a file missing either one
            # fails here at construction time.
            imgs, lbls = handle["imgs"], handle["labels"]
            self.num_dataset = imgs.shape[0]
            if not self.load_data_in_memory:
                return
            print("Load {path} into memory.".format(path=self.hdf5_path))
            self.data = imgs[:]
            self.labels = lbls[:]

    def _get_hdf5(self, index):
        """Open the file read-only and return ``(img, label)`` at ``index``."""
        with h5.File(self.hdf5_path, "r") as handle:
            img = handle["imgs"][index]
            label = handle["labels"][index]
        return img, label

    def __len__(self):
        """Return the number of items in the dataset."""
        # With no backing path, the size comes from the in-memory array.
        return len(self.data) if self.hdf5_path is None else self.num_dataset

    def __getitem__(self, index):
        """Return ``(transformed image, integer label)`` for ``index``."""
        if not self.load_data_in_memory:
            img, label = self._get_hdf5(index)
        else:
            img, label = self.data[index], self.labels[index]
        return self.trsf(img), int(label)
