################################################################################
# spectral/modules/spectral_bound.py
#
# 
# 
# 
# 2024
#
# Implements a modified version of the spectral normalization parametrization.
# Instead of rescaling the weight to a spectral norm of exactly 1, the spectral
# norm is clamped to the interval [lower_bound, 1].

from typing import Optional, Tuple, Union

import torch

# Short aliases for the torch types used in the annotations below.
Tensor, Module = torch.Tensor, torch.nn.Module

class _BoundSpectralNorm(torch.nn.utils.parametrizations._SpectralNorm):
  """Spectral-norm parametrization whose spectral norm is clamped, not fixed.

  The parent parametrization rescales a weight to spectral norm 1. This
  variant rescales the weight so that its spectral norm ``sigma`` becomes
  ``clamp(sigma, lower_bound, 1.0)``:

  * weights with ``sigma`` in ``[lower_bound, 1]`` are (numerically) unchanged;
  * weights with ``sigma > 1`` are scaled down to spectral norm 1;
  * weights with ``sigma < lower_bound`` are scaled up to ``lower_bound``.

  Args:
    weight: The tensor being parametrized. It is forwarded to the base class,
      which sets up the power-method vectors ``_u`` / ``_v`` read in
      ``forward`` for weights with more than one dimension.
    lower_bound: Smallest allowed spectral norm. Must not exceed 1.0.
    n_power_iterations: Power-method iterations run per training-mode forward.
    dim: Dimension corresponding to the number of outputs, used when the
      weight is reshaped to a matrix.
    eps: Small constant guarding the division by ``sigma``.

  Raises:
    ValueError: If ``lower_bound`` is greater than 1.0. With min > max,
      ``torch.clamp`` always returns the max, so the lower bound would be
      silently ignored.
  """

  def __init__(self,
      # Arguments:
      weight:      Tensor,
      lower_bound: float,
      # Keyword Arguments:
      n_power_iterations: int = 1,
      dim:                int = 0,
      eps:                float = 1e-12
    ):
    if lower_bound > 1.0:
      raise ValueError(
        f"lower_bound must not exceed the upper bound 1.0, got {lower_bound}."
      )
    super().__init__(
      weight,
      n_power_iterations = n_power_iterations,
      dim                = dim,
      eps                = eps
    )
    self.lower_bound: float = lower_bound
    # forward() refreshes the power-method vectors only when the module is in
    # training mode AND this flag is True. Setting it to False freezes _u/_v.
    self._enabled: bool = True

  def forward(self,
      weight: Tensor
    ) -> Tensor:
    """Return ``weight`` rescaled so its spectral norm lies in [lower_bound, 1].

    For 1-D weights the norm is computed exactly with ``torch.linalg.norm``.
    Otherwise it is estimated as ``u . (W v)`` from the power-method vectors.
    Those vectors are refreshed first when training and ``_enabled`` is set.
    """
    if weight.ndim == 1:
      sigma = torch.linalg.norm(weight)
    else:
      weight_mat = self._reshape_weight_to_matrix(weight)
      if self.training and self._enabled:
        self._power_method(weight_mat, self.n_power_iterations)
      # Clone the buffers, as torch's own _SpectralNorm.forward does, so the
      # tensors used in the autograd graph are not the ones _power_method
      # later updates in place.
      u = self._u.clone(memory_format = torch.contiguous_format)
      v = self._v.clone(memory_format = torch.contiguous_format)
      sigma = torch.dot(u, torch.mv(weight_mat, v))
    # Target spectral norm: sigma clamped into [lower_bound, 1].
    target = torch.clamp(sigma, self.lower_bound, 1.0)
    # Clamp the denominator to eps (the same guard F.normalize uses). An
    # all-zero weight (sigma == 0) then maps to zeros instead of 0/0 = NaN.
    return target * (weight / torch.clamp(sigma, min = self.eps))

def bound_spectral_norm(
    # Arguments:
    module:      Module,
    lower_bound: float,
    # Keyword Arguments:
    name:                str           = "weight",
    n_power_iterations:  int           = 1,
    eps:                 float         = 1e-12,
    get_parametrization: bool          = False,
    dim:                 Optional[int] = None
  ) -> Union[Module, Tuple[Module, "_BoundSpectralNorm"]]:
  """Register a bound spectral-norm parametrization on ``module.<name>``.

  After registration, accessing ``module.<name>`` yields the stored tensor
  rescaled so its spectral norm lies in ``[lower_bound, 1]``.

  Args:
    module: Module holding the tensor to parametrize.
    lower_bound: Smallest allowed spectral norm.
    name: Name of the parameter or buffer to parametrize.
    n_power_iterations: Power-method iterations per training-mode forward.
    eps: Small constant guarding the division by the spectral norm.
    get_parametrization: If True, also return the parametrization object,
      e.g. so the caller can toggle its ``_enabled`` flag.
    dim: Dimension corresponding to the number of outputs. Defaults to 1 for
      ``ConvTranspose{1,2,3}d`` modules and 0 otherwise.

  Returns:
    ``module`` (with the parametrization registered in place), or
    ``(module, parametrization)`` when ``get_parametrization`` is True.

  Raises:
    ValueError: If ``module`` has no tensor attribute called ``name``.
  """
  weight = getattr(module, name, None)
  # Raise explicitly rather than assert, so the check survives ``python -O``.
  if not isinstance(weight, Tensor):
    raise ValueError(
      f"{module} does not have a parameter or buffer called '{name}'."
    )
  if dim is None:
    # PyTorch stores ConvTranspose weights as (in_channels, out_channels, ...),
    # so their output dimension is 1 rather than 0.
    is_conv_transpose = isinstance(
      module,
      (
        torch.nn.ConvTranspose1d,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose3d
      )
    )
    dim = 1 if is_conv_transpose else 0
  parametrization = _BoundSpectralNorm(
    weight,
    lower_bound,
    n_power_iterations = n_power_iterations,
    dim                = dim,
    eps                = eps
  )
  torch.nn.utils.parametrize.register_parametrization(
    module,
    name,
    parametrization
  )
  return (module, parametrization) if get_parametrization else module