################################################################################
# spectral/training/train_model.py
#
# 
# 
# 
# 2024
#
#

import json
import torch

from argparse          import ArgumentParser
from easy_to_hard_data import PrefixSumDataset, MazeDataset
from typing            import Any, Callable, Dict, List, Tuple, Type, Union

import modules.warmup as warmup_lib

from experilog.logger           import Logger, JSONType
from modules.chesspuzzles.data  import FlippedChessPuzzleDataset
from modules.chesspuzzles.model import load_from_json_dict as load_chess
from modules.mazes.model        import load_from_json_dict as load_mazes
from modules.prefixsums.loss    import metric_solution_accuracy
from modules.prefixsums.model   import load_from_json_dict as load_prefix_sums
from modules.tsp.data           import TSPRandomGrids
from modules.tsp.loss           import tsp_loss
from modules.tsp.model          import load_from_json_dict as load_tsp
from modules.trainer            import DeepThinkingTrainer

#torch.autograd.set_detect_anomaly(True)

# This is highly recommended for the Lipschitz regularizer. Our testing on both
# MPS and CUDA has shown slight differences, depending on the PyTorch version.
# If using MPS, ensure PyTorch is at least version 2.4.0.
#torch.use_deterministic_algorithms(True)

# Any one of the dataset classes this script knows how to construct.
DatasetType = Union[
  Type[PrefixSumDataset],
  Type[MazeDataset],
  Type[FlippedChessPuzzleDataset],
  Type[TSPRandomGrids],
]
# Short aliases for the torch types used in the signatures below.
Module = torch.nn.Module
Tensor = torch.Tensor

def cross_entropy_deterministic(
    # Arguments:
    input,
    target
  ):
  """
  Mean cross-entropy loss, computed as an unreduced per-element loss followed
  by an explicit mean.

  Presumably split this way so training keeps working under
  ``torch.use_deterministic_algorithms(True)``, which some built-in reduced
  cross-entropy kernels do not support — TODO confirm which kernel this
  sidesteps. Note that, unlike the built-in ``"mean"`` reduction, entries
  whose target equals ``ignore_index`` still count toward the denominator.

  Args:
    input (Tensor):
      The unnormalized scores, in any layout accepted by
      ``torch.nn.functional.cross_entropy``.
    target (Tensor):
      The class targets, as accepted by ``torch.nn.functional.cross_entropy``.

  Returns:
    Tensor:
      A scalar tensor holding the mean loss over every element.
  """
  per_element_loss = torch.nn.functional.cross_entropy(
    input, target, reduction = "none"
  )
  return per_element_loss.mean()

# Problem name -> dataset class.
DATASET_MAP = dict(
  prefix_sums   = PrefixSumDataset,
  mazes         = MazeDataset,
  chess_puzzles = FlippedChessPuzzleDataset,
  tsp           = TSPRandomGrids,
)
# Problem name -> function that builds a model from the config's "model" dict.
MODEL_MAP = dict(
  prefix_sums   = load_prefix_sums,
  mazes         = load_mazes,
  chess_puzzles = load_chess,
  tsp           = load_tsp,
)
# Problem name -> (training loss, validation loss/metric). ``get_trainer``
# uses slot 0 as the training loss and slot 1 as ``validation_loss_fn``.
LOSS_MAP = dict(
  prefix_sums   = (cross_entropy_deterministic, metric_solution_accuracy),
  mazes         = (cross_entropy_deterministic, metric_solution_accuracy),
  chess_puzzles = (cross_entropy_deterministic, metric_solution_accuracy),
  tsp           = (tsp_loss, tsp_loss),
)

def get_config_json(
    # Arguments:
    filename: str
  ) -> JSONType:
  """
  Loads the config JSON file into a dictionary.

  Args:
    filename (str):
      The filename (location) of the JSON file.

  Returns:
    JSONType:
      The parsed config.

  Raises:
    OSError:
      If the file cannot be opened.
    json.JSONDecodeError:
      If the file does not contain valid JSON.
  """
  # JSON text is UTF-8 by specification; pin the encoding so parsing does not
  # depend on the platform's locale default (e.g. cp1252 on Windows).
  with open(filename, "r", encoding = "utf-8") as file:
    return json.load(file)

def get_dataset_type(
    # Arguments:
    problem: str
  ) -> DatasetType:
  """
  Gets the dataset type from the ``"problem"`` entry of the dictionary.

  Args:
    problem (str):
      The problem as a string; must be one of the keys of ``DATASET_MAP``.

  Returns:
    DatasetType:
      The class of the dataset.

  Raises:
    Exception:
      If ``problem`` is not a recognized problem name. The message lists the
      accepted names.
  """
  try:
    return DATASET_MAP[problem]
  except (KeyError, TypeError):
    # TypeError covers unhashable values (e.g. a list in the JSON config).
    # ``from None`` hides the uninformative KeyError traceback.
    raise Exception(
      f"'{problem}' is an unrecognized problem type. Looking for one of:\n" +
      "\n".join(f"- '{x}'" for x in DATASET_MAP)
    ) from None

def get_progress_loss_function(
    # Arguments:
    alpha:   float,
    loss_fn: Callable[[Tensor, Tensor], Tensor]
  ) -> Callable[[Tuple[Tensor, Tensor], Tensor], Tensor]:
  """
  Builds the blended loss used for incremental progress training.

  The returned callable takes a pair of predictions ``(y_hat_prog, y_hat_m)``
  and a target ``y``. It scores each prediction with ``loss_fn`` and returns
  ``(1 - alpha) * loss_fn(y_hat_m, y) + alpha * loss_fn(y_hat_prog, y)``.

  Args:
    alpha (float):
      Weight on the progressive-prediction term. The other term is weighted
      by ``1 - alpha``.
    loss_fn (Callable[[Tensor, Tensor], Tensor]):
      Base loss, applied separately to each of the two predictions.

  Returns:
    Callable[[Tuple[Tensor, Tensor], Tensor], Tensor]:
      The blended incremental progress loss function.
  """
  def blended_loss(predictions, target):
    progressive_prediction, m_prediction = predictions
    # The progressive term is evaluated first, in case loss_fn is stateful.
    progressive_term = loss_fn(progressive_prediction, target)
    m_term = loss_fn(m_prediction, target)
    return (1 - alpha) * m_term + alpha * progressive_term
  return blended_loss

def get_parameters(
    # Arguments:
    model: Module,
    lr:    float,
    # Keyword Arguments:
    do_lr_throttle:     bool = False,
    wd_not_constrained: bool = False,
    wd_only_weights:    bool = False
  ) -> List[Dict[str, Any]]:
  """
  Gets the optimizer parameter groups. Parameters are grouped by their
  learning-rate and weight-decay overrides, so learning rate throttling and
  selective weight decay can be expressed as per-group options.

  Args:
    model (Module):
      The model to train. It must expose ``max_iterations`` when
      ``do_lr_throttle`` is enabled.
    lr (float):
      The learning rate used for training.
    do_lr_throttle (bool, optional):
      Whether to use learning rate throttling. If set, parameters whose name
      contains ``"thought"`` get a learning rate of
      ``lr / model.max_iterations``.
      Defaults to ``False``.
    wd_not_constrained (bool, optional):
      If set, disables weight decay (``0.0``) for parameters whose name
      contains both ``"thought"`` and ``"weight"`` but neither ``"ortho"``
      nor ``"recall"``. These are presumably the constrained weights —
      confirm against the model definitions.
      Defaults to ``False``.
    wd_only_weights (bool, optional):
      If set, disables weight decay for every parameter whose name lacks
      ``"weight"`` or contains ``"norm"`` (i.e. biases and normalization
      parameters).
      Defaults to ``False``.

  Returns:
    List[Dict[str, Any]]:
      One dict per group, in order of first appearance. Each dict has a
      ``"params"`` entry, plus ``"lr"`` and/or ``"weight_decay"`` only when
      that group overrides the optimizer default.
  """
  # Only touch ``max_iterations`` when throttling is actually requested.
  throttle = lr / model.max_iterations if do_lr_throttle else None
  # (lr override, weight-decay override) -> parameters sharing them. Dicts
  # keep insertion order, so groups come out in first-seen order.
  groups = {}
  for name, parameter in model.named_parameters():
    lr_override = None
    wd_override = None
    if do_lr_throttle and "thought" in name:
      lr_override = throttle
    if wd_not_constrained:
      exempt = "ortho" in name or "recall" in name
      if "thought" in name and "weight" in name and not exempt:
        wd_override = 0.0
    if wd_only_weights and ("weight" not in name or "norm" in name):
      wd_override = 0.0
    groups.setdefault((lr_override, wd_override), []).append(parameter)
  parameters = []
  for (lr_override, wd_override), params in groups.items():
    group = {"params": params}
    if lr_override is not None:
      group["lr"] = lr_override
    if wd_override is not None:
      group["weight_decay"] = wd_override
    parameters.append(group)
  return parameters

def _lookup_or_raise(
    # Arguments:
    namespace: Any,
    attr_name: str,
    message:   str
  ) -> Any:
  """
  Returns ``getattr(namespace, attr_name)``. If the lookup fails, raises an
  ``Exception`` carrying ``message`` (with the original error chained).

  Args:
    namespace (Any):
      The module (or object) to look the name up on.
    attr_name (str):
      The attribute name, typically taken from the config.
    message (str):
      The error message used when the lookup fails.

  Returns:
    Any:
      The looked-up attribute.

  Raises:
    Exception:
      If ``attr_name`` is missing or is not a valid attribute name.
  """
  try:
    return getattr(namespace, attr_name)
  except (AttributeError, TypeError) as err:
    raise Exception(message) from err

def get_trainer(
    # Arguments:
    config:   JSONType,
    model:    Module,
    dataset:  DatasetType,
    log_dir:  str,
    save_dir: str,
    # Keyword Arguments:
    device:  str = "cpu",
    verbose: int = 0,
    name:    str = "model"
  ) -> DeepThinkingTrainer:
  """
  Constructs the ``DeepThinkingTrainer`` object from the config.

  Args:
    config (JSONType):
      The config dictionary.
    model (Module):
      The model for training.
    dataset (DatasetType):
      The dataset used for training. It is randomly split into training and
      validation subsets according to ``config["training"]["train_split"]``.
    log_dir (str):
      The directory used for saving logs.
    save_dir (str):
      The directory used for saving models.
    device (str, optional):
      The device used for the training.
      Defaults to ``"cpu"``.
    verbose (int, optional):
      The verbosity of training.
      Defaults to ``0``.
    name (str, optional):
      The name of the model.
      Defaults to ``"model"``.

  Returns:
    DeepThinkingTrainer:
      The constructed trainer object.

  Raises:
    Exception:
      If the configured optimizer, scheduler or warmup name cannot be found.
  """
  training = config["training"]
  # Loss functions: slot 0 trains, slot 1 is used for validation.
  loss_fn, validation_loss_fn = LOSS_MAP[config["problem"]]
  if config["model"]["use_incremental_progress"]:
    train_loss_fn = get_progress_loss_function(training["alpha"], loss_fn)
  else:
    train_loss_fn = loss_fn
  # Optimizer. Throttling and weight decay settings are optional.
  optimizer_config = training["optimizer"]
  optimizer_name = optimizer_config["name"]
  optimizer_type = _lookup_or_raise(
    torch.optim,
    optimizer_name,
    f"'{optimizer_name}' is not a known optimizer in torch.optim."
  )
  optimizer = optimizer_type(
    get_parameters(
      model,
      optimizer_config["args"]["lr"],
      do_lr_throttle     = optimizer_config.get("lr_throttle", False),
      wd_not_constrained = optimizer_config.get("wd_not_constrained", False),
      wd_only_weights    = optimizer_config.get("wd_only_weights", False),
    ),
    **optimizer_config["args"]
  )
  # Scheduler.
  scheduler_config = training["scheduler"]
  scheduler_name = scheduler_config["name"]
  scheduler_type = _lookup_or_raise(
    torch.optim.lr_scheduler,
    scheduler_name,
    f"'{scheduler_name}' is not a known scheduler in " +
    "torch.optim.lr_scheduler."
  )
  scheduler = scheduler_type(optimizer, **scheduler_config["args"])
  scheduler_callback = training.get("scheduler_callback")
  # Warmup (optional).
  warmup_config = training.get("warmup")
  if warmup_config is not None:
    warmup_name = warmup_config["name"]
    warmup_type = _lookup_or_raise(
      warmup_lib,
      warmup_name,
      f"'{warmup_name}' is not a known warmup in modules.warmup."
    )
    warmup = warmup_type(optimizer, **warmup_config["args"])
  else:
    warmup = None
  # Clipping: absent, a bare number (used as the max), or a dict with "max"
  # and optional "foreach". JSON integers parse as int, so accept both int
  # and float as the bare-number form.
  clip = training.get("clip")
  if clip is None or isinstance(clip, (int, float)):
    clip_foreach = None
  else:
    clip_foreach = clip.get("foreach")
    clip         = clip.get("max")
  # Data loaders.
  train_size = int(training["train_split"] * len(dataset))
  train_dataset, valid_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, len(dataset) - train_size]
  )
  train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    num_workers = training.get("train_workers", 0),
    batch_size  = training["batch_size"],
    shuffle     = training["shuffle"],
    drop_last   = training.get("drop_last", True)
  )
  valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    num_workers = training.get("valid_workers", 0),
    batch_size  = training["batch_size"],
    shuffle     = False,
    drop_last   = training.get("drop_last", False)
  )
  # Trainer.
  return DeepThinkingTrainer(
    model                 = model,
    loss_fn               = train_loss_fn,
    optimizer             = optimizer,
    train_dataloader      = train_dataloader,
    validation_dataloader = valid_dataloader,
    log_dir               = log_dir,
    config                = config,
    device                = device,
    verbose               = verbose,
    validation_loss_fn    = validation_loss_fn,
    clip                  = clip,
    clip_foreach          = clip_foreach,
    save_best             = training.get("save_best", "max"),
    save_dir              = save_dir,
    scheduler             = scheduler,
    scheduler_callback    = scheduler_callback,
    warmup                = warmup,
    format_loss_string    = "Loss: %.5f",
    name                  = name,
    log_name              = name
  )

if __name__ == "__main__":
  # Command-line interface: three required paths plus optional flags.
  parser = ArgumentParser(description = "Trains one of the specified models.")
  for positional, help_text in (
    ("config",   "The JSON configuration file."),
    ("log_dir",  "The directory for saving logs."),
    ("save_dir", "The directory for saving models."),
  ):
    parser.add_argument(positional, help = help_text)
  parser.add_argument("-d", "--dataset", default = None,    type = str)
  parser.add_argument("-v", "--verbose", default = 0,       type = int)
  parser.add_argument("-n", "--name",    default = "model", type = str)
  args = parser.parse_args()

  config = get_config_json(args.config)

  # Build the dataset. An explicit ``--dataset`` value is passed positionally
  # ahead of the config's keyword arguments.
  dataset_type = get_dataset_type(config["problem"])
  leading_args = () if args.dataset is None else (args.dataset,)
  train_dataset = dataset_type(*leading_args, **config["training"]["data"])

  # Prefer CUDA, then MPS, then fall back to the CPU.
  if torch.cuda.is_available():
    device, banner = "cuda", "Using CUDA."
  elif torch.backends.mps.is_available():
    device, banner = "mps", "Using MPS."
  else:
    device, banner = "cpu", "Using CPU."
  print(banner)

  # Build the model for this problem and move it to the chosen device.
  model = MODEL_MAP[config["problem"]](config["model"])
  model.to(device)

  # Assemble the trainer and run it for the configured number of epochs.
  trainer = get_trainer(
    config, model, train_dataset, args.log_dir, args.save_dir,
    device = device, verbose = args.verbose, name = args.name
  )
  trainer(config["training"]["epochs"])