################################################################################
# spectral/modules/trainer.py
#
# 
# 
# 
# 2024
#
#

import torch

from os.path  import isdir, isfile, join
from typing   import Optional
from warnings import warn

from .warmup                   import BaseWarmup
from choochoo.training.logging import LoggingTrainer
from experilog.logger          import JSONType

from visual.temptsp import TSPDrawer
from PIL import ImageFont
from modules.tsp.loss import nodes_to_edges
from shutil import copy as copyfile

import numpy as np

# Loss-source selectors for ``DeepThinkingTrainer``'s ``scheduler_callback``
# argument: they name which epoch-average loss is passed to the scheduler.
TRAINING:   str = "training"
VALIDATION: str = "validation"

# Alias used to annotate LR scheduler arguments. NOTE(review): this refers to
# torch's underscore-prefixed base class name — confirm it is still the
# appropriate base for the schedulers used with this trainer.
Scheduler = torch.optim.lr_scheduler._LRScheduler

class DeepThinkingTrainer(LoggingTrainer):
  """
  ``LoggingTrainer`` subclass that adds optional gradient clipping, LR
  scheduling, warmup dampening and best-model checkpointing, and renders a
  TSP visualization GIF at the end of every epoch.
  """

  def __init__(self,
      # Arguments:
      config: JSONType,
      *args,
      # Keyword Arguments:
      clip:               Optional[float]      = None,
      clip_foreach:       Optional[bool]       = None,
      save_best:          Optional[str]        = None,
      save_dir:           Optional[str]        = None,
      scheduler:          Optional[Scheduler]  = None,
      scheduler_callback: Optional[str]        = None,
      warmup:             Optional[BaseWarmup] = None,
      name:               str                  = "model",
      **kwargs
    ) -> None:
    """
    Initializes ``DeepThinkingTrainer``.

    Args:
      config (JSONType):
        The config dictionary used to construct the trainer and model. It is
        stored and written into every saved checkpoint.
      *args:
        Additional arguments for ``LoggingTrainer``.
      clip (float, optional):
        The float value to clip gradient norms to.
        Defaults to ``None`` (no clipping).
      clip_foreach (bool, optional):
        Only meaningful if ``clip`` is set. If ``True``, the gradient norm of
        every parameter is clipped individually; otherwise the norm over all
        parameters together is clipped.
        Defaults to ``None`` (treated as ``False`` when ``clip`` is set).
      save_best (str, optional):
        If specified, this is either ``"max"`` or ``"min"``, and determines the
        measurement for comparison on deciding whether to overwrite the current
        saved version. Any other string falls back to ``"min"`` with a warning.
        Defaults to ``None`` (no saving).
      save_dir (str, optional):
        Should be used if ``save_best`` is not ``None``. This is the directory
        to save models and configs to. If given, it must already exist.
        Defaults to ``None``.
      scheduler (Scheduler, optional):
        The LR scheduler for training.
        Defaults to ``None``.
      scheduler_callback (str, optional):
        Either ``"training"`` or ``"validation"``. The type of loss specified
        will be fed into the scheduler as a step. Otherwise, no value will be
        used.
        Defaults to ``None``.
      warmup (BaseWarmup, optional):
        The warmup to use for training.
        Defaults to ``None``.
      name (str, optional):
        The model name, used as the prefix of checkpoint filenames.
        Defaults to ``"model"``.
      **kwargs:
        Additional keyword arguments for ``LoggingTrainer``.

    Raises:
      AssertionError:
        If an argument has the wrong type, or ``save_dir`` is given but is not
        an existing directory.
      Exception:
        If ``save_best`` is used without ``save_dir``.
    """
    super().__init__(*args, **kwargs)
    # Config.
    self.config = config
    # Clip.
    assert isinstance(clip, float) or clip is None, \
      "clip must be a float or None."
    self.clip = clip
    # Clip foreach.
    assert isinstance(clip_foreach, bool) or clip_foreach is None, \
      "clip_foreach must be a bool or None."
    if self.clip is None:
      self.clip_foreach = None
    else:
      self.clip_foreach = bool(clip_foreach)
    # Save best.
    assert isinstance(save_best, str) or save_best is None, \
      "save_best must be a string or None."
    self.save_best = save_best
    # These are always defined so that ``on_train_start`` can read them even
    # when checkpointing is disabled.
    self._current_best = None
    self._is_better = lambda current, new: new <= current
    if self.save_best == "max":
      self._is_better = lambda current, new: new >= current
    elif self.save_best is not None and self.save_best != "min":
      warn(
        f"'{self.save_best}' is not a recognized save mode. Defaulting "
        "to 'min'."
      )
      self.save_best = "min"
    # Save directory.
    assert isinstance(save_dir, str) or save_dir is None, \
      "save_dir must be a string or None."
    if save_dir is None:
      if self.save_best is not None:
        raise Exception("save_dir must be specified if save_best is used.")
    else:
      # Only validate a directory that was actually supplied; ``isdir(None)``
      # would raise a ``TypeError``.
      assert isdir(save_dir), \
        f"{save_dir} is not a valid directory."
    self.save_dir = save_dir
    # Scheduler and warmup.
    self.scheduler          = scheduler
    self.scheduler_callback = scheduler_callback
    self.warmup             = warmup
    # Name.
    self._name = name
    # Drawer for the per-epoch visualization. NOTE: "arial.ttf" must be
    # resolvable by PIL on the host machine.
    self.tsp_drawer = TSPDrawer(
      image_size = (512, 512),
      point_color = (255, 0, 128, 255),
      font = ImageFont.truetype("arial.ttf", 55)
    )

  def _save_model(self) -> None:
    """
    Writes a checkpoint to ``save_dir``.

    The file is named ``<name><train start stamp>.tar`` and holds the
    timestamp, config, current epoch, average training and validation losses,
    and the model and optimizer state dicts. Because the name does not depend
    on the epoch, later saves in the same run overwrite earlier ones.
    """
    checkpoint = {
      "timestamp":  self._train_start_stamp,
      "config":     self.config,
      "epoch":      self._current_epoch,
      "train_loss": self._train_average(),
      "valid_loss": self._validation_average(),
      "model_state": self.model.state_dict(),
      "optim_state": self.optimizer.state_dict()
    }
    filename = join(self.save_dir, self._name + self._train_start_stamp + ".tar")
    torch.save(checkpoint, filename)

  def on_train_start(self) -> None:
    """
    Runs the parent hook, then seeds the best score if it is still unset:
    ``-inf`` when ``save_best`` is ``"max"`` and ``+inf`` otherwise, so the
    first comparison always counts as an improvement.
    """
    super().on_train_start()
    if self._current_best is None:
      self._current_best = -torch.inf if self.save_best == "max" else torch.inf

  def process_training_batch(self) -> None:
    """
    Consumes ``self._current_batch`` and stores the resulting loss in
    ``self._current_loss``.

    The batch is an ``(input, target)`` pair; both are moved to
    ``self.device`` before the forward pass. Afterwards every constraint of
    ``self.model.thought_module`` has ``_update_enabled`` set to ``True``.
    Does nothing if there is no pending batch.
    """
    if self._current_batch is None:
      return
    input_batch, target_batch = self._current_batch
    self._current_batch = None
    input_batch  = input_batch.to(self.device)
    target_batch = target_batch.to(self.device)
    predicted_batch = self.model(input_batch)
    self._current_loss = self.loss_fn(predicted_batch, target_batch)
    for constraint in self.model.thought_module.constraints:
      constraint._update_enabled = True

  def optimization_step(self) -> None:
    """
    Backpropagates ``self._current_loss``, optionally clips gradients and
    steps the optimizer.

    The scalar loss is recorded in ``self._batch_loss`` and
    ``self._current_loss`` is cleared. Does nothing if there is no pending
    loss.
    """
    if self._current_loss is None:
      return
    # NOTE: on MPS the loss has been observed to become 1.0 after .backward()
    # unless it is evaluated in some form first. The workaround is to call
    # ``self._current_loss.item()`` before ``backward()``; it is currently
    # disabled.
    self._current_loss.backward()
    self._batch_loss = self._current_loss.detach().item()
    self._current_loss = None
    if self.clip is not None:
      if self.clip_foreach:
        # Clip each parameter's gradient norm on its own.
        for p in self.model.parameters():
          torch.nn.utils.clip_grad_norm_(p, self.clip)
      else:
        # Clip the norm over all parameters together.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
    self.optimizer.step()

  def on_epoch_end(self) -> None:
    """
    Runs the parent hook, then in order: steps the scheduler (if any),
    dampens the warmup (if any), saves the model if the validation average
    improved (if ``save_best`` is set), and renders the visualization GIF.
    """
    super().on_epoch_end()
    # The scheduler is optional; only step it when one was supplied.
    if self.scheduler is not None:
      if self.scheduler_callback is not None:
        # Any callback value other than "training" selects validation loss.
        if self.scheduler_callback == TRAINING:
          scheduler_step_value = self._train_average()
        else:
          scheduler_step_value = self._validation_average()
        self.scheduler.step(scheduler_step_value)
      else:
        self.scheduler.step()
    if self.warmup is not None:
      self.warmup.dampen()
    if self.save_best is not None:
      contender = self._validation_average()
      if self._is_better(self._current_best, contender):
        if self.verbose >= 1:
          print("New best model. Saving.")
        self._current_best = contender
        self._save_model()
    # Temporary debugging visualization.
    self._render_validation_gif()

  def _render_validation_gif(self,
      num_nodes:      int = 15,
      max_iterations: int = 30
    ) -> None:
    """
    Runs the model on one random TSP instance for ``0 .. max_iterations - 1``
    iterations and writes the result to ``temp/val<epoch>.gif``, which is
    also copied to ``temp/valcurr.gif``.

    The model is put in eval mode for the duration and is always switched
    back to train mode afterwards, even if rendering fails.

    Args:
      num_nodes (int, optional):
        Number of random 2D points in the instance.
        Defaults to ``15``.
      max_iterations (int, optional):
        Number of frames; frame ``m`` calls the model with
        ``max_iterations = m``.
        Defaults to ``30``.
    """
    solver = self.model
    solver.eval()
    try:
      with torch.no_grad():
        points = torch.rand((1, num_nodes, 2))
        grid   = torch.cdist(points, points)
        # Normalize so the maximum distance is 1.
        grid = grid / torch.max(grid).clamp(1e-12)
        diag = torch.repeat_interleave(
          torch.eye(num_nodes).view(1, 1, num_nodes, num_nodes),
          grid.size(0),
          dim = 0
        )
        diag_flip = 1.0 - diag
        # Put ones on the diagonal of the distance matrix; broadcasting also
        # gives ``grid`` a channel axis.
        grid = diag_flip * grid + diag
        # Use the trainer's device instead of assuming CUDA is available.
        inp = torch.concat([grid, diag_flip], dim = 1).to(self.device)
        thoughts = []
        outs = []
        inps = []
        infos = []
        for m in range(max_iterations):
          out, thought = solver(inp, max_iterations = m, return_thought = True)
          # Channel 0 of the thought is collected separately from the rest.
          thoughts.append(thought[:, 1:].detach().cpu().numpy())
          inps.append(thought[:, [0]].detach().cpu().numpy())
          outs.append(
            nodes_to_edges(out.squeeze(0)).unsqueeze(0).detach().cpu().numpy()
          )
          infos.append({
            "epoch": self._current_epoch,
            "nodes": num_nodes,
            "iter": m
          })
        inps = np.array(inps)
        thoughts = np.array(thoughts)
        thoughts_shape = thoughts.shape
        # Flatten to one row per (frame, batch, cell) holding all channels.
        thoughts = np.moveaxis(thoughts, 2, -1).reshape((-1, thoughts_shape[2]))
        mean, std = np.mean(thoughts), np.std(thoughts)
        # Four edges spanning mean +/- one std give three histogram bins.
        bins = np.linspace(mean - 1 * std, mean + 1 * std, 4)
        thoughts = np.apply_along_axis(
          lambda x: np.histogram(x, bins = bins, density = False)[0],
          1,
          thoughts
        )
        # Turn the three counts into fractions; the epsilon guards rows whose
        # values all fall outside the bins.
        thoughts = thoughts / (np.sum(thoughts, axis = -1)[:, None] + 1e-12)
        thoughts = thoughts.reshape(
          thoughts_shape[:2] + thoughts_shape[-2:] + (3,)
        )
        thoughts = np.moveaxis(thoughts, -1, 2)
        outs = np.array(outs)
        images = self.tsp_drawer.draw_batch(
          points.numpy()[0],
          outs[:, 0],
          thoughts[:, 0],
          np.clip(inps[:, 0], -1.0, 1.0),
          np.linalg.eigvals(outs[:, 0]),
          infos,
          grid.detach().cpu().numpy()[0, 0]
        )
        filename = f"temp/val{self._current_epoch}.gif"
        self.tsp_drawer.save_as_gif(
          images,
          filename = filename,
          duration = 10
        )
        copyfile(filename, "temp/valcurr.gif")
    finally:
      # Never leave the model stuck in eval mode if rendering raised.
      solver.train()