################################################################################
# spectral/modules/warmup.py
#
# Original: https://github.com/ArneNx/pytorch_warmup/tree/warmup_fix
# Modified: July 2020 in https://github.com/aks2203/deep-thinking
#
# Adapted for this project (mostly stylistic changes) by...
# 
# 
# 
# 2024
#

import math

from torch.optim import Optimizer

class BaseWarmup():
  """
  Base class for all warmup schedules.

  On construction, and on every call to ``dampen``, each optimizer parameter
  group still inside its warmup period gets
  ``group["lr"] = warmup_factor(step, **params) * base_lr``, where ``base_lr``
  is that group's learning rate captured at construction time. Subclasses
  provide ``warmup_factor``.
  """

  def __init__(self,
      # Arguments:
      optimizer,
      warmup_params,
      # Keyword Arguments:
      last_step     = -1,
      warmup_period = 0
    ):
    """
    Initializes ``BaseWarmup`` and immediately dampens step ``last_step + 1``.

    Args:
      optimizer (Optimizer):
        An instance of a subclass of Optimizer.
      warmup_params (list):
        Warmup parameters: one dict per parameter group, each passed as
        keyword arguments to ``warmup_factor``.
      last_step (int, optional):
        The index of last step.
        Defaults to ``-1``.
      warmup_period (int | list, optional):
        Warmup period, either shared by every parameter group (int) or given
        per parameter group (list).
        Defaults to ``0``.

    Raises:
      TypeError: If ``optimizer`` is not an ``Optimizer``.
    """
    if not isinstance(optimizer, Optimizer):
      raise TypeError(f"{type(optimizer).__name__} is not an Optimizer.")
    self.optimizer     = optimizer
    self.warmup_params = warmup_params
    self.last_step     = last_step
    # Snapshot the target learning rates before dampen() overwrites them.
    self.base_lrs      = [group["lr"] for group in self.optimizer.param_groups]
    self.warmup_period = warmup_period
    self.dampen()

  def state_dict(self):
    """
    Returns the state of the warmup scheduler as a dictionary. It contains an
    entry for every variable in ``self.__dict__`` which is not the optimizer.
    The dictionary is shallow: list values are shared with the scheduler.
    """
    return {
      key: value for key, value in self.__dict__.items() if key != "optimizer"
    }

  def load_state_dict(self,
      # Arguments:
      state_dict
    ):
    """
    Loads the warmup scheduler's state.

    Args:
      state_dict (dict):
        Warmup scheduler state. Should be an object returned from a call to
        ``state_dict``.
    """
    self.__dict__.update(state_dict)

  def dampen(self,
      # Keyword Arguments:
      step = None
    ):
    """
    Dampen the learning rates for ``step``.

    With an int ``warmup_period`` every group is dampened while
    ``step < warmup_period``. With a list, group ``i`` is dampened while
    ``step < warmup_period[i]``. Groups whose warmup has finished keep
    whatever learning rate they currently have. Any other ``warmup_period``
    type leaves all learning rates untouched.

    Args:
      step (int, optional):
        The index of current step.
        Defaults to ``None``, meaning ``last_step + 1``.
    """
    if step is None:
      step = self.last_step + 1
    self.last_step = step
    period  = self.warmup_period
    is_list = isinstance(period, list)
    # Only a list needs the per-group check below. A scalar period that has
    # elapsed (or a non-int, non-list period) means nothing to do.
    if not is_list and not (isinstance(period, int) and step < period):
      return
    groups = zip(self.optimizer.param_groups, self.warmup_params)
    for i, (group, params) in enumerate(groups):
      if is_list and step >= period[i]:
        # This group's own warmup has finished.
        continue
      omega = self.warmup_factor(step, **params)
      group["lr"] = omega * self.base_lrs[i]

  def warmup_factor(self,
      # Arguments:
      step,
      warmup_period
    ):
    """
    Placeholder for objects that inherit BaseWarmup: returns the multiplier
    applied to a group's base learning rate at ``step``.
    """
    raise NotImplementedError

def get_warmup_params(
    # Arguments:
    warmup_period,
    group_count
  ):
  """
  Builds the per-group keyword arguments passed to ``warmup_factor``.

  Args:
    warmup_period (int | list):
      A single warmup period shared by every parameter group, or a list with
      one int period per parameter group.
    group_count (int):
      Number of optimizer parameter groups.

  Returns:
    list: ``group_count`` dicts of the form ``{"warmup_period": period}``.

  Raises:
    ValueError: If a list ``warmup_period`` does not have ``group_count``
      elements, or one of its elements is not an int.
    TypeError: If ``warmup_period`` is neither a list nor an int.
  """
  def _is_int(x):
    # ``bool`` subclasses ``int``; reject it so True/False are never taken
    # as periods.
    return isinstance(x, int) and not isinstance(x, bool)

  if isinstance(warmup_period, list):
    if len(warmup_period) != group_count:
      raise ValueError(f"Size of warmup_period does not equal {group_count}.")
    for x in warmup_period:
      if not _is_int(x):
        raise ValueError(
          f"An element in warmup_period, {type(x).__name__}, is not an int."
        )
    return [dict(warmup_period=x) for x in warmup_period]
  if _is_int(warmup_period):
    return [dict(warmup_period=warmup_period) for _ in range(group_count)]
  raise TypeError(f"{type(warmup_period).__name__} is not a list nor an int.")

class LinearWarmup(BaseWarmup):
  """
  Linear warmup schedule: the learning-rate multiplier is
  ``(step + 1) / warmup_period``, capped at ``1.0``.
  """

  def __init__(self,
      # Arguments:
      optimizer,
      warmup_period,
      # Keyword Arguments:
      last_step = -1
    ):
    """
    Initializes ``LinearWarmup``.

    Args:
      optimizer (Optimizer):
        An instance of a subclass of Optimizer.
      warmup_period (int | list):
        Warmup period, shared (int) or per parameter group (list).
      last_step (int, optional):
        The index of last step.
        Defaults to ``-1``.
    """
    super().__init__(
      optimizer,
      get_warmup_params(warmup_period, len(optimizer.param_groups)),
      last_step,
      warmup_period
    )

  def warmup_factor(self,
      # Arguments:
      step,
      warmup_period
    ):
    """
    Returns the fraction of the warmup completed after ``step``, at most 1.
    """
    progress = (step + 1) / warmup_period
    return progress if progress < 1.0 else 1.0

class ExponentialWarmup(BaseWarmup):
  """
  Exponential warmup schedule: the learning-rate multiplier is
  ``1 - exp(-(step + 1) / warmup_period)``, and becomes exactly ``1.0`` once
  ``step + 1`` reaches ``warmup_period``.
  """

  def __init__(self,
      # Arguments:
      optimizer,
      warmup_period,
      # Keyword Arguments:
      last_step = -1
    ):
    """
    Initializes ``ExponentialWarmup``.

    Args:
      optimizer (Optimizer):
        An instance of a subclass of Optimizer.
      warmup_period (int | list):
        Effective warmup period, shared (int) or per parameter group (list).
      last_step (int, optional):
        The index of last step.
        Defaults to ``-1``.
    """
    super().__init__(
      optimizer,
      get_warmup_params(warmup_period, len(optimizer.param_groups)),
      last_step,
      warmup_period
    )

  def warmup_factor(self,
      # Arguments:
      step,
      warmup_period
    ):
    """
    Returns the exponential warmup multiplier for ``step``.
    """
    elapsed = step + 1
    if elapsed >= warmup_period:
      return 1.0
    return 1.0 - math.exp(-elapsed / warmup_period)