import torch
import math
import torch.nn as nn
from math import sqrt
import numpy as np


class ALOFT(nn.Module):
    """
    Frequency Distribution Uncertainty Module operating on a complex spectrum.

    In training mode, with probability ``p``, the amplitude inside a
    low-frequency window of the input spectrum is modified while the phase
    is kept:

    * ``mask_or_model == 0``: the window amplitudes are replaced by their
      per-sample mean (over channels and window positions).
    * otherwise: the window amplitudes are perturbed with Gaussian noise
      scaled by their per-element standard deviation across the batch.

        Args:
        p   (float): probability of applying the module in a training
            forward pass, p in [0, 1].
        eps (float): constant added to the variance before the square root
            in the noise branch.
        mask_size (float): fraction of the spectrum covered by the window;
            each side of the window is scaled by ``sqrt(mask_size)``.
        factor (float): scale of the Gaussian perturbation (noise branch).
        mask_or_model (int): 0 selects the mean-replacement branch, any
            other value selects the noise branch.
    """

    def __init__(self, p=0.5, eps=1e-6, mask_size=0.5, factor=0.8, mask_or_model=0):
        super(ALOFT, self).__init__()
        self.eps = eps
        self.p = p

        self.mask_size = mask_size
        self.factor = factor
        self.mask_or_model = mask_or_model

    def forward(self, img_fft):
        """Return ``img_fft`` with its low-frequency amplitude mixed or perturbed.

        Args:
            img_fft: 4-D complex tensor indexed as (B, C, H, W). Presumably the
                output of ``torch.fft.rfft2`` over the last two dims (the
                width window starts at column 0 without a shift) — TODO confirm.

        Returns:
            Complex tensor of the same shape. The input object itself is
            returned in eval mode or when the random draw exceeds ``p``.
        """
        if (not self.training) or (np.random.random() > self.p):
            return img_fft

        B, _, h_fft, w_fft = img_fft.shape
        img_abs, img_pha = torch.abs(img_fft), torch.angle(img_fft)

        h_crop = int(h_fft * sqrt(self.mask_size))
        w_crop = int(w_fft * sqrt(self.mask_size))
        # Rows are fftshifted below, so the window is centred on the middle
        # row; columns are not shifted, so the window starts at column 0.
        h_start = h_fft // 2 - h_crop // 2
        w_start = 0
        crop = (..., slice(h_start, h_start + h_crop), slice(w_start, w_start + w_crop))

        img_abs = torch.fft.fftshift(img_abs, dim=(2, ))
        # Untouched copy: statistics are read from it while img_abs is
        # modified in place (keeps autograd's saved tensors intact).
        img_abs_ = img_abs.clone()

        if self.mask_or_model == 0:
            # Per-sample mean over channels and window, shape Bx1x1x1,
            # broadcast into the window. An empty window is a no-op here
            # instead of spreading a NaN mean over the whole spectrum.
            freq_avg = torch.mean(img_abs_[crop], dim=(1, 2, 3), keepdim=True)
            img_abs[crop] = freq_avg
        else:
            # Per-element variance across the batch, shape 1xCxhxw. The
            # unbiased estimator is NaN for a single sample, so fall back to
            # the biased one (variance 0) when B == 1.
            var_of_elem = torch.var(img_abs_[crop], dim=0, keepdim=True, unbiased=B > 1)
            sig_of_elem = (var_of_elem + self.eps).sqrt()

            epsilon_sig = torch.randn_like(img_abs[crop])  # N(0, 1), same shape as window
            gamma = epsilon_sig * sig_of_elem * self.factor

            img_abs[crop] = img_abs[crop] + gamma

        img_abs = torch.fft.ifftshift(img_abs, dim=(2, ))
        # Recombine amplitude with the original phase.
        img_stylized = img_abs * torch.exp(1j * img_pha)

        return img_stylized


class ALOFT_image(nn.Module):
    """
    Frequency Distribution Uncertainty Module operating on flattened feature maps.

    The input is reshaped to square maps, transformed with ``rfft2``, and in
    training mode (with probability ``p``) the amplitude inside a
    low-frequency window is modified while the phase is kept:

    * ``mask_or_model == 0``: the window amplitudes are replaced by their
      per-(B, K) mean over channels and window positions.
    * otherwise: the window amplitudes are perturbed with Gaussian noise
      scaled by their per-element standard deviation across the batch.

    The result is transformed back with ``irfft2`` and flattened again.

        Args:
        p   (float): probability of applying the module in a training
            forward pass, p in [0, 1].
        eps (float): constant added to the variance before the square root
            in the noise branch.
        mask_size (float): fraction of the spectrum covered by the window;
            each side of the window is scaled by ``sqrt(mask_size)``.
        factor (float): scale of the Gaussian perturbation (noise branch).
        mask_or_model (int): 0 selects the mean-replacement branch, any
            other value selects the noise branch.
    """

    def __init__(self, p=0.5, eps=1e-6, mask_size=0.5, factor=0.8, mask_or_model=0):
        super(ALOFT_image, self).__init__()
        self.eps = eps
        self.p = p

        self.mask_size = mask_size
        self.factor = factor
        self.mask_or_model = mask_or_model

    def forward(self, img):
        """Augment the low-frequency amplitude of ``img``.

        Args:
            img: real tensor of shape (B, K, C, L). L must be a perfect square
                (L = h * w with h = w = int(sqrt(L))), otherwise ``view`` fails.

        Returns:
            float32 tensor of shape (B, K, C, L) when the module is applied;
            the input object itself in eval mode or when the random draw
            exceeds ``p``.
        """
        if (not self.training) or (np.random.random() > self.p):
            return img

        B, K, C, L = img.shape
        h = w = int(sqrt(L))
        img = img.view(B, K, C, h, w)

        img = img.to(torch.float32)
        img_fft = torch.fft.rfft2(img, dim=(3, 4), norm='ortho')

        h_fft, w_fft = img_fft.shape[-2:]
        img_abs, img_pha = torch.abs(img_fft), torch.angle(img_fft)

        h_crop = int(h_fft * sqrt(self.mask_size))
        w_crop = int(w_fft * sqrt(self.mask_size))
        # Rows are fftshifted below, so the window is centred on the middle
        # row; the rfft2 column axis already starts at the zero frequency.
        h_start = h_fft // 2 - h_crop // 2
        w_start = 0
        crop = (..., slice(h_start, h_start + h_crop), slice(w_start, w_start + w_crop))

        img_abs = torch.fft.fftshift(img_abs, dim=(3, ))
        # Untouched copy: statistics are read from it while img_abs is
        # modified in place (keeps autograd's saved tensors intact).
        img_abs_ = img_abs.clone()

        if self.mask_or_model == 0:
            # Mean over channels and window, shape BxKx1x1x1, broadcast into
            # the window. An empty window is a no-op here instead of spreading
            # a NaN mean over the whole spectrum.
            freq_avg = torch.mean(img_abs_[crop], dim=(2, 3, 4), keepdim=True)
            img_abs[crop] = freq_avg
        else:
            # Per-element variance across the batch, shape 1xKxCxhxw. The
            # unbiased estimator is NaN for a single sample, so fall back to
            # the biased one (variance 0) when B == 1.
            var_of_elem = torch.var(img_abs_[crop], dim=0, keepdim=True, unbiased=B > 1)
            sig_of_elem = (var_of_elem + self.eps).sqrt()

            epsilon_sig = torch.randn_like(img_abs[crop])  # N(0, 1), same shape as window
            gamma = epsilon_sig * sig_of_elem * self.factor

            img_abs[crop] = img_abs[crop] + gamma

        img_abs = torch.fft.ifftshift(img_abs, dim=(3, ))
        # Recombine amplitude with the original phase and return to space.
        img_stylized = img_abs * torch.exp(1j * img_pha)

        img_stylized = torch.fft.irfft2(img_stylized, s=(h, w), dim=(3, 4), norm='ortho')
        img_stylized = img_stylized.view(B, K, C, -1)
        return img_stylized



