################################################################################
# spectral/modules/spectral_constraint.py
#
# 
# 
# 
# 2024
#
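# Implements spectral normalization as a torch.nn.utils.parametrize
# parametrization, with some modifications relative to PyTorch's built-in
# version: the power-iteration update in training mode runs at most once per
# `_update_enabled` toggle, and the weight is divided by (sigma + eps) rather
# than by sigma.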

import torch

from typing import Optional

# Short aliases for the torch types used in annotations throughout this module.
Tensor, Module = torch.Tensor, torch.nn.Module

class _SpectralNormConstraint(torch.nn.utils.parametrizations._SpectralNorm):

  def __init__(self,
      # Arguments:
      weight:      Tensor,
      # Keyword Arguments:
      n_power_iterations: int = 1,
      dim:                int = 0,
      eps:                float = 1e-12
    ):
    super(_SpectralNormConstraint, self).__init__(
      weight,
      n_power_iterations = n_power_iterations,
      dim                = dim,
      eps                = eps
    )
    self._update_enabled = True

  def forward(self,
      weight: Tensor
    ) -> Tensor:
    if weight.ndim == 1:
      sigma = torch.linalg.norm(weight)
      return weight / (sigma + self.eps)
    else:
      weight_mat = self._reshape_weight_to_matrix(weight)
      if self.training and self._update_enabled:
        self._power_method(weight_mat, self.n_power_iterations)
        self._update_enabled = False
      u = self._u.clone(memory_format = torch.contiguous_format)
      v = self._v.clone(memory_format = torch.contiguous_format)
      sigma = torch.dot(u, torch.mv(weight_mat, v))
      return weight / (sigma + self.eps)

def spectral_norm_constraint(
    # Arguments:
    module: Module,
    # Keyword Arguments:
    name:               str           = "weight",
    n_power_iterations: int           = 1,
    eps:                float         = 1e-12,
    dim:                Optional[int] = None
  ) -> Module:
  """Register a ``_SpectralNormConstraint`` parametrization on ``module.<name>``.

  Args:
    module: Module owning the tensor to constrain.
    name: Name of the parameter or buffer on ``module`` to parametrize.
    n_power_iterations: Power-method iterations per enabled training forward.
    eps: Added to the singular-value estimate before dividing.
    dim: Dimension passed to the parametrization for reshaping the weight to a
      matrix. When None, defaults to 1 for ``ConvTranspose{1,2,3}d`` modules
      and 0 otherwise.

  Returns:
    The registered parametrization object itself, not ``module``. Presumably
    this lets callers toggle its ``_update_enabled`` flag.

  Raises:
    ValueError: if ``module`` has no tensor attribute called ``name``.
  """
  weight = getattr(module, name, None)
  # Explicit raise instead of `assert`, which is stripped under `python -O`.
  if not isinstance(weight, Tensor):
    raise ValueError(
      f"{module} does not have a parameter or buffer called '{name}'."
    )

  if dim is None:
    # Mirrors torch's spectral_norm default: ConvTranspose weights are laid
    # out (in, out, ...), so their output dimension is 1.
    conv_transpose_types = (
      torch.nn.ConvTranspose1d,
      torch.nn.ConvTranspose2d,
      torch.nn.ConvTranspose3d,
    )
    dim = 1 if isinstance(module, conv_transpose_types) else 0

  parametrization = _SpectralNormConstraint(
    weight,
    n_power_iterations = n_power_iterations,
    dim                = dim,
    eps                = eps
  )
  torch.nn.utils.parametrize.register_parametrization(
    module,
    name,
    parametrization
  )
  return parametrization