################################################################################
# training/utils/trainer.py
#
# 
# 
# 2023
#
# High-level abstract training environment for models.
# [Object-oriented programming makes zero sense to me as a functional
#  programmer, so this is the best I can do.]

from abc import ABC, abstractmethod

class Trainer(ABC):
  """
  Abstract class that defines a high-level trainer for PyTorch ``Module``s.

  Subclasses provide a set of lifecycle hooks. Going by their docstrings,
  the hooks are intended to run in this order::

    on_train_start
    for each epoch:
      on_epoch_start
      for each batch:
        on_batch_start
        process_training_batch
        optimization_step
        on_batch_end
      on_epoch_end
    on_train_end

  This class does not drive that loop itself. The ``_full_batch``,
  ``_full_epoch`` and ``_full_train`` methods (implemented in
  ``SimpleTrainer``) are responsible for invoking the hooks, and
  ``__call__`` is the entry point for the whole process. Every method below
  is abstract, so a subclass must override all of them before it can be
  instantiated.
  """

  @abstractmethod
  def on_train_start(self):
    """
    Runs once at the beginning of training, before the loop.
    """

  @abstractmethod
  def on_epoch_start(self):
    """
    Runs once per epoch, before any batches have been processed.
    """

  @abstractmethod
  def on_batch_start(self):
    """
    Runs at the start of each batch.
    """

  @abstractmethod
  def process_training_batch(self):
    """
    Runs after ``on_batch_start`` but before ``optimization_step``; this is
    where losses should ideally be computed.
    """

  @abstractmethod
  def optimization_step(self):
    """
    Runs after ``process_training_batch``; this is where an optimizer step
    should ideally be performed.
    """

  @abstractmethod
  def on_batch_end(self):
    """
    Runs after ``optimization_step``, at the end of a batch.
    """

  @abstractmethod
  def on_epoch_end(self):
    """
    Runs after every batch of the current epoch has finished.
    """

  @abstractmethod
  def on_train_end(self):
    """
    Runs once, after the training loop has completed (or been broken out of).
    """

  @abstractmethod
  def _full_batch(self):
    """
    Runs all the per-batch hooks for a single batch. Implemented in
    ``SimpleTrainer``; only override if absolutely necessary.
    """

  @abstractmethod
  def _full_epoch(self):
    """
    Runs all the per-epoch hooks (and the batches within the epoch).
    Implemented in ``SimpleTrainer``; only override if absolutely necessary.
    """

  @abstractmethod
  def _full_train(self):
    """
    Runs the complete training process, from ``on_train_start`` to
    ``on_train_end``. Implemented in ``SimpleTrainer``; only override if
    absolutely necessary.
    """

  @abstractmethod
  def __call__(self):
    """
    Calling a ``Trainer`` instance should run the full training process from
    start to end.
    """