################################################################################
# spectral/modules/lipschitz_regularizer.py
#
# 
# 
# 
# 2024
#
# Implements a regularizer that runs power iteration to estimate the Lipschitz
# constant (largest singular value) of a layer's weight and returns the estimate.

import torch

from typing import Callable, List, Tuple

Tensor = torch.Tensor
Module = torch.nn.Module

# For every convolution module type, the pair of functional operators
# ``(convolution, transposed convolution)`` of matching spatial rank. The
# power method alternates between them to apply ``A x`` and ``A^T x``.
CONV_UPDATERS = {
  conv_cls: (conv_fn, conv_transpose_fn)
  for conv_cls, conv_fn, conv_transpose_fn in (
    (torch.nn.Conv1d,
     torch.nn.functional.conv1d,
     torch.nn.functional.conv_transpose1d),
    (torch.nn.Conv2d,
     torch.nn.functional.conv2d,
     torch.nn.functional.conv_transpose2d),
    (torch.nn.Conv3d,
     torch.nn.functional.conv3d,
     torch.nn.functional.conv_transpose3d),
  )
}

# Each transposed-convolution module type, keyed to the plain convolution type
# whose ``CONV_UPDATERS`` entry it reuses (with the two operators swapped).
CONV_TRANSPOSE_MAP = dict(zip(
  (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d),
  (torch.nn.Conv1d,          torch.nn.Conv2d,          torch.nn.Conv3d)
))

class _LipschitzRegularizer(Module):
  """
  Estimates the Lipschitz constant (largest singular value) of the linear
  operator defined by a module's weight, using the power method.

  The operator ``A`` is given implicitly: ``forward_updater(weight, x)``
  computes ``A x`` and ``backward_updater(weight, x)`` computes ``A^T x``.
  The running singular-vector estimates live in the buffers ``_u`` (shape
  ``u_size``) and ``_v`` (shape ``v_size``). They therefore persist between
  calls and are part of the ``state_dict``.
  """

  def __init__(self,
      # Arguments:
      module:           Module,
      name:             str,
      forward_updater:  Callable[[Tensor, Tensor], Tensor],
      backward_updater: Callable[[Tensor, Tensor], Tensor],
      u_size:           Tuple[int, ...],
      v_size:           Tuple[int, ...],
      # Keyword Arguments:
      n_power_iterations: int   = 1,
      eps:                float = 1e-12,
      distributed:        bool  = False
    ):
    """
    Args:
      module (Module):
        Module whose weight defines the operator.
      name (str):
        Attribute name of the weight on ``module``.
      forward_updater (Callable[[Tensor, Tensor], Tensor]):
        Computes ``A x`` given ``(weight, x)``.
      backward_updater (Callable[[Tensor, Tensor], Tensor]):
        Computes ``A^T x`` given ``(weight, x)``.
      u_size (Tuple[int, ...]):
        Shape of the left singular-vector estimate (the output shape of
        ``forward_updater``).
      v_size (Tuple[int, ...]):
        Shape of the right singular-vector estimate (the input shape of
        ``forward_updater``).
      n_power_iterations (int):
        Power-method steps run by each ``forward`` call in training mode.
        Must be greater than 0.
      eps (float):
        Lower bound applied to a norm before dividing by it.
      distributed (bool):
        Currently unused.
    """
    super(_LipschitzRegularizer, self).__init__()
    self.module = module
    self.name = name
    assert n_power_iterations > 0, \
      "n_power_iterations must be greater than 0."
    self.n_power_iterations = n_power_iterations
    self.eps = eps
    self._forward_updater = forward_updater
    self._backward_updater = backward_updater
    # Allocate the estimates on the weight's device and with its dtype. The
    # torch.randn default (CPU, float32) makes the updaters fail for GPU or
    # non-float32 weights.
    weight = self._get_weight()
    u = torch.randn(u_size, device = weight.device, dtype = weight.dtype)
    v = torch.randn(v_size, device = weight.device, dtype = weight.dtype)
    u, v = self._normalize(u, u), self._normalize(v, v)
    self.register_buffer("_u", u)
    self.register_buffer("_v", v)
    # Refine the random initial estimates before the first ``forward`` call.
    self._power_method(15)

  @torch.autograd.no_grad()
  def _normalize(self,
      # Arguments:
      x:   Tensor,
      out: Tensor,
    ) -> Tensor:
    """
    Based on implementation of ``torch.nn.functional.normalize``, but allows for
    unspecified dims.

    Args:
      x (Tensor):
        Input tensor to normalize; its norm is taken over all elements.
      out (Tensor):
        Output tensor; may be ``x`` itself for in-place normalization.

    Returns:
      Tensor:
        ``out``, holding ``x`` divided by ``max(||x||, eps)``.
    """
    denominator = torch.linalg.norm(x, keepdim = True).clamp_min_(self.eps)
    return torch.div(x, denominator.expand_as(x), out = out)

  def _get_weight(self) -> Tensor:
    """
    Gets the weight from the module the regularizer acts for.

    Returns:
      Tensor:
        The weight of the module.
    """
    return getattr(self.module, self.name)

  @torch.autograd.no_grad()
  def _power_method(self,
      n_power_iterations: int
    ) -> None:
    """
    Runs ``n_power_iterations`` steps of the power method, writing the
    updated estimates into the ``_u`` and ``_v`` buffers in place.

    Args:
      n_power_iterations (int):
        Number of steps to run.
    """
    weight = self._get_weight()
    for _ in range(n_power_iterations):
      # u <- A v / ||A v||
      self._u = self._normalize(
        self._forward_updater(weight, self._v),
        self._u
      )
      # v <- A^T u / ||A^T u||
      self._v = self._normalize(
        self._backward_updater(weight, self._u),
        self._v
      )

  def forward(self) -> Tensor:
    """
    Returns the current Lipschitz-constant estimate.

    A 1-D weight returns its Euclidean norm directly. Otherwise, in training
    mode the power method first advances by ``n_power_iterations`` steps.
    The estimate is ``||A v||``, which is differentiable with respect to the
    weight.

    Returns:
      Tensor:
        Scalar estimate of the largest singular value.
    """
    weight = self._get_weight()
    if weight.ndim == 1:
      return torch.linalg.norm(weight, ord = 2)
    else:
      if self.training:
        self._power_method(self.n_power_iterations)
      # Clone so later in-place buffer updates do not invalidate the graph.
      v = self._v.clone(memory_format = torch.contiguous_format)
      sigma = torch.linalg.norm(self._forward_updater(weight, v))
      return sigma

def calculate_sizes(
    kernel_size: List[int],
    padding:     List[int],
    stride:      List[int]
  ) -> List[int]:
  """
  Chooses a spatial size per dimension for a convolution's power-iteration
  vectors.

  Each size is ``2 * kernel_size - 2 * padding + 1``, clamped to at least 1.
  Without the clamp, a padding larger than the kernel would produce a zero or
  negative (invalid) tensor size.

  Args:
    kernel_size (List[int]):
      Kernel size per spatial dimension.
    padding (List[int]):
      Padding per spatial dimension.
    stride (List[int]):
      Stride per spatial dimension. Currently ignored.

  Returns:
    List[int]:
      One size per dimension (truncated to the shortest of the input lists).
  """
  # Stride is intentionally ignored for now.
  return [
    max(1, 2 * k - 2 * p + 1)
    for k, p, _ in zip(kernel_size, padding, stride)
  ]

def _lookup_by_type(module: Module, table: dict):
  """
  Returns the value ``table`` maps to the nearest class in ``type(module)``'s
  MRO, or ``None`` if no class in the MRO is a key. Unlike an exact
  ``table.get(type(module))`` lookup, this also accepts subclasses.
  """
  for cls in type(module).__mro__:
    if cls in table:
      return table[cls]
  return None

def _numeric_padding(module: Module) -> List[int]:
  """
  Returns ``module.padding`` as a per-dimension list of ints. The string forms
  ``"valid"`` and ``"same"`` accepted by ``torch.nn.ConvNd`` are resolved,
  because the functional transposed convolutions only take numeric padding.

  Raises:
    ValueError: if ``"same"`` padding would be asymmetric (odd total padding
      ``dilation * (kernel_size - 1)``), which symmetric numeric padding
      cannot express.
  """
  padding = module.padding
  if not isinstance(padding, str):
    return list(padding)
  if padding == "valid":
    return [0] * len(module.kernel_size)
  # "same" (stride 1): total padding per dimension is dilation * (k - 1).
  totals = [d * (k - 1) for k, d in zip(module.kernel_size, module.dilation)]
  if any(t % 2 for t in totals):
    raise ValueError(
      f"padding='same' requiring asymmetric padding is not supported for "
      f"'{module}'."
    )
  return [t // 2 for t in totals]

def lipschitz_regularizer(
    # Arguments:
    module: Module,
    # Keyword Arguments:
    name:               str   = "weight",
    n_power_iterations: int   = 1,
    eps:                float = 1e-12
  ) -> Module:
  """
  Builds a regularizer that estimates the Lipschitz constant (largest
  singular value) of ``module``'s weight via power iteration.

  Supported modules are ``torch.nn.Linear``, ``torch.nn.Conv{1,2,3}d`` and
  ``torch.nn.ConvTranspose{1,2,3}d``, including their subclasses.
  Convolutions honour the module's stride, padding, dilation and groups.
  ``padding_mode`` other than zeros and ``output_padding`` are not modelled.

  Args:
    module (Module):
      Module whose weight defines the operator.
    name (str):
      Attribute name of the weight on ``module``.
    n_power_iterations (int):
      Power-method steps per training-mode forward call.
    eps (float):
      Lower bound applied to norms before dividing by them.

  Returns:
    Module:
      The regularizer; calling it returns the current estimate.

  Raises:
    AssertionError: if ``module`` has no tensor attribute ``name``.
    TypeError: if ``module``'s type is not supported.
    ValueError: for unsupported asymmetric ``padding='same'``.
  """
  weight = getattr(module, name, None)
  assert isinstance(weight, Tensor), \
    f"{module} does not have a parameter or buffer called '{name}'."
  if isinstance(module, torch.nn.Linear):
    forward_updater  = lambda w, x: torch.mv(w,     x)
    backward_updater = lambda w, x: torch.mv(w.t(), x)
    u_size, v_size = weight.size()
  else:
    conv_cls = _lookup_by_type(module, CONV_TRANSPOSE_MAP)
    if conv_cls is not None:
      # A transposed convolution is the adjoint of the matching convolution,
      # so the two functional operators swap roles.
      bward, fward = CONV_UPDATERS[conv_cls]
    else:
      updaters = _lookup_by_type(module, CONV_UPDATERS)
      if updaters is None:
        raise TypeError(f"Unsupported module '{module}'.")
      fward, bward = updaters
    conv_padding = tuple(_numeric_padding(module))
    # The forward/adjoint pair must share stride, padding, dilation and
    # groups. Without dilation the operator is wrong; without groups,
    # grouped convolutions fail on a channel mismatch.
    forward_updater = lambda w, x: fward(
      x, w,
      stride   = module.stride,
      padding  = conv_padding,
      dilation = module.dilation,
      groups   = module.groups
    )
    backward_updater = lambda w, x: bward(
      x, w,
      stride   = module.stride,
      padding  = conv_padding,
      dilation = module.dilation,
      groups   = module.groups
    )
    # Seed the spatial size from the effective (dilated) kernel size, clamped
    # so it is always a valid tensor size.
    effective_kernel = [
      d * (k - 1) + 1 for k, d in zip(module.kernel_size, module.dilation)
    ]
    sizes = calculate_sizes(
      effective_kernel,
      list(conv_padding),
      list(module.stride)
    )
    sizes = [max(1, s) for s in sizes]
    # Derive the buffer shapes from the operators themselves: u is A v and v
    # is A^T u. This keeps both buffers consistent with what the updaters
    # produce, so the power method never has to resize them.
    with torch.no_grad():
      probe_u = forward_updater(
        weight, weight.new_zeros([1, module.in_channels] + sizes)
      )
      probe_v = backward_updater(weight, torch.zeros_like(probe_u))
    u_size, v_size = tuple(probe_u.shape), tuple(probe_v.shape)
  regularizer = _LipschitzRegularizer(
    module,
    name,
    forward_updater,
    backward_updater,
    u_size,
    v_size,
    n_power_iterations,
    eps
  )
  return regularizer

# Reference implementation (Keras):
# https://github.com/henrygouk/keras-lipschitz-networks/blob/master/arch/lipschitz.py