# coding=utf-8
# Copyright 2020 The Attribution Gnn Benchmarking Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Attribution related metrics."""
import functools
import warnings
from typing import Callable, List

import numpy as np
import scipy
from absl import logging

import graph_nets
import graphs as graph_utils
import sklearn.metrics

# Typing aliases.
# Short module-level name for the `GraphsTuple` class exposed by
# `graph_nets.graphs`.
GraphsTuple = graph_nets.graphs.GraphsTuple


def silent_nan_np(f):
  """Decorator that mutes numpy noise and maps undefined results to nan.

  Functions such as np.nanmax emit RuntimeWarnings when every input is nan
  (e.g. "All-NaN axis encountered"). The decorated function runs with numpy
  'invalid' floating point errors ignored and with all Python warnings
  suppressed. If the call raises ValueError (or an sklearn
  UndefinedMetricWarning raised as an exception), np.nan is returned instead.

  Args:
    f: function to decorate.

  Returns:
    A function with the same signature and metadata as `f` that stays silent
    on invalid numpy operations and yields np.nan when the metric is
    undefined.
  """

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    with np.errstate(invalid='ignore'):
      with warnings.catch_warnings():
        # Every warning category is ignored for the duration of the call.
        warnings.simplefilter('ignore')
        try:
          outcome = f(*args, **kwargs)
        except (ValueError, sklearn.exceptions.UndefinedMetricWarning):
          # The metric is undefined for these inputs.
          outcome = np.nan
    return outcome

  return wrapper


# Base metric callables used below. `accuracy_score` is sklearn's function
# unchanged; the `nan*` names are wrapped with `silent_nan_np`, so they suppress
# warnings and return np.nan when the underlying call raises ValueError.
accuracy_score = sklearn.metrics.accuracy_score
nanmax = silent_nan_np(np.nanmax)
nan_auroc_score = silent_nan_np(sklearn.metrics.roc_auc_score)
nan_f1_score = silent_nan_np(sklearn.metrics.f1_score)


def _nodewise_metric(f):
  """Lifts a pairwise metric so it is evaluated once per graph on `nodes`.

  For example, given a function `auroc` that scores a pair (y_true, y_pred),
  `_nodewise_metric(auroc)` evaluates `auroc(y_true[i].nodes, y_pred[i].nodes)`
  for every index `i` of `y_true`.

  Args:
    f: A function taking the `nodes` of a true graph and the `nodes` of a
      predicted graph (plus any extra arguments) and returning a metric value.

  Returns:
    A function taking indexable collections `y_true` and `y_pred` of objects
    exposing a `nodes` attribute, returning an np.array with one metric value
    per element of `y_true`. Extra positional and keyword arguments are
    forwarded to `f`.
  """

  def vectorized_f(y_true, y_pred, *args, **kwargs):
    per_graph = []
    # The number of evaluations is driven by `y_true`; `y_pred` is indexed at
    # the same positions.
    for idx in range(len(y_true)):
      true_nodes = y_true[idx].nodes
      pred_nodes = y_pred[idx].nodes
      per_graph.append(f(true_nodes, pred_nodes, *args, **kwargs))
    return np.array(per_graph)

  return vectorized_f


def _validate_attribution_inputs(y_true,
                                 y_pred):
  """Checks attribution metric inputs, raising ValueError when malformed.

  Two conditions are enforced: `y_true` and `y_pred` must have the same
  length, and the `nodes` array of every element of `y_true` must be 2D.
  The elements of `y_pred` are not inspected.

  Args:
    y_true: Sized iterable of objects exposing a `nodes` array.
    y_pred: Sized collection of predictions.

  Raises:
    ValueError: If the lengths differ or a true `nodes` array is not 2D.
  """
  n_true, n_pred = len(y_true), len(y_pred)
  if n_true != n_pred:
    raise ValueError(
        f'Expected same number of graphs in y_true and y_pred, found {n_true} and {n_pred}'
    )
  for graph in y_true:
    node_shape = graph.nodes.shape
    if len(node_shape) == 2:
      continue
    # The first offending graph is reported.
    raise ValueError(
        f'Expecting 2D nodes for true attribution, found at least one with shape {node_shape}'
    )


def _attribution_metric(f):
  """Lifts a pairwise metric into a per-graph 'attribution' style metric.

  For example, given a function `auroc` that scores a pair (y_true, y_pred),
  `_attribution_metric(auroc)` scores every graph pair by evaluating `auroc`
  between each column of the true `nodes` array and the predicted `nodes`,
  keeping the largest value while ignoring nan (via `nanmax`).

  Args:
    f: A function taking one column of a true `nodes` array and the predicted
      `nodes` (plus any extra arguments) and returning some metric value.

  Returns:
    A function taking collections `y_true` and `y_pred` of objects exposing a
    `nodes` attribute, returning an np.array with one value per graph pair.
    Inputs are first checked with `_validate_attribution_inputs`, which
    raises ValueError on malformed input.
  """

  def vectorized_f(y_true, y_pred, *args, **kwargs):
    _validate_attribution_inputs(y_true, y_pred)

    def best_over_columns(att_true, att_pred):
      # Transposing the 2D true nodes yields one candidate vector per column.
      column_scores = [
          f(column, att_pred.nodes, *args, **kwargs)
          for column in att_true.nodes.T
      ]
      return nanmax(column_scores)

    return np.array([
        best_over_columns(att_true, att_pred)
        for att_true, att_pred in zip(y_true, y_pred)
    ])

  return vectorized_f


def kendall_tau_score(y_true, y_pred):
  """Returns Kendall's tau rank correlation between `y_true` and `y_pred`.

  Useful when only the relative ordering of the values matters.
  """
  tau_result = scipy.stats.kendalltau(y_true, y_pred)
  return tau_result.correlation


def pearson_r_score(y_true, y_pred):
  """Returns Pearson's r, the linear correlation of `y_true` and `y_pred`."""
  coefficient, _ = scipy.stats.pearsonr(y_true, y_pred)
  # When scipy hands back a 1-D array, unwrap its first entry.
  if getattr(coefficient, 'ndim', None) == 1:
    return coefficient[0]
  return coefficient


def rmse(y_true, y_pred):
  """Returns the root mean squared error between `y_true` and `y_pred`."""
  mean_sq_err = sklearn.metrics.mean_squared_error(y_true, y_pred)
  return np.sqrt(mean_sq_err)


# Per-graph metrics built with `_nodewise_metric`: the base metric is applied
# to the `nodes` of each (true, predicted) graph pair.
nodewise_f1_score = _nodewise_metric(nan_f1_score)
nodewise_kendall_tau_score = _nodewise_metric(kendall_tau_score)
nodewise_pearson_r_score = _nodewise_metric(pearson_r_score)

# Per-graph metrics built with `_attribution_metric`: the base metric is
# evaluated against each column of the true `nodes` and the nan-ignoring
# maximum is kept.
attribution_auroc = _attribution_metric(nan_auroc_score)
attribution_accuracy = _attribution_metric(accuracy_score)
attribution_f1 = _attribution_metric(nan_f1_score)


def get_optimal_threshold(y_true,
                          y_prob,
                          grid_spacing=0.01,
                          verbose=False):
  """Finds the probability threshold that maximizes the f1 score.

  Given ground truth labels and predicted probabilities for those labels,
  runs a grid search over candidate thresholds. For each candidate the
  predictions `y_prob >= threshold` are scored with f1_score, and the
  candidate with the highest score is returned (the first one on ties).

  Arguments:
    y_true (np.array): 1D array with true labels.
    y_prob (np.array): 1D array with predicted probabilities.
    grid_spacing (float): controls the spacing for the grid search, should be a
      positive value lower than 1.0 . Defaults to 0.01.
    verbose (bool): flag to log the selected threshold.

  Returns:
    p_threshold (float): Probability threshold.

  """
  candidates = np.arange(0.0, 1.0 + grid_spacing, grid_spacing)
  with warnings.catch_warnings():
    # Only sklearn's undefined-metric warnings are muted during scoring.
    warnings.filterwarnings(
        action='ignore', category=sklearn.exceptions.UndefinedMetricWarning)
    f1_per_candidate = []
    for cutoff in candidates:
      predicted_labels = y_prob >= cutoff
      f1_per_candidate.append(
          sklearn.metrics.f1_score(y_true, predicted_labels))
  best_index = np.argmax(f1_per_candidate)
  p_threshold = candidates[best_index]
  if verbose:
    logging.info('Optimal p_threshold is %.2f', p_threshold)
  return p_threshold


def get_opt_binary_attributions(atts_true,
                                atts_pred,
                                metric=nodewise_f1_score,
                                n_steps=20):
  """Binarizes predicted attributions at the threshold maximizing a metric.

  Sweeps `n_steps` evenly spaced thresholds in [0, 1]. For each threshold the
  predicted attributions are binarized with `graph_utils.binarize_np_nodes`
  and scored against `atts_true`; the score of a threshold is the nan-ignoring
  mean of the values returned by `metric`. The predictions binarized at the
  best-scoring threshold are returned.

  Thresholds whose score is nan (for example when `metric` is undefined for
  every graph at that threshold) are skipped when picking the best one. If
  every score is nan, the first threshold is used.

  Args:
    atts_true: Collection of ground truth attributions, passed as the first
      argument to `metric`.
    atts_pred: Iterable of predicted attributions; each element is passed to
      `graph_utils.binarize_np_nodes` together with a threshold.
    metric: Callable `metric(atts_true, binarized_atts_pred)` returning
      values that np.nanmean can average. Defaults to `nodewise_f1_score`.
    n_steps: Number of thresholds in the sweep.

  Returns:
    List with one binarized attribution per element of `atts_pred`.

  Raises:
    ValueError: If `n_steps` is 0, since there is no threshold to choose from.
  """

  thresholds = np.linspace(0, 1, num=n_steps)
  scores = []
  for thres in thresholds:
    atts = [graph_utils.binarize_np_nodes(g, thres) for g in atts_pred]
    scores.append(np.nanmean(metric(atts_true, atts)))
  scores = np.asarray(scores, dtype=float)
  if np.isnan(scores).all():
    # All-nan keeps the historical choice (index 0); an empty sweep still
    # raises ValueError from np.argmax.
    best_index = np.argmax(scores)
  else:
    # np.argmax would return the position of the first nan, so a threshold
    # with an undefined score would be reported as optimal. Skip nans.
    best_index = np.nanargmax(scores)
  opt_threshold = thresholds[best_index]
  return [graph_utils.binarize_np_nodes(g, opt_threshold) for g in atts_pred]
