################################################################################
# spectral/testing/test_model.py
#
# 
# 
# 
# 2024
#
# Performs testing (using the test dataset) of trained models.

import json
import numpy as np
import torch

from argparse          import ArgumentParser
from easy_to_hard_data import PrefixSumDataset, MazeDataset
from typing            import (
  Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
)

from experilog.logger           import Logger, JSONType
from modules.chesspuzzles.data  import FlippedChessPuzzleDataset
from modules.chesspuzzles.model import load_from_json_dict as load_chess_puzzles
from modules.mazes.model        import load_from_json_dict as load_mazes
from modules.prefixsums.model   import load_from_json_dict as load_prefix_sums

# Short aliases for the torch types used throughout this file.
DataLoader = torch.utils.data.DataLoader
Device     = torch.device
Module     = torch.nn.Module
Tensor     = torch.Tensor

# The class (not an instance) of any supported dataset.
DatasetType = Union[
  Type[PrefixSumDataset],
  Type[MazeDataset],
  Type[FlippedChessPuzzleDataset],
]
# Maps a metric name to a function of ``(prediction, target)``.
MetricsDict = Dict[str, Callable[[Tensor, Tensor], Tensor]]

# Problem name -> dataset class.
DATASET_MAP = dict(
  prefix_sums   = PrefixSumDataset,
  mazes         = MazeDataset,
  chess_puzzles = FlippedChessPuzzleDataset,
)
# Problem name -> function that builds a model from its ``"model"`` config.
MODEL_MAP = dict(
  prefix_sums   = load_prefix_sums,
  mazes         = load_mazes,
  chess_puzzles = load_chess_puzzles,
)

def metric_solution_accuracy(
    # Arguments:
    y_hat: Tensor,
    y:     Tensor
  ) -> Tensor:
  """
  Computes the solution accuracy between the prediction and the target. The
  solution accuracy differs from the category accuracy in that an individual
  problem will only increase the accuracy if all of its classes are correct.

  Args:
    y_hat (Tensor):
      The predicted logits for each class, with the class scores along
      dimension 1.
    y (Tensor):
      The target classes, indexed like ``y_hat`` with dimension 1 removed.

  Returns:
    Tensor:
      A float tensor of shape ``(y.size(0),)`` holding ``1.0`` for every
      problem whose classes are all correct and ``0.0`` otherwise.
  """
  selected_classes = y_hat.argmax(dim = 1).long()
  equal_classes = selected_classes == y
  # ``reshape`` rather than ``view``: the comparison result may inherit
  # non-contiguous strides from its inputs. With zero elements (e.g. an empty
  # batch) a ``-1`` size is ambiguous, so the trailing size is given as 0.
  trailing = -1 if equal_classes.numel() > 0 else 0
  per_problem = equal_classes.reshape(y.size(0), trailing)
  return torch.all(per_problem, dim = -1).float()

def metric_category_accuracy(
    # Arguments:
    y_hat: Tensor,
    y:     Tensor
  ) -> Tensor:
  """
  Computes the category accuracy between the prediction and the target: for
  each problem, the fraction of its classes that are predicted correctly.

  Args:
    y_hat (Tensor):
      The predicted logits for each class, with the class scores along
      dimension 1.
    y (Tensor):
      The target classes, indexed like ``y_hat`` with dimension 1 removed.

  Returns:
    Tensor:
      A float tensor of shape ``(y.size(0),)`` with the category accuracy of
      each problem.
  """
  selected_classes = y_hat.argmax(dim = 1).long()
  equal_classes = (selected_classes == y).float()
  # ``reshape`` rather than ``view``: the comparison result may inherit
  # non-contiguous strides from its inputs. With zero elements (e.g. an empty
  # batch) a ``-1`` size is ambiguous, so the trailing size is given as 0.
  trailing = -1 if equal_classes.numel() > 0 else 0
  per_problem = equal_classes.reshape(y.size(0), trailing)
  return torch.mean(per_problem, dim = -1)

def metric_cross_entropy(
    # Arguments:
    y_hat: Tensor,
    y:     Tensor
  ) -> Tensor:
  """
  Evaluates the cross-entropy between the predicted logits and the targets
  without applying any reduction.

  Args:
    y_hat (Tensor):
      The predicted logits, with the class scores along dimension 1.
    y (Tensor):
      The target classes.

  Returns:
    Tensor:
      The element-wise loss. Unlike the accuracy metrics, it is not collapsed
      to one value per problem; for class-index targets it has the shape of
      ``y``.
  """
  cross_entropy = torch.nn.functional.cross_entropy
  per_element_loss = cross_entropy(y_hat, y, reduction = "none")
  return per_element_loss

# Every metric ``perform_test`` can compute, keyed by the name used in the
# testing config's ``"metrics"`` list.
METRICS = dict(
  solution_accuracy = metric_solution_accuracy,
  category_accuracy = metric_category_accuracy,
  cross_entropy     = metric_cross_entropy,
)

def get_dataset_type(
    # Arguments:
    problem: str
  ) -> DatasetType:
  """
  Gets the dataset type from the ``"problem"`` entry of the dictionary.

  Args:
    problem (str):
      The problem as a string.

  Returns:
    DatasetType:
      The class of the dataset.

  Raises:
    ValueError:
      If ``problem`` is not a key of ``DATASET_MAP``. The message lists every
      valid problem name.
  """
  try:
    return DATASET_MAP[problem]
  except (KeyError, TypeError):
    # TypeError covers unhashable values (e.g. a list read from JSON).
    # ``from None`` hides the internal lookup error, because the message
    # already lists every valid option. ValueError is still an Exception, so
    # existing ``except Exception`` callers keep working.
    raise ValueError(
      f"'{problem}' is an unrecognized problem type. Looking for one of:\n" +
      "\n".join(f"- '{x}'" for x in DATASET_MAP)
    ) from None

def load_model_for_inference(
    # Arguments:
    file: str
  ) -> Tuple[Module, JSONType, Any]:
  """
  Loads a model and its config for inference from the specified file.

  Args:
    file (str):
      The file location for the saved model.

  Returns:
    Tuple[Module, JSONType, Any]:
      The model with the trained weights loaded, the JSON-ready config
      dictionary stored under the checkpoint's ``"config"`` key, and the full
      checkpoint contents (for callers that need other saved entries).

  Raises:
    Exception:
      Raised by ``get_dataset_type`` if the config's ``"problem"`` is not a
      recognized problem type.
  """
  # ``map_location = "cpu"`` deserializes every tensor onto the CPU,
  # regardless of the device it was saved from.
  # NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
  # from trusted sources.
  contents = torch.load(file, map_location = "cpu")
  config = contents["config"]
  # Called only for its validation: it raises a readable error listing the
  # known problems, rather than a bare KeyError from ``MODEL_MAP`` below.
  _ = get_dataset_type(config["problem"])
  model_fn = MODEL_MAP[config["problem"]]
  model = model_fn(config["model"])
  model.load_state_dict(contents["model_state"])
  return (model, config, contents)

def perform_test(
    # Arguments:
    model:      Module,
    dataloader: DataLoader,
    logger:     Logger,
    # Keyword Arguments:
    device:         Optional[Union[Device, str]] = None,
    max_iterations: Optional[int]                = None,
    metrics:        Optional[Iterable[str]]      = None
  ) -> None:
  """
  Runs ``model`` over every batch of ``dataloader`` without gradients,
  evaluates the requested metrics and records a summary of each one.

  Args:
    model (Module):
      The model to test. It is called as
      ``model(input_batch, max_iterations = ...)``. This function does not
      move the model, so place it on ``device`` beforehand.
    dataloader (DataLoader):
      Yields ``(input_batch, target_batch)`` pairs.
    logger (Logger):
      Summarises each metric with ``array_summary`` and stores the results
      with ``record_result``.
    device (Optional[Union[Device, str]]):
      The device the batches are moved to. Defaults to the CPU.
    max_iterations (Optional[int]):
      The iteration budget passed to the model. Defaults to
      ``model.max_iterations``.
    metrics (Optional[Iterable[str]]):
      The names of the metrics (keys of ``METRICS``) to compute. Defaults to
      all of them.

  Raises:
    KeyError:
      If a requested metric name is not in ``METRICS``.
    ValueError:
      If metrics were requested but ``dataloader`` yielded no batches.
  """
  # Device.
  device = torch.device("cpu") if device is None else device
  device = torch.device(device) if isinstance(device, str) else device
  # Max iterations.
  if max_iterations is None:
    max_iterations = model.max_iterations
  # Metrics are resolved before any inference, so a bad name fails fast.
  metrics = METRICS if metrics is None else {m: METRICS[m] for m in metrics}
  # Model.
  model.eval()
  measurements = {m: [] for m in metrics}
  num_batches = 0
  with torch.no_grad():
    for input_batch, target_batch in dataloader:
      num_batches += 1
      input_batch  = input_batch.to(device)
      target_batch = target_batch.to(device)
      predicted_batch = model(input_batch, max_iterations = max_iterations)
      for name, metric_fn in metrics.items():
        measurements[name].append(
          metric_fn(predicted_batch, target_batch).detach().cpu().numpy()
        )
  # Without this check, np.concatenate fails with an opaque
  # "need at least one array" error.
  if measurements and num_batches == 0:
    raise ValueError("The dataloader yielded no batches; nothing to test.")
  concatenated = {m: np.concatenate(arrays) for m, arrays in measurements.items()}
  # Only the summary of each metric is recorded, not the raw per-sample data.
  results = {
    m: {"summary": logger.array_summary(values)}
    for m, values in concatenated.items()
  }
  logger.record_result(results)

if __name__ == "__main__":
  # Construct an argument parser.
  parser = ArgumentParser(description = "Tests a trained model.")
  parser.add_argument("model", help = "The model file.")
  parser.add_argument("dataset", help = "The root location of the dataset.")
  parser.add_argument("log_dir", help = "The directory for saving logs.")
  parser.add_argument(
    "-d", "--data", default = None, type = str,
    help = "JSON file with the testing data config (defaults to the "
           "'testing' entry of the model's config)."
  )
  parser.add_argument(
    "-t", "--title", default = None, type = str,
    help = "Title passed to the logger."
  )
  parser.add_argument(
    "-i", "--iterations", default = None, type = int,
    help = "Maximum number of model iterations (defaults to the model's "
           "own max_iterations)."
  )
  args = parser.parse_args()
  # Get the device.
  if torch.cuda.is_available():
    print("Using CUDA.")
    device = "cuda"
  elif torch.backends.mps.is_available():
    print("Using MPS.")
    device = "mps"
  else:
    print("Using CPU.")
    device = "cpu"
  # Get the model and config.
  model, config, contents = load_model_for_inference(args.model)
  model.to(device)
  # Get dataset: an explicit config file overrides the one saved with the model.
  if args.data is not None:
    with open(args.data, "r", encoding = "utf-8") as file:
      dataset_config = json.load(file)
  else:
    dataset_config = config["testing"]
  dataset_type = get_dataset_type(config["problem"])
  test_dataset = dataset_type(args.dataset, **dataset_config["data"])
  test_dataloader = DataLoader(
    test_dataset,
    num_workers = 0,
    batch_size  = dataset_config["batch_size"]
  )
  # Set up logger.
  logger = Logger(args.log_dir, args.title)
  logger.set_controls({
    "model_config": config,
    "data_config":  dataset_config,
    "iterations":   args.iterations,
    "device":       device,
    "train_info": {
      "timestamp":  contents["timestamp"],
      "epoch":      contents["epoch"],
      "train_loss": contents["train_loss"],
      "valid_loss": contents["valid_loss"]
    }
  })
  logger.start_experiment()
  # Get results.
  perform_test(
    model,
    test_dataloader,
    logger,
    device         = device,
    max_iterations = args.iterations,
    metrics        = dataset_config.get("metrics")
  )
  # End experiment.
  logger.stop_experiment()
  logger.write()