from torchvision import transforms, datasets
from typing import *
# import imageio
import torch
import pickle
import os
import glob
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import matplotlib.image as mpimg
import torch.nn.functional as F

def get_dim(name):
    """Return the flattened input size (channels * height * width) for a dataset.

    :param name: dataset name, compared case-sensitively.
    :return: the number of input features, or None for an unlisted name.
    """
    flat_sizes = (
        ("cifar10", 3 * 32 * 32),
        ("mnist", 28 * 28),
        ("imagenet", 3 * 224 * 224),
        ("tiny_imagenet", 3 * 64 * 64),
        ("fashion", 28 * 28),
    )
    for known_name, size in flat_sizes:
        if name == known_name:
            return size
    return None

def get_num_classes(dataset: str):
    """Return the number of classes in the dataset (case-insensitive name).

    :return: 1000 for ImageNet, 10 for CIFAR-10, 200 for Tiny ImageNet,
        None for any other name.
    """
    key = dataset.lower()
    if key == "imagenet":
        return 1000
    if key == "cifar10":
        return 10
    if key == "tiny_imagenet":
        return 200
    return None

def get_input_channels(dataset: str):
    """Return the number of channels in input images.

    Only CIFAR-10 and Tiny ImageNet are listed (both 3-channel); every other
    name, matched case-insensitively, yields None.
    """
    three_channel = ("cifar10", "tiny_imagenet")
    return 3 if dataset.lower() in three_channel else None

def get_num_labels(name):
    """Return the label-space size used for dataset ``name``.

    Any name containing the substring "imagenet" maps to 1000; everything else
    maps to 10.

    NOTE(review): this also yields 1000 for "tiny_imagenet", whereas
    get_num_classes("tiny_imagenet") returns 200. Left unchanged because models
    or checkpoints may already be sized from this value — confirm with callers.
    """
    if "imagenet" in name:
        return 1000
    return 10

# Name of the environment variable that must hold the path to your ImageNet directory if you want to read
# ImageNet data. The "train" subdirectory is used for the train split and the "val" subdirectory for the
# test split; make sure "val" is preprocessed to look like the train directory, e.g. by running this script
# https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/testprep.sh
IMAGENET_LOC_ENV = "IMAGENET_DIR"

# list of all datasets that get_dataset() can load
DATASETS = ["cifar10", "imagenet", "tiny_imagenet"]

class PadAndShift(object):
    """Pad an image onto a larger background and place it at a chosen offset.

    Used as a torchvision-style transform: ``__call__`` accepts a PIL image (or
    anything ``np.array`` turns into an (H, W, 3) array) and returns a
    ``uint8`` numpy array of shape (H + 2 * pad_size, W + 2 * pad_size, 3)
    with the image pasted in.

    ``transform_params`` keys:
        pad_size: one-sided padding in pixels; total padding per axis is
            ``2 * pad_size``. With ``pad_size <= 0`` the image is placed at (0, 0).
        num_image_locations: placement strategy -- "1" (center), "2"
            (top-left / bottom-right), "4" (the four corners), "8" (corners plus
            edge midpoints), "edges" (any offset touching the canvas border) or
            "random" (any offset). Integers such as ``4`` are accepted too.
        background: "black" or None (zero canvas), or "nature" (a random
            BG-20k image resized to the canvas).
        bg_train_glob / bg_test_glob (optional): glob patterns selecting the
            BG-20k image files for each split; default to the placeholders below.
    """

    # NOTE(review): debug print in the class body; it runs once when this module is imported.
    print("using data utils file!")

    # Placeholder glob patterns for the BG-20k background images of each split
    # (the original notes say these images are 48x48; they are resized to the
    # canvas anyway). Override via transform_params["bg_<split>_glob"].
    _BG20K_DEFAULT_GLOBS = {
        "train": "data path for 20k bg train split",
        "test": "data path for 20k bg test split",
    }

    def __init__(self, transform_params: Optional[dict] = None, split: str = "train"):
        """
        :param transform_params: see the class docstring; "pad_size",
            "num_image_locations" and "background" are required.
        :param split: "train" or "test" (case-insensitive); only used to pick the
            BG-20k image pool when background == "nature".
        :raises KeyError: if a required key is missing from ``transform_params``.
        """
        # None instead of a shared mutable default; missing keys still raise KeyError.
        if transform_params is None:
            transform_params = {}
        self.pad_size = transform_params["pad_size"]  # one-sided pad size
        self.num_image_locations = transform_params["num_image_locations"]
        self.background = transform_params["background"]  # "black", "nature" or None

        # Pool of background image paths; stays empty unless background == "nature".
        self.random_bg_image_paths = []
        if self.background == "nature":
            print("using nature background!")
            split_key = split.lower()
            if split_key in self._BG20K_DEFAULT_GLOBS:
                pattern = transform_params.get(
                    f"bg_{split_key}_glob", self._BG20K_DEFAULT_GLOBS[split_key]
                )
                self.random_bg_image_paths = glob.glob(pattern)

    def __call__(self, image):
        """Return ``image`` pasted onto a padded canvas as a uint8 (H', W', 3) array.

        :raises ValueError: for an unknown background or placement strategy.
        :raises RuntimeError: if the "nature" background has no images to sample.
        """
        # torchvision passes a PIL image; work on a numpy array.
        image = np.array(image)
        h, w, _ = image.shape
        new_h = h + 2 * self.pad_size
        new_w = w + 2 * self.pad_size

        canvas = self._make_canvas(new_h, new_w)
        row, col = self._choose_offset(h, w, new_h, new_w)
        canvas[row:row + h, col:col + w, :] = image
        return canvas

    def _make_canvas(self, new_h, new_w):
        """Build the (new_h, new_w, 3) background the image is pasted onto."""
        if self.background is None or self.background == "black":
            return np.zeros((new_h, new_w, 3), dtype=np.uint8)
        if self.background == "nature":
            if not self.random_bg_image_paths:
                raise RuntimeError(
                    "nature background requested but no BG-20k images were found; "
                    "set transform_params['bg_train_glob'/'bg_test_glob'] and use split 'train' or 'test'"
                )
            bg_path = random.choice(self.random_bg_image_paths)
            # Close the file handle; force RGB so the paste below always sees 3 channels.
            with Image.open(bg_path) as bg_img:
                return np.array(bg_img.convert("RGB").resize((new_w, new_h)))
        raise ValueError(
            f"background must be 'black', 'nature' or None, got {self.background!r}"
        )

    def _choose_offset(self, h, w, new_h, new_w):
        """Return the (row, col) of the image's top-left corner on the canvas."""
        if not self.pad_size > 0:
            return 0, 0

        # Largest valid offsets; both equal 2 * pad_size and are inclusive.
        max_row, max_col = new_h - h, new_w - w
        mode = str(self.num_image_locations)
        if mode == "1":
            # center
            return self.pad_size, self.pad_size
        if mode == "2":
            # top left or bottom right
            return random.choice([(0, 0), (max_row, max_col)])
        if mode == "4":
            # top left, bottom left, top right, bottom right
            return random.choice([(0, 0), (max_row, 0), (0, max_col), (max_row, max_col)])
        if mode == "8":
            # the four corners plus the midpoints of the four edges
            return random.choice([
                (0, 0),
                (max_row, 0),
                (0, max_col),
                (max_row, max_col),
                (0, self.pad_size),
                (max_row, self.pad_size),
                (self.pad_size, 0),
                (self.pad_size, max_col),
            ])
        if mode == "edges":
            span = 2 * self.pad_size
            # One coordinate is pinned to a border; the other ranges over 0..span
            # inclusive so the bottom-right corner is reachable too.
            pinned = random.choice([0, span])
            free = random.choice(range(span + 1))
            if np.random.binomial(1, 0.5) == 0:
                return pinned, free
            return free, pinned
        if mode == "random":
            # `high` is exclusive, so +1 lets the image sit flush with the bottom/right edge.
            row, col = np.random.randint(low=0, high=2 * self.pad_size + 1, size=2)
            return int(row), int(col)
        raise ValueError(
            "num_image_locations must be one of '1', '2', '4', '8', 'edges' or 'random', "
            f"got {self.num_image_locations!r}"
        )

class CustomDataset(Dataset):
    """Map-style dataset over in-memory images and labels.

    ``data`` and ``targets`` are converted with ``np.array`` unless they are
    already numpy arrays. When ``transform`` is set, each sample is rounded to
    the nearest integer, cast to ``uint8`` and wrapped in a PIL image before
    the transform is applied; otherwise the raw array row is returned.
    """

    def __init__(self, data, targets, transform=None, target_transform=None):
        self.data = data if isinstance(data, np.ndarray) else np.array(data)
        self.targets = targets if isinstance(targets, np.ndarray) else np.array(targets)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # The dataset length is defined by the number of labels.
        return len(self.targets)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.targets[idx]

        if self.transform is not None:
            pixels = np.rint(sample).astype(np.uint8)
            sample = self.transform(Image.fromarray(pixels))

        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label

# For TinyImagenet
class TinyImageNetDataset(Dataset):
    """Tiny ImageNet split read from annotation files under ``root_dir``.

    Expected layout (derived from the paths this class builds):
        root_dir/class_ids                        pickled dict: class nid -> integer label
        root_dir/<mode>/<mode>_annotations.txt    tab-separated, no header row
        root_dir/<mode>/<nid>/images/<file>       image files

    For "train" the class nid is the filename prefix before the first "_";
    for "test" it is read from the annotation's second column.
    """

    _MODES = ("train", "test")

    def __init__(self, root_dir, mode='train', transform=None, target_transform=None):
        """
        :param root_dir: dataset root (see the class docstring for the layout).
        :param mode: "train" or "test".
        :param transform: optional callable applied to a PIL image.
        :param target_transform: optional callable applied to the label tensor.
        :raises ValueError: for any other ``mode``; its annotation columns are unknown.
        """
        if mode not in self._MODES:
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.target_transform = target_transform

        # Read the annotation table listing every image of this split.
        anno_file = os.path.join(root_dir, self.mode, self.mode + "_annotations.txt")
        self.image_table = pd.read_csv(anno_file, sep="\t", header=None)
        if self.mode == "train":
            self.image_table.columns = ["images", "x1", "y1", "x2", "y2"]
        else:
            self.image_table.columns = ["images", "nid", "x1", "y1", "x2", "y2"]

        # Class nid -> integer id mapping.
        # NOTE(review): pickle.load can execute arbitrary code; only use trusted dataset files.
        with open(os.path.join(root_dir, "class_ids"), "rb") as f:
            self.nid_to_id_mapping = pickle.load(f)

    def __len__(self):
        return len(self.image_table)

    def __getitem__(self, idx):
        """Return ``(image, label_tensor)`` for row ``idx`` of the annotation table."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        record = self.image_table.iloc[idx]
        image_file = record["images"]

        if self.mode == "train":
            class_nid = image_file.split("_")[0]
        else:
            class_nid = record["nid"]

        image_path = os.path.join(self.root_dir, self.mode, class_nid, "images", image_file)
        img = mpimg.imread(image_path)
        if img.ndim == 2:
            # Grayscale files decode to a 2-D array; replicate to 3 channels so every
            # sample has the same shape (get_input_channels("tiny_imagenet") == 3).
            img = np.stack([img] * 3, axis=-1)

        class_label = torch.tensor(self.nid_to_id_mapping[class_nid])

        if self.transform is not None:
            img = self.transform(Image.fromarray(img))
        if self.target_transform is not None:
            class_label = self.target_transform(class_label)
        return img, class_label

def get_dataset(dataset: str, split: str, path: str, transform_params: Optional[dict] = None) -> Dataset:
    """Return the dataset as a PyTorch Dataset object.

    :param dataset: "cifar10", "imagenet" or "tiny_imagenet" (case-insensitive).
    :param split: "train" or "test".
    :param path: data root for CIFAR-10 / Tiny ImageNet (ImageNet reads its
        location from an environment variable instead).
    :param transform_params: PadAndShift parameters; only used for CIFAR-10.
    :raises ValueError: for an unknown dataset name (previously returned None).
    """
    # None instead of a shared mutable default argument.
    if transform_params is None:
        transform_params = {}

    name = dataset.lower()
    if name == "imagenet":
        return _imagenet(split)
    if name == "cifar10":
        return _cifar10(split, path, transform_params)
    if name == "tiny_imagenet":
        return _tiny_imagenet(split, path)
    raise ValueError(
        f"unknown dataset {dataset!r}; expected 'cifar10', 'imagenet' or 'tiny_imagenet'"
    )

# Per-channel means and standard deviations consumed by get_normalize_layer().

# CIFAR-10
_CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
_CIFAR10_STDDEV = [0.2023, 0.1994, 0.2010]

# TINY IMAGENET (same values as the ImageNet statistics below)
_TINY_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_TINY_IMAGENET_STDDEV = [0.229, 0.224, 0.225]

# IMAGENET
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STDDEV = [0.229, 0.224, 0.225]

def get_normalize_layer(dataset: str) -> torch.nn.Module:
    """Return a NormalizeLayer built from the dataset's channel statistics.

    :param dataset: "cifar10", "tiny_imagenet" or "imagenet" (case-insensitive).
    :return: the layer, or None for any other dataset name.
    """
    key = dataset.lower()
    if key == "cifar10":
        mean, std = _CIFAR10_MEAN, _CIFAR10_STDDEV
    elif key == "tiny_imagenet":
        mean, std = _TINY_IMAGENET_MEAN, _TINY_IMAGENET_STDDEV
    elif key == "imagenet":
        mean, std = _IMAGENET_MEAN, _IMAGENET_STDDEV
    else:
        return None
    return NormalizeLayer(mean, std)

# TINY_IMAGENET
# Train-time pipeline: random 64x64 crop after 4-pixel padding, random horizontal flip, conversion to a tensor.
# No normalization here -- get_normalize_layer() provides it as a model layer.
TINYIMAGENET_TRAIN_TRANSFORM = transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
])
# Test-time pipeline: conversion to a tensor only.
TINYIMAGENET_TEST_TRANSFORM = transforms.ToTensor()

# CIFAR10
# CIFAR_TRAIN_TRANSFORM = transforms.Compose([
#             # transforms.RandomCrop(32, padding=4),
#             # transforms.RandomHorizontalFlip(),
#             transforms.ToTensor()
#         ])
# CIFAR_TEST_TRANSFORM = transforms.ToTensor()

# CIFAR_TRAIN_TRANSFORM = transforms.Compose([
#             # transforms.CenterCrop(crop_size),
#             transforms.Pad(pad_size, fill=0, padding_mode="constant"),
#             transforms.ToTensor()
#         ])

# CIFAR_TEST_TRANSFORM = transforms.Compose([
#             # transforms.CenterCrop(crop_size),
#             transforms.Pad(pad_size, fill=0, padding_mode="constant"),
#             transforms.ToTensor()
#         ])

# def _cifar10(split: str, path: str, transform_params: dict) -> Dataset:

#     if split.lower() == "train":
#         train_bool = True
#         CIFAR_TRAIN_TRANSFORM = transforms.Compose([
#             PadAndShift(transform_params), 
#             transforms.ToTensor(),
#         ])
#         dataset = datasets.CIFAR10(path, train=train_bool, download=True, transform=CIFAR_TRAIN_TRANSFORM)
#     elif split.lower() == "test":
#         train_bool = False
#         CIFAR_TEST_TRANSFORM = transforms.Compose([
#             PadAndShift(transform_params), 
#             transforms.ToTensor()
#         ])
#         dataset = datasets.CIFAR10(path, train=train_bool, download=True, transform=CIFAR_TEST_TRANSFORM)

#     return dataset

def _cifar10(split: str, path: str, transform_params: dict) -> Dataset:

    if split.lower() == "train":
        train_bool = True
        CIFAR_TRAIN_TRANSFORM = transforms.Compose([
            PadAndShift(transform_params, split=split), # the split param is only required for bg-20k nature bg dataset
            transforms.ToTensor(),
        ])
        dataset = datasets.CIFAR10(path, train=train_bool, download=True, transform=CIFAR_TRAIN_TRANSFORM)
    elif split.lower() == "test":
        train_bool = False
        CIFAR_TEST_TRANSFORM = transforms.Compose([
            PadAndShift(transform_params, split=split), # the split param is only required for bg-20k nature bg dataset
            transforms.ToTensor()
        ])
        dataset = datasets.CIFAR10(path, train=train_bool, download=True, transform=CIFAR_TEST_TRANSFORM)

    return dataset

def _tiny_imagenet(split: str, path: str) -> Dataset:
    """
    """
    if split.lower() == "train":
        transform = TINYIMAGENET_TRAIN_TRANSFORM
    elif split.lower() == "test":
        transform = TINYIMAGENET_TEST_TRANSFORM

    # data, targets = get_tiny_imagenet_data(root_dir=path, mode=split)
    # dataset = CustomDataset(data, targets, transform=transform)

    dataset = TinyImageNetDataset(root_dir=path,
                                  mode=split,
                                  transform=transform,
                                  target_transform=None)

    return dataset

def _imagenet(split: str) -> Dataset:
    if not IMAGENET_LOC_ENV in os.environ:
        raise RuntimeError("environment variable for ImageNet directory not set")

    dir = os.environ[IMAGENET_LOC_ENV]
    if split == "train":
        subdir = os.path.join(dir, "train")
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
    elif split == "test":
        subdir = os.path.join(dir, "val")
        transform = transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
    return datasets.ImageFolder(subdir, transform)

class AverageLayer(torch.nn.Module):
    """Downsample a batch of images by 2-D average pooling.

    The pooling kernel size and stride are supplied on every forward call
    rather than at construction time; the most recent values are stored on
    the module as ``kernel_avg`` and ``stride_avg``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, kernel_avg: int, stride_avg: int, inputs: torch.tensor):
        """
        :param kernel_avg: pooling window size.
        :param stride_avg: pooling stride.
        :param inputs: tensor to pool; no dimension permutation is applied.
        :return: the average-pooled tensor (no padding, floor rounding).
        """
        self.kernel_avg, self.stride_avg = kernel_avg, stride_avg
        return F.avg_pool2d(inputs, kernel_size=kernel_avg, stride=stride_avg)

class NormalizeLayer(torch.nn.Module):
    """Standardize the channels of a batch of images by subtracting the dataset mean
    and dividing by the dataset standard deviation.

    In order to certify radii in original coordinates rather than standardized coordinates, we
    add the Gaussian noise _before_ standardizing, which is why we have standardization be the first
    layer of the classifier rather than as a part of preprocessing as is typical.
    """

    def __init__(self, means: List[float], sds: List[float]):
        """
        :param means: the channel means
        :param sds: the channel standard deviations
        """
        super(NormalizeLayer, self).__init__()
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        # Registered as buffers so module.to()/.cuda() and DataParallel replication
        # move them with the model; non-persistent so they stay out of state_dict
        # and existing checkpoints keep loading.
        self.register_buffer("means", torch.tensor(means).to(device), persistent=False)
        self.register_buffer("sds", torch.tensor(sds).to(device), persistent=False)

    def forward(self, input: torch.tensor):
        """Return ``(input - mean) / sd`` per channel for an (N, C, H, W) batch.

        :raises ValueError: if ``input`` is not 4-D.
        """
        if input.dim() != 4:
            raise ValueError(f"expected a 4-D (N, C, H, W) input, got shape {tuple(input.shape)}")
        # Broadcast (1, C, 1, 1) statistics instead of materializing full-size copies;
        # .to() is a no-op when the statistics already live on the input's device.
        means = self.means.to(input.device).view(1, -1, 1, 1)
        sds = self.sds.to(input.device).view(1, -1, 1, 1)
        return (input - means) / sds

class PreProcessLayer(torch.nn.Module):
    """Randomly mirror the input horizontally.

    The only transformation applied is torchvision's ``RandomHorizontalFlip``
    with probability ``prob_flip``.
    NOTE(review): on a batched tensor torchvision presumably draws one coin per
    call, flipping the whole batch together -- confirm this is intended.
    """

    def __init__(self, prob_flip=0.5):
        """
        :param prob_flip: probability with which the input is flipped
        """
        super().__init__()
        flip = transforms.RandomHorizontalFlip(p=prob_flip)
        self.pre_transforms = transforms.Compose([flip])

    def forward(self, input: torch.tensor):
        return self.pre_transforms(input)

def unpickle(file):
    """Load and return the object pickled in ``file``.

    ``encoding='bytes'`` makes 8-bit strings from Python 2 pickles load as
    ``bytes`` objects.

    NOTE(review): unpickling can execute arbitrary code; only use on trusted files.
    """
    # Uses the module-level `pickle` import; the local name no longer shadows the builtin `dict`.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')

