################################################################################
# spectral/modules/nonlocalblock.py
#
# 
# 
# 
# 2024
#
# Implementation of a modified asymmetric pyramid non-local block [1]:
# https://arxiv.org/abs/1908.07678

from typing import Any, Callable, List, Optional

import torch

Tensor = torch.Tensor
Module = torch.nn.Module

class NonLocalBlock(Module):
  """
  Asymmetric pyramid non-local convolutional block, modified from [1].

  Queries are computed at full spatial resolution, while keys and values are
  pooled to a small pyramid of sizes (``pool_sizes``). The attention matrix
  therefore has shape ``(B, H * W, S)`` with ``S = sum(s * s for s in
  pool_sizes)``, instead of ``(B, H * W, H * W)``. The attention output is
  L2-normalized over the channel dimension and returned directly; no residual
  connection or output projection is applied in this module.

  [1] https://arxiv.org/abs/1908.07678
  """

  def __init__(self,
      # Arguments:
      width:    int,
      channels: int,
      *args,
      # Keyword Arguments:
      pool_sizes: Optional[List[int]] = None,
      pool_type:  str                 = "max",
      **kwargs
    ):
    """
    Initializes ``NonLocalBlock``.

    Args:
      width (int):
        The number of input channels.
      channels (int):
        The number of attention channels (also the number of output channels).
      *args:
        Additional arguments forwarded to ``torch.nn.Module``.
      pool_sizes (list[int], optional):
        The output sizes of the adaptive pools used for pyramidal processing.
        Duplicates are ignored (first occurrence wins); each distinct size
        ``s`` is registered as submodule ``pool{s}``. Defaults to ``[1]``.
      pool_type (str, optional):
        The type of pool to perform, either ``"max"`` or ``"average"``.
        Defaults to ``"max"``.
      **kwargs:
        Additional keyword arguments forwarded to ``torch.nn.Module``.

    Raises:
      ValueError:
        If ``pool_sizes`` is empty or contains a size smaller than 1, or if
        ``pool_type`` is neither ``"max"`` nor ``"average"``.
    """
    super().__init__(*args, **kwargs)
    if pool_sizes is None:
      pool_sizes = [1]
    # Deduplicate while keeping the caller's order (``set`` would not).
    sizes = list(dict.fromkeys(pool_sizes))
    if not sizes:
      # ``torch.cat`` in ``_perform_pool`` would otherwise fail at forward time.
      raise ValueError("pool_sizes must contain at least one size")
    if any(s < 1 for s in sizes):
      raise ValueError(f"pool sizes must be >= 1, got {sizes}")
    if pool_type == "average":
      pool = torch.nn.AdaptiveAvgPool2d
    elif pool_type == "max":
      pool = torch.nn.AdaptiveMaxPool2d
    else:
      # Previously any unknown value silently fell back to max pooling.
      raise ValueError(
        f'pool_type must be "max" or "average", got {pool_type!r}'
      )
    self.channels   = channels
    self.pool_sizes = sizes
    # Submodules are registered in the same order as before so parameter
    # order and state_dict keys are unchanged.
    self.conv_theta = self._conv1x1(width, channels)
    self.conv_phi   = self._conv1x1(width, channels)
    self.conv_g     = self._conv1x1(width, channels)
    self.softmax    = torch.nn.Softmax(dim = -1)
    for s in self.pool_sizes:
      setattr(self, f"pool{s}", pool((s, s)))

  @staticmethod
  def _conv1x1(in_channels: int, out_channels: int) -> torch.nn.Conv2d:
    """
    Builds a bias-free 1x1 convolution used for the theta/phi/g projections.
    """
    return torch.nn.Conv2d(
      in_channels  = in_channels,
      out_channels = out_channels,
      kernel_size  = (1, 1),
      stride       = (1, 1),
      bias         = False
    )

  def _perform_pool(self,
      # Arguments:
      x: Tensor
    ) -> Tensor:
    """
    Helper function that performs each separate pool, flattens the spatial
    dimensions, then concatenates the results.

    Args:
      x (Tensor):
        Input for pooling, of shape ``(B, C, H, W)``.

    Returns:
      Tensor:
        Pooled features of shape ``(B, C, S)`` where
        ``S = sum(s * s for s in self.pool_sizes)``.
    """
    # ``flatten`` copies only if needed, unlike ``view`` which requires
    # compatible strides.
    pooled = [getattr(self, f"pool{s}")(x).flatten(2) for s in self.pool_sizes]
    return torch.cat(pooled, dim = -1)

  def forward(self,
      # Arguments:
      x: Tensor,
    ) -> Tensor:
    """
    Forward function for ``NonLocalBlock``.

    Args:
      x (Tensor):
        Input tensor of shape ``(B, width, H, W)``.

    Returns:
      Tensor:
        The attention channels, of shape ``(B, channels, H, W)``,
        L2-normalized along the channel dimension.
    """
    B, _, H, W = x.shape
    theta_x = self._perform_pool(self.conv_theta(x))              # (B, C, S)   keys
    phi_x   = self.conv_phi(x).flatten(2).transpose(1, 2)         # (B, H*W, C) queries
    g_x     = self._perform_pool(self.conv_g(x)).transpose(1, 2)  # (B, S, C)   values
    # Softmax over the pooled positions (last dim).
    v       = self.softmax(torch.bmm(phi_x, theta_x))             # (B, H*W, S)
    y       = torch.bmm(v, g_x).transpose(1, 2)                   # (B, C, H*W)
    y       = y.reshape(B, self.channels, H, W)
    return torch.nn.functional.normalize(y, dim = 1)