# coding=utf-8
# Copyright 2023 The Soar Neurips2023 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions."""
import collections
import os
import tempfile
from typing import Optional

import numpy as np
import scann
from scann.partitioning import partitioner_pb2


def normalize(pts):
  """Rescales every row of `pts` to unit L2 norm.

  Args:
    pts: array whose axis 1 holds the vector components (n x d in the 2D case).

  Returns:
    Array shaped like `pts`; rows with zero norm are returned unchanged.
  """
  lengths = np.linalg.norm(pts, axis=1, keepdims=True)
  # Dividing an all-zero row by 1 keeps it at zero instead of producing NaNs.
  lengths[lengths == 0] = 1
  return pts / lengths


def compute_ground_truth(ds, qs, k):
  """Finds the k best dot-product matches in `ds` for every query in `qs`.

  Args:
    ds: dataset handed to the ScaNN builder.
    qs: batch of queries to search with.
    k: number of neighbors requested per query.

  Returns:
    The first element of ScaNN's batched search result (the neighbor ids).
  """
  builder = scann.scann_ops_pybind.builder(ds, k, "dot_product")
  # Brute-force scoring with no partitioning or quantization stage configured.
  searcher = builder.score_brute_force().build()
  search_result = searcher.search_batched_parallel(qs)
  return search_result[0]


def redo_assignment(centers, ds):
  """Assigns each datapoint to its two nearest centers by squared L2 distance.

  Args:
    centers: array of partition centers, indexed by a brute-force ScaNN searcher.
    ds: datapoints used as queries against `centers`.

  Returns:
    Tuple (first, second): column 0 and column 1 of the top-2 search result,
    i.e. the ids of the #1 and #2 center for each datapoint.
  """
  builder = scann.scann_ops_pybind.builder(centers, 2, "squared_l2")
  searcher = builder.score_brute_force().build()
  top2 = searcher.search_batched_parallel(ds)[0]
  first_choice = top2[:, 0]
  second_choice = top2[:, 1]
  return first_choice, second_choice


def compute_score_diffs(ds, centers, tokenization, qs, gt):
  """Computes <q, x> - <q, quant(x)> for every x referenced in `gt`.

  Args:
    ds: n x d array of datapoints.
    centers: c x d array of partition centers.
    tokenization: length-n array mapping datapoint index -> center index.
    qs: q x d array of queries.
    gt: q x k integer array of datapoint indices, one row per query.

  Returns:
    q x k array; entry [i][j] is the dot product of qs[i] with ds[gt[i][j]]
    minus the dot product of qs[i] with the center assigned to ds[gt[i][j]].
  """
  per_query_dot = "qkd,qd->qk"
  true_points = ds[gt]
  quantized_points = centers[tokenization[gt]]
  true_scores = np.einsum(per_query_dot, true_points, qs)
  quantized_scores = np.einsum(per_query_dot, quantized_points, qs)
  return true_scores - quantized_scores


def extract_partitioner_centers(partitioner):
  """Copies the root-level k-means centers of a partitioner into a 2D array.

  Args:
    partitioner: object exposing `kmeans.kmeans_tree.root.centers`, where each
      center has a `dimension` sequence (a ScaNN SerializedPartitioner proto).

  Returns:
    c x d float array; row i holds the `dimension` values of center i. The
    width d is taken from the first center.
  """
  root_centers = partitioner.kmeans.kmeans_tree.root.centers

  dims = len(root_centers[0].dimension)
  out = np.zeros(shape=(len(root_centers), dims))
  # Each `row` is a view into `out`, so the slice assignment fills it in place.
  for row, center in zip(out, root_centers):
    row[:] = center.dimension
  return out


def extract_centroids(scann_artifacts_dir):
  """Loads partitioner centers from a ScaNN artifacts directory.

  Args:
    scann_artifacts_dir: directory containing `serialized_partitioner.pb`.

  Returns:
    c x d 2D Numpy array of partition centers.
  """
  proto_path = os.path.join(scann_artifacts_dir, "serialized_partitioner.pb")
  with open(proto_path, "rb") as proto_file:
    serialized = proto_file.read()
  # Parsing only needs the bytes, so it happens after the file is closed.
  partitioner = partitioner_pb2.SerializedPartitioner()
  partitioner.ParseFromString(serialized)
  return extract_partitioner_centers(partitioner)


def extract_tokenization(scann_artifacts_dir):
  """Loads the point -> centroid mapping from a ScaNN artifacts directory.

  Args:
    scann_artifacts_dir: directory containing `datapoint_to_token.npy`.

  Returns:
    The array stored in that file (length-n 1D Numpy array of center ids).
  """
  token_path = os.path.join(scann_artifacts_dir, "datapoint_to_token.npy")
  return np.load(token_path)


def train_kmeans(ds, k):
  """Trains a k-partition ScaNN tree on `ds` and reads back its k-means state.

  Args:
    ds: dataset to partition; every point is used as training sample.
    k: number of partitions passed to the ScaNN tree builder.

  Returns:
    (centroids, tokenization) tuple, as read from the serialized searcher.
  """
  builder = scann.scann_ops_pybind.builder(ds, 1, "dot_product")
  builder = builder.tree(k, 1, training_sample_size=len(ds))
  searcher = builder.score_brute_force().build()
  # The searcher is serialized to a scratch directory only so that the
  # partitioner and tokenization files can be read back from it.
  with tempfile.TemporaryDirectory() as artifacts_dir:
    searcher.serialize(artifacts_dir)
    centroids = extract_centroids(artifacts_dir)
    tokenization = extract_tokenization(artifacts_dir)
  return centroids, tokenization


def compute_avq_center(points, eta):
  """Computes center to minimize anisotropic loss on `points`.

  Args:
    points: n x d array of the datapoints belonging to one partition.
    eta: anisotropic weighting parameter. With eta == 1 the result is the
      plain mean of `points`.

  Returns:
    Length-d array. All zeros when the norm weights sum to zero, e.g. when
    `points` is empty or (for eta > 1) contains only zero vectors.
  """
  norms = np.linalg.norm(points, axis=1)
  d = points.shape[1]
  # Zero-norm points deliberately produce inf/nan intermediates here that are
  # scrubbed by nan_to_num right after, so the floating-point warnings (or
  # errors under np.seterr(all="raise")) they would trigger are just noise.
  with np.errstate(divide="ignore", invalid="ignore"):
    norms_pow = norms ** (eta - 1)
    points_renormed = points * (norms ** (0.5 * eta - 1.5))[:, np.newaxis]
  np.nan_to_num(norms_pow, copy=False)
  np.nan_to_num(points_renormed, copy=False)

  # Summed once with NumPy rather than twice with the Python-level builtin.
  weight_total = np.sum(norms_pow)
  if weight_total == 0:
    return np.zeros(d)

  # Solve (sum_i w_i * I + (eta - 1) * P^T P) c = sum_i w_i x_i, then scale
  # by eta, where w_i = ||x_i||^(eta - 1) and P holds the renormed points.
  mat = np.identity(d) * weight_total + (eta - 1) * np.dot(
      points_renormed.T, points_renormed
  )
  return eta * np.linalg.solve(mat, np.dot(norms_pow, points))


def compute_avq_centers(points, centers, tokenization, eta):
  """Recomputes every partition center with the AVQ update.

  Args:
    points: n x d array of datapoints.
    centers: c x d array of current centers; only its shape and dtype are used
      for the output.
    tokenization: length-n sequence; tokenization[i] is the center index that
      points[i] is assigned to.
    eta: anisotropic weighting parameter forwarded to `compute_avq_center`.

  Returns:
    c x d array (same dtype as `centers`) of post-AVQ centers.
  """
  # members[j] lists the datapoint indices assigned to center j, in order.
  members = [[] for _ in centers]
  for point_id, token in enumerate(tokenization):
    members[token].append(point_id)

  updated = np.zeros_like(centers)
  for center_id, point_ids in enumerate(members):
    updated[center_id] = compute_avq_center(points[point_ids], eta)
  return updated


def get_centroid_ranks(centers, tokenization, qs, want):
  """Looks up, per query, the rank of the centroid owning each wanted point.

  Args:
    centers: c x d array of partition centers.
    tokenization: length-n array mapping datapoint index -> center index.
    qs: q x d array of queries.
    want: q x k integer array of datapoint indices, one row per query.

  Returns:
    q x k array; entry [i][j] is the 0-indexed rank (by descending dot product
    with qs[i]) of the center that want[i][j] is assigned to.
  """
  scores = np.dot(qs, centers.T)
  # best_first[i] lists center ids from highest to lowest score for query i;
  # argsorting that permutation inverts it into center id -> rank.
  best_first = np.argsort(-scores, axis=-1)
  rank_of_center = np.argsort(best_first)
  wanted_centers = tokenization[want]
  return np.take_along_axis(rank_of_center, wanted_centers, 1)


def soar_assign(ds, centers, toke1, l, dup_ok, batch_size=1000):
  """Returns tokenization minimizing orthogonality-amplified loss against toke1.

  For datapoint x with primary center centers[toke1], let r be the unit-norm
  residual x - centers[toke1]. The loss of assigning x to center c is
  ||x - c||^2 + l * <r, x - c>^2, and the center with the lowest loss wins.

  Args:
    ds: n x d array of datapoints.
    centers: c x d array of partition centers.
    toke1: length-n array; toke1[i] is the primary center index of ds[i].
    l: weight of the orthogonality-amplification term.
    dup_ok: if False, the primary center toke1[i] is excluded as a candidate
      for datapoint i.
    batch_size: number of datapoints processed per batch; this bounds the
      size of the (batch_size x c) loss matrix held in memory.

  Returns:
    Length-n integer array of chosen center indices (empty if `ds` is empty).
  """
  # Loop-invariant: squared norms of the centers, as a 1 x c row.
  center_sq_norms = (np.linalg.norm(centers, axis=1) ** 2)[np.newaxis, :]
  res = []
  for batch_start in range(0, len(ds), batch_size):
    batch_end = batch_start + batch_size
    cur_ds = ds[batch_start:batch_end]
    cur_toke = toke1[batch_start:batch_end]

    # len(cur_ds) x len(centers); [i][j] = ||cur_ds[i] - centers[j]||^2
    sq_l2_dists = (
        (np.linalg.norm(cur_ds, axis=1) ** 2)[:, np.newaxis]
        + center_sq_norms
        - 2 * np.dot(cur_ds, centers.T)
    )

    resids_normed = normalize(cur_ds - centers[cur_toke])
    # oa_loss1[i] - oa_loss2[i][j] = <resids_normed[i], cur_ds[i] - centers[j]>
    oa_loss1 = np.sum(cur_ds * resids_normed, axis=1)
    oa_loss2 = np.dot(resids_normed, centers.T)
    tot_loss = sq_l2_dists + l * (oa_loss1[:, np.newaxis] - oa_loss2) ** 2
    # If we want to ensure this tokenization doesn't duplicate the original
    # toke1, we make all the toke1 loss infinite so argmin never picks it.
    if not dup_ok:
      tot_loss[np.arange(len(cur_toke)), cur_toke] = np.inf
    res.append(np.argmin(tot_loss, axis=1))
  if not res:
    # np.concatenate rejects an empty list; an empty dataset has no tokens.
    return np.zeros(0, dtype=np.intp)
  return np.concatenate(res)


def group_rank_data(ranks, other, agg_func):
  """Returns x and y array for other, grouped and aggregated by rank."""
  d = collections.defaultdict(list)
  for rank, o in zip(ranks, other):
    d[rank].append(o)
  xs = list(sorted(d.keys()))
  ys = []
  for x in xs:
    ys.append(agg_func(d[x]))
  return xs, ys


def kmr(centers, toke1, toke2, qs, gt):
  """Datapoint-weighted k-means recall curve; pass toke2=None if no spilling."""
  # sizes[i]: how many datapoints belong to partition i, from toke1 + toke2?
  sizes = np.zeros(len(centers), int)
  np.add.at(sizes, toke1, 1)
  if toke2 is not None:
    # If toke1 == toke2, consider it not spilled (doesn't make sense to put one
    # datapoint into the same partition twice).
    np.add.at(sizes, toke2[toke2 != toke1], 1)
  all_ranks_inv = np.argsort(-np.dot(qs, centers.T), axis=-1)
  # For query i, center j has rank all_ranks[i][j] (0-indexed).
  all_ranks = all_ranks_inv.argsort()
  # gt[i][j] belongs to the ranks[i][j] highest ranked centroid for that query.
  ranks = np.take_along_axis(all_ranks, toke1[gt], 1)
  if toke2 is not None:
    ranks = np.minimum(ranks, np.take_along_axis(all_ranks, toke2[gt], 1))
  # cum_sizes[i][j] = total # datapoints in top j+1 partitions for query i.
  cum_sizes = np.cumsum(sizes[all_ranks_inv], axis=-1)
  add_locs = np.take_along_axis(cum_sizes, ranks, axis=1)
  res = np.zeros(sum(sizes) + 1, float)
  np.add.at(res, add_locs, 1)
  return np.cumsum(res) / gt.size
