• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-short-docstring-punctuation
16"""Asserts and Boolean Checks."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23
24import numpy as np
25
26from tensorflow.python.eager import context
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.util import compat
37from tensorflow.python.util import deprecation
38from tensorflow.python.util import dispatch
39from tensorflow.python.util.tf_export import tf_export
40
# Dtypes this module regards as numeric. The membership is an explicit list,
# not derived from dtype properties: e.g. float16, bfloat16 and complex128 are
# not members.
NUMERIC_TYPES = frozenset({
    dtypes.float32,
    dtypes.float64,
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.qint8,
    dtypes.qint32,
    dtypes.quint8,
    dtypes.complex64,
})

# Public API of this module.
__all__ = [
    'assert_negative',
    'assert_positive',
    'assert_proper_iterable',
    'assert_non_negative',
    'assert_non_positive',
    'assert_equal',
    'assert_none_equal',
    'assert_near',
    'assert_integer',
    'assert_less',
    'assert_less_equal',
    'assert_greater',
    'assert_greater_equal',
    'assert_rank',
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_same_float_dtype',
    'assert_scalar',
    'assert_type',
    'assert_shapes',
    'is_non_decreasing',
    'is_numeric_tensor',
    'is_strictly_increasing',
]
71
72
def _maybe_constant_value_string(t):
  """Returns a string describing `t`, preferring its static value if known.

  Args:
    t: An entry of an assert op's `data` list: a `Tensor` or any other
      object (e.g. a message string).

  Returns:
    `str(t)` for non-tensors. For tensors, the string form of the value that
    `tensor_util.constant_value` can determine statically, or `str(t)` of the
    tensor itself when no static value is available.
  """
  if not isinstance(t, ops.Tensor):
    return str(t)
  const_t = tensor_util.constant_value(t)
  if const_t is not None:
    return str(const_t)
  # Always return a string: _assert_static joins these results with
  # '\n'.join, which raises TypeError on non-str items such as a Tensor.
  return str(t)
80
81
82def _assert_static(condition, data):
83  """Raises a InvalidArgumentError with as much information as possible."""
84  if not condition:
85    data_static = [_maybe_constant_value_string(x) for x in data]
86    raise errors.InvalidArgumentError(node_def=None, op=None,
87                                      message='\n'.join(data_static))
88
89
90def _shape_and_dtype_str(tensor):
91  """Returns a string containing tensor's shape and dtype."""
92  return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
93
94
95def _unary_assert_doc(sym, sym_name):
96  """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.
97
98  Args:
99    sym: Mathematical symbol for the check performed on each element, i.e. "> 0"
100    sym_name: English-language name for the op described by sym
101
102  Returns:
103    Decorator that adds the appropriate docstring to the function for symbol
104    `sym`.
105  """
106
107  def _decorator(func):
108    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
109
110    Args:
111      func: Function for a TensorFlow op
112
113    Returns:
114      Version of `func` with documentation attached.
115    """
116    opname = func.__name__
117    cap_sym_name = sym_name.capitalize()
118
119    func.__doc__ = """
120    Assert the condition `x {sym}` holds element-wise.
121
122    When running in graph mode, you should add a dependency on this operation
123    to ensure that it runs. Example of adding a dependency to an operation:
124
125    ```python
126    with tf.control_dependencies([tf.debugging.{opname}(x, y)]):
127      output = tf.reduce_sum(x)
128    ```
129
130    {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
131    If `x` is empty this is trivially satisfied.
132
133    Args:
134      x:  Numeric `Tensor`.
135      data:  The tensors to print out if the condition is False.  Defaults to
136        error message and first few entries of `x`.
137      summarize: Print this many entries of each tensor.
138      message: A string to prefix to the default message.
139      name: A name for this operation (optional).  Defaults to "{opname}".
140
141    Returns:
142      Op that raises `InvalidArgumentError` if `x {sym}` is False.
143      @compatibility(eager)
144        returns None
145      @end_compatibility
146
147    Raises:
148      InvalidArgumentError: if the check can be performed immediately and
149        `x {sym}` is False. The check can be performed immediately during
150        eager execution or if `x` is statically known.
151    """.format(
152        sym=sym, sym_name=cap_sym_name, opname=opname)
153    return func
154
155  return _decorator
156
157
158def _binary_assert_doc(sym, test_var):
159  """Common docstring for most of the v1 assert_* ops that compare two tensors element-wise.
160
161  Args:
162    sym: Binary operation symbol, i.e. "=="
163    test_var: a string that represents the variable in the right-hand side of
164      binary operator of the test case
165
166  Returns:
167    Decorator that adds the appropriate docstring to the function for
168  symbol `sym`.
169  """
170
171  def _decorator(func):
172    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
173
174    Args:
175      func: Function for a TensorFlow op
176
177    Returns:
178      A version of `func` with documentation attached.
179    """
180    opname = func.__name__
181
182    func.__doc__ = """
183    Assert the condition `x {sym} y` holds element-wise.
184
185    This condition holds if for every pair of (possibly broadcast) elements
186    `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
187    If both `x` and `y` are empty, this is trivially satisfied.
188
189    When running in graph mode, you should add a dependency on this operation
190    to ensure that it runs. Example of adding a dependency to an operation:
191
192    ```python
193    with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
194      output = tf.reduce_sum(x)
195    ```
196
197    Args:
198      x:  Numeric `Tensor`.
199      y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
200      data:  The tensors to print out if the condition is False.  Defaults to
201        error message and first few entries of `x`, `y`.
202      summarize: Print this many entries of each tensor.
203      message: A string to prefix to the default message.
204      name: A name for this operation (optional).  Defaults to "{opname}".
205
206    Returns:
207      Op that raises `InvalidArgumentError` if `x {sym} y` is False.
208
209    Raises:
210      InvalidArgumentError: if the check can be performed immediately and
211        `x {sym} y` is False. The check can be performed immediately during
212        eager execution or if `x` and `y` are statically known.
213
214    @compatibility(TF2)
215    `tf.compat.v1.{opname}` is compatible with eager execution and
216    `tf.function`.
217    Please use `tf.debugging.{opname}` instead when migrating to TF2. Apart
218    from `data`, all arguments are supported with the same argument name.
219
220    If you want to ensure the assert statements run before the
221    potentially-invalid computation, please use `tf.control_dependencies`,
222    as tf.function auto-control dependencies are insufficient for assert
223    statements.
224
225    #### Structural Mapping to Native TF2
226
227    Before:
228
229    ```python
230    tf.compat.v1.{opname}(
231      x=x, y=y, data=data, summarize=summarize,
232      message=message, name=name)
233    ```
234
235    After:
236
237    ```python
238    tf.debugging.{opname}(
239      x=x, y=y, message=message,
240      summarize=summarize, name=name)
241    ```
242
243    #### TF1 & TF2 Usage Example
244
245    TF1:
246
247    >>> g = tf.Graph()
248    >>> with g.as_default():
249    ...   a = tf.compat.v1.placeholder(tf.float32, [2])
250    ...   b = tf.compat.v1.placeholder(tf.float32, [2])
251    ...   result = tf.compat.v1.{opname}(a, b,
252    ...     message='"a {sym} b" does not hold for the given inputs')
253    ...   with tf.compat.v1.control_dependencies([result]):
254    ...     sum_node = a + b
255    >>> sess = tf.compat.v1.Session(graph=g)
256    >>> val = sess.run(sum_node, feed_dict={{a: [1, 2], b:{test_var}}})
257
258
259    TF2:
260
261    >>> a = tf.Variable([1, 2], dtype=tf.float32)
262    >>> b = tf.Variable({test_var}, dtype=tf.float32)
263    >>> assert_op = tf.debugging.{opname}(a, b, message=
264    ...   '"a {sym} b" does not hold for the given inputs')
265    >>> # When working with tf.control_dependencies
266    >>> with tf.control_dependencies([assert_op]):
267    ...   val = a + b
268
269    @end_compatibility
270    """.format(
271        sym=sym, opname=opname, test_var=test_var)
272    return func
273
274  return _decorator
275
276
277def _make_assert_msg_data(sym, x, y, summarize, test_op):
278  """Subroutine of _binary_assert that generates the components of the default error message when running in eager mode.
279
280  Args:
281    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
282      i.e. "=="
283    x: First input to the assertion after applying `convert_to_tensor()`
284    y: Second input to the assertion
285    summarize: Value of the "summarize" parameter to the original assert_* call;
286      tells how many elements of each tensor to print.
287    test_op: TensorFlow op that returns a Boolean tensor with True in each
288      position where the assertion is satisfied.
289
290  Returns:
291    List of tensors and scalars that, when stringified and concatenated,
292    will produce the error message string.
293  """
294  # Prepare a message with first elements of x and y.
295  data = []
296
297  data.append('Condition x %s y did not hold.' % sym)
298
299  if summarize > 0:
300    if x.shape == y.shape and x.shape.as_list():
301      # If the shapes of x and y are the same (and not scalars),
302      # Get the values that actually differed and their indices.
303      # If shapes are different this information is more confusing
304      # than useful.
305      mask = math_ops.logical_not(test_op)
306      indices = array_ops.where(mask)
307      indices_np = indices.numpy()
308      x_vals = array_ops.boolean_mask(x, mask)
309      y_vals = array_ops.boolean_mask(y, mask)
310      num_vals = min(summarize, indices_np.shape[0])
311      data.append('Indices of first %d different values:' % num_vals)
312      data.append(indices_np[:num_vals])
313      data.append('Corresponding x values:')
314      data.append(x_vals.numpy().reshape((-1,))[:num_vals])
315      data.append('Corresponding y values:')
316      data.append(y_vals.numpy().reshape((-1,))[:num_vals])
317
318    # reshape((-1,)) is the fastest way to get a flat array view.
319    x_np = x.numpy().reshape((-1,))
320    y_np = y.numpy().reshape((-1,))
321    x_sum = min(x_np.size, summarize)
322    y_sum = min(y_np.size, summarize)
323    data.append('First %d elements of x:' % x_sum)
324    data.append(x_np[:x_sum])
325    data.append('First %d elements of y:' % y_sum)
326    data.append(y_np[:y_sum])
327
328  return data
329
330
def _pretty_print(data_item, summarize):
  """Format a data item for use in an error message in eager mode.

  Args:
    data_item: One of the items in the "data" argument to an assert_* function.
      Can be a Tensor or a scalar value.
    summarize: How many elements to retain of each tensor-valued entry in data.
      `None` or a negative value keeps every element. Non-integer values (e.g.
      a float such as 1e9) are truncated to an int before slicing.

  Returns:
    An appropriate string representation of data_item.
  """
  if not isinstance(data_item, ops.Tensor):
    return str(data_item)

  arr = data_item.numpy()
  if np.isscalar(arr):
    # Tensor.numpy() returns a scalar for zero-dimensional tensors
    return str(arr)

  flat = arr.reshape((-1,))
  if summarize is None or summarize < 0:
    limit = flat.size
  else:
    # numpy rejects float slice bounds, so coerce to int.
    limit = int(min(summarize, flat.size))
  lst = [str(x) for x in flat[:limit]]
  if len(lst) < flat.size:
    lst.append('...')
  return str(lst)
355
356
def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc() above.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      e.g. "=="
    opname: Name of the assert op in the public API, e.g. "assert_equal"
    op_func: Function that, if passed the two Tensor inputs to the assertion (x
      and y), will return the test to be passed to reduce_all(), e.g.
      math_ops.equal for assert_equal()
    static_func: Function that, if passed numpy ndarray versions of the two
      inputs to the assertion, will return a Boolean ndarray containing
      True in all positions where the assertion PASSES,
      e.g. np.equal for assert_equal()
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor. In eager mode, `None`
      means 3 and a negative value means every entry.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to the value of
      `opname`.

  Returns:
    See docstring template in _binary_assert_doc().

  Raises:
    InvalidArgumentError: in eager mode when the assertion fails, or in graph
      mode when both `x` and `y` have statically known values that fail it.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Must be an int: it is used as a numpy slice bound when formatting
        # the message, and numpy rejects float indices. Code below clamps it
        # to the exact size of x and y.
        summarize = 10**9

      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)

      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join(_pretty_print(d, summarize) for d in data)))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      # Fail at graph-construction time when both inputs are constants.
      if x_static is not None and y_static is not None:
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_ops.Assert(condition, data, summarize=summarize)
430
431
@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  Ops that expect an iterable of `Tensor`s can call this to validate their
  input. A plain iterability test is not enough, because `Tensor`, `ndarray`
  and byte/text strings are themselves iterable.

  Args:
    values:  Object to be checked.

  Raises:
    TypeError:  If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  # Types that support iteration but must not be treated as a collection.
  rejected_types = tuple(
      [ops.Tensor, sparse_tensor.SparseTensor, np.ndarray] +
      list(compat.bytes_or_text_types))
  found_type = type(values)

  if isinstance(values, rejected_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable.  Found: %s' %
        found_type)

  if not hasattr(values, '__iter__'):
    raise TypeError(
        'Expected argument "values" to be iterable.  Found: %s' % found_type)
462
463
@tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  The check passes when every element `x[i]` of `x` satisfies `x[i] < 0`.
  An empty `x` passes trivially.

  When some element of `x` is not negative, `message` and the first
  `summarize` entries of `x` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to hold back later
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  return assert_negative(
      x=x, summarize=summarize, message=message, name=name)
495
496
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the _unary_assert_doc decorator.
  message = message or ''
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors are described by shape/dtype; graph tensors by name.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [message,
              'Condition x < 0 did not hold element-wise:',
              'x (%s) = ' % x_label,
              x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x < 0 is checked as x < zero.
    return assert_less(x, zero, data=data, summarize=summarize)
516
517
@tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  The check passes when every element `x[i]` of `x` satisfies `x[i] > 0`.
  An empty `x` passes trivially.

  When some element of `x` is not positive, `message` and the first
  `summarize` entries of `x` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to hold back later
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  return assert_positive(
      x=x, message=message, summarize=summarize, name=name)
549
550
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the _unary_assert_doc decorator.
  message = message or ''
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors are described by shape/dtype; graph tensors by name.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [message,
              'Condition x > 0 did not hold element-wise:',
              'x (%s) = ' % x_label,
              x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x > 0 is checked as zero < x.
    return assert_less(zero, x, data=data, summarize=summarize)
569
570
@tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  The check passes when every element `x[i]` of `x` satisfies `x[i] >= 0`.
  An empty `x` passes trivially.

  When some element of `x` is below zero, `message` and the first `summarize`
  entries of `x` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to hold back later
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  return assert_non_negative(
      x=x, message=message, summarize=summarize, name=name)
604
605
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the _unary_assert_doc decorator.
  message = message or ''
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors are described by shape/dtype; graph tensors by name.
      x_label = (_shape_and_dtype_str(x) if context.executing_eagerly()
                 else x.name)
      data = [message,
              'Condition x >= 0 did not hold element-wise:',
              'x (%s) = ' % x_label,
              x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x >= 0 is checked as zero <= x.
    return assert_less_equal(zero, x, data=data, summarize=summarize)
625
626
@tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  The check passes when every element `x[i]` of `x` satisfies `x[i] <= 0`.
  An empty `x` passes trivially.

  When some element of `x` is above zero, `message` and the first `summarize`
  entries of `x` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to hold back later
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  return assert_non_positive(
      x=x, message=message, summarize=summarize, name=name)
660
661
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the _unary_assert_doc decorator.
  message = message or ''
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors are described by shape/dtype; graph tensors by name.
      if context.executing_eagerly():
        x_label = _shape_and_dtype_str(x)
      else:
        x_label = x.name
      # The header and the x label are separate entries, as in the sibling
      # unary asserts. Without the comma, Python would concatenate the two
      # adjacent string literals into one.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % x_label, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    # x <= 0 is checked as x <= zero.
    return assert_less_equal(x, zero, data=data, summarize=summarize)
681
682
@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.add_dispatch_support
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x == y` holds element-wise.

  The check passes when `x[i] == y[i]` for every pair of (possibly broadcast)
  elements of `x` and `y`. It is trivially satisfied when both are empty.

  When `x` and `y` differ, `message` and the first `summarize` entries of `x`
  and `y` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x == y` is False. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to hold back later
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x == y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)
716
717
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.add_dispatch_support
@_binary_assert_doc('==', '[1, 2]')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by the _binary_assert_doc decorator.
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    # The very same object is trivially equal to itself: skip the check.
    if x is y:
      if context.executing_eagerly():
        return None
      return control_flow_ops.no_op()
  return _binary_assert(
      sym='==', opname='assert_equal', op_func=math_ops.equal,
      static_func=np.equal, x=x, y=y, data=data, summarize=summarize,
      message=message, name=name)
728
729
@tf_export('debugging.assert_none_equal', v1=[])
@dispatch.add_dispatch_support
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  """Assert the condition `x != y` holds for all elements.

  The check passes when `x[i] != y[i]` for every pair of (possibly broadcast)
  elements of `x` and `y`. It is trivially satisfied when both are empty.

  When any pair of elements is equal, `message` and the first `summarize`
  entries of `x` and `y` are printed and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_none_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x != y` is ever False. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to hold back later
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x != y` is False for any pair of elements in `x` and `y`. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.
  """
  return assert_none_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)
767
768
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=', '[2, 1]')
def assert_none_equal(x, y, data=None, summarize=None, message=None,
                      name=None):
  # The public docstring is generated by the _binary_assert_doc decorator.
  return _binary_assert(
      sym='!=', opname='assert_none_equal', op_func=math_ops.not_equal,
      static_func=np.not_equal, x=x, y=y, data=data, summarize=summarize,
      message=message, name=name)
777
778
@tf_export('debugging.assert_near', v1=[])
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])`
  holds for every pair of (possibly broadcast) elements of `x` and `y`. If both
  `x` and `y` are empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` are `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
      This can be used with `tf.control_dependencies` inside of `tf.function`s
      to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close for some pair of elements. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)
831
832
@tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```.

  The comparison is inclusive, so identical inputs pass even when both
  tolerances are zero.

  If both `x` and `y` are empty, this is trivially satisfied.

  The default `atol` and `rtol` are both `10 * eps`, where `eps` is the
  smallest representable positive number such that `1 + eps != 1`. `10 * eps`
  is about `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in
  `16bit`. See `numpy.finfo`. For complex inputs, `eps` is taken from the
  corresponding real dtype.

  Args:
    x:  Float or complex `Tensor`.
    y:  Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol:  `Tensor`.  Same `dtype` as `x` (its real dtype if `x` is complex),
      and broadcastable to `x`. The relative tolerance.  Default is
      `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as `x` (its real dtype if `x` is complex),
      and broadcastable to `x`. The absolute tolerance.  Default is
      `10 * eps`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  message = message or ''
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Tolerances are real-valued even for complex inputs: they are combined
    # with `math_ops.abs(y)` and compared against `math_ops.abs(x - y)`.
    dtype = x.dtype
    if dtype.is_complex:
      dtype = dtype.real_dtype
    eps = np.finfo(dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)

    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    # Inclusive `<=` matches the documented condition and
    # `numpy.testing.assert_allclose`; a strict `<` would reject identical
    # inputs whenever both tolerances are zero.
    condition = math_ops.reduce_all(math_ops.less_equal(diff, tol))
    return control_flow_ops.Assert(condition, data, summarize=summarize)
914
915
@tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.add_dispatch_support
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x < y` holds element-wise.

  Every pair of (possibly broadcast) elements `x[i]`, `y[i]` must satisfy
  `x[i] < y[i]`. When both `x` and `y` are empty the check passes trivially.

  If some element of `x` is not less than the matching element of `y`,
  `message` and the first `summarize` entries of `x` and `y` are printed and
  an `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor` of the same dtype as `x` and broadcastable to it.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_less".

  Returns:
    Op that raises `InvalidArgumentError` if `x < y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x < y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_less(x=x, y=y, message=message, summarize=summarize,
                     name=name)
950
951
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.add_dispatch_support
@_binary_assert_doc('<', '[2, 3]')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  # Element-wise `x < y` via the shared binary-assert helper; the public
  # docstring is presumably attached by `_binary_assert_doc`.
  return _binary_assert(
      '<', 'assert_less', math_ops.less, np.less,
      x, y, data, summarize, message, name)
958
959
@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.add_dispatch_support
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x <= y` holds element-wise.

  Every pair of (possibly broadcast) elements `x[i]`, `y[i]` must satisfy
  `x[i] <= y[i]`. When both `x` and `y` are empty the check passes trivially.

  If some element of `x` is greater than the matching element of `y`,
  `message` and the first `summarize` entries of `x` and `y` are printed and
  an `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor` of the same dtype as `x` and broadcastable to it.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_less_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x <= y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x <= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_less_equal(x=x, y=y, message=message, summarize=summarize,
                           name=name)
995
996
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=', '[1, 3]')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  # Element-wise `x <= y` via the shared binary-assert helper; the public
  # docstring is presumably attached by `_binary_assert_doc`.
  return _binary_assert(
      '<=', 'assert_less_equal', math_ops.less_equal, np.less_equal,
      x, y, data, summarize, message, name)
1004
1005
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.add_dispatch_support
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x > y` holds element-wise.

  Every pair of (possibly broadcast) elements `x[i]`, `y[i]` must satisfy
  `x[i] > y[i]`. When both `x` and `y` are empty the check passes trivially.

  If some element of `x` is not greater than the matching element of `y`,
  `message` and the first `summarize` entries of `x` and `y` are printed and
  an `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor` of the same dtype as `x` and broadcastable to it.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_greater".

  Returns:
    Op that raises `InvalidArgumentError` if `x > y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x > y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_greater(x=x, y=y, message=message, summarize=summarize,
                        name=name)
1041
1042
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.add_dispatch_support
@_binary_assert_doc('>', '[0, 1]')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # Element-wise `x > y` via the shared binary-assert helper; the public
  # docstring is presumably attached by `_binary_assert_doc`.
  return _binary_assert(
      '>', 'assert_greater', math_ops.greater, np.greater,
      x, y, data, summarize, message, name)
1049
1050
@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.add_dispatch_support
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert that `x >= y` holds element-wise.

  Every pair of (possibly broadcast) elements `x[i]`, `y[i]` must satisfy
  `x[i] >= y[i]`. When both `x` and `y` are empty the check passes trivially.

  If some element of `x` is less than the matching element of `y`, `message`
  and the first `summarize` entries of `x` and `y` are printed and an
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor` of the same dtype as `x` and broadcastable to it.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x >= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_greater_equal(x=x, y=y, message=message, summarize=summarize,
                              name=name)
1087
1088
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=', '[1, 0]')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  # Element-wise `x >= y` via the shared binary-assert helper; the public
  # docstring is presumably attached by `_binary_assert_doc`.
  return _binary_assert(
      '>=', 'assert_greater_equal', math_ops.greater_equal, np.greater_equal,
      x, y, data, summarize, message, name)
1097
1098
def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition.

  When both the rank of `x` and the value of `rank` are known at graph
  construction time, the condition is decided immediately. Otherwise a runtime
  `Assert` op is built.

  Args:
    x:  Numeric `Tensor`.
    rank:  Scalar `int32` `Tensor`.
    static_condition:  Python callable taking `(actual_rank, given_rank)` and
      returning `True` if the condition is satisfied, `False` otherwise.
    dynamic_condition:  Op-building callable taking `(actual_rank, given_rank)`
      and returning a boolean tensor that is `True` when satisfied.
    data:  The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the condition is statically satisfied, otherwise an op raising
    `InvalidArgumentError` if `x` fails `dynamic_condition`.

  Raises:
    TypeError:  If `rank` is not `int32`.
    ValueError:  If `rank` is statically known but not a scalar, or if static
      checks determine `x` fails `static_condition`. In the latter case the
      exception args are `('Static rank condition failed', actual_rank,
      given_rank)`; callers match on the first arg to build their message.
  """
  assert_type(rank, dtypes.int32)

  # Try to resolve the requested rank at graph-construction time.
  static_rank = tensor_util.constant_value(rank)
  if static_rank is not None:
    if static_rank.ndim != 0:
      raise ValueError('Rank must be a scalar.')

    actual_static_rank = x.get_shape().ndims
    if actual_static_rank is not None:
      if static_condition(actual_static_rank, static_rank):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', actual_static_rank, static_rank)

  condition = dynamic_condition(array_ops.rank(x), rank)

  if static_rank is None:
    # `rank` itself must be a scalar. This catches the common mistake of
    # calling assert_rank(x, [n]) instead of assert_rank(x, n).
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
1145
1146
@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  This Op checks that the rank of `x` is equal to `rank`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank `rank`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank (typically
      the case during eager execution, where shapes are known).
  """
  return assert_rank(x=x, rank=rank, message=message, name=name)
1179
1180
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    rank:  Scalar integer `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    is_sparse = isinstance(x, sparse_tensor.SparseTensor)
    if not is_sparse:
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    # Only graph-mode dense tensors carry a meaningful name for the message.
    tensor_name = '' if context.executing_eagerly() or is_sparse else x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % tensor_name, rank,
          'Received shape: ', array_ops.shape(x),
      ]

    try:
      assert_op = _assert_rank_condition(
          x, rank,
          lambda actual_rank, given_rank: actual_rank == given_rank,
          math_ops.equal, data, summarize)
    except ValueError as e:
      # Rewrite only the structured static-failure error; re-raise others.
      if e.args[0] != 'Static rank condition failed':
        raise
      raise ValueError(
          '%s.  Tensor %s must have rank %d.  Received rank %d, shape %s' %
          (message, tensor_name, e.args[2], e.args[1], x.get_shape()))

  return assert_op
1243
1244
@tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  The rank of `x` must be greater than or equal to `rank`. Otherwise `message`
  and the shape of `x` are printed and an `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  return assert_rank_at_least(x=x, rank=rank, name=name, message=message)
1277
1278
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    rank:  Scalar integer `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(
      name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    # Like `assert_rank` and `assert_rank_in`, accept a `SparseTensor` as-is.
    # Everything below only uses its static shape, `array_ops.rank` and
    # `array_ops.shape`, which those siblings already apply to `SparseTensor`.
    is_sparse = isinstance(x, sparse_tensor.SparseTensor)
    if not is_sparse:
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = message or ''

    static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
    dynamic_condition = math_ops.greater_equal

    # Only graph-mode dense tensors carry a meaningful name for the message.
    if context.executing_eagerly() or is_sparse:
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s.  Tensor %s must have rank at least %d.  Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op
1344
1345
1346def _static_rank_in(actual_rank, given_ranks):
1347  return actual_rank in given_ranks
1348
1349
def _dynamic_rank_in(actual_rank, given_ranks):
  """Builds a boolean tensor: does `actual_rank` equal any of `given_ranks`?

  An empty `given_ranks` yields a constant `False` tensor. Otherwise the
  result is the logical OR of one `equal` comparison per candidate rank.
  """
  if len(given_ranks) == 0:  # pylint: disable=g-explicit-length-test
    return ops.convert_to_tensor(False)
  first, remaining = given_ranks[0], given_ranks[1:]
  matched = math_ops.equal(first, actual_rank)
  for candidate in remaining:
    matched = math_ops.logical_or(
        matched, math_ops.equal(candidate, actual_rank))
  return matched
1358
1359
def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition over `ranks`.

  When the rank of `x` and every entry of `ranks` are known at graph
  construction time, the condition is decided immediately. Otherwise a runtime
  `Assert` op is built.

  Args:
    x:  Numeric `Tensor`.
    ranks:  Sequence of scalar `int32` `Tensor`s.
    static_condition:  Python callable taking `(actual_rank, given_ranks)` and
      returning `True` if the condition is satisfied, `False` otherwise.
    dynamic_condition:  Op-building callable taking
      `(actual_rank, given_ranks)` and returning a boolean tensor that is
      `True` when satisfied.
    data:  The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` if the condition is statically satisfied, otherwise an op raising
    `InvalidArgumentError` if `x` fails `dynamic_condition`.

  Raises:
    TypeError:  If any entry of `ranks` is not `int32`.
    ValueError:  If all `ranks` are statically known but one is not a scalar,
      or if static checks determine `x` fails `static_condition`. In the latter
      case the exception args are `('Static rank condition failed',
      actual_rank, static_ranks)`; callers match on the first arg.
  """
  for candidate in ranks:
    assert_type(candidate, dtypes.int32)

  # Try to resolve every requested rank at graph-construction time.
  static_ranks = tuple(tensor_util.constant_value(r) for r in ranks)
  if all(r is not None for r in static_ranks):
    if any(r.ndim != 0 for r in static_ranks):
      raise ValueError('Rank must be a scalar.')

    actual_static_rank = x.get_shape().ndims
    if actual_static_rank is not None:
      if static_condition(actual_static_rank, static_ranks):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', actual_static_rank, static_ranks)

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # Each rank whose value is not static must itself be a scalar. This catches
  # passing e.g. [n] where n was meant.
  for rank, static_rank in zip(ranks, static_ranks):
    if static_rank is not None:
      continue
    scalar_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([scalar_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
1410
1411
@tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  The rank of `x` must equal one of the entries of `ranks`. Otherwise
  `message` and the shape of `x` are printed and an `InvalidArgumentError` is
  raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  return assert_rank_in(x=x, ranks=ranks, name=name, message=message)
1443
1444
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    ranks:  Iterable of scalar `Tensor` objects. Any iterable is accepted,
      including a one-shot generator.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message, the allowed ranks, and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has mismatched rank.
  """
  # Materialize `ranks` once up front. It is iterated both for the name-scope
  # values and for tensor conversion; a generator would already be exhausted
  # by the time of the second pass, silently leaving no allowed ranks.
  ranks = tuple(ranks)
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + ranks + tuple(data or [])):
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    ranks = tuple(ops.convert_to_tensor(rank, name='rank') for rank in ranks)
    message = message or ''

    # Only graph-mode dense tensors carry a meaningful name for the message.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)

    except ValueError as e:
      if e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%s.  Tensor %s must have rank in %s.  Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      else:
        raise

  return assert_op
1508
1509
@tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  If `x` has a non-integer type, a `TypeError` is raised whose message
  includes `message` and the dtype of `x`.

  The dtype is always known statically, so the check happens immediately and
  this method returns nothing.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is not a non-quantized integer type.
  """
  assert_integer(x=x, message=message, name=name)
1529
1530
@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is anything other than non-quantized integer.

  Returns:
    A `no_op` that does nothing.  Type can be determined statically.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if x.dtype.is_integer:
      return control_flow_ops.no_op('statically_determined_was_integer')
    # Eager tensors have no graph name, so a generic label is used instead.
    tensor_name = 'tensor' if context.executing_eagerly() else x.name
    raise TypeError(
        '%s  Expected "x" to be integer type.  Found: %s of dtype %s'
        % (message, tensor_name, x.dtype))
1569
1570
@tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  This can always be checked statically, so this method returns nothing.

  Example:

  >>> a = tf.Variable(1.0)
  >>> tf.debugging.assert_type(a, tf_type= tf.float32)

  >>> b = tf.constant(21)
  >>> tf.debugging.assert_type(b, tf_type=tf.bool)
  Traceback (most recent call last):
  ...
  TypeError: ...

  >>> c = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
  ...  dense_shape=[3, 4])
  >>> tf.debugging.assert_type(c, tf_type= tf.int32)

  Args:
    tensor: A `Tensor`, `SparseTensor` or `tf.Variable`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_type".

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  assert_type(tensor=tensor, tf_type=tf_type, message=message, name=name)
1604
1605
@tf_export(v1=['debugging.assert_type', 'assert_type'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
  """Statically asserts that the given `Tensor` is of the specified type.

  Args:
    tensor: A `Tensor` or `SparseTensor`. Anything else is first converted
      with `convert_to_tensor`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc). Normalized with `dtypes.as_dtype`.
    message: A string to prefix to the default message.
    name:  A name to give this `Op`.  Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.

  Returns:
    A `no_op` that does nothing.  Type can be determined statically.
  """
  prefix = message or ''
  expected_dtype = dtypes.as_dtype(tf_type)
  with ops.name_scope(name, 'assert_type', [tensor]):
    if isinstance(tensor, sparse_tensor.SparseTensor):
      checked = tensor
    else:
      checked = ops.convert_to_tensor(tensor, name='tensor')

    if checked.dtype == expected_dtype:
      return control_flow_ops.no_op('statically_determined_correct_type')

    # Eager tensors have no meaningful name, so use a generic label there.
    if context.executing_eagerly():
      raise TypeError(
          '%s tensor must be of type %s' % (prefix, expected_dtype))
    tensor_label = checked.name if hasattr(checked, 'name') else ''
    raise TypeError(
        '%s  %s must be of type %s' % (prefix, tensor_label, expected_dtype))
1639
1640
def _dimension_sizes(x):
  """Gets the dimension sizes of a tensor `x`.

  Sizes known statically are returned as Python ints; the rest are scalar
  size tensors sliced from `tf.shape(x)`. A scalar is treated as a rank-1
  tensor of size 1.

  Args:
    x: A `Tensor`.

  Returns:
    `(1,)` if `x` has static rank 0; a list of ints and/or size tensors if the
    static rank is positive; otherwise (unknown rank) a 1-D size tensor that
    evaluates to `[1]` when `x` turns out to be a scalar.
  """
  # Created up front (before any branch) to keep op-creation order stable.
  dynamic_shape = array_ops.shape(x)
  static_shape = x.get_shape()
  static_rank = static_shape.rank

  if static_rank is None:
    # Rank unknown at graph-construction time: decide at runtime.
    is_scalar = math_ops.equal(array_ops.rank(x), 0)
    return control_flow_ops.cond(
        is_scalar, lambda: array_ops.constant([1]), lambda: dynamic_shape)

  if static_rank == 0:
    return (1,)

  return [
      dynamic_shape[axis] if dim is None else int(dim)
      for axis, dim in enumerate(static_shape.as_list())
  ]
1670
1671
1672def _symbolic_dimension_sizes(symbolic_shape):
1673  # If len(symbolic_shape) == 0 construct a tuple
1674  if not symbolic_shape:
1675    return tuple([1])
1676
1677  return symbolic_shape
1678
1679
1680def _has_known_value(dimension_size):
1681  not_none = dimension_size is not None
1682  try:
1683    int(dimension_size)
1684    can_be_parsed_as_int = True
1685  except (ValueError, TypeError):
1686    can_be_parsed_as_int = False
1687  return not_none and can_be_parsed_as_int
1688
1689
1690def _is_symbol_for_any_size(symbol):
1691  return symbol in [None, '.']
1692
1693
1694_TensorDimSizes = collections.namedtuple(
1695    '_TensorDimSizes',
1696    ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])
1697
1698
@tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that the shape relationships among a collection of tensors
  satisfy the given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message` and the first `summarize` entries
  of the first violating tensor are printed, and `InvalidArgumentError`
  is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*', it stands
  for a variable number of outer dimensions of unspecified size, i.e. the
  constraint applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: dictionary with (`Tensor` to shape) items, or a list of
      (`Tensor`, shape) tuples. A shape must be an iterable.
    data: The tensors to print out if the condition is False.  Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_shapes".

  Raises:
    ValueError:  If static checks determine any shape constraint is violated.
  """
  # Delegate to the v1 implementation; unlike v1, this endpoint returns nothing.
  assert_shapes(
      shapes=shapes,
      data=data,
      summarize=summarize,
      message=message,
      name=name)
1764
1765
@tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that the shape relationships among a collection of tensors
  satisfy the given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.debugging.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message` and the first `summarize` entries
  of the first violating tensor are printed, and `InvalidArgumentError`
  is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*', it stands
  for a variable number of outer dimensions of unspecified size, i.e. the
  constraint applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
      expected shape of `Tensor`, or a dictionary mapping `Tensor` to `shape`.
      See the example code above. The `shape` must
      be an iterable. Each element of the iterable can be either a concrete
      integer value or a string that abstractly represents the dimension.
      A `shape` of `None` places no constraint on its tensor.
      For example,
        - `('N', 'Q')` specifies a 2D shape wherein the first and second
          dimensions of shape may or may not be equal.
        - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
          dimensions are equal.
        - `(1, 'N')` specifies a 2D shape wherein the first dimension is
          exactly 1 and the second dimension can be any value.
      Note that the abstract dimension letters take effect across different
      tuple elements of the list. For example,
      `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))])` asserts
      that both `x` and `y` are rank-2 tensors and their first dimensions are
      equal (`N`).
      `shape` can also be a `tf.TensorShape`.
    data: The tensors to print out if the condition is False.  Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError:  If static checks determine any shape constraint is violated.
  """
  # If the user manages to assemble a dict containing tensors (possible in
  # Graph mode only), make sure we still accept that.
  if isinstance(shapes, dict):
    shapes = shapes.items()

  message = message or ''
  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
    # Shape specified as None implies no constraint
    # Dense inputs are converted to Tensors; SparseTensors are kept as-is.
    shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
                          ops.convert_to_tensor(x), s)
                         for x, s in shapes if s is not None]

    executing_eagerly = context.executing_eagerly()

    # Label for error messages: eager tensors and SparseTensors are described
    # by shape/dtype, graph tensors by their name.
    def tensor_name(x):
      if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
        return _shape_and_dtype_str(x)
      return x.name

    # Phase 1: validate every specified shape and record, per tensor, its
    # actual sizes next to the symbolic sizes it must satisfy.
    tensor_dim_sizes = []
    for tensor, symbolic_shape in shape_constraints:
      is_iterable = (
          hasattr(symbolic_shape, '__iter__') or
          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
      )
      if not is_iterable:
        raise ValueError(
            '%s.  '
            'Tensor %s.  Specified shape must be an iterable.  '
            'An iterable has the attribute `__iter__` or `__getitem__`.  '
            'Received specified shape: %s' %
            (message, tensor_name(tensor), symbolic_shape))

      # We convert this into a tuple to handle strings, lists and numpy arrays
      symbolic_shape_tuple = tuple(symbolic_shape)

      # A `...`/'*' entry is only legal in position 0, where it marks the
      # constraint as applying to the inner-most dimensions only.
      tensors_specified_innermost = False
      for i, symbol in enumerate(symbolic_shape_tuple):
        if symbol not in [Ellipsis, '*']:
          continue

        if i != 0:
          raise ValueError(
              '%s.  '
              'Tensor %s specified shape index %d.  '
              'Symbol `...` or `*` for a variable number of '
              'unspecified dimensions is only allowed as the first entry' %
              (message, tensor_name(tensor), i))

        tensors_specified_innermost = True

      # Only include the size of the specified dimensions since the 0th symbol
      # is either ellipsis or *
      tensor_dim_sizes.append(
          _TensorDimSizes(
              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
              _symbolic_dimension_sizes(
                  symbolic_shape_tuple[1:]
                  if tensors_specified_innermost else symbolic_shape_tuple)))

    # Phase 2: one rank assertion per tensor. These must exist before any
    # dynamic size lookup below, which is guarded on them.
    rank_assertions = []
    for sizes in tensor_dim_sizes:
      rank = len(sizes.symbolic_sizes)
      rank_zero_or_one = rank in [0, 1]
      if sizes.unspecified_dim:
        if rank_zero_or_one:
          # No assertion of rank needed as `x` only need to have rank at least
          # 0. See elif rank_zero_or_one case comment.
          continue
        assertion = assert_rank_at_least(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      elif rank_zero_or_one:
        # Rank 0 is treated as rank 1 size 1, i.e. there is
        # no distinction between the two in terms of rank.
        # See _dimension_sizes.
        assertion = assert_rank_in(
            x=sizes.x,
            ranks=[0, 1],
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      else:
        assertion = assert_rank(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      rank_assertions.append(assertion)

    # Phase 3: size checks. `size_specifications` maps a symbol to the
    # (size, tensor, dim) of the first tensor that bound it; later uses of the
    # same symbol are checked against that binding.
    size_assertions = []
    size_specifications = {}
    for sizes in tensor_dim_sizes:
      for i, size_symbol in enumerate(sizes.symbolic_sizes):

        if _is_symbol_for_any_size(size_symbol):
          # Size specified as any implies no constraint
          continue

        # With a leading `...`/'*', index from the end so the symbols line up
        # with the inner-most dimensions.
        if sizes.unspecified_dim:
          tensor_dim = i - len(sizes.symbolic_sizes)
        else:
          tensor_dim = i

        if size_symbol in size_specifications or _has_known_value(size_symbol):
          if _has_known_value(size_symbol):
            specified_size = int(size_symbol)
            size_check_message = 'Specified explicitly'
          else:
            specified_size, specified_by_y, specified_at_dim = \
                size_specifications[size_symbol]
            size_check_message = (
                'Specified by tensor %s dimension %d' %
                (tensor_name(specified_by_y), specified_at_dim))

          # This is extremely subtle. If actual_sizes is dynamic, we must
          # make sure a control dependency is inserted here so that this slice
          # can not execute until the rank is asserted to be enough for the
          # slice to not fail.
          with ops.control_dependencies(rank_assertions):
            actual_size = sizes.actual_sizes[tensor_dim]
          # Both sizes known statically: check now, no runtime Assert needed.
          if _has_known_value(actual_size) and _has_known_value(specified_size):
            if int(actual_size) != int(specified_size):
              raise ValueError(
                  '%s.  %s.  Tensor %s dimension %s must have size %d.  '
                  'Received size %d, shape %s' %
                  (message, size_check_message, tensor_name(sizes.x),
                   tensor_dim, specified_size, actual_size,
                   sizes.x.get_shape()))
            # No dynamic assertion needed
            continue

          condition = math_ops.equal(
              ops.convert_to_tensor(actual_size),
              ops.convert_to_tensor(specified_size))
          data_ = data
          if data is None:
            data_ = [
                message, size_check_message,
                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
                'must have size', specified_size, 'Received shape: ',
                array_ops.shape(sizes.x)
            ]
          size_assertions.append(
              control_flow_ops.Assert(condition, data_, summarize=summarize))
        else:
          # Not sure if actual_sizes is a constant, but for safety, guard
          # on rank. See explanation above about actual_sizes need for safety.
          with ops.control_dependencies(rank_assertions):
            size = sizes.actual_sizes[tensor_dim]
          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)

  # Ensure both assertions actually occur.
  # NOTE(review): this group is created outside the 'assert_shapes' name scope
  # above; moving it inside would change op names, so it is left as-is.
  with ops.control_dependencies(rank_assertions):
    shapes_assertion = control_flow_ops.group(size_assertions)

  return shapes_assertion
2020
2021
2022# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Gets the difference x[1:] - x[:-1] of `x` flattened in row-major order.

  Args:
    x: A tensor-like value; it is reshaped to 1-D first.

  Returns:
    A 1-D tensor of adjacent differences, empty if `x` has fewer than two
    elements.

  Raises:
    TypeError: if the flattened `x` is not a numeric tensor.
  """
  flat = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(flat):
    raise TypeError('Expected x to be numeric, instead found: %s' % flat)

  num_elements = array_ops.size(flat)
  fewer_than_two = math_ops.less(num_elements, 2)
  last_index = array_ops.shape(flat) - 1

  def _no_pairs():
    # Fewer than two elements: there are no adjacent pairs to compare.
    return ops.convert_to_tensor([], dtype=flat.dtype)

  def _adjacent_diff():
    upper = array_ops.strided_slice(flat, [1], [1] + last_index)
    lower = array_ops.strided_slice(flat, [0], last_index)
    return upper - lower

  return control_flow_ops.cond(fewer_than_two, _no_pairs, _adjacent_diff)
2037
2038
@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if the elements of `tensor` are numbers.

  Specifically, returns `True` if the dtype of `tensor` is one of the following:

  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.qint8`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.complex64`

  Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not
  a `tf.Tensor` object.
  """
  if not isinstance(tensor, ops.Tensor):
    return False
  return tensor.dtype in NUMERIC_TYPES
2064
2065
@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
  If `x` has fewer than two elements, it is trivially non-decreasing.

  See also:  `is_strictly_increasing`

  >>> x1 = tf.constant([1.0, 1.0, 3.0])
  >>> tf.math.is_non_decreasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_non_decreasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).  Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    steps = _get_diff_for_monotonic_comparison(x)
    # With fewer than two elements `steps` is empty and reduce_all gives True.
    zero = ops.convert_to_tensor(0, dtype=steps.dtype)
    step_ok = math_ops.less_equal(zero, steps)
    return math_ops.reduce_all(step_ok)
2106
2107
@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
  If `x` has fewer than two elements, it is trivially strictly increasing.

  See also:  `is_non_decreasing`

  >>> x1 = tf.constant([1.0, 2.0, 3.0])
  >>> tf.math.is_strictly_increasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_strictly_increasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    steps = _get_diff_for_monotonic_comparison(x)
    # With fewer than two elements `steps` is empty and reduce_all gives True.
    zero = ops.convert_to_tensor(0, dtype=steps.dtype)
    step_ok = math_ops.less(zero, steps)
    return math_ops.reduce_all(step_ok)
2149
2150
2151def _assert_same_base_type(items, expected_type=None):
2152  r"""Asserts all items are of the same base type.
2153
2154  Args:
2155    items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
2156        `Operation`, or `IndexedSlices`). Can include `None` elements, which
2157        will be ignored.
2158    expected_type: Expected type. If not specified, assert all items are
2159        of the same base type.
2160
2161  Returns:
2162    Validated type, or none if neither expected_type nor items provided.
2163
2164  Raises:
2165    ValueError: If any types do not match.
2166  """
2167  original_expected_type = expected_type
2168  mismatch = False
2169  for item in items:
2170    if item is not None:
2171      item_type = item.dtype.base_dtype
2172      if not expected_type:
2173        expected_type = item_type
2174      elif expected_type != item_type:
2175        mismatch = True
2176        break
2177  if mismatch:
2178    # Loop back through and build up an informative error message (this is very
2179    # slow, so we don't do it unless we found an error above).
2180    expected_type = original_expected_type
2181    original_item_str = None
2182    for item in items:
2183      if item is not None:
2184        item_type = item.dtype.base_dtype
2185        if not expected_type:
2186          expected_type = item_type
2187          original_item_str = item.name if hasattr(item, 'name') else str(item)
2188        elif expected_type != item_type:
2189          raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
2190              item.name if hasattr(item, 'name') else str(item),
2191              item_type, expected_type,
2192              (' as %s' % original_item_str) if original_item_str else ''))
2193    return expected_type  # Should be unreachable
2194  else:
2195    return expected_type
2196
2197
@tf_export(
    'debugging.assert_same_float_dtype',
    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
  """Validate and return float type based on `tensors` and `dtype`.

  For ops such as matrix multiplication, inputs and weights must be of the
  same float type. This function validates that all `tensors` are the same type,
  validates that type is `dtype` (if supplied), and returns the type. Type must
  be a floating point type. If neither `tensors` nor `dtype` is supplied,
  the function will return `dtypes.float32`.

  Args:
    tensors: Tensors of input values. Can include `None` elements, which will be
        ignored.
    dtype: Expected type. Converted with `dtypes.as_dtype`, so a `DType`, a
        NumPy dtype or a dtype name string are all accepted.

  Returns:
    Validated type.

  Raises:
    ValueError: if the base types of `tensors` differ from each other or from
        `dtype`, or if the resulting type is not a floating point type.
  """
  if dtype is not None:
    # Normalize so `.is_floating` below works for non-`DType` inputs too.
    dtype = dtypes.as_dtype(dtype)
  if tensors:
    dtype = _assert_same_base_type(tensors, dtype)
  if not dtype:
    dtype = dtypes.float32
  elif not dtype.is_floating:
    raise ValueError('Expected floating point type, got %s.' % dtype)
  return dtype
2231
2232
@tf_export('debugging.assert_scalar', v1=[])
@dispatch.add_dispatch_support
def assert_scalar_v2(tensor, message=None, name=None):
  """Asserts that the given `tensor` is a scalar.

  Raises `ValueError` unless it is certain that `tensor` is a scalar; an
  unknown static shape also raises `ValueError`.

  The check is always performed statically, so this method returns nothing.

  Args:
    tensor: A `Tensor`.
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_scalar"

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  # Keyword arguments matter: the v1 signature puts `name` before `message`.
  assert_scalar(tensor, name=name, message=message)
2254
2255
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
  """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).

  Raises `ValueError` unless it is certain that `tensor` is a scalar; an
  unknown static shape also raises `ValueError`.

  Args:
    tensor: A `Tensor`.
    name:  A name for this operation. Defaults to "assert_scalar"
    message: A string to prefix to the default message.

  Returns:
    The input tensor (potentially converted to a `Tensor`).

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
    tensor = ops.convert_to_tensor(tensor, name=name_scope)
    static_shape = tensor.get_shape()
    # `ndims` is None for an unknown shape, which is rejected as well.
    if static_shape.ndims == 0:
      return tensor

    prefix = message or ''
    if context.executing_eagerly():
      raise ValueError('%sExpected scalar shape, saw shape: %s.'
                       % (prefix, static_shape))
    raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
                     % (prefix, tensor.name, static_shape))
2289
2290
@tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None):
  """Updates the shape of a tensor and checks at runtime that the shape holds.

  When executed, this operation asserts that the input tensor `x`'s shape
  is compatible with the `shape` argument.
  See `tf.TensorShape.is_compatible_with` for details.

  >>> x = tf.constant([[1, 2, 3],
  ...                  [4, 5, 6]])
  >>> x = tf.ensure_shape(x, [2, 3])

  Use `None` for unknown dimensions:

  >>> x = tf.ensure_shape(x, [None, 3])
  >>> x = tf.ensure_shape(x, [2, None])

  If the tensor's shape is not compatible with the `shape` argument, an error
  is raised:

  >>> x = tf.ensure_shape(x, [5])
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [2,3] is not
    compatible with expected shape [5]. [Op:EnsureShape]

  During graph construction (typically tracing a `tf.function`),
  `tf.ensure_shape` updates the static-shape of the **result** tensor by
  merging the two shapes. See `tf.TensorShape.merge_with` for details.

  This is most useful when **you** know a shape that can't be determined
  statically by TensorFlow.

  The following trivial `tf.function` prints the input tensor's
  static-shape before and after `ensure_shape` is applied.

  >>> @tf.function
  ... def f(tensor):
  ...   print("Static-shape before:", tensor.shape)
  ...   tensor = tf.ensure_shape(tensor, [None, 3])
  ...   print("Static-shape after:", tensor.shape)
  ...   return tensor

  This lets you see the effect of `tf.ensure_shape` when the function is traced:

  >>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
  Static-shape before: (None, None)
  Static-shape after: (None, 3)

  >>> _ = cf(tf.zeros([3, 3]))  # Passes
  >>> cf(tf.zeros([3, 2]))  # Fails
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Shape of tensor ... [3,2] is not compatible with
    expected shape [?,3].

  The above example raises `tf.errors.InvalidArgumentError`, because the
  argument's shape, `(3, 2)`, is not compatible with the `shape` argument,
  `(None, 3)`. (The argument matches the traced `TensorSpec([None, None])`
  signature, so the failure comes from `tf.ensure_shape` itself.)

  Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
  runtime shapes. This is stricter than `tf.Tensor.set_shape` which only
  checks the buildtime shape.

  Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape
  of the resulting tensor and enforces it at runtime, raising an error if the
  tensor's runtime shape is incompatible with the specified shape.
  `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it
  at runtime, which may result in inconsistencies between the statically-known
  shape of tensors and the runtime value of tensors.

  For example, when loading images of a known size:

  >>> @tf.function
  ... def decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image = tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  When tracing a function, no ops are executed, so shapes may be unknown.
  See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
  for details.

  >>> concrete_decode = decode_image.get_concrete_function(
  ...     tf.TensorSpec([], dtype=tf.string))
  Initial shape:  (None, None, 3)
  Final shape:  (28, 28, 3)

  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
  >>> image = tf.cast(image,tf.uint8)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  >>> print(image2.shape)
  (28, 28, 3)

  >>> image = tf.concat([image,image], axis=0)
  >>> print(image.shape)
  (56, 28, 3)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError:  Shape of tensor DecodePng [56,28,3] is not
    compatible with expected shape [28,28,3].

  Caution: if you don't use the result of `tf.ensure_shape` the check may not
  run.

  >>> @tf.function
  ... def bad_decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   # BAD: forgot to use the returned tensor.
  ...   tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  >>> image = bad_decode_image(png)
  Initial shape:  (None, None, 3)
  Final shape:  (None, None, 3)
  >>> print(image.shape)
  (56, 28, 3)

  Args:
    x: A `Tensor`.
    shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`.

  Raises:
    tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape
    of `x`.
  """
  # Normalize the accepted spellings (list, tuple, proto, None) to TensorShape.
  if not isinstance(shape, tensor_shape.TensorShape):
    shape = tensor_shape.TensorShape(shape)

  return array_ops.ensure_shape(x, shape, name=name)
2433
2434
@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient of `EnsureShape`: forwards `grad` unchanged.

  `ensure_shape` returns a tensor with the same type and contents as its
  input, so the incoming gradient passes straight through.
  """
  _ = op  # The forward op is not needed to compute the gradient.
  return grad
2439