################################################################################
# spectral/testing/get_spectral.py
#
# 
# 
# 
# 2024
#
# Gets the spectral norm of weight matrices of a DT model.

import json
import numpy as np
import torch

from argparse     import ArgumentParser
from numpy.typing import NDArray
from typing       import Any, Optional, Sequence, Tuple

from experilog.logger         import Logger, JSONType
# from modules.chesspuzzles.model import load_from_json_dict
from modules.mazes.model      import load_from_json_dict as load_mazes
from modules.prefixsums.model import load_from_json_dict as load_prefix_sums

# Short aliases for torch types.
DataLoader, Device = torch.utils.data.DataLoader, torch.device
Module, Tensor     = torch.nn.Module, torch.Tensor

# Maps the checkpoint config's "problem" value to the loader that builds the
# matching model from its JSON config dictionary.
MODEL_MAP = dict(
  prefix_sums = load_prefix_sums,
  mazes       = load_mazes,
)

# State-dict keys of the weights whose spectral norms are measured: the four
# convolution weights of the thought module.
WEIGHT_NAMES = [
  f"thought_module.{layer}.weight"
  for layer in ("conv1", "conv2", "conv3", "conv4")
]

def load_model_for_inference(
    # Arguments:
    file: str
  ) -> Tuple[Module, JSONType, Any]:
  """
  Loads a model and its config for inference from the specified file.

  The file is read with `torch.load` and must provide the keys "config"
  (which itself holds "problem" and "model") and "model_state". The model is
  built by the loader registered in MODEL_MAP for config["problem"], receives
  the saved weights, and is switched to evaluation mode.

  Args:
    file (str):
      The file location for the saved model.

  Returns:
    Tuple[Module, JSONType, Any]:
      The model with the trained weights (in eval mode), the JSON-ready config
      dictionary, and the full loaded file contents.

  Raises:
    KeyError:
      If config["problem"] has no registered loader in MODEL_MAP, or if a
      required key is missing from the file contents.
  """
  # Map everything to the CPU so checkpoints saved on a GPU open anywhere.
  contents = torch.load(file, map_location = "cpu")
  config = contents["config"]
  problem = config["problem"]
  try:
    model_fn = MODEL_MAP[problem]
  except KeyError:
    # Re-raise with the list of supported problems instead of a bare key.
    raise KeyError(
      f"Unknown problem {problem!r}; expected one of {sorted(MODEL_MAP)}."
    ) from None
  model = model_fn(config["model"])
  model.load_state_dict(contents["model_state"])
  # The model is intended for inference, so disable training-only behavior.
  model.eval()
  return (model, config, contents)

def get_spectral_norm(
    # Arguments:
    weight: NDArray
  ) -> float:
  """
  Computes the spectral norm (largest singular value) of a weight tensor.

  The tensor is first flattened to a 2-D matrix: the first axis is kept as
  the rows and all remaining axes are merged into the columns. For PyTorch
  conv kernels the first axis is presumably the output channels. The norm
  returned is that of this reshaped matrix. In general it differs from the
  operator norm of the convolution itself.

  Args:
    weight (NDArray):
      The weight tensor as a NumPy array with at least one dimension.

  Returns:
    float:
      The largest singular value of the flattened weight matrix.
  """
  rows = weight.shape[0]
  matrix = weight.reshape(rows, -1)
  largest_singular_value = np.linalg.norm(matrix, ord = 2)
  return float(largest_singular_value)

def perform_test(
    # Arguments:
    model:        Module,
    logger:       Logger,
    # Keyword Arguments:
    *,
    weight_names: Optional[Sequence[str]] = None
  ) -> None:
  """
  Measures the spectral norm of selected model weights and records them.

  Each selected weight is read from the model's state dict, converted to a
  NumPy array and passed to `get_spectral_norm`. The results are recorded in
  one call to `logger.record_result`, shaped {name: {"data": norm}}.

  Args:
    model (Module):
      The model to measure. It is switched to evaluation mode.
    logger (Logger):
      The logger that receives the results.
    weight_names (Optional[Sequence[str]]):
      State-dict keys of the weights to measure. Defaults to WEIGHT_NAMES.

  Raises:
    KeyError:
      If any requested weight name is absent from the model's state dict.
  """
  names = WEIGHT_NAMES if weight_names is None else weight_names
  model.eval()
  state_dict = model.state_dict()
  # Report every missing key at once rather than failing on the first one.
  missing = [name for name in names if name not in state_dict]
  if missing:
    raise KeyError(f"Weights not found in model state dict: {missing}.")
  results = {
    name: {
      "data": get_spectral_norm(state_dict[name].detach().cpu().numpy())
    }
    for name in names
  }
  logger.record_result(results)

if __name__ == "__main__":
  # Command line: checkpoint path, log directory, optional experiment title.
  cli = ArgumentParser(description = "Gets the spectral norm of weights.")
  cli.add_argument("model", help = "The model file.")
  cli.add_argument("log_dir", help = "The directory for saving logs.")
  cli.add_argument("-t", "--title", default = None, type = str)
  opts = cli.parse_args()
  # Restore the trained model, its config and the raw checkpoint contents.
  model, config, contents = load_model_for_inference(opts.model)
  model.to("cpu")
  # Create the logger and store the model config and training metadata as
  # experiment controls.
  logger = Logger(opts.log_dir, opts.title)
  train_info = {
    key: contents[key]
    for key in ("timestamp", "epoch", "train_loss", "valid_loss")
  }
  logger.set_controls({"model_config": config, "train_info": train_info})
  # Run the measurement inside the experiment, then persist the log.
  logger.start_experiment()
  perform_test(model, logger)
  logger.stop_experiment()
  logger.write()