################################################################################
# training/utils/validation.py
#
# 
# 
# 2023
#
# Implementation of the simple trainer with an end-of-epoch validation pass.

from collections.abc import Callable

import torch

from .utils.reduction_tools import *
from .utils.type_checking   import *
from .simple                import SimpleTrainer

# Short module-level aliases for the torch types used in annotations below.
DataLoader, Module, Optimizer, Tensor = (
  torch.utils.data.DataLoader,
  torch.nn.Module,
  torch.optim.Optimizer,
  torch.Tensor,
)

# Verbosity enumerable:
# No output.
SILENT: int = 0
# Only epoch results.
QUIET: int = 1
# Epoch progress bars and live results.
FULL: int = 2

class ValidationTrainer(SimpleTrainer):
  """
  Implementation of ``SimpleTrainer`` which evaluates the model on a
  validation dataset at the end of every epoch.

  The trainer attributes read here (``model``, ``device``, ``loss_fn``,
  ``verbose``, ``format_loss_string`` and ``_current_epoch``) are assumed to
  be provided by ``SimpleTrainer``.
  """

  def __init__(self,
      # Arguments:
      validation_dataloader: DataLoader,
      *args,
      # Keyword Arguments:
      validation_loss_fn: Callable[[Tensor, Tensor], Tensor] = None,
      **kwargs
    ):
    """
    Initializes ``ValidationTrainer``.

    Args:
      validation_dataloader (DataLoader or None):
        The dataloader for the validation dataset. Each batch is unpacked as
        ``(input_batch, target_batch)``. If ``None``, the validation pass is
        skipped at the end of each epoch.
      *args, **kwargs:
        Forwarded unchanged to ``SimpleTrainer.__init__``.
      validation_loss_fn (Callable[[Tensor, Tensor], Tensor], optional):
        The loss function to use on the validation dataset. It is called as
        ``validation_loss_fn(predicted_batch, target_batch)`` and must return
        a single-element tensor, because ``.item()`` is called on the result.
        Defaults to ``None`` (uses ``loss_fn``).
    """
    super().__init__(*args, **kwargs)
    # ``None`` is accepted here; ``on_epoch_end`` then skips validation.
    check_if_type_or_none(
      validation_dataloader, DataLoader, "validation_dataloader"
    )
    self.validation_dataloader = validation_dataloader
    # Assume validation_loss_fn is callable; fall back to the training loss.
    if validation_loss_fn is None:
      self.validation_loss_fn = self.loss_fn
    else:
      self.validation_loss_fn = validation_loss_fn

  def on_epoch_end(self):
    """
    Runs the inherited end-of-epoch hook, then evaluates the model on the
    validation dataset.

    Each batch's loss is accumulated into ``self._validation_average`` (a
    ``RunningAverage``). When ``verbose >= QUIET``, the averaged value is
    printed on the line below the epoch output.

    During the pass the model is put in eval mode with gradients disabled.
    Its previous training mode is restored afterwards, even if the pass
    raises. When no validation dataloader was given, only the inherited hook
    runs.
    """
    # Run inherited on_epoch_end first.
    super().on_epoch_end()
    # A ``None`` dataloader is allowed by ``__init__``; there is nothing to
    # validate against, so skip the pass instead of iterating ``None``.
    if self.validation_dataloader is None:
      return
    # Create a running average for validation.
    self._validation_average = RunningAverage()
    # Remember the current mode so it can be restored exactly afterwards.
    was_training = self.model.training
    self.model.eval()
    try:
      with torch.no_grad():
        for input_batch, target_batch in self.validation_dataloader:
          input_batch  = input_batch.to(self.device)
          target_batch = target_batch.to(self.device)
          predicted_batch = self.model(input_batch)
          loss = self.validation_loss_fn(predicted_batch, target_batch).item()
          # RunningAverage overloads ``+`` for accumulation. The expression's
          # result is discarded, so it presumably updates in place.
          self._validation_average + loss
    finally:
      # Restore the previous mode even if validation raised.
      self.model.train(was_training)
    # Print on the line below the output.
    if self.verbose >= QUIET:
      print(
        len(f"Epoch {self._current_epoch} ") * " " + "(valid.):  " + \
        self.format_loss_string % self._validation_average()
      )