# Source code for cleverhans.compat

"""
Wrapper functions for writing code that is compatible with many versions
of TensorFlow.
"""
import functools
import warnings

import tensorflow as tf
# The following 2 imports are not used in this module. They are imported so that users of cleverhans.compat can
# get access to device_lib, app, and flags. A pylint bug makes these imports cause errors when using python3+tf1.8.
# Doing the sanitized import here once makes it possible to do "from cleverhans.compat import flags" throughout the
# library without needing to repeat the pylint boilerplate.
from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module,unused-import
from tensorflow.python.platform import app, flags # pylint: disable=no-name-in-module,unused-import

def _wrap(f):
  """
  Wraps a callable `f` in a function that warns that the function is deprecated.

  Every call to the returned wrapper emits a warning and then forwards all
  positional and keyword arguments to `f` unchanged, returning its result.
  `functools.wraps` copies `f`'s name, docstring and other metadata onto the
  wrapper, so introspection and generated docs still describe `f`.

  :param f: the callable to wrap (e.g. `tf.reduce_sum`).
  :return: a callable that accepts the same arguments as `f`.
  """
  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    # Issues a deprecation warning and passes through the arguments.
    # stacklevel=2 attributes the warning to the caller of the deprecated
    # alias rather than to this module.
    warnings.warn(str(f) + " is deprecated. Switch to calling the equivalent function in tensorflow. "
                  " This function was originally needed as a compatibility layer for old versions of tensorflow, "
                  " but support for those versions has now been dropped.",
                  stacklevel=2)
    return f(*args, **kwargs)
  return wrapper

# Deprecated aliases kept for backward compatibility: each name is the
# corresponding TensorFlow reduction op passed through `_wrap`, so calling it
# emits a warning and then forwards to the op.
(reduce_sum, reduce_max, reduce_min,
 reduce_mean, reduce_prod, reduce_any) = map(
     _wrap,
     (tf.reduce_sum, tf.reduce_max, tf.reduce_min,
      tf.reduce_mean, tf.reduce_prod, tf.reduce_any))

def reduce_function(op_func, input_tensor, axis=None, keepdims=None,
                    name=None, reduction_indices=None):
  """
  Applies the reduction `op_func` to `input_tensor`.

  This function used to be needed to support tf 1.4 and earlier, but support for tf 1.4 and earlier is now dropped.

  `reduction_indices` is folded into `axis` here rather than forwarded, so
  `op_func` only needs to accept the `axis`, `keepdims` and `name` keywords.
  TF 2.x reduction ops such as `tf.reduce_sum` no longer accept
  `reduction_indices`, so forwarding it would raise a TypeError.

  :param op_func: expects the function to handle eg: tf.reduce_sum.
  :param input_tensor: The tensor to reduce. Should have numeric type.
  :param axis: The dimensions to reduce. If None (the default),
          reduces all dimensions. Must be in the range
          [-rank(input_tensor), rank(input_tensor)).
  :param keepdims: If true, retains reduced dimensions with length 1.
  :param name: A name for the operation (optional).
  :param reduction_indices: The old (deprecated) name for axis.
  :return: outputs same value as op_func.
  :raises ValueError: if both `axis` and `reduction_indices` are given.
  """

  warnings.warn("`reduce_function` is deprecated and may be removed on or after 2019-09-08.",
                stacklevel=2)

  # Translate the deprecated keyword ourselves instead of relying on op_func
  # to understand it; specifying both is ambiguous, so reject it.
  if reduction_indices is not None:
    if axis is not None:
      raise ValueError("Cannot specify both `axis` and `reduction_indices`.")
    axis = reduction_indices

  return op_func(input_tensor, axis=axis, keepdims=keepdims, name=name)

def softmax_cross_entropy_with_logits(sentinel=None,
                                      labels=None,
                                      logits=None,
                                      dim=-1):
  """
  Wrapper around tf.nn.softmax_cross_entropy_with_logits_v2 to handle
  deprecated warning.

  Gradients are stopped on `labels` before the loss is computed, so only
  `logits` receive gradients from the returned loss.

  :param sentinel: placeholder that must be left as None; it exists only to
      reject positional calls.
  :param labels: the target distribution (required, keyword-only in practice).
  :param logits: the unscaled log probabilities (required, keyword-only in
      practice).
  :param dim: the class dimension; forwarded to the TF op as `dim`.
  :return: the loss tensor returned by
      tf.nn.softmax_cross_entropy_with_logits_v2.
  :raises ValueError: if an argument is passed positionally, or if either
      `labels` or `logits` is missing.
  :raises RuntimeError: if the installed TensorFlow does not provide
      tf.nn.softmax_cross_entropy_with_logits_v2.
  """
  # Make sure that all arguments were passed as named arguments.
  if sentinel is not None:
    name = "softmax_cross_entropy_with_logits"
    raise ValueError("Only call `%s` with "
                     "named arguments (labels=..., logits=..., ...)" % name)
  if labels is None or logits is None:
    raise ValueError("Both labels and logits must be provided.")

  try:
    f = tf.nn.softmax_cross_entropy_with_logits_v2
  except AttributeError:
    raise RuntimeError("This version of TensorFlow is no longer supported. See cleverhans/README.md")

  # The _v2 op backpropagates into labels as well as logits; stop_gradient
  # keeps the labels constant with respect to this loss.
  labels = tf.stop_gradient(labels)
  loss = f(labels=labels, logits=logits, dim=dim)
  return loss