""" Convert dataset to HDF5
    This script preprocesses a dataset and saves it (images and labels) to
    an HDF5 file for improved I/O. """
import pdb
from argparse import ArgumentParser

import h5py as h5
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# Lookup tables keyed by dataset identifier ('I<N>' / 'I<N>_hdf5' / 'C10' / 'C100').
# NOTE(review): neither table is referenced by the functions in this script;
# presumably they are kept for parity with a shared dataset utility module.

# Dataset identifier -> image side length in pixels.
imsize_dict = {'I32': 32, 'I32_hdf5': 32,
               'I64': 64, 'I64_hdf5': 64,
               'I128': 128, 'I128_hdf5': 128,
               'I256': 256, 'I256_hdf5': 256,
               'C10': 32, 'C100': 32}

# Dataset identifier -> name of the directory or HDF5 file holding the data
# (presumably relative to the data root; verify against callers).
root_dict = {'I32': 'ImageNet', 'I32_hdf5': 'ILSVRC32.hdf5',
             'I64': 'ImageNet', 'I64_hdf5': 'ILSVRC64.hdf5',
             'I128': 'ImageNet', 'I128_hdf5': 'ILSVRC128.hdf5',
             'I256': 'ImageNet', 'I256_hdf5': 'ILSVRC256.hdf5',
             'C10': 'cifar', 'C100': 'cifar'}


class RandomCropLongEdge(object):
    """Crop a PIL image to a square at a random position along its long edge.

    The output is a square whose side equals the image's short edge. The short
    edge is kept whole; the crop window is placed at a uniformly random
    position along the long edge. The transform takes no constructor
    arguments.
    """

    @staticmethod
    def _sample_origin(width, height):
        """Pick the top-left corner of the square crop window.

        Args:
            width (int): Image width in pixels.
            height (int): Image height in pixels.
        Returns:
            tuple: ``(top, left)`` pixel offsets of the crop window. The
            offset along the short edge is always 0.
        """
        side = min(width, height)
        # Only step forward along an edge if it is the long edge.
        # randint's upper bound is exclusive, hence the +1 so the window can
        # also sit flush against the far end of the long edge.
        left = (0 if width == side
                else np.random.randint(low=0, high=width - side + 1))
        top = (0 if height == side
               else np.random.randint(low=0, high=height - side + 1))
        return top, left

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.
        Returns:
            PIL Image: Square crop with side ``min(img.size)``.
        """
        width, height = img.size  # PIL reports size as (width, height)
        side = min(width, height)
        top, left = self._sample_origin(width, height)
        # crop() takes (img, top, left, height, width): the vertical offset
        # comes first, so the width-axis offset must go in the `left` slot.
        return transforms.functional.crop(img, top, left, side, side)

    def __repr__(self):
        return self.__class__.__name__


class CenterCropLongEdge(object):
    """Center-crop a PIL image to a square along its long edge.

    The side of the resulting square equals the image's short edge, so only
    the long edge is trimmed. The transform takes no constructor arguments.
    """

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.
        Returns:
            PIL Image: Centered square crop with side ``min(img.size)``.
        """
        short_edge = min(img.size)
        return transforms.functional.center_crop(img, short_edge)

    def __repr__(self):
        return type(self).__name__


def get_data_loaders(dataset, data_root=None, img_size=256, augment=False, batch_size=64,
                     num_workers=8, shuffle=True, load_in_mem=False, hdf5=False,
                     pin_memory=True, drop_last=True, start_itr=0,
                     num_epochs=500, use_multiepoch_sampler=False, train_test="train", **kwargs):
    """Build the data loader(s) for an image-folder dataset.

    Images are read with ``ImageFolder`` from ``<data_root>/<train_test>``,
    converted to tensors and normalized with mean 0.5 / std 0.5 per channel.

    Args:
        dataset (str): Dataset identifier. 'C10' / 'C100' skip the crop and
            resize steps; any name containing 'hdf5' gets no transform at all.
        data_root (str): Directory containing the ``train_test`` sub-folder.
        img_size (int): Size passed to ``transforms.Resize``.
        batch_size (int): Batch size of the loader.
        num_workers (int): Number of DataLoader worker processes.
        shuffle (bool): Whether the loader shuffles the data.
        pin_memory (bool): Forwarded to ``DataLoader``.
        drop_last (bool): Whether to drop the last incomplete batch.
        train_test (str): Name of the split sub-directory under ``data_root``.
        augment, load_in_mem, hdf5, start_itr, num_epochs,
        use_multiepoch_sampler: Accepted for interface compatibility; not
            used by this function.
        **kwargs: ``use_ratio`` (float, default 1.0) - fraction of the dataset
            to keep. Values below 1.0 select a fixed random subset.
    Returns:
        list: A list holding a single ``DataLoader`` (the list leaves room for
        validation / test loaders).
    Raises:
        ValueError: If ``data_root`` is not provided.
    """
    if data_root is None:
        raise ValueError("data_root must be provided")
    data_root = f"{data_root}/{train_test}"
    print('Using dataset root location %s' % data_root)

    norm_mean = [0.5, 0.5, 0.5]
    norm_std = [0.5, 0.5, 0.5]

    # HDF5 datasets have their own inbuilt transform, no need to train_transform
    # NOTE(review): the data is still read through ImageFolder below, so an
    # 'hdf5' dataset name yields untransformed PIL images - confirm intent.
    if 'hdf5' in dataset:
        train_transform = None
    else:
        if dataset in ['C10', 'C100']:
            geometry = []
        else:
            geometry = [CenterCropLongEdge(), transforms.Resize(img_size)]
        train_transform = transforms.Compose(geometry + [
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std)])

    train_set = ImageFolder(root=data_root, transform=train_transform)
    use_ratio = kwargs.get("use_ratio", 1.0)
    if use_ratio < 1.0:
        full_size = len(train_set)
        use_size = int(full_size * use_ratio)
        # A dedicated generator with a fixed seed keeps the subset
        # reproducible without reseeding the global RNG.
        split_rng = torch.Generator().manual_seed(42)
        train_set, _ = random_split(train_set, [use_size, full_size - use_size],
                                    generator=split_rng)
        print(f"## Reduced dataset size from {full_size} to {len(train_set)}")

    # Prepare loader; the loaders list is for forward compatibility with
    # using validation / test splits.
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory,
                     'drop_last': drop_last}
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=shuffle, **loader_kwargs)
    return [train_loader]


def prepare_parser():
    """Build the command-line parser for the HDF5 conversion script.

    Returns:
        ArgumentParser: Parser exposing ``--img_size``, ``--data_root``
        (required), ``--data_name`` (required), ``--batch_size``,
        ``--num_workers``, ``--chunk_size``, ``--use_ratio`` and the
        ``--compression`` flag.
    """
    usage = "Parser for ImageNet HDF5 scripts."
    parser = ArgumentParser(description=usage)
    parser.add_argument(
        "--img_size", type=int, default=256, help="Input image size to train (default: %(default)s)"
    )
    # Required arguments have no meaningful default, so none is advertised.
    parser.add_argument(
        "--data_root", type=str, required=True,
        help="Root directory where the data is stored; the HDF5 file is written here (required)"
    )
    parser.add_argument(
        "--data_name", type=str, required=True,
        help="Dataset identifier, also used as prefix of the output HDF5 file name (required)"
    )
    parser.add_argument("--batch_size", type=int, default=256, help="Default overall batchsize (default: %(default)s)")
    parser.add_argument(
        "--num_workers", type=int, default=16, help="Number of dataloader workers (default: %(default)s)"
    )
    parser.add_argument(
        "--chunk_size", type=int, default=500, help="Number of images per HDF5 chunk (default: %(default)s)"
    )
    parser.add_argument(
        "--use_ratio", type=float, default=1.0, help="Fraction of the dataset to be used (default: %(default)s)"
    )
    parser.add_argument(
        "--compression", action="store_true", default=False, help="Use LZF compression? (default: %(default)s)"
    )
    return parser


def run(config):
    """Read a dataset through the train loader and write it to an HDF5 file.

    The file ``<data_root>/<data_name><img_size>.hdf5`` receives two
    datasets: ``imgs`` (uint8, N x H x W x 3) and ``labels`` (int64, N). They
    are created from the first batch and grown batch by batch.

    Args:
        config (dict): Must provide ``data_root``, ``data_name``,
            ``img_size``, ``batch_size``, ``num_workers``, ``chunk_size``,
            ``use_ratio`` and ``compression`` (truthy -> LZF). The dict is
            updated in place: ``image_size`` is added and ``compression`` is
            replaced by ``"lzf"`` or ``None``.
    Raises:
        ValueError: If ``data_root`` contains "hdf5", to avoid overwriting
            the file that is being read.
    """
    if "hdf5" in config["data_root"]:
        raise ValueError(
            "Reading from an HDF5 file which you will probably be "
            "about to overwrite! Override this error only if you know "
            "what you're doing!"
        )
    # Get image size
    config["image_size"] = config["img_size"]
    image_size = config["image_size"]

    # Update compression entry
    config["compression"] = "lzf" if config["compression"] else None
    compression = config["compression"]

    # Get dataset; order is kept (no shuffling) and no sample is dropped.
    train_loader = get_data_loaders(
        dataset=config["data_name"],
        data_root=config["data_root"],
        img_size=config["img_size"],
        batch_size=config["batch_size"],
        shuffle=False,
        use_multiepoch_sampler=False,
        num_workers=config["num_workers"],
        pin_memory=False,
        drop_last=False,
        use_ratio=config["use_ratio"],
    )[0]

    # HDF5 supports chunking and compression. You may want to experiment
    # with different chunk sizes to see how it runs on your machines.
    # Chunk Size/compression     Read speed @ 256x256   Read speed @ 128x128  Filesize @ 128x128    Time to write @128x128
    # 1 / None                   20/s
    # 500 / None                 ramps up to 77/s       102/s                 61GB                  23min
    # 500 / LZF                                         8/s                   56GB                  23min
    # 1000 / None                78/s
    # 5000 / None                81/s
    # auto:(125,1,16,32) / None                         11/s                  61GB

    print(
        "Starting to load %s into an HDF5 file with chunk size %i and compression %s..."
        % (config["data_name"], config["chunk_size"], config["compression"])
    )

    num_samples = len(train_loader.dataset)
    hdf5_path = config["data_root"] + f"/{config['data_name']}{image_size}.hdf5"
    # HDF5 rejects chunks larger than a fixed maximum dimension, so a small
    # dataset must not get a chunk longer than the dataset itself.
    chunk_len = min(config["chunk_size"], num_samples)

    # Loop over train loader
    print(type(train_loader), len(train_loader))
    for i, (x, y) in enumerate(tqdm(train_loader)):
        # Stick X into the range [0, 255] since it's coming from the train loader
        x = _to_uint8_nhwc(x)
        # Numpyify y
        y = y.numpy()
        # If we're on the first batch, prepare the hdf5
        if i == 0:
            with h5.File(hdf5_path, "w") as f:
                print("Producing dataset of len %d" % num_samples)
                imgs_dset = f.create_dataset(
                    "imgs",
                    x.shape,
                    dtype="uint8",
                    maxshape=(num_samples, image_size, image_size, 3),
                    chunks=(chunk_len, image_size, image_size, 3),
                    compression=compression,
                )
                print("Image chunks chosen as " + str(imgs_dset.chunks))
                imgs_dset[...] = x
                labels_dset = f.create_dataset(
                    "labels",
                    y.shape,
                    dtype="int64",
                    maxshape=(num_samples,),
                    chunks=(chunk_len,),
                    compression=compression,
                )
                print("Label chunks chosen as " + str(labels_dset.chunks))
                labels_dset[...] = y
        # Else append to the hdf5: grow along axis 0, then fill the new tail
        else:
            with h5.File(hdf5_path, "a") as f:
                imgs_dset, labels_dset = f["imgs"], f["labels"]
                imgs_dset.resize(imgs_dset.shape[0] + x.shape[0], axis=0)
                imgs_dset[-x.shape[0]:] = x
                labels_dset.resize(labels_dset.shape[0] + y.shape[0], axis=0)
                labels_dset[-y.shape[0]:] = y


def _to_uint8_nhwc(x):
    """Convert a normalized NCHW float batch in [-1, 1] to an NHWC uint8 array."""
    # Round before the integer cast: a plain cast truncates, so float error
    # (e.g. 254.99998) would come back one lower than the source pixel.
    pixels = (255 * ((x + 1) / 2.0)).round().clamp(0, 255).byte().numpy()
    return np.transpose(pixels, (0, 2, 3, 1))


def main():
    """Parse the command line into a config dict, echo it and run the conversion."""
    cli_args = prepare_parser().parse_args()
    config = vars(cli_args)
    print(config)
    run(config)


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
