import cv2
import inspect
import numpy as np
from PIL import Image, ImageFilter


import torch
from torchvision import transforms as _transforms
from torchvision.transforms import functional

from openselfsup.utils import build_from_cfg

from ..registry import PIPELINES

# Auto-register every class exported by torchvision.transforms in PIPELINES.
# Names listed in _EXCLUDED_TRANSFORMS are skipped because this module defines
# its own transform under that name (the custom PIL-based GaussianBlur below).
_EXCLUDED_TRANSFORMS = ['GaussianBlur']
for m in inspect.getmembers(_transforms, inspect.isclass):
    if m[0] in _EXCLUDED_TRANSFORMS:
        continue
    PIPELINES.register_module(m[1])


@PIPELINES.register_module
class RandomAppliedTrans(torch.nn.Module):
    """Wrap a list of pipeline transforms in torchvision's ``RandomApply``.

    Every entry of ``transforms`` is a config dict that is instantiated
    through the ``PIPELINES`` registry. The built objects are handed, in
    order, to ``torchvision.transforms.RandomApply`` together with ``p``.

    Args:
        transforms (list[dict]): Configs of the transformations to wrap.
        p (float): Probability forwarded to ``RandomApply``. Default: 0.5.
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        built = []
        for cfg in transforms:
            built.append(build_from_cfg(cfg, PIPELINES))
        self.trans = _transforms.RandomApply(built, p=p)

    def __call__(self, img):
        # NOTE: this overrides nn.Module.__call__ directly, so ``forward``
        # and module hooks are never involved; the wrapped RandomApply does
        # all the work.
        return self.trans(img)

    def __repr__(self):
        return self.__class__.__name__


@PIPELINES.register_module
class MultiTransform(torch.nn.Module):
    """Apply one composed transform to consecutive chunks of a batch.

    The leading dimension of the input is cut into windows of
    ``num_imgs_one_sp`` items (the final window may be shorter). The composed
    transform is called once per window, and its output is written back into
    the input tensor in place. Presumably each call samples its own random
    parameters, so items within a window share them -- TODO confirm against
    the configured transforms.

    Args:
        num_imgs_one_sp (int): Number of items per window.
        transforms (list[dict]): Configs built through the ``PIPELINES``
            registry and chained with ``torchvision.transforms.Compose``.
    """

    def __init__(
            self, num_imgs_one_sp, transforms):
        super().__init__()
        self.num_imgs_one_sp = num_imgs_one_sp
        self.transform = _transforms.Compose(
            [build_from_cfg(cfg, PIPELINES) for cfg in transforms])

    def forward(self, img):
        """Transform ``img`` window by window; returns the same (mutated) tensor."""
        total = img.size(0)
        step = self.num_imgs_one_sp
        for start in range(0, total, step):
            window = slice(start, min(start + step, total))
            # The result must fit back into the window (same shape after
            # broadcasting), since it is assigned in place.
            img[window] = self.transform(img[window])
        return img


# custom transforms
@PIPELINES.register_module
class Lighting(object):
    """Lighting noise (AlexNet-style PCA-based noise).

    Adds to every channel of a tensor image an offset built from the fixed
    ImageNet PCA eigenvectors/eigenvalues, weighted by random
    ``N(0, alphastd)`` coefficients drawn once per call.

    Args:
        alphastd (float): Standard deviation of the random PCA weights.
            ``0`` disables the transform (the input is returned as is).
            Default: 0.1 (the value that used to be hard-coded).
    """

    _IMAGENET_PCA = {
        'eigval':
        torch.Tensor([0.2175, 0.0188, 0.0045]),
        'eigvec':
        torch.Tensor([
            [-0.5675, 0.7192, 0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203],
        ])
    }

    def __init__(self, alphastd=0.1):
        self.alphastd = alphastd
        self.eigval = self._IMAGENET_PCA['eigval']
        self.eigvec = self._IMAGENET_PCA['eigvec']

    def __call__(self, img):
        """Return ``img`` plus a random per-channel PCA lighting offset.

        Assumes the channel dimension has size 3 and can be expanded as
        ``(3, 1, 1)`` against ``img`` (e.g. a CHW image) -- the ``view``/
        ``expand_as`` below fail otherwise.
        """
        assert isinstance(img, torch.Tensor), \
            "Expect torch.Tensor, got {}".format(type(img))
        if self.alphastd == 0:
            return img

        # One random weight per principal component (same RNG draw as the
        # previous ``img.new().resize_(3).normal_`` form).
        alpha = img.new_empty(3).normal_(0, self.alphastd)
        # Match both PCA tensors to the image's dtype AND device; previously
        # only eigvec was converted, which broke CUDA inputs and up-cast
        # half-precision ones.
        eigvec = self.eigvec.to(img)
        eigval = self.eigval.to(img)
        # rgb[c] = sum_k eigvec[c, k] * alpha[k] * eigval[k]
        rgb = (eigvec * alpha * eigval).sum(1)

        return img.add(rgb.view(3, 1, 1).expand_as(img))

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str


@PIPELINES.register_module
class GaussianBlur(object):
    """Gaussian blur augmentation from SimCLR (https://arxiv.org/abs/2002.05709).

    On each call a radius is drawn with ``np.random.uniform(sigma_min,
    sigma_max)`` and applied through PIL's ``ImageFilter.GaussianBlur``.

    Args:
        sigma_min (float): Lower bound for the sampled radius.
        sigma_max (float): Upper bound for the sampled radius.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(self, sigma_min, sigma_max, **kwargs):
        self.sigma_min, self.sigma_max = sigma_min, sigma_max

    def __call__(self, img):
        # ``img`` must expose a PIL-style ``filter`` method (presumably a
        # PIL image -- confirm with the pipeline config).
        radius = np.random.uniform(self.sigma_min, self.sigma_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))

    def __repr__(self):
        return type(self).__name__


@PIPELINES.register_module
class GaussianBlurTorch(_transforms.GaussianBlur):
    """torchvision's ``GaussianBlur``, registered under a distinct name.

    The torchvision class is listed in ``_EXCLUDED_TRANSFORMS`` and so is not
    auto-registered, because this module's own PIL-based ``GaussianBlur``
    uses that name. This subclass adds nothing; it only makes the torchvision
    implementation reachable from configs as ``GaussianBlurTorch``.
    """


@PIPELINES.register_module
class Solarization(object):
    """Solarization augmentation from BYOL (https://arxiv.org/abs/2006.07733).

    Pixel values below ``threshold`` are kept; every other value ``v`` is
    replaced by ``255 - v``. The result is cast to ``uint8`` and returned as
    a PIL image.

    Args:
        threshold (int): Values at or above this are inverted. Default: 128.
    """

    def __init__(self, threshold=128):
        self.threshold = threshold

    def __call__(self, img):
        pixels = np.array(img)
        keep = pixels < self.threshold
        solarized = np.where(keep, pixels, 255 - pixels)
        return Image.fromarray(solarized.astype(np.uint8))

    def __repr__(self):
        return type(self).__name__


@PIPELINES.register_module
class ResizeCenterPad(object):
    """Random resize (shrink), then pad back with a constant value.

    A target edge length ``rnd_size`` is drawn from
    ``np.random.randint(min_edge, img_size)``, where ``img_size`` is
    ``img.size[-1]`` (the height, for a PIL image whose ``size`` is
    ``(width, height)``). The image is resized with ``functional.resize``
    and then padded by ``img_size - rnd_size`` pixels in total on each axis,
    filled with ``pad_value``. For a square input the output therefore has
    the original size -- non-square inputs are presumably not expected;
    confirm with callers.

    Args:
        min_edge (int): Smallest sampled edge length (inclusive). Default: 50.
        pad_value (int): Constant fill value for the padding. Default: 127.
        interpolation: Interpolation passed to ``functional.resize``.
            Default: ``Image.BILINEAR``.

    Raises:
        ValueError: From ``np.random.randint`` when ``img_size <= min_edge``
            (there is no smaller size to sample).
    """

    def __init__(
            self, min_edge=50, pad_value=127,
            interpolation=Image.BILINEAR,
            ):
        # Fix: the argument used to be ignored and 50 was always used.
        self.min_edge = min_edge
        self.interpolation = interpolation
        self.pad_value = pad_value

    def __call__(self, img):
        img_size = img.size[-1]
        # randint's upper bound is exclusive, so the image always shrinks.
        rnd_size = np.random.randint(self.min_edge, img_size)
        img = functional.resize(img, rnd_size, self.interpolation)
        total_pad = img_size - rnd_size
        pad_before = total_pad // 2
        pad_after = total_pad - pad_before
        # functional.pad takes a 4-tuple as (left, top, right, bottom): the
        # smaller half goes on the left/top, the larger on the right/bottom.
        img = functional.pad(
                img, (pad_before, pad_before, pad_after, pad_after),
                fill=self.pad_value)
        return img

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
