tensorflow.python.ops.check_ops 源代码

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""Asserts and Boolean Checks."""

import collections

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export

# Dtypes grouped under the name NUMERIC_TYPES: floating point (including
# bfloat16), signed and unsigned integer, quantized integer, and complex.
# Stored as a frozenset, so the collection is immutable and supports fast
# membership tests.
NUMERIC_TYPES = frozenset([
    dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16,
    dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32,
    dtypes.uint64, dtypes.qint8, dtypes.qint16, dtypes.qint32, dtypes.quint8,
    dtypes.quint16, dtypes.complex64, dtypes.complex128, dtypes.bfloat16
])

# Names this module exposes through `from <module> import *`.
__all__ = [
    'assert_negative',
    'assert_positive',
    'assert_proper_iterable',
    'assert_non_negative',
    'assert_non_positive',
    'assert_equal',
    'assert_none_equal',
    'assert_near',
    'assert_integer',
    'assert_less',
    'assert_less_equal',
    'assert_greater',
    'assert_greater_equal',
    'assert_rank',
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_type',
    'assert_shapes',
    'is_non_decreasing',
    'is_numeric_tensor',
    'is_strictly_increasing',
]


def _maybe_constant_value_string(t):
  """Returns a string describing `t` for use in a static assertion message.

  Args:
    t: Any object; typically a `Tensor` or a plain Python value taken from the
      `data` list of an assert.

  Returns:
    The stringified constant value of `t` when `t` is a `Tensor` whose value
    `tensor_util.constant_value` can resolve; `str(t)` in every other case.
  """
  if isinstance(t, ops.Tensor):
    const_t = tensor_util.constant_value(t)
    if const_t is not None:
      return str(const_t)
  # Always hand back a string: `_assert_static` joins the results with
  # '\n'.join(), which raises TypeError when given a raw Tensor and would mask
  # the InvalidArgumentError that is supposed to be reported.
  return str(t)


def _assert_static(condition, data):
  """Raises `InvalidArgumentError` when a statically evaluated check fails.

  Args:
    condition: Truth value of the check. Nothing happens when it is truthy.
    data: Iterable of items describing the failure. Each item is passed
      through `_maybe_constant_value_string` and the results are joined with
      newlines to form the error message.

  Raises:
    InvalidArgumentError: If `condition` is falsy.
  """
  if condition:
    return
  rendered = map(_maybe_constant_value_string, data)
  raise errors.InvalidArgumentError(
      node_def=None, op=None, message='\n'.join(rendered))


def _shape_and_dtype_str(tensor):
  """Describes `tensor` by its shape and the name of its dtype.

  Args:
    tensor: Object exposing `.shape` and `.dtype.name`.

  Returns:
    A string of the form "shape=<shape> dtype=<dtype name>".
  """
  shape, dtype_name = tensor.shape, tensor.dtype.name
  return 'shape=%s dtype=%s' % (shape, dtype_name)


def _unary_assert_doc(sym, sym_name):
  """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.

  Args:
    sym: Mathematical symbol for the check performed on each element, i.e. "> 0"
    sym_name: English-language name for the op described by sym

  Returns:
    Decorator that adds the appropriate docstring to the function for symbol
    `sym`.
  """

  def _decorator(func):
    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.

    Args:
      func: Function for a TensorFlow op

    Returns:
      The same `func` object, with its `__doc__` replaced by the generated
      documentation.
    """
    opname = func.__name__
    # The template starts a sentence with the name, so capitalize it.
    cap_sym_name = sym_name.capitalize()

    # The usage example passes only `x`: these ops are unary, and a second
    # positional argument would be taken as `data` rather than as an operand.
    func.__doc__ = """
    Assert the condition `x {sym}` holds element-wise.

    When running in graph mode, you should add a dependency on this operation
    to ensure that it runs. Example of adding a dependency to an operation:

    ```python
    with tf.control_dependencies([tf.debugging.{opname}(x)]):
      output = tf.reduce_sum(x)
    ```

    {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
    If `x` is empty this is trivially satisfied.

    Args:
      x:  Numeric `Tensor`.
      data:  The tensors to print out if the condition is False.  Defaults to
        error message and first few entries of `x`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional).  Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym}` is False.
      @compatibility(eager)
        returns None
      @end_compatibility

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym}` is False. The check can be performed immediately during
        eager execution or if `x` is statically known.
    """.format(
        sym=sym, sym_name=cap_sym_name, opname=opname)
    return func

  return _decorator


def _binary_assert_doc(sym, test_var):
  """Builds the shared docstring decorator for the v1 element-wise binary assert_* ops.

  Args:
    sym: Binary operation symbol, i.e. "=="
    test_var: String spelling the right-hand operand used in the generated
      usage examples.

  Returns:
    Decorator that replaces a function's docstring with the documentation
    generated for symbol `sym`.
  """

  def _decorator(func):
    """Attaches the generated documentation for symbol `sym` to `func`.

    Args:
      func: Function for a TensorFlow op

    Returns:
      The same `func` object, with `__doc__` overwritten.
    """
    template = """
    Assert the condition `x {sym} y` holds element-wise.

    This condition holds if for every pair of (possibly broadcast) elements
    `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
    If both `x` and `y` are empty, this is trivially satisfied.

    When running in graph mode, you should add a dependency on this operation
    to ensure that it runs. Example of adding a dependency to an operation:

    ```python
    with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
      output = tf.reduce_sum(x)
    ```

    Args:
      x:  Numeric `Tensor`.
      y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
      data:  The tensors to print out if the condition is False.  Defaults to
        error message and first few entries of `x`, `y`.
      summarize: Print this many entries of each tensor.
      message: A string to prefix to the default message.
      name: A name for this operation (optional).  Defaults to "{opname}".

    Returns:
      Op that raises `InvalidArgumentError` if `x {sym} y` is False.

    Raises:
      InvalidArgumentError: if the check can be performed immediately and
        `x {sym} y` is False. The check can be performed immediately during
        eager execution or if `x` and `y` are statically known.

    @compatibility(TF2)
    `tf.compat.v1.{opname}` is compatible with eager execution and
    `tf.function`.
    Please use `tf.debugging.{opname}` instead when migrating to TF2. Apart
    from `data`, all arguments are supported with the same argument name.

    If you want to ensure the assert statements run before the
    potentially-invalid computation, please use `tf.control_dependencies`,
    as tf.function auto-control dependencies are insufficient for assert
    statements.

    #### Structural Mapping to Native TF2

    Before:

    ```python
    tf.compat.v1.{opname}(
      x=x, y=y, data=data, summarize=summarize,
      message=message, name=name)
    ```

    After:

    ```python
    tf.debugging.{opname}(
      x=x, y=y, message=message,
      summarize=summarize, name=name)
    ```

    #### TF1 & TF2 Usage Example

    TF1:

    >>> g = tf.Graph()
    >>> with g.as_default():
    ...   a = tf.compat.v1.placeholder(tf.float32, [2])
    ...   b = tf.compat.v1.placeholder(tf.float32, [2])
    ...   result = tf.compat.v1.{opname}(a, b,
    ...     message='"a {sym} b" does not hold for the given inputs')
    ...   with tf.compat.v1.control_dependencies([result]):
    ...     sum_node = a + b
    >>> sess = tf.compat.v1.Session(graph=g)
    >>> val = sess.run(sum_node, feed_dict={{a: [1, 2], b:{test_var}}})


    TF2:

    >>> a = tf.Variable([1, 2], dtype=tf.float32)
    >>> b = tf.Variable({test_var}, dtype=tf.float32)
    >>> assert_op = tf.debugging.{opname}(a, b, message=
    ...   '"a {sym} b" does not hold for the given inputs')
    >>> # When working with tf.control_dependencies
    >>> with tf.control_dependencies([assert_op]):
    ...   val = a + b

    @end_compatibility
    """
    # The op name in the examples is taken from the decorated function itself.
    func.__doc__ = template.format(
        sym=sym, opname=func.__name__, test_var=test_var)
    return func

  return _decorator


def _make_assert_msg_data(sym, x, y, summarize, test_op):
  """Builds the default failure-message components for `_binary_assert` in eager mode.

  Args:
    sym: Mathematical symbol for the test applied to pairs of tensor elements,
      i.e. "=="
    x: First input to the assertion after applying `convert_to_tensor()`
    y: Second input to the assertion
    summarize: Value of the "summarize" parameter to the original assert_* call;
      tells how many elements of each tensor to print.
    test_op: TensorFlow op that returns a Boolean tensor with True in each
      position where the assertion is satisfied.

  Returns:
    List of strings and arrays that, when stringified and concatenated,
    will produce the error message string.
  """
  parts = ['Condition x %s y did not hold.' % sym]
  if not summarize > 0:
    # Nothing beyond the headline was requested.
    return parts

  if x.shape == y.shape and x.shape.as_list():
    # Same (non-scalar) shape: report which positions failed and the values
    # found there. With differing shapes this would be more confusing than
    # useful, so it is skipped.
    failed = math_ops.logical_not(test_op)
    failed_idx = array_ops.where(failed).numpy()
    bad_x = array_ops.boolean_mask(x, failed)
    bad_y = array_ops.boolean_mask(y, failed)
    shown = min(summarize, failed_idx.shape[0])
    parts.extend([
        'Indices of first %d different values:' % shown,
        failed_idx[:shown],
        'Corresponding x values:',
        bad_x.numpy().reshape((-1,))[:shown],
        'Corresponding y values:',
        bad_y.numpy().reshape((-1,))[:shown],
    ])

  # Flatten both inputs and show the leading entries of each.
  flat_x = x.numpy().reshape((-1,))
  flat_y = y.numpy().reshape((-1,))
  n_x = min(flat_x.size, summarize)
  n_y = min(flat_y.size, summarize)
  parts.extend([
      'First %d elements of x:' % n_x,
      flat_x[:n_x],
      'First %d elements of y:' % n_y,
      flat_y[:n_y],
  ])
  return parts


def _pretty_print(data_item, summarize):
  """Format a data item for use in an error message in eager mode.

  Args:
    data_item: One of the items in the "data" argument to an assert_* function.
      Can be a Tensor or a scalar value.
    summarize: How many elements to retain of each tensor-valued entry in data.
      May be an int or a float holding a whole number.

  Returns:
    An appropriate string representation of data_item: `str(data_item)` for
    anything that is not a Tensor, the scalar's string for zero-dimensional
    tensors, and otherwise the string form of a list of the leading elements,
    followed by '...' when elements were left out.
  """
  if not isinstance(data_item, ops.Tensor):
    return str(data_item)

  arr = data_item.numpy()
  if np.isscalar(arr):
    # Tensor.numpy() returns a scalar for zero-dimensional tensors
    return str(arr)

  flat = arr.reshape((-1,))
  # `summarize` can arrive as a float: the eager path of `_binary_assert` uses
  # 1e9 as its "print everything" sentinel. NumPy rejects non-integer slice
  # bounds, so clamp to the array size and coerce to int before slicing.
  limit = int(min(summarize, flat.size))
  lst = [str(elem) for elem in flat[:limit]]
  if len(lst) < flat.size:
    lst.append('...')
  return str(lst)


def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc() above.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      i.e. "=="
    opname: Name of the assert op in the public API, i.e. "assert_equal"
    op_func: Function that, if passed the two Tensor inputs to the assertion (x
      and y), will return the element-wise Boolean test to be passed to
      reduce_all(), e.g. math_ops.equal for assert_equal()
    static_func: Function that, if passed numpy ndarray versions of the two
      inputs to the assertion, will return a Boolean ndarray containing
      True in all positions where the assertion PASSES.
      i.e. np.equal for assert_equal()
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to the value of
      `opname`.

  Returns:
    `None` in eager mode when the assertion holds; in graph mode, the op built
    by `control_flow_ops.Assert`. See docstring template in
    _binary_assert_doc().

  Raises:
    InvalidArgumentError: In eager mode when the condition does not hold, or
      in graph mode when both inputs have statically known values and the
      condition does not hold for them.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Keep the sentinel an integer: `_pretty_print` uses `summarize` as a
        # slice bound, and NumPy rejects float slice indices. Code below will
        # find exact size of x and y.
        summarize = int(1e9)

      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)

      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join(_pretty_print(d, summarize) for d in data)))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      # When both operands are compile-time constants the check can be
      # evaluated right away with NumPy instead of waiting for the graph to
      # run.
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      if x_static is not None and y_static is not None:
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_ops.Assert(condition, data, summarize=summarize)


@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  Intended for `Ops` that take an iterable of `Tensor` and want to validate
  their input. A dedicated check is useful because `Tensor`, `ndarray` and
  byte/text types are themselves iterable.

  Args:
    values:  Object to be checked.

  Raises:
    TypeError:  If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  # Types that can be iterated but are not accepted as a collection of values.
  rejected_types = (
      ops.Tensor, sparse_tensor.SparseTensor, np.ndarray
  ) + compat.bytes_or_text_types
  if isinstance(values, rejected_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable.  Found: %s' %
        type(values))

  if hasattr(values, '__iter__'):
    return
  raise TypeError(
      'Expected argument "values" to be iterable.  Found: %s' % type(values))


@tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert that every element of `x` is strictly negative (`x < 0`).

  The check passes only when `x[i] < 0` for each element of `x`; an empty `x`
  satisfies it trivially.

  On failure `InvalidArgumentError` is raised, and `message` is printed
  together with the first `summarize` entries of `x`.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Forwards to the v1 implementation, leaving its `data` argument unset.
  return assert_negative(
      x=x, summarize=summarize, message=message, name=name)


@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is installed by the `_unary_assert_doc` decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Label `x` by shape and dtype in eager mode, by tensor name otherwise.
      x_label = (
          _shape_and_dtype_str(x) if context.executing_eagerly() else x.name)
      data = [
          message,
          'Condition x < 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    # x < 0 is checked as the binary assertion x < zero.
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less(x, zero, data=data, summarize=summarize)


@tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert that every element of `x` is strictly positive (`x > 0`).

  The check passes only when `x[i] > 0` for each element of `x`; an empty `x`
  satisfies it trivially.

  On failure `InvalidArgumentError` is raised, and `message` is printed
  together with the first `summarize` entries of `x`.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Forwards to the v1 implementation, leaving its `data` argument unset.
  return assert_positive(
      x=x, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is installed by the `_unary_assert_doc` decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Label `x` by shape and dtype in eager mode, by tensor name otherwise.
      x_label = (
          _shape_and_dtype_str(x) if context.executing_eagerly() else x.name)
      data = [
          message,
          'Condition x > 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    # x > 0 is checked as the binary assertion zero < x.
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less(zero, x, data=data, summarize=summarize)


@tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert that every element of `x` is non-negative (`x >= 0`).

  The check passes only when `x[i] >= 0` for each element of `x`; an empty `x`
  satisfies it trivially.

  On failure `InvalidArgumentError` is raised, and `message` is printed
  together with the first `summarize` entries of `x`.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. This can
      be used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Forwards to the v1 implementation, leaving its `data` argument unset.
  return assert_non_negative(
      x=x, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is installed by the `_unary_assert_doc` decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Label `x` by shape and dtype in eager mode, by tensor name otherwise.
      x_label = (
          _shape_and_dtype_str(x) if context.executing_eagerly() else x.name)
      data = [
          message,
          'Condition x >= 0 did not hold element-wise:',
          'x (%s) = ' % x_label,
          x,
      ]
    # x >= 0 is checked as the binary assertion zero <= x.
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less_equal(zero, x, data=data, summarize=summarize)


@tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert that every element of `x` is non-positive (`x <= 0`).

  The check passes only when `x[i] <= 0` for each element of `x`; an empty `x`
  satisfies it trivially.

  On failure `InvalidArgumentError` is raised, and `message` is printed
  together with the first `summarize` entries of `x`.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. This can
      be used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  # Forwards to the v1 implementation, leaving its `data` argument unset.
  return assert_non_positive(
      x=x, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is installed by the `_unary_assert_doc` decorator.
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Label `x` by shape and dtype in eager mode, by tensor name otherwise.
      if context.executing_eagerly():
        name = _shape_and_dtype_str(x)
      else:
        name = x.name
      # Each entry is a separate item. The comma after the condition text
      # matters: without it the two adjacent literals are implicitly
      # concatenated and the `x (...) = ` label is glued onto the condition
      # line, unlike the sibling assert_negative/positive/non_negative ops.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % name, x]
    # x <= 0 is checked as the binary assertion x <= zero.
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less_equal(x, zero, data=data, summarize=summarize)


@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.add_dispatch_support
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x` and `y` are equal element-wise (`x == y`).

  The check passes only when `x[i] == y[i]` for every pair of (possibly
  broadcast) elements of `x` and `y`. It is trivially satisfied when both `x`
  and `y` are empty.

  On failure `InvalidArgumentError` is raised, and `message` is printed
  together with the first `summarize` entries of `x` and `y`.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x == y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x == y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Forwards to the v1 implementation, leaving its `data` argument unset.
  return assert_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)


@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.add_dispatch_support
@_binary_assert_doc('==', '[1, 2]')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The docstring is installed by the `_binary_assert_doc` decorator.
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    # Short-circuit if x and y are the same tensor.
    if x is y:
      if context.executing_eagerly():
        return None
      return control_flow_ops.no_op()
  # NOTE(review): the flattened source line does not show whether this call sat
  # inside the name scope above; it is placed after it here. Confirm against
  # upstream.
  return _binary_assert(
      '==', 'assert_equal', math_ops.equal, np.equal,
      x, y, data, summarize, message, name)
@tf_export('debugging.assert_none_equal', v1=[])
@dispatch.add_dispatch_support
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  """Assert the condition `x != y` holds for all elements.

  This Op checks that `x[i] != y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If any elements of `x` and `y` are equal, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_none_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x != y` is ever False. This can
      be used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x != y` is False for any pair of elements in `x` and `y`. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.
  """
  return assert_none_equal(x=x, y=y, summarize=summarize, message=message,
                           name=name)


@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=', '[2, 1]')
def assert_none_equal(
    x, y, data=None, summarize=None, message=None, name=None):
  # The docstring is installed by the `_binary_assert_doc` decorator.
  return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal,
                        np.not_equal, x, y, data, summarize, message, name)


@tf_export('debugging.assert_near', v1=[])
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])` holds
  for every pair of (possibly broadcast) elements of `x` and `y`. If both `x`
  and `y` are empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
      This can be used with `tf.control_dependencies` inside of `tf.function`s
      to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close enough for some pair of elements. The check
      can be performed immediately during eager execution or if `x` and `y`
      are statically known.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)


@tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) < atol + rtol * tf.abs(y[i])```.

  The comparison is strict, so a tolerance of exactly zero is never satisfied.
  If both `x` and `y` are empty, this is trivially satisfied.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x:  Float or complex `Tensor`.
    y:  Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Tolerances are real-valued, so complex inputs use their real dtype.
    dtype = x.dtype
    if dtype.is_complex:
      dtype = dtype.real_dtype
    eps = np.finfo(dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)

    # Label the inputs by shape and dtype in eager mode, by tensor name
    # otherwise.
    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    # Strict comparison: |x - y| must be smaller than the tolerance.
    condition = math_ops.reduce_all(math_ops.less(diff, tol))
    return control_flow_ops.Assert(condition, data, summarize=summarize)


@tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.add_dispatch_support
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x < y` holds element-wise.

  This Op checks that `x[i] < y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not less than `y` element-wise, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is
  raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_less".

  Returns:
    Op that raises `InvalidArgumentError` if `x < y` is False.
      This can be used with `tf.control_dependencies` inside of `tf.function`s
      to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x < y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_less(x=x, y=y, summarize=summarize, message=message, name=name)
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.add_dispatch_support
@_binary_assert_doc('<', '[2, 3]')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  # The docstring is installed by the `_binary_assert_doc` decorator.
  return _binary_assert(
      '<', 'assert_less', math_ops.less, np.less,
      x, y, data, summarize, message, name)
@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.add_dispatch_support
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x <= y` holds for every pair of elements.

  The comparison `x[i] <= y[i]` is applied to each pair of (possibly
  broadcast) elements of `x` and `y`. When both `x` and `y` are empty, the
  condition is trivially satisfied.

  On failure, `message` and the first `summarize` entries of `x` and `y` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_less_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x <= y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x <= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # The v2 signature has no `data` argument; the v1 implementation is called
  # with its default for it.
  return assert_less_equal(
      x=x,
      y=y,
      message=message,
      summarize=summarize,
      name=name)


@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=', '[1, 3]')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # No literal docstring: the `_binary_assert_doc` decorator presumably
  # provides it (TODO confirm against its definition).
  tf_compare, np_compare = math_ops.less_equal, np.less_equal
  return _binary_assert('<=', 'assert_less_equal', tf_compare, np_compare,
                        x, y, data, summarize, message, name)


@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.add_dispatch_support
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x > y` holds for every pair of elements.

  The comparison `x[i] > y[i]` is applied to each pair of (possibly
  broadcast) elements of `x` and `y`. When both `x` and `y` are empty, the
  condition is trivially satisfied.

  On failure, `message` and the first `summarize` entries of `x` and `y` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_greater".

  Returns:
    Op that raises `InvalidArgumentError` if `x > y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x > y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # The v2 signature has no `data` argument; the v1 implementation is called
  # with its default for it.
  return assert_greater(
      x=x,
      y=y,
      message=message,
      summarize=summarize,
      name=name)
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.add_dispatch_support
@_binary_assert_doc('>', '[0, 1]')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # TF1 endpoint for the element-wise `x > y` assertion.
  #
  # No docstring is written here on purpose: the `_binary_assert_doc`
  # decorator above presumably supplies it (TODO confirm against its
  # definition elsewhere in this file), and a literal docstring could clash.
  #
  # Everything is forwarded to `_binary_assert` together with the comparison
  # symbol, the default op name, and the TensorFlow / NumPy comparison
  # functions (`math_ops.greater`, `np.greater`).
  return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
                        y, data, summarize, message, name)
@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.add_dispatch_support
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x >= y` holds for every pair of elements.

  The comparison `x[i] >= y[i]` is applied to each pair of (possibly
  broadcast) elements of `x` and `y`. When both `x` and `y` are empty, the
  condition is trivially satisfied.

  On failure, `message` and the first `summarize` entries of `x` and `y` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_greater_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x >= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # The v2 signature has no `data` argument; the v1 implementation is called
  # with its default for it.
  return assert_greater_equal(
      x=x,
      y=y,
      message=message,
      summarize=summarize,
      name=name)


@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=', '[1, 0]')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):  # pylint: disable=missing-docstring
  # No literal docstring: the `_binary_assert_doc` decorator presumably
  # provides it (TODO confirm against its definition).
  tf_compare, np_compare = math_ops.greater_equal, np.greater_equal
  return _binary_assert('>=', 'assert_greater_equal', tf_compare, np_compare,
                        x, y, data, summarize, message, name)


def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Assert that the rank of `x` satisfies a condition relative to `rank`.

  Args:
    x: `Tensor` whose rank is checked.
    rank: Scalar `Tensor`; must have dtype `int32`.
    static_condition: A python function that takes
      `[actual_rank, given_rank]` and returns `True` if the condition is
      satisfied, `False` otherwise.
    dynamic_condition: An `op` that takes [actual_rank, given_rank] and
      returns `True` if the condition is satisfied, `False` otherwise.
    data: The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition, or a
    `no_op` when the condition was verified statically.

  Raises:
    ValueError: If static checks determine `x` fails static_condition, or
      that `rank` is not a scalar.
  """
  assert_type(rank, dtypes.int32)

  # Resolve as much as possible at graph-construction time.
  static_rank = tensor_util.constant_value(rank)
  rank_is_static = static_rank is not None
  if rank_is_static:
    if static_rank.ndim != 0:
      raise ValueError('Rank must be a scalar.')

    static_x_rank = x.get_shape().ndims
    if static_x_rank is not None:
      if static_condition(static_x_rank, static_rank):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', static_x_rank, static_rank)

  rank_ok = dynamic_condition(array_ops.rank(x), rank)

  # When `rank` is only known at run time, also require it to be a scalar.
  # This prevents the bug where someone writes assert_rank(x, [n]) rather
  # than assert_rank(x, n).
  if not rank_is_static:
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    rank_ok = control_flow_ops.with_dependencies([scalar_check], rank_ok)

  return control_flow_ops.Assert(rank_ok, data, summarize=summarize)


@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that the rank of `x` equals `rank`.

  If `x` has a different rank, `message` and the shape of `x` are printed,
  and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` does not have rank `rank`. The check can be performed immediately
      during eager execution or if the shape of `x` is statically known.
  """
  # The v2 signature exposes neither `data` nor `summarize`; the v1
  # implementation is called with its defaults for both.
  return assert_rank(
      x=x,
      rank=rank,
      message=message,
      name=name)
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` or `SparseTensor`. Anything that is not a `SparseTensor` is
      converted with `ops.convert_to_tensor`.
    rank: Scalar integer `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
    dynamic_condition = math_ops.equal

    # `x.name` is only read in graph mode on a dense tensor; eager and sparse
    # inputs use an empty name in messages.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % name, rank, 'Received shape: ',
          array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      # `_assert_rank_condition` signals a static mismatch with the args
      # ('Static rank condition failed', actual_rank, given_rank); turn that
      # into a readable message. `e.args` is checked first so an argument-less
      # ValueError cannot raise IndexError here.
      if e.args and e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank %d. Received rank %d, shape %s' %
            (message, name, e.args[2], e.args[1], x.get_shape()))
      # Any other ValueError is re-raised unchanged, keeping its traceback and
      # full args, as assert_rank_at_least and assert_rank_in already do.
      raise

  return assert_op
@tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  This Op checks that the rank of `x` is greater than or equal to `rank`.

  If `x` has a rank lower than `rank`, `message`, as well as the shape of `x`
  are printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  return assert_rank_at_least(x=x, rank=rank, message=message, name=name)


@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(
      name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
    dynamic_condition = math_ops.greater_equal

    # `x.name` is only read in graph mode; eager mode uses an empty name.
    if context.executing_eagerly():
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      # A static mismatch arrives as
      # ('Static rank condition failed', actual_rank, given_rank).
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank at least %d. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op


def _static_rank_in(actual_rank, given_ranks):
  """Returns whether `actual_rank` is a member of `given_ranks`."""
  return actual_rank in given_ranks


def _dynamic_rank_in(actual_rank, given_ranks):
  """Builds an op that is true when `actual_rank` equals any given rank.

  Args:
    actual_rank: The rank to test.
    given_ranks: Sequence of candidate ranks.

  Returns:
    A `False` tensor when `given_ranks` is empty; otherwise the logical OR of
    `math_ops.equal(given_rank, actual_rank)` over all entries.
  """
  if len(given_ranks) < 1:
    return ops.convert_to_tensor(False)
  result = math_ops.equal(given_ranks[0], actual_rank)
  for given_rank in given_ranks[1:]:
    result = math_ops.logical_or(
        result, math_ops.equal(given_rank, actual_rank))
  return result


def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition.

  Args:
    x: Numeric `Tensor`.
    ranks: Iterable of scalar `Tensor`s; each must have dtype `int32`.
    static_condition: A python function that takes
      `[actual_rank, given_ranks]` and returns `True` if the condition is
      satisfied, `False` otherwise.
    dynamic_condition: An `op` that takes [actual_rank, given_ranks] and
      returns `True` if the condition is satisfied, `False` otherwise.
    data: The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition, or a
    `no_op` when the condition was verified statically.

  Raises:
    ValueError: If static checks determine `x` fails static_condition, or
      that one of `ranks` is not a scalar.
  """
  for rank in ranks:
    assert_type(rank, dtypes.int32)

  # Attempt to statically determine the ranks.
  ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks])
  if not any(r is None for r in ranks_static):
    for rank_static in ranks_static:
      if rank_static.ndim != 0:
        raise ValueError('Rank must be a scalar.')

    x_rank_static = x.get_shape().ndims
    if x_rank_static is not None:
      if not static_condition(x_rank_static, ranks_static):
        raise ValueError(
            'Static rank condition failed', x_rank_static, ranks_static)
      return control_flow_ops.no_op(name='static_checks_determined_all_ok')

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # Add the condition that `rank` must have rank zero. Prevents the bug where
  # someone does assert_rank(x, [n]), rather than assert_rank(x, n).
  for rank, rank_static in zip(ranks, ranks_static):
    if rank_static is None:
      this_data = ['Rank must be a scalar. Received rank: ', rank]
      rank_check = assert_rank(rank, 0, data=this_data)
      condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)


@tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  This Op checks that the rank of `x` is in `ranks`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  return assert_rank_in(x=x, ranks=ranks, message=message, name=name)


@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` or `SparseTensor`.
    ranks: Iterable of scalar `Tensor` objects.
    data: The tensors to print out if the condition is False. Defaults to
      error message, the entries of `ranks`, and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError: If static checks determine `x` has mismatched rank.
  """
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])):
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
    message = _message_prefix(message)

    # `x.name` is only read in graph mode on a dense tensor.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)

    except ValueError as e:
      # A static mismatch arrives as
      # ('Static rank condition failed', actual_rank, given_ranks).
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank in %s. Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op


@tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  If `x` has a non-integer type, `message`, as well as the dtype of `x` are
  printed, and `InvalidArgumentError` is raised.

  This can always be checked statically, so this method returns nothing.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is not a non-quantized integer type.
  """
  assert_integer(x=x, message=message, name=name)


@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError: If `x.dtype` is anything other than non-quantized integer.

  Returns:
    A `no_op` that does nothing. Type can be determined statically.
  """
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if not x.dtype.is_integer:
      # `x.name` is only read in graph mode; eager mode uses a generic label.
      if context.executing_eagerly():
        name = 'tensor'
      else:
        name = x.name
      err_msg = (
          '%sExpected "x" to be integer type. Found: %s of dtype %s'
          % (_message_prefix(message), name, x.dtype))
      raise TypeError(err_msg)

    return control_flow_ops.no_op('statically_determined_was_integer')


@tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  This can always be checked statically, so this method returns nothing.

  Example:

  >>> a = tf.Variable(1.0)
  >>> tf.debugging.assert_type(a, tf_type= tf.float32)

  >>> b = tf.constant(21)
  >>> tf.debugging.assert_type(b, tf_type=tf.bool)
  Traceback (most recent call last):
  ...
  TypeError: ...

  >>> c = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
  ...  dense_shape=[3, 4])
  >>> tf.debugging.assert_type(c, tf_type= tf.int32)

  Args:
    tensor: A `Tensor`, `SparseTensor` or `tf.Variable` .
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  assert_type(tensor=tensor, tf_type=tf_type, message=message, name=name)


@tf_export(v1=['debugging.assert_type', 'assert_type'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` is of the specified type.

  Args:
    tensor: A `Tensor` or `SparseTensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name: A name to give this `Op`. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.

  Returns:
    A `no_op` that does nothing. Type can be determined statically.
  """
  tf_type = dtypes.as_dtype(tf_type)
  with ops.name_scope(name, 'assert_type', [tensor]):
    if not isinstance(tensor, sparse_tensor.SparseTensor):
      tensor = ops.convert_to_tensor(tensor, name='tensor')
    if tensor.dtype != tf_type:
      # Falls back to the literal "tensor" when the object has no `name`
      # attribute.
      raise TypeError(
          f'{_message_prefix(message)}{getattr(tensor, "name", "tensor")}'
          f' must be of type {tf_type!r}; got {tensor.dtype!r}')

    return control_flow_ops.no_op('statically_determined_correct_type')


def _dimension_sizes(x):
  """Gets the dimension sizes of a tensor `x`.

  If a size can be determined statically it is returned as an integer,
  otherwise as a tensor.

  If `x` is a scalar it is treated as rank 1 size 1.

  Args:
    x: A `Tensor`.

  Returns:
    Dimension sizes: the tuple `(1,)` for a statically known scalar, a list
    mixing ints and tensor slices when the rank is statically known and
    positive, or a tensor when the rank is unknown.
  """
  dynamic_shape = array_ops.shape(x)
  rank = x.get_shape().rank
  rank_is_known = rank is not None
  if rank_is_known and rank == 0:
    return (1,)
  if rank_is_known and rank > 0:
    static_shape = x.get_shape().as_list()
    sizes = [
        int(size) if size is not None else dynamic_shape[i]
        for i, size in enumerate(static_shape)
    ]
    return sizes
  # Rank unknown: decide at run time whether `x` is a scalar.
  has_rank_zero = math_ops.equal(array_ops.rank(x), 0)
  return control_flow_ops.cond(
      has_rank_zero, lambda: array_ops.constant([1]), lambda: dynamic_shape)


def _symbolic_dimension_sizes(symbolic_shape):
  """Returns `symbolic_shape`, or `(1,)` when it is empty."""
  # If len(symbolic_shape) == 0 construct a tuple
  if not symbolic_shape:
    return tuple([1])

  return symbolic_shape


def _has_known_value(dimension_size):
  """Returns True if `dimension_size` is not None and `int()` accepts it."""
  not_none = dimension_size is not None
  try:
    int(dimension_size)
    can_be_parsed_as_int = True
  except (ValueError, TypeError):
    can_be_parsed_as_int = False
  return not_none and can_be_parsed_as_int


def _is_symbol_for_any_size(symbol):
  """Returns True if `symbol` is one of the wildcards `None` or '.'."""
  return symbol in [None, '.']


# Per-tensor record built by `assert_shapes`:
#   x: the tensor being checked.
#   unspecified_dim: True when the specified shape started with `...` or '*'.
#   actual_sizes: result of `_dimension_sizes(x)`.
#   symbolic_sizes: the specified shape with any leading `...`/'*' removed.
_TensorDimSizes = collections.namedtuple(
    '_TensorDimSizes',
    ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])


@tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize`
  entries of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as
      an integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that
  indicates a variable number of outer dimensions of unspecified size, i.e.
  the constraint applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the
  'inner-most' prefix) are both treated as having a single dimension of size
  one.

  Args:
    shapes: dictionary with (`Tensor` to shape) items, or a list of
      (`Tensor`, shape) tuples. A shape must be an iterable.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  assert_shapes(
      shapes, data=data, summarize=summarize, message=message, name=name)


@tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies(
      [tf.compat.v1.debugging.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize`
  entries of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as
      an integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that
  indicates a variable number of outer dimensions of unspecified size, i.e.
  the constraint applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the
  'inner-most' prefix) are both treated as having a single dimension of size
  one.

  Args:
    shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
      expected shape of `Tensor`. A dict mapping `Tensor` to `shape` is also
      accepted. See the example code above. The `shape` must be an iterable.
      Each element of the iterable can be either a concrete integer value or
      a string that abstractly represents the dimension. For example,
        - `('N', 'Q')` specifies a 2D shape wherein the first and second
          dimensions of shape may or may not be equal.
        - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
          dimensions are equal.
        - `(1, 'N')` specifies a 2D shape wherein the first dimension is
          exactly 1 and the second dimension can be any value.
      Note that the abstract dimension letters take effect across different
      tuple elements of the list. For example,
      `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))])`
      asserts that both `x` and `y` are rank-2 tensors and their first
      dimensions are equal (`N`).
      `shape` can also be a `tf.TensorShape`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  # If the user manages to assemble a dict containing tensors (possible in
  # Graph mode only), make sure we still accept that.
  if isinstance(shapes, dict):
    shapes = shapes.items()

  message_prefix = _message_prefix(message)
  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
    # Shape specified as None implies no constraint
    shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
                          ops.convert_to_tensor(x), s)
                         for x, s in shapes if s is not None]

    executing_eagerly = context.executing_eagerly()

    def tensor_name(x):
      # `x.name` is only read in graph mode on a dense tensor; otherwise a
      # shape/dtype description is used in messages.
      if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
        return _shape_and_dtype_str(x)
      return x.name

    # Step 1: validate each specified shape and record, per tensor, its
    # actual and symbolic dimension sizes.
    tensor_dim_sizes = []
    for tensor, symbolic_shape in shape_constraints:
      is_iterable = (
          hasattr(symbolic_shape, '__iter__') or
          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
      )
      if not is_iterable:
        raise ValueError(
            '%s'
            'Tensor %s. Specified shape must be an iterable. '
            'An iterable has the attribute `__iter__` or `__getitem__`. '
            'Received specified shape: %s' %
            (message_prefix, tensor_name(tensor), symbolic_shape))

      # We convert this into a tuple to handle strings, lists and numpy arrays
      symbolic_shape_tuple = tuple(symbolic_shape)

      tensors_specified_innermost = False
      for i, symbol in enumerate(symbolic_shape_tuple):
        if symbol not in [Ellipsis, '*']:
          continue

        if i != 0:
          raise ValueError(
              '%s'
              'Tensor %s specified shape index %d. '
              'Symbol `...` or `*` for a variable number of '
              'unspecified dimensions is only allowed as the first entry' %
              (message_prefix, tensor_name(tensor), i))

        tensors_specified_innermost = True

      # Only include the size of the specified dimensions since the 0th symbol
      # is either ellipsis or *
      tensor_dim_sizes.append(
          _TensorDimSizes(
              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
              _symbolic_dimension_sizes(
                  symbolic_shape_tuple[1:]
                  if tensors_specified_innermost else symbolic_shape_tuple)))

    # Step 2: one rank assertion per tensor (where one is needed).
    rank_assertions = []
    for sizes in tensor_dim_sizes:
      rank = len(sizes.symbolic_sizes)
      rank_zero_or_one = rank in [0, 1]
      if sizes.unspecified_dim:
        if rank_zero_or_one:
          # No assertion of rank needed as `x` only need to have rank at least
          # 0. See elif rank_zero_or_one case comment.
          continue
        assertion = assert_rank_at_least(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      elif rank_zero_or_one:
        # Rank 0 is treated as rank 1 size 1, i.e. there is
        # no distinction between the two in terms of rank.
        # See _dimension_sizes.
        assertion = assert_rank_in(
            x=sizes.x,
            ranks=[0, 1],
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      else:
        assertion = assert_rank(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      rank_assertions.append(assertion)

    # Step 3: per-dimension size checks. `size_specifications` maps each
    # symbol to (size, tensor, dim) from the first place it was seen, so
    # later occurrences are compared against that first binding.
    size_assertions = []
    size_specifications = {}
    for sizes in tensor_dim_sizes:
      for i, size_symbol in enumerate(sizes.symbolic_sizes):

        if _is_symbol_for_any_size(size_symbol):
          # Size specified as any implies no constraint
          continue

        # With a leading `...`/'*' the symbols describe the inner-most
        # dimensions, so index from the end (negative index).
        if sizes.unspecified_dim:
          tensor_dim = i - len(sizes.symbolic_sizes)
        else:
          tensor_dim = i

        if size_symbol in size_specifications or _has_known_value(size_symbol):
          if _has_known_value(size_symbol):
            specified_size = int(size_symbol)
            size_check_message = 'Specified explicitly'
          else:
            specified_size, specified_by_y, specified_at_dim = \
                size_specifications[size_symbol]
            size_check_message = (
                'Specified by tensor %s dimension %d' %
                (tensor_name(specified_by_y), specified_at_dim))

          # This is extremely subtle. If actual_sizes is dynamic, we must
          # make sure a control dependency is inserted here so that this slice
          # can not execute until the rank is asserted to be enough for the
          # slice to not fail.
          with ops.control_dependencies(rank_assertions):
            actual_size = sizes.actual_sizes[tensor_dim]
          if _has_known_value(actual_size) and _has_known_value(specified_size):
            if int(actual_size) != int(specified_size):
              raise ValueError(
                  '%s%s. Tensor %s dimension %s must have size %d. '
                  'Received size %d, shape %s' %
                  (message_prefix, size_check_message, tensor_name(sizes.x),
                   tensor_dim, specified_size, actual_size,
                   sizes.x.get_shape()))
            # No dynamic assertion needed
            continue

          condition = math_ops.equal(
              ops.convert_to_tensor(actual_size),
              ops.convert_to_tensor(specified_size))
          data_ = data
          if data is None:
            data_ = [
                message_prefix, size_check_message,
                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
                'must have size', specified_size, 'Received shape: ',
                array_ops.shape(sizes.x)
            ]
          size_assertions.append(
              control_flow_ops.Assert(condition, data_, summarize=summarize))
        else:
          # Not sure if actual_sizes is a constant, but for safety, guard
          # on rank. See explanation above about actual_sizes need for safety.
          with ops.control_dependencies(rank_assertions):
            size = sizes.actual_sizes[tensor_dim]
          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)

    # Ensure both assertions actually occur.
    with ops.control_dependencies(rank_assertions):
      shapes_assertion = control_flow_ops.group(size_assertions)

  return shapes_assertion


# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Gets the difference x[1:] - x[:-1].

  `x` is flattened first, so the difference is taken over all elements in
  flattened order.

  Args:
    x: Numeric `Tensor`.

  Returns:
    A 1-D tensor of adjacent differences; empty when `x` has fewer than two
    elements.

  Raises:
    TypeError: if the flattened `x` is not a numeric tensor.
  """
  x = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(x):
    raise TypeError('Expected x to be numeric, instead found: %s' % x)

  # If x has less than 2 elements, there is nothing to compare. So return [].
  is_shorter_than_two = math_ops.less(array_ops.size(x), 2)
  short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype)

  # With 2 or more elements, return x[1:] - x[:-1]
  s_len = array_ops.shape(x) - 1
  diff = lambda: array_ops.strided_slice(x, [1], [1] + s_len)- array_ops.strided_slice(x, [0], s_len)
  return control_flow_ops.cond(is_shorter_than_two, short_result, diff)


@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if the elements of `tensor` are numbers.

  Specifically, returns `True` if the dtype of `tensor` is one of the
  following:

  * `tf.float16`
  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.uint16`
  * `tf.uint32`
  * `tf.uint64`
  * `tf.qint8`
  * `tf.qint16`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.quint16`
  * `tf.complex64`
  * `tf.complex128`
  * `tf.bfloat16`

  Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not
  a `tf.Tensor` object.
  """
  return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES


@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order. The tensor `[x[0],...]`
  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
  If `x` has less than two elements, it is trivially non-decreasing.

  See also: `is_strictly_increasing`

  >>> x1 = tf.constant([1.0, 1.0, 3.0])
  >>> tf.math.is_non_decreasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_non_decreasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional). Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    diff = _get_diff_for_monotonic_comparison(x)
    # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True.
zero = ops.convert_to_tensor(0, dtype=diff.dtype) return math_ops.reduce_all(math_ops.less_equal(zero, diff)) @tf_export( 'math.is_strictly_increasing', v1=[ 'math.is_strictly_increasing', 'debugging.is_strictly_increasing', 'is_strictly_increasing' ]) @dispatch.add_dispatch_support @deprecation.deprecated_endpoints('debugging.is_strictly_increasing', 'is_strictly_increasing') def is_strictly_increasing(x, name=None): """Returns `True` if `x` is strictly increasing. Elements of `x` are compared in row-major order. The tensor `[x[0],...]` is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`. If `x` has less than two elements, it is trivially strictly increasing. See also: `is_non_decreasing` >>> x1 = tf.constant([1.0, 2.0, 3.0]) >>> tf.math.is_strictly_increasing(x1) <tf.Tensor: shape=(), dtype=bool, numpy=True> >>> x2 = tf.constant([3.0, 1.0, 2.0]) >>> tf.math.is_strictly_increasing(x2) <tf.Tensor: shape=(), dtype=bool, numpy=False> Args: x: Numeric `Tensor`. name: A name for this operation (optional). Defaults to "is_strictly_increasing" Returns: Boolean `Tensor`, equal to `True` iff `x` is strictly increasing. Raises: TypeError: if `x` is not a numeric tensor. """ with ops.name_scope(name, 'is_strictly_increasing', [x]): diff = _get_diff_for_monotonic_comparison(x) # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True. zero = ops.convert_to_tensor(0, dtype=diff.dtype) return math_ops.reduce_all(math_ops.less(zero, diff)) def _assert_same_base_type(items, expected_type=None): r"""Asserts all items are of the same base type. Args: items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, `Operation`, or `IndexedSlices`). Can include `None` elements, which will be ignored. expected_type: Expected type. If not specified, assert all items are of the same base type. Returns: Validated type, or none if neither expected_type nor items provided. Raises: ValueError: If any types do not match. 
""" original_expected_type = expected_type mismatch = False for item in items: if item is not None: item_type = item.dtype.base_dtype if not expected_type: expected_type = item_type elif expected_type != item_type: mismatch = True break if mismatch: # Loop back through and build up an informative error message (this is very # slow, so we don't do it unless we found an error above). expected_type = original_expected_type original_item_str = None for item in items: if item is not None: item_type = item.dtype.base_dtype if not expected_type: expected_type = item_type original_item_str = item.name if hasattr(item, 'name') else str(item) elif expected_type != item_type: raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( item.name if hasattr(item, 'name') else str(item), item_type, expected_type, (' as %s' % original_item_str) if original_item_str else '')) return expected_type # Should be unreachable else: return expected_type @tf_export( 'debugging.assert_same_float_dtype', v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype']) @dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_same_float_dtype') def assert_same_float_dtype(tensors=None, dtype=None): """Validate and return float type based on `tensors` and `dtype`. For ops such as matrix multiplication, inputs and weights must be of the same float type. This function validates that all `tensors` are the same type, validates that type is `dtype` (if supplied), and returns the type. Type must be a floating point type. If neither `tensors` nor `dtype` is supplied, the function will return `dtypes.float32`. Args: tensors: Tensors of input values. Can include `None` elements, which will be ignored. dtype: Expected type. Returns: Validated type. Raises: ValueError: if neither `tensors` nor `dtype` is supplied, or result is not float, or the common type of the inputs is not a floating point type. 
""" if tensors: dtype = _assert_same_base_type(tensors, dtype) if not dtype: dtype = dtypes.float32 elif not dtype.is_floating: raise ValueError('Expected floating point type, got %s.' % dtype) return dtype @tf_export('debugging.assert_scalar', v1=[]) @dispatch.add_dispatch_support def assert_scalar_v2(tensor, message=None, name=None): """Asserts that the given `tensor` is a scalar. This function raises `ValueError` unless it can be certain that the given `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is unknown. This is always checked statically, so this method returns nothing. Args: tensor: A `Tensor`. message: A string to prefix to the default message. name: A name for this operation. Defaults to "assert_scalar" Raises: ValueError: If the tensor is not scalar (rank 0), or if its shape is unknown. """ assert_scalar(tensor=tensor, message=message, name=name) @tf_export(v1=['debugging.assert_scalar', 'assert_scalar']) @dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_scalar') def assert_scalar(tensor, name=None, message=None): """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional). This function raises `ValueError` unless it can be certain that the given `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is unknown. Args: tensor: A `Tensor`. name: A name for this operation. Defaults to "assert_scalar" message: A string to prefix to the default message. Returns: The input tensor (potentially converted to a `Tensor`). Raises: ValueError: If the tensor is not scalar (rank 0), or if its shape is unknown. """ with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope: tensor = ops.convert_to_tensor(tensor, name=name_scope) shape = tensor.get_shape() message = _message_prefix(message) if shape.ndims != 0: if context.executing_eagerly(): raise ValueError('%sExpected scalar shape, saw shape: %s.' 
% (message, shape,)) else: raise ValueError('%sExpected scalar shape for %s, saw shape: %s.' % (message, tensor.name, shape)) return tensor def _message_prefix(message): if message: return '%s. ' % message return ''
@tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None):
  """Returns `x` with an updated static shape that is also verified at runtime.

  When the returned op executes, it asserts that the shape of the input tensor
  `x` is compatible with `shape`. See `tf.TensorShape.is_compatible_with` for
  the definition of compatibility.

  >>> x = tf.constant([[1, 2, 3],
  ...                  [4, 5, 6]])
  >>> x = tf.ensure_shape(x, [2, 3])

  Dimensions that should not be constrained are given as `None`:

  >>> x = tf.ensure_shape(x, [None, 3])
  >>> x = tf.ensure_shape(x, [2, None])

  An error is raised when the tensor's shape is not compatible with `shape`:

  >>> x = tf.ensure_shape(x, [5])
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not compatible with expected shape [5]. [Op:EnsureShape]

  During graph construction (typically while tracing a `tf.function`),
  `tf.ensure_shape` sets the static shape of the **result** tensor to the merge
  of the two shapes. See `tf.TensorShape.merge_with` for details. This is most
  useful when **you** know a shape that TensorFlow cannot determine statically.

  The following trivial `tf.function` prints the static shape of its input
  before and after `ensure_shape` is applied.

  >>> @tf.function
  ... def f(tensor):
  ...   print("Static-shape before:", tensor.shape)
  ...   tensor = tf.ensure_shape(tensor, [None, 3])
  ...   print("Static-shape after:", tensor.shape)
  ...   return tensor

  Tracing the function shows the effect of `tf.ensure_shape`:

  >>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
  Static-shape before: (None, None)
  Static-shape after: (None, 3)

  >>> cf(tf.zeros([3, 3])) # Passes
  >>> cf(tf.constant([1, 2, 3])) # fails
  Traceback (most recent call last):
  ...
  InvalidArgumentError:  Shape of tensor x [3] is not compatible with expected shape [3,3].

  The call above raises `tf.errors.InvalidArgumentError` because the shape of
  `x`, `(3,)`, is not compatible with the `shape` argument, `(None, 3)`.

  Inside a `tf.function` or `v1.Graph` context both the build-time and the
  runtime shapes are checked. This is stricter than `tf.Tensor.set_shape`,
  which only checks the build-time shape.

  Note: Unlike `tf.Tensor.set_shape`, this function sets the static shape of
  the resulting tensor and also enforces it at runtime, raising an error if the
  runtime shape is incompatible with the specified shape.
  `tf.Tensor.set_shape` sets the static shape without enforcing it at runtime,
  which may lead to inconsistencies between the statically-known shape of a
  tensor and its runtime value.

  For example, when loading images of a known size:

  >>> @tf.function
  ... def decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image = tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  While a function is being traced no ops are executed, so shapes may be
  unknown. See the
  [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
  for details.

  >>> concrete_decode = decode_image.get_concrete_function(
  ...     tf.TensorSpec([], dtype=tf.string))
  Initial shape:  (None, None, 3)
  Final shape:  (28, 28, 3)

  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
  >>> image = tf.cast(image,tf.uint8)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  >>> print(image2.shape)
  (28, 28, 3)

  >>> image = tf.concat([image,image], axis=0)
  >>> print(image.shape)
  (56, 28, 3)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError:  Shape of tensor DecodePng [56,28,3] is not
    compatible with expected shape [28,28,3].

  Caution: if the result of `tf.ensure_shape` is not used, the check may not
  run.

  >>> @tf.function
  ... def bad_decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   # BAD: forgot to use the returned tensor.
  ...   tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  >>> image = bad_decode_image(png)
  Initial shape:  (None, None, 3)
  Final shape:  (None, None, 3)
  >>> print(image.shape)
  (56, 28, 3)

  Args:
    x: A `Tensor`.
    shape: The expected shape of `x`: a `TensorShape`, a `TensorShapeProto`, a
      list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`.

  Raises:
    tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape
      of `x`.
  """
  # Anything that is not already a `TensorShape` is wrapped in one before it is
  # handed to the underlying op.
  static_shape = (
      shape if isinstance(shape, tensor_shape.TensorShape)
      else tensor_shape.TensorShape(shape))
  return array_ops.ensure_shape(x, static_shape, name=name)

@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient function registered for the `EnsureShape` op.

  The incoming gradient is returned as-is.

  Args:
    op: The `EnsureShape` operation being differentiated; it is not consulted.
    grad: The gradient with respect to the output of the op.

  Returns:
    `grad`, unchanged.
  """
  del op  # Intentionally unused; only the incoming gradient is needed.
  return grad