################################################################################
# spectral/modules/tsp/loss.py
#
# 
# 
# 
# 2024
#
# Implements a loss function for TSP.

from typing import Optional

import torch

# Shorthand alias used by the type annotations in this module.
Tensor = torch.Tensor

# Softmax temperature for nodes_to_edges: the conjugated matrix is divided by
# this value before the row/column softmax is applied.
TEMPERATURE = 0.1

def sinkhorn(
    # Arguments:
    x: Tensor,
    # Options:
    n_iters: int = 1
  ) -> Tensor:
  """
  Applies Sinkhorn normalisation to the last two axes of a tensor.

  Each iteration first divides every row (last axis) by its sum and then
  divides every column (second-to-last axis) by its sum. The sums are clamped
  from below at 1e-30 so that all-zero rows or columns stay zero instead of
  producing NaNs.

  Args:
    x (Tensor):
      The tensor to normalise; the last two axes are treated as the matrix.
    n_iters (int):
      The number of row/column passes to run. The default of 1 reproduces a
      single pass; 0 returns the input unchanged.

  Returns:
    Tensor:
      The normalised tensor, with the same shape as the input.

  Raises:
    ValueError:
      If n_iters is negative.
  """
  if n_iters < 0:
    raise ValueError(f"n_iters must be non-negative, got {n_iters}")
  for _ in range(n_iters):
    # Row pass, then column pass; the column pass runs last, so columns are
    # the axis that is normalised exactly on return.
    x = x / torch.sum(x, dim = -1, keepdim = True).clamp(min = 1e-30)
    x = x / torch.sum(x, dim = -2, keepdim = True).clamp(min = 1e-30)
  return x

def nodes_to_edges(
    # Arguments:
    t: Tensor,
    # Options:
    temperature: Optional[float] = None
  ) -> Tensor:
  """
  Converts the model's "nodes-after-0th" ordering into the edges selected by
  that ordering.

  The ordering is embedded in the lower-right block of an identity matrix that
  is one row and column larger (so the top-left entry stays fixed at 1),
  conjugated with a constant single-cycle permutation matrix, divided by the
  temperature and finally relaxed as the average of a row-wise and a
  column-wise softmax.

  Args:
    t (Tensor):
      The tensor of the ordering, of shape (B, N, N).
    temperature (Optional[float]):
      The softmax temperature. When None, the module-level TEMPERATURE is
      used.

  Returns:
    Tensor:
      The tensor of selected edges, of shape (B, N + 1, N + 1). Its dtype is
      the promotion of the input dtype and the default floating dtype, so
      double-precision orderings are no longer downcast.

  Raises:
    ValueError:
      If the temperature is not strictly positive.
  """
  if temperature is None:
    temperature = TEMPERATURE
  if temperature <= 0.0:
    raise ValueError(f"temperature must be positive, got {temperature}")
  N = t.size(-1)
  # Promote against the default dtype: float64 input keeps its precision,
  # while half-precision and integer input are still lifted to the default
  # floating dtype.
  dtype    = torch.promote_types(t.dtype, torch.get_default_dtype())
  identity = torch.eye(N + 1, dtype = dtype, device = t.device)
  # Procedure for fixing top-left item: one identity per batch element, with
  # the ordering written into the lower-right N x N block.
  t_expand = identity.repeat(t.size(0), 1, 1)
  t_expand[:, 1:, 1:] = t
  # Generate single-cycle permutation matrix (row i has its 1 in column
  # i + 1, wrapping around at the end).
  const_cycle = torch.roll(identity, -1, dims = 0)
  # Calculate group conjugate.
  logits = torch.matmul(
    t_expand.mT, torch.matmul(const_cycle, t_expand)
  ) / temperature
  # Average the row-wise and column-wise softmax of the scaled conjugate.
  return 0.5 * (
    torch.softmax(logits, dim = -1) + torch.softmax(logits, dim = -2)
  )

def tsp_loss(
    # Arguments:
    t: Tensor,
    d_tilde:   Tensor
  ) -> Tensor:
  """
  Scores a predicted ordering matrix (t) against a grid of distances
  (d_tilde) and reduces the result to a single loss value.

  Args:
    t (Tensor):
      The predicted ordering matrix. Axis 1 is squeezed away when it has
      size 1 (presumably a singleton channel axis -- confirm with the caller).
    d_tilde (Tensor):
      The grid of distances. Only index 0 along axis 1 is used, and it must
      broadcast against the edge matrix returned by nodes_to_edges.

  Returns:
    Tensor:
      A scalar tensor: the mean of the per-sample path cost plus 1.
  """
  # Remove channel dimensions.
  ordering = t.squeeze(dim = 1)
  # Keep index 0 of axis 1 only and lower every edge length by one.
  distances = d_tilde[:, 0] - 1.0
  # Edge matrix of the ordering, raised element-wise to the power 1.5.
  edges = nodes_to_edges(ordering).pow(1.5)
  # Distance-weighted edge mass per sample, divided by the ordering size.
  path_cost = (edges * distances).sum(dim = (-1, -2)) / ordering.size(-1)
  # NOTE: only the path term contributes. The orthogonality penalty
  # (t^T t - I)^2 and the permutation penalty t^2 (1 - t)^2 that were sketched
  # here are disabled.
  return (path_cost + 1.0).mean()
