1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=g-short-docstring-punctuation 16"""Asserts and Boolean Checks.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23 24import numpy as np 25 26from tensorflow.python.eager import context 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import math_ops 36from tensorflow.python.util import compat 37from tensorflow.python.util import deprecation 38from tensorflow.python.util import dispatch 39from tensorflow.python.util.tf_export import tf_export 40 41NUMERIC_TYPES = frozenset( 42 [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32, 43 dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8, 44 dtypes.complex64]) 45 46__all__ = [ 47 'assert_negative', 48 'assert_positive', 49 'assert_proper_iterable', 50 'assert_non_negative', 51 
'assert_non_positive', 52 'assert_equal', 53 'assert_none_equal', 54 'assert_near', 55 'assert_integer', 56 'assert_less', 57 'assert_less_equal', 58 'assert_greater', 59 'assert_greater_equal', 60 'assert_rank', 61 'assert_rank_at_least', 62 'assert_rank_in', 63 'assert_same_float_dtype', 64 'assert_scalar', 65 'assert_type', 66 'assert_shapes', 67 'is_non_decreasing', 68 'is_numeric_tensor', 69 'is_strictly_increasing', 70] 71 72 73def _maybe_constant_value_string(t): 74 if not isinstance(t, ops.Tensor): 75 return str(t) 76 const_t = tensor_util.constant_value(t) 77 if const_t is not None: 78 return str(const_t) 79 return t 80 81 82def _assert_static(condition, data): 83 """Raises a InvalidArgumentError with as much information as possible.""" 84 if not condition: 85 data_static = [_maybe_constant_value_string(x) for x in data] 86 raise errors.InvalidArgumentError(node_def=None, op=None, 87 message='\n'.join(data_static)) 88 89 90def _shape_and_dtype_str(tensor): 91 """Returns a string containing tensor's shape and dtype.""" 92 return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name) 93 94 95def _unary_assert_doc(sym, sym_name): 96 """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor. 97 98 Args: 99 sym: Mathematical symbol for the check performed on each element, i.e. "> 0" 100 sym_name: English-language name for the op described by sym 101 102 Returns: 103 Decorator that adds the appropriate docstring to the function for symbol 104 `sym`. 105 """ 106 107 def _decorator(func): 108 """Generated decorator that adds the appropriate docstring to the function for symbol `sym`. 109 110 Args: 111 func: Function for a TensorFlow op 112 113 Returns: 114 Version of `func` with documentation attached. 115 """ 116 opname = func.__name__ 117 cap_sym_name = sym_name.capitalize() 118 119 func.__doc__ = """ 120 Assert the condition `x {sym}` holds element-wise. 
121 122 When running in graph mode, you should add a dependency on this operation 123 to ensure that it runs. Example of adding a dependency to an operation: 124 125 ```python 126 with tf.control_dependencies([tf.debugging.{opname}(x, y)]): 127 output = tf.reduce_sum(x) 128 ``` 129 130 {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`. 131 If `x` is empty this is trivially satisfied. 132 133 Args: 134 x: Numeric `Tensor`. 135 data: The tensors to print out if the condition is False. Defaults to 136 error message and first few entries of `x`. 137 summarize: Print this many entries of each tensor. 138 message: A string to prefix to the default message. 139 name: A name for this operation (optional). Defaults to "{opname}". 140 141 Returns: 142 Op that raises `InvalidArgumentError` if `x {sym}` is False. 143 @compatibility(eager) 144 returns None 145 @end_compatibility 146 147 Raises: 148 InvalidArgumentError: if the check can be performed immediately and 149 `x {sym}` is False. The check can be performed immediately during 150 eager execution or if `x` is statically known. 151 """.format( 152 sym=sym, sym_name=cap_sym_name, opname=opname) 153 return func 154 155 return _decorator 156 157 158def _binary_assert_doc(sym, test_var): 159 """Common docstring for most of the v1 assert_* ops that compare two tensors element-wise. 160 161 Args: 162 sym: Binary operation symbol, i.e. "==" 163 test_var: a string that represents the variable in the right-hand side of 164 binary operator of the test case 165 166 Returns: 167 Decorator that adds the appropriate docstring to the function for 168 symbol `sym`. 169 """ 170 171 def _decorator(func): 172 """Generated decorator that adds the appropriate docstring to the function for symbol `sym`. 173 174 Args: 175 func: Function for a TensorFlow op 176 177 Returns: 178 A version of `func` with documentation attached. 
179 """ 180 opname = func.__name__ 181 182 func.__doc__ = """ 183 Assert the condition `x {sym} y` holds element-wise. 184 185 This condition holds if for every pair of (possibly broadcast) elements 186 `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`. 187 If both `x` and `y` are empty, this is trivially satisfied. 188 189 When running in graph mode, you should add a dependency on this operation 190 to ensure that it runs. Example of adding a dependency to an operation: 191 192 ```python 193 with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]): 194 output = tf.reduce_sum(x) 195 ``` 196 197 Args: 198 x: Numeric `Tensor`. 199 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 200 data: The tensors to print out if the condition is False. Defaults to 201 error message and first few entries of `x`, `y`. 202 summarize: Print this many entries of each tensor. 203 message: A string to prefix to the default message. 204 name: A name for this operation (optional). Defaults to "{opname}". 205 206 Returns: 207 Op that raises `InvalidArgumentError` if `x {sym} y` is False. 208 209 Raises: 210 InvalidArgumentError: if the check can be performed immediately and 211 `x {sym} y` is False. The check can be performed immediately during 212 eager execution or if `x` and `y` are statically known. 213 214 @compatibility(TF2) 215 `tf.compat.v1.{opname}` is compatible with eager execution and 216 `tf.function`. 217 Please use `tf.debugging.{opname}` instead when migrating to TF2. Apart 218 from `data`, all arguments are supported with the same argument name. 219 220 If you want to ensure the assert statements run before the 221 potentially-invalid computation, please use `tf.control_dependencies`, 222 as tf.function auto-control dependencies are insufficient for assert 223 statements. 
224 225 #### Structural Mapping to Native TF2 226 227 Before: 228 229 ```python 230 tf.compat.v1.{opname}( 231 x=x, y=y, data=data, summarize=summarize, 232 message=message, name=name) 233 ``` 234 235 After: 236 237 ```python 238 tf.debugging.{opname}( 239 x=x, y=y, message=message, 240 summarize=summarize, name=name) 241 ``` 242 243 #### TF1 & TF2 Usage Example 244 245 TF1: 246 247 >>> g = tf.Graph() 248 >>> with g.as_default(): 249 ... a = tf.compat.v1.placeholder(tf.float32, [2]) 250 ... b = tf.compat.v1.placeholder(tf.float32, [2]) 251 ... result = tf.compat.v1.{opname}(a, b, 252 ... message='"a {sym} b" does not hold for the given inputs') 253 ... with tf.compat.v1.control_dependencies([result]): 254 ... sum_node = a + b 255 >>> sess = tf.compat.v1.Session(graph=g) 256 >>> val = sess.run(sum_node, feed_dict={{a: [1, 2], b:{test_var}}}) 257 258 259 TF2: 260 261 >>> a = tf.Variable([1, 2], dtype=tf.float32) 262 >>> b = tf.Variable({test_var}, dtype=tf.float32) 263 >>> assert_op = tf.debugging.{opname}(a, b, message= 264 ... '"a {sym} b" does not hold for the given inputs') 265 >>> # When working with tf.control_dependencies 266 >>> with tf.control_dependencies([assert_op]): 267 ... val = a + b 268 269 @end_compatibility 270 """.format( 271 sym=sym, opname=opname, test_var=test_var) 272 return func 273 274 return _decorator 275 276 277def _make_assert_msg_data(sym, x, y, summarize, test_op): 278 """Subroutine of _binary_assert that generates the components of the default error message when running in eager mode. 279 280 Args: 281 sym: Mathematical symbol for the test to apply to pairs of tensor elements, 282 i.e. "==" 283 x: First input to the assertion after applying `convert_to_tensor()` 284 y: Second input to the assertion 285 summarize: Value of the "summarize" parameter to the original assert_* call; 286 tells how many elements of each tensor to print. 
287 test_op: TensorFlow op that returns a Boolean tensor with True in each 288 position where the assertion is satisfied. 289 290 Returns: 291 List of tensors and scalars that, when stringified and concatenated, 292 will produce the error message string. 293 """ 294 # Prepare a message with first elements of x and y. 295 data = [] 296 297 data.append('Condition x %s y did not hold.' % sym) 298 299 if summarize > 0: 300 if x.shape == y.shape and x.shape.as_list(): 301 # If the shapes of x and y are the same (and not scalars), 302 # Get the values that actually differed and their indices. 303 # If shapes are different this information is more confusing 304 # than useful. 305 mask = math_ops.logical_not(test_op) 306 indices = array_ops.where(mask) 307 indices_np = indices.numpy() 308 x_vals = array_ops.boolean_mask(x, mask) 309 y_vals = array_ops.boolean_mask(y, mask) 310 num_vals = min(summarize, indices_np.shape[0]) 311 data.append('Indices of first %d different values:' % num_vals) 312 data.append(indices_np[:num_vals]) 313 data.append('Corresponding x values:') 314 data.append(x_vals.numpy().reshape((-1,))[:num_vals]) 315 data.append('Corresponding y values:') 316 data.append(y_vals.numpy().reshape((-1,))[:num_vals]) 317 318 # reshape((-1,)) is the fastest way to get a flat array view. 319 x_np = x.numpy().reshape((-1,)) 320 y_np = y.numpy().reshape((-1,)) 321 x_sum = min(x_np.size, summarize) 322 y_sum = min(y_np.size, summarize) 323 data.append('First %d elements of x:' % x_sum) 324 data.append(x_np[:x_sum]) 325 data.append('First %d elements of y:' % y_sum) 326 data.append(y_np[:y_sum]) 327 328 return data 329 330 331def _pretty_print(data_item, summarize): 332 """Format a data item for use in an error message in eager mode. 333 334 Args: 335 data_item: One of the items in the "data" argument to an assert_* function. 336 Can be a Tensor or a scalar value. 337 summarize: How many elements to retain of each tensor-valued entry in data. 
338 339 Returns: 340 An appropriate string representation of data_item 341 """ 342 if isinstance(data_item, ops.Tensor): 343 arr = data_item.numpy() 344 if np.isscalar(arr): 345 # Tensor.numpy() returns a scalar for zero-dimensional tensors 346 return str(arr) 347 else: 348 flat = arr.reshape((-1,)) 349 lst = [str(x) for x in flat[:summarize]] 350 if len(lst) < flat.size: 351 lst.append('...') 352 return str(lst) 353 else: 354 return str(data_item) 355 356 357def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize, 358 message, name): 359 """Generic binary elementwise assertion. 360 361 Implements the behavior described in _binary_assert_doc() above. 362 Args: 363 sym: Mathematical symbol for the test to apply to pairs of tensor elements, 364 i.e. "==" 365 opname: Name of the assert op in the public API, i.e. "assert_equal" 366 op_func: Function that, if passed the two Tensor inputs to the assertion (x 367 and y), will return the test to be passed to reduce_all() i.e. 368 static_func: Function that, if passed numpy ndarray versions of the two 369 inputs to the assertion, will return a Boolean ndarray with containing 370 True in all positions where the assertion PASSES. 371 i.e. np.equal for assert_equal() 372 x: Numeric `Tensor`. 373 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 374 data: The tensors to print out if the condition is False. Defaults to 375 error message and first few entries of `x`, `y`. 376 summarize: Print this many entries of each tensor. 377 message: A string to prefix to the default message. 378 name: A name for this operation (optional). Defaults to the value of 379 `opname`. 380 381 Returns: 382 See docstring template in _binary_assert_doc(). 
383 """ 384 with ops.name_scope(name, opname, [x, y, data]): 385 x = ops.convert_to_tensor(x, name='x') 386 y = ops.convert_to_tensor(y, name='y') 387 388 if context.executing_eagerly(): 389 test_op = op_func(x, y) 390 condition = math_ops.reduce_all(test_op) 391 if condition: 392 return 393 394 # If we get here, the assertion has failed. 395 # Default to printing 3 elements like control_flow_ops.Assert (used 396 # by graph mode) does. Also treat negative values as "print 397 # everything" for consistency with Tensor::SummarizeValue(). 398 if summarize is None: 399 summarize = 3 400 elif summarize < 0: 401 summarize = 1e9 # Code below will find exact size of x and y. 402 403 if data is None: 404 data = _make_assert_msg_data(sym, x, y, summarize, test_op) 405 406 if message is not None: 407 data = [message] + list(data) 408 409 raise errors.InvalidArgumentError( 410 node_def=None, 411 op=None, 412 message=('\n'.join(_pretty_print(d, summarize) for d in data))) 413 414 else: # not context.executing_eagerly() 415 if data is None: 416 data = [ 417 'Condition x %s y did not hold element-wise:' % sym, 418 'x (%s) = ' % x.name, x, 419 'y (%s) = ' % y.name, y 420 ] 421 if message is not None: 422 data = [message] + list(data) 423 condition = math_ops.reduce_all(op_func(x, y)) 424 x_static = tensor_util.constant_value(x) 425 y_static = tensor_util.constant_value(y) 426 if x_static is not None and y_static is not None: 427 condition_static = np.all(static_func(x_static, y_static)) 428 _assert_static(condition_static, data) 429 return control_flow_ops.Assert(condition, data, summarize=summarize) 430 431 432@tf_export( 433 'debugging.assert_proper_iterable', 434 v1=['debugging.assert_proper_iterable', 'assert_proper_iterable']) 435@dispatch.add_dispatch_support 436@deprecation.deprecated_endpoints('assert_proper_iterable') 437def assert_proper_iterable(values): 438 """Static assert that values is a "proper" iterable. 
439 440 `Ops` that expect iterables of `Tensor` can call this to validate input. 441 Useful since `Tensor`, `ndarray`, byte/text type are all iterables themselves. 442 443 Args: 444 values: Object to be checked. 445 446 Raises: 447 TypeError: If `values` is not iterable or is one of 448 `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`. 449 """ 450 unintentional_iterables = ( 451 (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) 452 + compat.bytes_or_text_types 453 ) 454 if isinstance(values, unintentional_iterables): 455 raise TypeError( 456 'Expected argument "values" to be a "proper" iterable. Found: %s' % 457 type(values)) 458 459 if not hasattr(values, '__iter__'): 460 raise TypeError( 461 'Expected argument "values" to be iterable. Found: %s' % type(values)) 462 463 464@tf_export('debugging.assert_negative', v1=[]) 465@dispatch.add_dispatch_support 466def assert_negative_v2(x, message=None, summarize=None, name=None): 467 """Assert the condition `x < 0` holds element-wise. 468 469 This Op checks that `x[i] < 0` holds for every element of `x`. If `x` is 470 empty, this is trivially satisfied. 471 472 If `x` is not negative everywhere, `message`, as well as the first `summarize` 473 entries of `x` are printed, and `InvalidArgumentError` is raised. 474 475 Args: 476 x: Numeric `Tensor`. 477 message: A string to prefix to the default message. 478 summarize: Print this many entries of each tensor. 479 name: A name for this operation (optional). Defaults to "assert_negative". 480 481 Returns: 482 Op raising `InvalidArgumentError` unless `x` is all negative. This can be 483 used with `tf.control_dependencies` inside of `tf.function`s to block 484 followup computation until the check has executed. 485 @compatibility(eager) 486 returns None 487 @end_compatibility 488 489 Raises: 490 InvalidArgumentError: if the check can be performed immediately and 491 `x[i] < 0` is False. 
The check can be performed immediately during eager 492 execution or if `x` is statically known. 493 """ 494 return assert_negative(x=x, message=message, summarize=summarize, name=name) 495 496 497@tf_export(v1=['debugging.assert_negative', 'assert_negative']) 498@dispatch.add_dispatch_support 499@deprecation.deprecated_endpoints('assert_negative') 500@_unary_assert_doc('< 0', 'negative') 501def assert_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 502 message = message or '' 503 with ops.name_scope(name, 'assert_negative', [x, data]): 504 x = ops.convert_to_tensor(x, name='x') 505 if data is None: 506 if context.executing_eagerly(): 507 name = _shape_and_dtype_str(x) 508 else: 509 name = x.name 510 data = [ 511 message, 512 'Condition x < 0 did not hold element-wise:', 513 'x (%s) = ' % name, x] 514 zero = ops.convert_to_tensor(0, dtype=x.dtype) 515 return assert_less(x, zero, data=data, summarize=summarize) 516 517 518@tf_export('debugging.assert_positive', v1=[]) 519@dispatch.add_dispatch_support 520def assert_positive_v2(x, message=None, summarize=None, name=None): 521 """Assert the condition `x > 0` holds element-wise. 522 523 This Op checks that `x[i] > 0` holds for every element of `x`. If `x` is 524 empty, this is trivially satisfied. 525 526 If `x` is not positive everywhere, `message`, as well as the first `summarize` 527 entries of `x` are printed, and `InvalidArgumentError` is raised. 528 529 Args: 530 x: Numeric `Tensor`. 531 message: A string to prefix to the default message. 532 summarize: Print this many entries of each tensor. 533 name: A name for this operation (optional). Defaults to "assert_positive". 534 535 Returns: 536 Op raising `InvalidArgumentError` unless `x` is all positive. This can be 537 used with `tf.control_dependencies` inside of `tf.function`s to block 538 followup computation until the check has executed. 
539 @compatibility(eager) 540 returns None 541 @end_compatibility 542 543 Raises: 544 InvalidArgumentError: if the check can be performed immediately and 545 `x[i] > 0` is False. The check can be performed immediately during eager 546 execution or if `x` is statically known. 547 """ 548 return assert_positive(x=x, summarize=summarize, message=message, name=name) 549 550 551@tf_export(v1=['debugging.assert_positive', 'assert_positive']) 552@dispatch.add_dispatch_support 553@deprecation.deprecated_endpoints('assert_positive') 554@_unary_assert_doc('> 0', 'positive') 555def assert_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 556 message = message or '' 557 with ops.name_scope(name, 'assert_positive', [x, data]): 558 x = ops.convert_to_tensor(x, name='x') 559 if data is None: 560 if context.executing_eagerly(): 561 name = _shape_and_dtype_str(x) 562 else: 563 name = x.name 564 data = [ 565 message, 'Condition x > 0 did not hold element-wise:', 566 'x (%s) = ' % name, x] 567 zero = ops.convert_to_tensor(0, dtype=x.dtype) 568 return assert_less(zero, x, data=data, summarize=summarize) 569 570 571@tf_export('debugging.assert_non_negative', v1=[]) 572@dispatch.add_dispatch_support 573def assert_non_negative_v2(x, message=None, summarize=None, name=None): 574 """Assert the condition `x >= 0` holds element-wise. 575 576 This Op checks that `x[i] >= 0` holds for every element of `x`. If `x` is 577 empty, this is trivially satisfied. 578 579 If `x` is not >= 0 everywhere, `message`, as well as the first `summarize` 580 entries of `x` are printed, and `InvalidArgumentError` is raised. 581 582 Args: 583 x: Numeric `Tensor`. 584 message: A string to prefix to the default message. 585 summarize: Print this many entries of each tensor. 586 name: A name for this operation (optional). Defaults to 587 "assert_non_negative". 588 589 Returns: 590 Op raising `InvalidArgumentError` unless `x` is all non-negative. 
This can 591 be used with `tf.control_dependencies` inside of `tf.function`s to block 592 followup computation until the check has executed. 593 @compatibility(eager) 594 returns None 595 @end_compatibility 596 597 Raises: 598 InvalidArgumentError: if the check can be performed immediately and 599 `x[i] >= 0` is False. The check can be performed immediately during eager 600 execution or if `x` is statically known. 601 """ 602 return assert_non_negative(x=x, summarize=summarize, message=message, 603 name=name) 604 605 606@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative']) 607@dispatch.add_dispatch_support 608@deprecation.deprecated_endpoints('assert_non_negative') 609@_unary_assert_doc('>= 0', 'non-negative') 610def assert_non_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 611 message = message or '' 612 with ops.name_scope(name, 'assert_non_negative', [x, data]): 613 x = ops.convert_to_tensor(x, name='x') 614 if data is None: 615 if context.executing_eagerly(): 616 name = _shape_and_dtype_str(x) 617 else: 618 name = x.name 619 data = [ 620 message, 621 'Condition x >= 0 did not hold element-wise:', 622 'x (%s) = ' % name, x] 623 zero = ops.convert_to_tensor(0, dtype=x.dtype) 624 return assert_less_equal(zero, x, data=data, summarize=summarize) 625 626 627@tf_export('debugging.assert_non_positive', v1=[]) 628@dispatch.add_dispatch_support 629def assert_non_positive_v2(x, message=None, summarize=None, name=None): 630 """Assert the condition `x <= 0` holds element-wise. 631 632 This Op checks that `x[i] <= 0` holds for every element of `x`. If `x` is 633 empty, this is trivially satisfied. 634 635 If `x` is not <= 0 everywhere, `message`, as well as the first `summarize` 636 entries of `x` are printed, and `InvalidArgumentError` is raised. 637 638 Args: 639 x: Numeric `Tensor`. 640 message: A string to prefix to the default message. 641 summarize: Print this many entries of each tensor. 
642 name: A name for this operation (optional). Defaults to 643 "assert_non_positive". 644 645 Returns: 646 Op raising `InvalidArgumentError` unless `x` is all non-positive. This can 647 be used with `tf.control_dependencies` inside of `tf.function`s to block 648 followup computation until the check has executed. 649 @compatibility(eager) 650 returns None 651 @end_compatibility 652 653 Raises: 654 InvalidArgumentError: if the check can be performed immediately and 655 `x[i] <= 0` is False. The check can be performed immediately during eager 656 execution or if `x` is statically known. 657 """ 658 return assert_non_positive(x=x, summarize=summarize, message=message, 659 name=name) 660 661 662@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive']) 663@dispatch.add_dispatch_support 664@deprecation.deprecated_endpoints('assert_non_positive') 665@_unary_assert_doc('<= 0', 'non-positive') 666def assert_non_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 667 message = message or '' 668 with ops.name_scope(name, 'assert_non_positive', [x, data]): 669 x = ops.convert_to_tensor(x, name='x') 670 if data is None: 671 if context.executing_eagerly(): 672 name = _shape_and_dtype_str(x) 673 else: 674 name = x.name 675 data = [ 676 message, 677 'Condition x <= 0 did not hold element-wise:' 678 'x (%s) = ' % name, x] 679 zero = ops.convert_to_tensor(0, dtype=x.dtype) 680 return assert_less_equal(x, zero, data=data, summarize=summarize) 681 682 683@tf_export('debugging.assert_equal', 'assert_equal', v1=[]) 684@dispatch.add_dispatch_support 685def assert_equal_v2(x, y, message=None, summarize=None, name=None): 686 """Assert the condition `x == y` holds element-wise. 687 688 This Op checks that `x[i] == y[i]` holds for every pair of (possibly 689 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 690 trivially satisfied. 
691 692 If `x` and `y` are not equal, `message`, as well as the first `summarize` 693 entries of `x` and `y` are printed, and `InvalidArgumentError` is raised. 694 695 Args: 696 x: Numeric `Tensor`. 697 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 698 message: A string to prefix to the default message. 699 summarize: Print this many entries of each tensor. 700 name: A name for this operation (optional). Defaults to "assert_equal". 701 702 Returns: 703 Op that raises `InvalidArgumentError` if `x == y` is False. This can be 704 used with `tf.control_dependencies` inside of `tf.function`s to block 705 followup computation until the check has executed. 706 @compatibility(eager) 707 returns None 708 @end_compatibility 709 710 Raises: 711 InvalidArgumentError: if the check can be performed immediately and 712 `x == y` is False. The check can be performed immediately during eager 713 execution or if `x` and `y` are statically known. 714 """ 715 return assert_equal(x=x, y=y, summarize=summarize, message=message, name=name) 716 717 718@tf_export(v1=['debugging.assert_equal', 'assert_equal']) 719@dispatch.add_dispatch_support 720@_binary_assert_doc('==', '[1, 2]') 721def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 722 with ops.name_scope(name, 'assert_equal', [x, y, data]): 723 # Short-circuit if x and y are the same tensor. 724 if x is y: 725 return None if context.executing_eagerly() else control_flow_ops.no_op() 726 return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y, 727 data, summarize, message, name) 728 729 730@tf_export('debugging.assert_none_equal', v1=[]) 731@dispatch.add_dispatch_support 732def assert_none_equal_v2(x, y, summarize=None, message=None, name=None): 733 """Assert the condition `x != y` holds for all elements. 734 735 This Op checks that `x[i] != y[i]` holds for every pair of (possibly 736 broadcast) elements of `x` and `y`. 
If both `x` and `y` are empty, this is 737 trivially satisfied. 738 739 If any elements of `x` and `y` are equal, `message`, as well as the first 740 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` 741 is raised. 742 743 Args: 744 x: Numeric `Tensor`. 745 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 746 summarize: Print this many entries of each tensor. 747 message: A string to prefix to the default message. 748 name: A name for this operation (optional). Defaults to 749 "assert_none_equal". 750 751 Returns: 752 Op that raises `InvalidArgumentError` if `x != y` is ever False. This can 753 be used with `tf.control_dependencies` inside of `tf.function`s to block 754 followup computation until the check has executed. 755 @compatibility(eager) 756 returns None 757 @end_compatibility 758 759 Raises: 760 InvalidArgumentError: if the check can be performed immediately and 761 `x != y` is False for any pair of elements in `x` and `y`. The check can 762 be performed immediately during eager execution or if `x` and `y` are 763 statically known. 764 """ 765 return assert_none_equal(x=x, y=y, summarize=summarize, message=message, 766 name=name) 767 768 769@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal']) 770@dispatch.add_dispatch_support 771@deprecation.deprecated_endpoints('assert_none_equal') 772@_binary_assert_doc('!=', '[2, 1]') 773def assert_none_equal( 774 x, y, data=None, summarize=None, message=None, name=None): 775 return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal, 776 np.not_equal, x, y, data, summarize, message, name) 777 778 779@tf_export('debugging.assert_near', v1=[]) 780@dispatch.add_dispatch_support 781def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None, 782 name=None): 783 """Assert the condition `x` and `y` are close element-wise. 
784 785 This Op checks that `x[i] - y[i] < atol + rtol * tf.abs(y[i])` holds for every 786 pair of (possibly broadcast) elements of `x` and `y`. If both `x` and `y` are 787 empty, this is trivially satisfied. 788 789 If any elements of `x` and `y` are not close, `message`, as well as the first 790 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` 791 is raised. 792 793 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 794 representable positive number such that `1 + eps != 1`. This is about 795 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 796 See `numpy.finfo`. 797 798 Args: 799 x: Float or complex `Tensor`. 800 y: Float or complex `Tensor`, same dtype as and broadcastable to `x`. 801 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 802 The relative tolerance. Default is `10 * eps`. 803 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 804 The absolute tolerance. Default is `10 * eps`. 805 message: A string to prefix to the default message. 806 summarize: Print this many entries of each tensor. 807 name: A name for this operation (optional). Defaults to "assert_near". 808 809 Returns: 810 Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. 811 This can be used with `tf.control_dependencies` inside of `tf.function`s 812 to block followup computation until the check has executed. 813 @compatibility(eager) 814 returns None 815 @end_compatibility 816 817 Raises: 818 InvalidArgumentError: if the check can be performed immediately and 819 `x != y` is False for any pair of elements in `x` and `y`. The check can 820 be performed immediately during eager execution or if `x` and `y` are 821 statically known. 822 823 @compatibility(numpy) 824 Similar to `numpy.testing.assert_allclose`, except tolerance depends on data 825 type. This is due to the fact that `TensorFlow` is often used with `32bit`, 826 `64bit`, and even `16bit` data. 
827 @end_compatibility 828 """ 829 return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize, 830 message=message, name=name) 831 832 833@tf_export(v1=['debugging.assert_near', 'assert_near']) 834@dispatch.add_dispatch_support 835@deprecation.deprecated_endpoints('assert_near') 836def assert_near( 837 x, y, rtol=None, atol=None, data=None, summarize=None, message=None, 838 name=None): 839 """Assert the condition `x` and `y` are close element-wise. 840 841 Example of adding a dependency to an operation: 842 843 ```python 844 with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]): 845 output = tf.reduce_sum(x) 846 ``` 847 848 This condition holds if for every pair of (possibly broadcast) elements 849 `x[i]`, `y[i]`, we have 850 851 ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```. 852 853 If both `x` and `y` are empty, this is trivially satisfied. 854 855 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 856 representable positive number such that `1 + eps != 1`. This is about 857 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 858 See `numpy.finfo`. 859 860 Args: 861 x: Float or complex `Tensor`. 862 y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`. 863 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 864 The relative tolerance. Default is `10 * eps`. 865 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 866 The absolute tolerance. Default is `10 * eps`. 867 data: The tensors to print out if the condition is False. Defaults to 868 error message and first few entries of `x`, `y`. 869 summarize: Print this many entries of each tensor. 870 message: A string to prefix to the default message. 871 name: A name for this operation (optional). Defaults to "assert_near". 872 873 Returns: 874 Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. 
875 876 @compatibility(numpy) 877 Similar to `numpy.testing.assert_allclose`, except tolerance depends on data 878 type. This is due to the fact that `TensorFlow` is often used with `32bit`, 879 `64bit`, and even `16bit` data. 880 @end_compatibility 881 """ 882 message = message or '' 883 with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]): 884 x = ops.convert_to_tensor(x, name='x') 885 y = ops.convert_to_tensor(y, name='y', dtype=x.dtype) 886 887 dtype = x.dtype 888 if dtype.is_complex: 889 dtype = dtype.real_dtype 890 eps = np.finfo(dtype.as_numpy_dtype).eps 891 rtol = 10 * eps if rtol is None else rtol 892 atol = 10 * eps if atol is None else atol 893 894 rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype) 895 atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype) 896 897 if context.executing_eagerly(): 898 x_name = _shape_and_dtype_str(x) 899 y_name = _shape_and_dtype_str(y) 900 else: 901 x_name = x.name 902 y_name = y.name 903 904 if data is None: 905 data = [ 906 message, 907 'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol), 908 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 909 ] 910 tol = atol + rtol * math_ops.abs(y) 911 diff = math_ops.abs(x - y) 912 condition = math_ops.reduce_all(math_ops.less(diff, tol)) 913 return control_flow_ops.Assert(condition, data, summarize=summarize) 914 915 916@tf_export('debugging.assert_less', 'assert_less', v1=[]) 917@dispatch.add_dispatch_support 918def assert_less_v2(x, y, message=None, summarize=None, name=None): 919 """Assert the condition `x < y` holds element-wise. 920 921 This Op checks that `x[i] < y[i]` holds for every pair of (possibly 922 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 923 trivially satisfied. 924 925 If `x` is not less than `y` element-wise, `message`, as well as the first 926 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is 927 raised. 928 929 Args: 930 x: Numeric `Tensor`. 
931 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 932 message: A string to prefix to the default message. 933 summarize: Print this many entries of each tensor. 934 name: A name for this operation (optional). Defaults to "assert_less". 935 936 Returns: 937 Op that raises `InvalidArgumentError` if `x < y` is False. 938 This can be used with `tf.control_dependencies` inside of `tf.function`s 939 to block followup computation until the check has executed. 940 @compatibility(eager) 941 returns None 942 @end_compatibility 943 944 Raises: 945 InvalidArgumentError: if the check can be performed immediately and 946 `x < y` is False. The check can be performed immediately during eager 947 execution or if `x` and `y` are statically known. 948 """ 949 return assert_less(x=x, y=y, summarize=summarize, message=message, name=name) 950 951 952@tf_export(v1=['debugging.assert_less', 'assert_less']) 953@dispatch.add_dispatch_support 954@_binary_assert_doc('<', '[2, 3]') 955def assert_less(x, y, data=None, summarize=None, message=None, name=None): 956 return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data, 957 summarize, message, name) 958 959 960@tf_export('debugging.assert_less_equal', v1=[]) 961@dispatch.add_dispatch_support 962def assert_less_equal_v2(x, y, message=None, summarize=None, name=None): 963 """Assert the condition `x <= y` holds element-wise. 964 965 This Op checks that `x[i] <= y[i]` holds for every pair of (possibly 966 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 967 trivially satisfied. 968 969 If `x` is not less or equal than `y` element-wise, `message`, as well as the 970 first `summarize` entries of `x` and `y` are printed, and 971 `InvalidArgumentError` is raised. 972 973 Args: 974 x: Numeric `Tensor`. 975 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 976 message: A string to prefix to the default message. 977 summarize: Print this many entries of each tensor. 
978 name: A name for this operation (optional). Defaults to "assert_less_equal". 979 980 Returns: 981 Op that raises `InvalidArgumentError` if `x <= y` is False. This can be 982 used with `tf.control_dependencies` inside of `tf.function`s to block 983 followup computation until the check has executed. 984 @compatibility(eager) 985 returns None 986 @end_compatibility 987 988 Raises: 989 InvalidArgumentError: if the check can be performed immediately and 990 `x <= y` is False. The check can be performed immediately during eager 991 execution or if `x` and `y` are statically known. 992 """ 993 return assert_less_equal(x=x, y=y, 994 summarize=summarize, message=message, name=name) 995 996 997@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal']) 998@dispatch.add_dispatch_support 999@deprecation.deprecated_endpoints('assert_less_equal') 1000@_binary_assert_doc('<=', '[1, 3]') 1001def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): 1002 return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal, 1003 np.less_equal, x, y, data, summarize, message, name) 1004 1005 1006@tf_export('debugging.assert_greater', 'assert_greater', v1=[]) 1007@dispatch.add_dispatch_support 1008def assert_greater_v2(x, y, message=None, summarize=None, name=None): 1009 """Assert the condition `x > y` holds element-wise. 1010 1011 This Op checks that `x[i] > y[i]` holds for every pair of (possibly 1012 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 1013 trivially satisfied. 1014 1015 If `x` is not greater than `y` element-wise, `message`, as well as the first 1016 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is 1017 raised. 1018 1019 Args: 1020 x: Numeric `Tensor`. 1021 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 1022 message: A string to prefix to the default message. 1023 summarize: Print this many entries of each tensor. 1024 name: A name for this operation (optional). 
Defaults to "assert_greater". 1025 1026 Returns: 1027 Op that raises `InvalidArgumentError` if `x > y` is False. This can be 1028 used with `tf.control_dependencies` inside of `tf.function`s to block 1029 followup computation until the check has executed. 1030 @compatibility(eager) 1031 returns None 1032 @end_compatibility 1033 1034 Raises: 1035 InvalidArgumentError: if the check can be performed immediately and 1036 `x > y` is False. The check can be performed immediately during eager 1037 execution or if `x` and `y` are statically known. 1038 """ 1039 return assert_greater(x=x, y=y, summarize=summarize, message=message, 1040 name=name) 1041 1042 1043@tf_export(v1=['debugging.assert_greater', 'assert_greater']) 1044@dispatch.add_dispatch_support 1045@_binary_assert_doc('>', '[0, 1]') 1046def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 1047 return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x, 1048 y, data, summarize, message, name) 1049 1050 1051@tf_export('debugging.assert_greater_equal', v1=[]) 1052@dispatch.add_dispatch_support 1053def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None): 1054 """Assert the condition `x >= y` holds element-wise. 1055 1056 This Op checks that `x[i] >= y[i]` holds for every pair of (possibly 1057 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 1058 trivially satisfied. 1059 1060 If `x` is not greater or equal to `y` element-wise, `message`, as well as the 1061 first `summarize` entries of `x` and `y` are printed, and 1062 `InvalidArgumentError` is raised. 1063 1064 Args: 1065 x: Numeric `Tensor`. 1066 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 1067 message: A string to prefix to the default message. 1068 summarize: Print this many entries of each tensor. 1069 name: A name for this operation (optional). Defaults to 1070 "assert_greater_equal". 
1071 1072 Returns: 1073 Op that raises `InvalidArgumentError` if `x >= y` is False. This can be 1074 used with `tf.control_dependencies` inside of `tf.function`s to block 1075 followup computation until the check has executed. 1076 @compatibility(eager) 1077 returns None 1078 @end_compatibility 1079 1080 Raises: 1081 InvalidArgumentError: if the check can be performed immediately and 1082 `x >= y` is False. The check can be performed immediately during eager 1083 execution or if `x` and `y` are statically known. 1084 """ 1085 return assert_greater_equal(x=x, y=y, summarize=summarize, message=message, 1086 name=name) 1087 1088 1089@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal']) 1090@dispatch.add_dispatch_support 1091@deprecation.deprecated_endpoints('assert_greater_equal') 1092@_binary_assert_doc('>=', '[1, 0]') 1093def assert_greater_equal(x, y, data=None, summarize=None, message=None, 1094 name=None): 1095 return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal, 1096 np.greater_equal, x, y, data, summarize, message, name) 1097 1098 1099def _assert_rank_condition( 1100 x, rank, static_condition, dynamic_condition, data, summarize): 1101 """Assert `x` has a rank that satisfies a given condition. 1102 1103 Args: 1104 x: Numeric `Tensor`. 1105 rank: Scalar `Tensor`. 1106 static_condition: A python function that takes `[actual_rank, given_rank]` 1107 and returns `True` if the condition is satisfied, `False` otherwise. 1108 dynamic_condition: An `op` that takes [actual_rank, given_rank] and return 1109 `True` if the condition is satisfied, `False` otherwise. 1110 data: The tensors to print out if the condition is false. Defaults to 1111 error message and first few entries of `x`. 1112 summarize: Print this many entries of each tensor. 1113 1114 Returns: 1115 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 1116 1117 Raises: 1118 ValueError: If static checks determine `x` fails static_condition. 
1119 """ 1120 assert_type(rank, dtypes.int32) 1121 1122 # Attempt to statically defined rank. 1123 rank_static = tensor_util.constant_value(rank) 1124 if rank_static is not None: 1125 if rank_static.ndim != 0: 1126 raise ValueError('Rank must be a scalar.') 1127 1128 x_rank_static = x.get_shape().ndims 1129 if x_rank_static is not None: 1130 if not static_condition(x_rank_static, rank_static): 1131 raise ValueError( 1132 'Static rank condition failed', x_rank_static, rank_static) 1133 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 1134 1135 condition = dynamic_condition(array_ops.rank(x), rank) 1136 1137 # Add the condition that `rank` must have rank zero. Prevents the bug where 1138 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 1139 if rank_static is None: 1140 this_data = ['Rank must be a scalar. Received rank: ', rank] 1141 rank_check = assert_rank(rank, 0, data=this_data) 1142 condition = control_flow_ops.with_dependencies([rank_check], condition) 1143 1144 return control_flow_ops.Assert(condition, data, summarize=summarize) 1145 1146 1147@tf_export('debugging.assert_rank', 'assert_rank', v1=[]) 1148@dispatch.add_dispatch_support 1149def assert_rank_v2(x, rank, message=None, name=None): 1150 """Assert that `x` has rank equal to `rank`. 1151 1152 This Op checks that the rank of `x` is equal to `rank`. 1153 1154 If `x` has a different rank, `message`, as well as the shape of `x` are 1155 printed, and `InvalidArgumentError` is raised. 1156 1157 Args: 1158 x: `Tensor`. 1159 rank: Scalar integer `Tensor`. 1160 message: A string to prefix to the default message. 1161 name: A name for this operation (optional). Defaults to 1162 "assert_rank". 1163 1164 Returns: 1165 Op raising `InvalidArgumentError` unless `x` has specified rank. 1166 If static checks determine `x` has correct rank, a `no_op` is returned. 
1167 This can be used with `tf.control_dependencies` inside of `tf.function`s 1168 to block followup computation until the check has executed. 1169 @compatibility(eager) 1170 returns None 1171 @end_compatibility 1172 1173 Raises: 1174 InvalidArgumentError: if the check can be performed immediately and 1175 `x` does not have rank `rank`. The check can be performed immediately 1176 during eager execution or if the shape of `x` is statically known. 1177 """ 1178 return assert_rank(x=x, rank=rank, message=message, name=name) 1179 1180 1181@tf_export(v1=['debugging.assert_rank', 'assert_rank']) 1182@dispatch.add_dispatch_support 1183def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): 1184 """Assert `x` has rank equal to `rank`. 1185 1186 Example of adding a dependency to an operation: 1187 1188 ```python 1189 with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]): 1190 output = tf.reduce_sum(x) 1191 ``` 1192 1193 Args: 1194 x: Numeric `Tensor`. 1195 rank: Scalar integer `Tensor`. 1196 data: The tensors to print out if the condition is False. Defaults to 1197 error message and the shape of `x`. 1198 summarize: Print this many entries of each tensor. 1199 message: A string to prefix to the default message. 1200 name: A name for this operation (optional). Defaults to "assert_rank". 1201 1202 Returns: 1203 Op raising `InvalidArgumentError` unless `x` has specified rank. 1204 If static checks determine `x` has correct rank, a `no_op` is returned. 1205 1206 Raises: 1207 ValueError: If static checks determine `x` has wrong rank. 
1208 """ 1209 with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])): 1210 if not isinstance(x, sparse_tensor.SparseTensor): 1211 x = ops.convert_to_tensor(x, name='x') 1212 rank = ops.convert_to_tensor(rank, name='rank') 1213 message = message or '' 1214 1215 static_condition = lambda actual_rank, given_rank: actual_rank == given_rank 1216 dynamic_condition = math_ops.equal 1217 1218 if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor): 1219 name = '' 1220 else: 1221 name = x.name 1222 1223 if data is None: 1224 data = [ 1225 message, 1226 'Tensor %s must have rank' % name, rank, 'Received shape: ', 1227 array_ops.shape(x) 1228 ] 1229 1230 try: 1231 assert_op = _assert_rank_condition(x, rank, static_condition, 1232 dynamic_condition, data, summarize) 1233 1234 except ValueError as e: 1235 if e.args[0] == 'Static rank condition failed': 1236 raise ValueError( 1237 '%s. Tensor %s must have rank %d. Received rank %d, shape %s' % 1238 (message, name, e.args[2], e.args[1], x.get_shape())) 1239 else: 1240 raise 1241 1242 return assert_op 1243 1244 1245@tf_export('debugging.assert_rank_at_least', v1=[]) 1246@dispatch.add_dispatch_support 1247def assert_rank_at_least_v2(x, rank, message=None, name=None): 1248 """Assert that `x` has rank of at least `rank`. 1249 1250 This Op checks that the rank of `x` is greater or equal to `rank`. 1251 1252 If `x` has a rank lower than `rank`, `message`, as well as the shape of `x` 1253 are printed, and `InvalidArgumentError` is raised. 1254 1255 Args: 1256 x: `Tensor`. 1257 rank: Scalar integer `Tensor`. 1258 message: A string to prefix to the default message. 1259 name: A name for this operation (optional). Defaults to 1260 "assert_rank_at_least". 1261 1262 Returns: 1263 Op raising `InvalidArgumentError` unless `x` has specified rank or higher. 1264 If static checks determine `x` has correct rank, a `no_op` is returned. 
1265 This can be used with `tf.control_dependencies` inside of `tf.function`s 1266 to block followup computation until the check has executed. 1267 @compatibility(eager) 1268 returns None 1269 @end_compatibility 1270 1271 Raises: 1272 InvalidArgumentError: `x` does not have rank at least `rank`, but the rank 1273 cannot be statically determined. 1274 ValueError: If static checks determine `x` has mismatched rank. 1275 """ 1276 return assert_rank_at_least(x=x, rank=rank, message=message, name=name) 1277 1278 1279@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least']) 1280@dispatch.add_dispatch_support 1281@deprecation.deprecated_endpoints('assert_rank_at_least') 1282def assert_rank_at_least( 1283 x, rank, data=None, summarize=None, message=None, name=None): 1284 """Assert `x` has rank equal to `rank` or higher. 1285 1286 Example of adding a dependency to an operation: 1287 1288 ```python 1289 with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]): 1290 output = tf.reduce_sum(x) 1291 ``` 1292 1293 Args: 1294 x: Numeric `Tensor`. 1295 rank: Scalar `Tensor`. 1296 data: The tensors to print out if the condition is False. Defaults to 1297 error message and first few entries of `x`. 1298 summarize: Print this many entries of each tensor. 1299 message: A string to prefix to the default message. 1300 name: A name for this operation (optional). 1301 Defaults to "assert_rank_at_least". 1302 1303 Returns: 1304 Op raising `InvalidArgumentError` unless `x` has specified rank or higher. 1305 If static checks determine `x` has correct rank, a `no_op` is returned. 1306 1307 Raises: 1308 ValueError: If static checks determine `x` has wrong rank. 
1309 """ 1310 with ops.name_scope( 1311 name, 'assert_rank_at_least', (x, rank) + tuple(data or [])): 1312 x = ops.convert_to_tensor(x, name='x') 1313 rank = ops.convert_to_tensor(rank, name='rank') 1314 message = message or '' 1315 1316 static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank 1317 dynamic_condition = math_ops.greater_equal 1318 1319 if context.executing_eagerly(): 1320 name = '' 1321 else: 1322 name = x.name 1323 1324 if data is None: 1325 data = [ 1326 message, 1327 'Tensor %s must have rank at least' % name, rank, 1328 'Received shape: ', array_ops.shape(x) 1329 ] 1330 1331 try: 1332 assert_op = _assert_rank_condition(x, rank, static_condition, 1333 dynamic_condition, data, summarize) 1334 1335 except ValueError as e: 1336 if e.args[0] == 'Static rank condition failed': 1337 raise ValueError( 1338 '%s. Tensor %s must have rank at least %d. Received rank %d, ' 1339 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1340 else: 1341 raise 1342 1343 return assert_op 1344 1345 1346def _static_rank_in(actual_rank, given_ranks): 1347 return actual_rank in given_ranks 1348 1349 1350def _dynamic_rank_in(actual_rank, given_ranks): 1351 if len(given_ranks) < 1: 1352 return ops.convert_to_tensor(False) 1353 result = math_ops.equal(given_ranks[0], actual_rank) 1354 for given_rank in given_ranks[1:]: 1355 result = math_ops.logical_or( 1356 result, math_ops.equal(given_rank, actual_rank)) 1357 return result 1358 1359 1360def _assert_ranks_condition( 1361 x, ranks, static_condition, dynamic_condition, data, summarize): 1362 """Assert `x` has a rank that satisfies a given condition. 1363 1364 Args: 1365 x: Numeric `Tensor`. 1366 ranks: Scalar `Tensor`. 1367 static_condition: A python function that takes 1368 `[actual_rank, given_ranks]` and returns `True` if the condition is 1369 satisfied, `False` otherwise. 
1370 dynamic_condition: An `op` that takes [actual_rank, given_ranks] 1371 and return `True` if the condition is satisfied, `False` otherwise. 1372 data: The tensors to print out if the condition is false. Defaults to 1373 error message and first few entries of `x`. 1374 summarize: Print this many entries of each tensor. 1375 1376 Returns: 1377 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 1378 1379 Raises: 1380 ValueError: If static checks determine `x` fails static_condition. 1381 """ 1382 for rank in ranks: 1383 assert_type(rank, dtypes.int32) 1384 1385 # Attempt to statically defined rank. 1386 ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks]) 1387 if not any(r is None for r in ranks_static): 1388 for rank_static in ranks_static: 1389 if rank_static.ndim != 0: 1390 raise ValueError('Rank must be a scalar.') 1391 1392 x_rank_static = x.get_shape().ndims 1393 if x_rank_static is not None: 1394 if not static_condition(x_rank_static, ranks_static): 1395 raise ValueError( 1396 'Static rank condition failed', x_rank_static, ranks_static) 1397 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 1398 1399 condition = dynamic_condition(array_ops.rank(x), ranks) 1400 1401 # Add the condition that `rank` must have rank zero. Prevents the bug where 1402 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 1403 for rank, rank_static in zip(ranks, ranks_static): 1404 if rank_static is None: 1405 this_data = ['Rank must be a scalar. Received rank: ', rank] 1406 rank_check = assert_rank(rank, 0, data=this_data) 1407 condition = control_flow_ops.with_dependencies([rank_check], condition) 1408 1409 return control_flow_ops.Assert(condition, data, summarize=summarize) 1410 1411 1412@tf_export('debugging.assert_rank_in', v1=[]) 1413@dispatch.add_dispatch_support 1414def assert_rank_in_v2(x, ranks, message=None, name=None): 1415 """Assert that `x` has a rank in `ranks`. 
1416 1417 This Op checks that the rank of `x` is in `ranks`. 1418 1419 If `x` has a different rank, `message`, as well as the shape of `x` are 1420 printed, and `InvalidArgumentError` is raised. 1421 1422 Args: 1423 x: `Tensor`. 1424 ranks: `Iterable` of scalar `Tensor` objects. 1425 message: A string to prefix to the default message. 1426 name: A name for this operation (optional). Defaults to "assert_rank_in". 1427 1428 Returns: 1429 Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. 1430 If static checks determine `x` has matching rank, a `no_op` is returned. 1431 This can be used with `tf.control_dependencies` inside of `tf.function`s 1432 to block followup computation until the check has executed. 1433 @compatibility(eager) 1434 returns None 1435 @end_compatibility 1436 1437 Raises: 1438 InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot 1439 be statically determined. 1440 ValueError: If static checks determine `x` has mismatched rank. 1441 """ 1442 return assert_rank_in(x=x, ranks=ranks, message=message, name=name) 1443 1444 1445@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in']) 1446@dispatch.add_dispatch_support 1447@deprecation.deprecated_endpoints('assert_rank_in') 1448def assert_rank_in( 1449 x, ranks, data=None, summarize=None, message=None, name=None): 1450 """Assert `x` has rank in `ranks`. 1451 1452 Example of adding a dependency to an operation: 1453 1454 ```python 1455 with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]): 1456 output = tf.reduce_sum(x) 1457 ``` 1458 1459 Args: 1460 x: Numeric `Tensor`. 1461 ranks: Iterable of scalar `Tensor` objects. 1462 data: The tensors to print out if the condition is False. Defaults to 1463 error message and first few entries of `x`. 1464 summarize: Print this many entries of each tensor. 1465 message: A string to prefix to the default message. 1466 name: A name for this operation (optional). 1467 Defaults to "assert_rank_in". 
1468 1469 Returns: 1470 Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. 1471 If static checks determine `x` has matching rank, a `no_op` is returned. 1472 1473 Raises: 1474 ValueError: If static checks determine `x` has mismatched rank. 1475 """ 1476 with ops.name_scope( 1477 name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])): 1478 if not isinstance(x, sparse_tensor.SparseTensor): 1479 x = ops.convert_to_tensor(x, name='x') 1480 ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) 1481 message = message or '' 1482 1483 if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor): 1484 name = '' 1485 else: 1486 name = x.name 1487 1488 if data is None: 1489 data = [ 1490 message, 'Tensor %s must have rank in' % name 1491 ] + list(ranks) + [ 1492 'Received shape: ', array_ops.shape(x) 1493 ] 1494 1495 try: 1496 assert_op = _assert_ranks_condition(x, ranks, _static_rank_in, 1497 _dynamic_rank_in, data, summarize) 1498 1499 except ValueError as e: 1500 if e.args[0] == 'Static rank condition failed': 1501 raise ValueError( 1502 '%s. Tensor %s must have rank in %s. Received rank %d, ' 1503 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1504 else: 1505 raise 1506 1507 return assert_op 1508 1509 1510@tf_export('debugging.assert_integer', v1=[]) 1511@dispatch.add_dispatch_support 1512def assert_integer_v2(x, message=None, name=None): 1513 """Assert that `x` is of integer dtype. 1514 1515 If `x` has a non-integer type, `message`, as well as the dtype of `x` are 1516 printed, and `InvalidArgumentError` is raised. 1517 1518 This can always be checked statically, so this method returns nothing. 1519 1520 Args: 1521 x: A `Tensor`. 1522 message: A string to prefix to the default message. 1523 name: A name for this operation (optional). Defaults to "assert_integer". 1524 1525 Raises: 1526 TypeError: If `x.dtype` is not a non-quantized integer type. 
1527 """ 1528 assert_integer(x=x, message=message, name=name) 1529 1530 1531@tf_export(v1=['debugging.assert_integer', 'assert_integer']) 1532@dispatch.add_dispatch_support 1533@deprecation.deprecated_endpoints('assert_integer') 1534def assert_integer(x, message=None, name=None): 1535 """Assert that `x` is of integer dtype. 1536 1537 Example of adding a dependency to an operation: 1538 1539 ```python 1540 with tf.control_dependencies([tf.compat.v1.assert_integer(x)]): 1541 output = tf.reduce_sum(x) 1542 ``` 1543 1544 Args: 1545 x: `Tensor` whose basetype is integer and is not quantized. 1546 message: A string to prefix to the default message. 1547 name: A name for this operation (optional). Defaults to "assert_integer". 1548 1549 Raises: 1550 TypeError: If `x.dtype` is anything other than non-quantized integer. 1551 1552 Returns: 1553 A `no_op` that does nothing. Type can be determined statically. 1554 """ 1555 message = message or '' 1556 with ops.name_scope(name, 'assert_integer', [x]): 1557 x = ops.convert_to_tensor(x, name='x') 1558 if not x.dtype.is_integer: 1559 if context.executing_eagerly(): 1560 name = 'tensor' 1561 else: 1562 name = x.name 1563 err_msg = ( 1564 '%s Expected "x" to be integer type. Found: %s of dtype %s' 1565 % (message, name, x.dtype)) 1566 raise TypeError(err_msg) 1567 1568 return control_flow_ops.no_op('statically_determined_was_integer') 1569 1570 1571@tf_export('debugging.assert_type', v1=[]) 1572@dispatch.add_dispatch_support 1573def assert_type_v2(tensor, tf_type, message=None, name=None): 1574 """Asserts that the given `Tensor` is of the specified type. 1575 1576 This can always be checked statically, so this method returns nothing. 1577 1578 Example: 1579 1580 >>> a = tf.Variable(1.0) 1581 >>> tf.debugging.assert_type(a, tf_type= tf.float32) 1582 1583 >>> b = tf.constant(21) 1584 >>> tf.debugging.assert_type(b, tf_type=tf.bool) 1585 Traceback (most recent call last): 1586 ... 1587 TypeError: ... 
1588 1589 >>> c = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], 1590 ... dense_shape=[3, 4]) 1591 >>> tf.debugging.assert_type(c, tf_type= tf.int32) 1592 1593 Args: 1594 tensor: A `Tensor`, `SparseTensor` or `tf.Variable . 1595 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1596 etc). 1597 message: A string to prefix to the default message. 1598 name: A name for this operation. Defaults to "assert_type" 1599 1600 Raises: 1601 TypeError: If the tensor's data type doesn't match `tf_type`. 1602 """ 1603 assert_type(tensor=tensor, tf_type=tf_type, message=message, name=name) 1604 1605 1606@tf_export(v1=['debugging.assert_type', 'assert_type']) 1607@dispatch.add_dispatch_support 1608@deprecation.deprecated_endpoints('assert_type') 1609def assert_type(tensor, tf_type, message=None, name=None): 1610 """Statically asserts that the given `Tensor` is of the specified type. 1611 1612 Args: 1613 tensor: A `Tensor` or `SparseTensor`. 1614 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1615 etc). 1616 message: A string to prefix to the default message. 1617 name: A name to give this `Op`. Defaults to "assert_type" 1618 1619 Raises: 1620 TypeError: If the tensors data type doesn't match `tf_type`. 1621 1622 Returns: 1623 A `no_op` that does nothing. Type can be determined statically. 
1624 """ 1625 message = message or '' 1626 tf_type = dtypes.as_dtype(tf_type) 1627 with ops.name_scope(name, 'assert_type', [tensor]): 1628 if not isinstance(tensor, sparse_tensor.SparseTensor): 1629 tensor = ops.convert_to_tensor(tensor, name='tensor') 1630 if tensor.dtype != tf_type: 1631 if context.executing_eagerly(): 1632 raise TypeError('%s tensor must be of type %s' % (message, tf_type)) 1633 else: 1634 raise TypeError( 1635 '%s %s must be of type %s' % 1636 (message, tensor.name if hasattr(tensor, 'name') else '', tf_type)) 1637 1638 return control_flow_ops.no_op('statically_determined_correct_type') 1639 1640 1641def _dimension_sizes(x): 1642 """Gets the dimension sizes of a tensor `x`. 1643 1644 If a size can be determined statically it is returned as an integer, 1645 otherwise as a tensor. 1646 1647 If `x` is a scalar it is treated as rank 1 size 1. 1648 1649 Args: 1650 x: A `Tensor`. 1651 1652 Returns: 1653 Dimension sizes. 1654 """ 1655 dynamic_shape = array_ops.shape(x) 1656 rank = x.get_shape().rank 1657 rank_is_known = rank is not None 1658 if rank_is_known and rank == 0: 1659 return (1,) 1660 if rank_is_known and rank > 0: 1661 static_shape = x.get_shape().as_list() 1662 sizes = [ 1663 int(size) if size is not None else dynamic_shape[i] 1664 for i, size in enumerate(static_shape) 1665 ] 1666 return sizes 1667 has_rank_zero = math_ops.equal(array_ops.rank(x), 0) 1668 return control_flow_ops.cond( 1669 has_rank_zero, lambda: array_ops.constant([1]), lambda: dynamic_shape) 1670 1671 1672def _symbolic_dimension_sizes(symbolic_shape): 1673 # If len(symbolic_shape) == 0 construct a tuple 1674 if not symbolic_shape: 1675 return tuple([1]) 1676 1677 return symbolic_shape 1678 1679 1680def _has_known_value(dimension_size): 1681 not_none = dimension_size is not None 1682 try: 1683 int(dimension_size) 1684 can_be_parsed_as_int = True 1685 except (ValueError, TypeError): 1686 can_be_parsed_as_int = False 1687 return not_none and can_be_parsed_as_int 1688 1689 
def _is_symbol_for_any_size(symbol):
  """Returns True if `symbol` places no constraint on a dimension's size.

  Both `None` and the string '.' act as wildcards in the symbolic shapes
  accepted by `assert_shapes`.
  """
  return symbol in [None, '.']


# Per-tensor bookkeeping used by `assert_shapes`:
#   x: the tensor (or SparseTensor) being checked.
#   unspecified_dim: True when the symbolic shape started with `...` or '*',
#     i.e. only the inner-most dimensions are constrained.
#   actual_sizes: the value returned by `_dimension_sizes(x)`.
#   symbolic_sizes: the value returned by `_symbolic_dimension_sizes` for the
#     specified shape with any leading `...`/'*' removed.
_TensorDimSizes = collections.namedtuple(
    '_TensorDimSizes',
    ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])


@tf_export('debugging.assert_shapes', v1=[])
@dispatch.add_dispatch_support
def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
                     name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: dictionary with (`Tensor` to shape) items, or a list of
      (`Tensor`, shape) tuples. A shape must be an iterable.
    data: The tensors to print out if the condition is False.  Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_shapes".

  Returns:
    The op returned by the v1 `assert_shapes`: an op raising
    `InvalidArgumentError` unless all shape constraints are satisfied, which
    can be passed to `tf.control_dependencies` inside a `tf.function` to block
    follow-up computation until the check has run.

  Raises:
    ValueError:  If static checks determine any shape constraint is violated.
  """
  # Previously the result was discarded, which made the v2 endpoint unusable
  # with `tf.control_dependencies`; returning it is backward-compatible.
  return assert_shapes(
      shapes, data=data, summarize=summarize, message=message, name=name)


@tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that a collection of tensors shape relationships
  satisfies given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.debugging.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: A list of (`Tensor`, `shape`) tuples, or a dictionary mapping
      `Tensor` to `shape`, wherein `shape` is the expected shape of `Tensor`.
      See the example code above. The `shape` must be an iterable. Each element
      of the iterable can be either a concrete integer value or a string that
      abstractly represents the dimension.
      For example,
        - `('N', 'Q')` specifies a 2D shape wherein the first and second
          dimensions of shape may or may not be equal.
        - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
          dimensions are equal.
        - `(1, 'N')` specifies a 2D shape wherein the first dimension is
          exactly 1 and the second dimension can be any value.
      Note that the abstract dimension letters take effect across different
      tuple elements of the list. For example,
      `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))])` asserts
      that both `x` and `y` are rank-2 tensors and their first dimensions are
      equal (`N`).
      `shape` can also be a `tf.TensorShape`.
    data: The tensors to print out if the condition is False.  Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError:  If static checks determine any shape constraint is violated.
  """
  # If the user manages to assemble a dict containing tensors (possible in
  # Graph mode only), make sure we still accept that.
  if isinstance(shapes, dict):
    shapes = shapes.items()

  message = message or ''
  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
    # Shape specified as None implies no constraint
    shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
                          ops.convert_to_tensor(x), s)
                         for x, s in shapes if s is not None]

    executing_eagerly = context.executing_eagerly()

    def tensor_name(x):
      # Eager tensors and SparseTensors have no usable graph name, so describe
      # them by shape and dtype instead.
      if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
        return _shape_and_dtype_str(x)
      return x.name

    # Pass 1: validate each symbolic shape and record actual vs. symbolic
    # dimension sizes per tensor.
    tensor_dim_sizes = []
    for tensor, symbolic_shape in shape_constraints:
      is_iterable = (
          hasattr(symbolic_shape, '__iter__') or
          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
      )
      if not is_iterable:
        raise ValueError(
            '%s. '
            'Tensor %s. Specified shape must be an iterable. '
            'An iterable has the attribute `__iter__` or `__getitem__`. '
            'Received specified shape: %s' %
            (message, tensor_name(tensor), symbolic_shape))

      # We convert this into a tuple to handle strings, lists and numpy arrays
      symbolic_shape_tuple = tuple(symbolic_shape)

      tensors_specified_innermost = False
      for i, symbol in enumerate(symbolic_shape_tuple):
        if symbol not in [Ellipsis, '*']:
          continue

        if i != 0:
          raise ValueError(
              '%s. '
              'Tensor %s specified shape index %d. '
              'Symbol `...` or `*` for a variable number of '
              'unspecified dimensions is only allowed as the first entry' %
              (message, tensor_name(tensor), i))

        tensors_specified_innermost = True

      # Only include the size of the specified dimensions since the 0th symbol
      # is either ellipsis or *
      tensor_dim_sizes.append(
          _TensorDimSizes(
              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
              _symbolic_dimension_sizes(
                  symbolic_shape_tuple[1:]
                  if tensors_specified_innermost else symbolic_shape_tuple)))

    # Pass 2: rank checks for every constrained tensor.
    rank_assertions = []
    for sizes in tensor_dim_sizes:
      rank = len(sizes.symbolic_sizes)
      rank_zero_or_one = rank in [0, 1]
      if sizes.unspecified_dim:
        if rank_zero_or_one:
          # No assertion of rank needed as `x` only need to have rank at least
          # 0. See elif rank_zero_or_one case comment.
          continue
        assertion = assert_rank_at_least(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      elif rank_zero_or_one:
        # Rank 0 is treated as rank 1 size 1, i.e. there is
        # no distinction between the two in terms of rank.
        # See _dimension_sizes.
        assertion = assert_rank_in(
            x=sizes.x,
            ranks=[0, 1],
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      else:
        assertion = assert_rank(
            x=sizes.x,
            rank=rank,
            data=data,
            summarize=summarize,
            message=message,
            name=name)
      rank_assertions.append(assertion)

    # Pass 3: size checks. The first occurrence of a symbolic size records the
    # actual size; later occurrences (and explicit integer sizes) are checked
    # against it, statically when possible and with an Assert op otherwise.
    size_assertions = []
    size_specifications = {}
    for sizes in tensor_dim_sizes:
      for i, size_symbol in enumerate(sizes.symbolic_sizes):

        if _is_symbol_for_any_size(size_symbol):
          # Size specified as any implies no constraint
          continue

        if sizes.unspecified_dim:
          # Count from the inner-most end when outer dims are unspecified.
          tensor_dim = i - len(sizes.symbolic_sizes)
        else:
          tensor_dim = i

        if size_symbol in size_specifications or _has_known_value(size_symbol):
          if _has_known_value(size_symbol):
            specified_size = int(size_symbol)
            size_check_message = 'Specified explicitly'
          else:
            specified_size, specified_by_y, specified_at_dim = (
                size_specifications[size_symbol])
            size_check_message = (
                'Specified by tensor %s dimension %d' %
                (tensor_name(specified_by_y), specified_at_dim))

          # This is extremely subtle. If actual_sizes is dynamic, we must
          # make sure a control dependency is inserted here so that this slice
          # can not execute until the rank is asserted to be enough for the
          # slice to not fail.
          with ops.control_dependencies(rank_assertions):
            actual_size = sizes.actual_sizes[tensor_dim]
          if _has_known_value(actual_size) and _has_known_value(specified_size):
            if int(actual_size) != int(specified_size):
              # Format the same int() values that were compared above, so
              # `%d` never sees a non-int object.
              raise ValueError(
                  '%s. %s. Tensor %s dimension %s must have size %d. '
                  'Received size %d, shape %s' %
                  (message, size_check_message, tensor_name(sizes.x),
                   tensor_dim, int(specified_size), int(actual_size),
                   sizes.x.get_shape()))
            # No dynamic assertion needed
            continue

          condition = math_ops.equal(
              ops.convert_to_tensor(actual_size),
              ops.convert_to_tensor(specified_size))
          data_ = data
          if data is None:
            data_ = [
                message, size_check_message,
                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
                'must have size', specified_size, 'Received shape: ',
                array_ops.shape(sizes.x)
            ]
          size_assertions.append(
              control_flow_ops.Assert(condition, data_, summarize=summarize))
        else:
          # Not sure if actual_sizes is a constant, but for safety, guard
          # on rank. See explanation above about actual_sizes need for safety.
          with ops.control_dependencies(rank_assertions):
            size = sizes.actual_sizes[tensor_dim]
          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)

  # Ensure both assertions actually occur.
  with ops.control_dependencies(rank_assertions):
    shapes_assertion = control_flow_ops.group(size_assertions)

  return shapes_assertion


# pylint: disable=line-too-long
def _get_diff_for_monotonic_comparison(x):
  """Gets the difference x[1:] - x[:-1] of `x` flattened to rank 1.

  Args:
    x: A `Tensor`; it is reshaped to a vector before differencing.

  Returns:
    A rank-1 `Tensor` of adjacent differences, empty when `x` has fewer than
    two elements.

  Raises:
    TypeError: if `x` is not a numeric tensor (see `is_numeric_tensor`).
  """
  x = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(x):
    raise TypeError('Expected x to be numeric, instead found: %s' % x)

  # If x has less than 2 elements, there is nothing to compare.  So return [].
  is_shorter_than_two = math_ops.less(array_ops.size(x), 2)
  short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype)

  # With 2 or more elements, return x[1:] - x[:-1]
  s_len = array_ops.shape(x) - 1
  diff = lambda: (array_ops.strided_slice(x, [1], [1] + s_len) -
                  array_ops.strided_slice(x, [0], s_len))
  return control_flow_ops.cond(is_shorter_than_two, short_result, diff)


@tf_export(
    'debugging.is_numeric_tensor',
    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
  """Returns `True` if the elements of `tensor` are numbers.

  Specifically, returns `True` if the dtype of `tensor` is one of the following:

  * `tf.float32`
  * `tf.float64`
  * `tf.int8`
  * `tf.int16`
  * `tf.int32`
  * `tf.int64`
  * `tf.uint8`
  * `tf.qint8`
  * `tf.qint32`
  * `tf.quint8`
  * `tf.complex64`

  Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not
  a `tf.Tensor` object.
  """
  return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES


@tf_export(
    'math.is_non_decreasing',
    v1=[
        'math.is_non_decreasing', 'debugging.is_non_decreasing',
        'is_non_decreasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
                                  'is_non_decreasing')
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
  If `x` has less than two elements, it is trivially non-decreasing.

  See also:  `is_strictly_increasing`

  >>> x1 = tf.constant([1.0, 1.0, 3.0])
  >>> tf.math.is_non_decreasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_non_decreasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).  Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    diff = _get_diff_for_monotonic_comparison(x)
    # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True.
    zero = ops.convert_to_tensor(0, dtype=diff.dtype)
    return math_ops.reduce_all(math_ops.less_equal(zero, diff))


@tf_export(
    'math.is_strictly_increasing',
    v1=[
        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
        'is_strictly_increasing'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
                                  'is_strictly_increasing')
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
  If `x` has less than two elements, it is trivially strictly increasing.

  See also:  `is_non_decreasing`

  >>> x1 = tf.constant([1.0, 2.0, 3.0])
  >>> tf.math.is_strictly_increasing(x1)
  <tf.Tensor: shape=(), dtype=bool, numpy=True>
  >>> x2 = tf.constant([3.0, 1.0, 2.0])
  >>> tf.math.is_strictly_increasing(x2)
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    diff = _get_diff_for_monotonic_comparison(x)
    # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True.
    zero = ops.convert_to_tensor(0, dtype=diff.dtype)
    return math_ops.reduce_all(math_ops.less(zero, diff))


def _assert_same_base_type(items, expected_type=None):
  r"""Asserts all items are of the same base type.

  Args:
    items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
        `Operation`, or `IndexedSlices`). Can include `None` elements, which
        will be ignored.
    expected_type: Expected type. If not specified, assert all items are
        of the same base type.

  Returns:
    Validated type, or none if neither expected_type nor items provided.

  Raises:
    ValueError: If any types do not match.
  """
  original_expected_type = expected_type
  mismatch = False
  for item in items:
    if item is not None:
      item_type = item.dtype.base_dtype
      if not expected_type:
        expected_type = item_type
      elif expected_type != item_type:
        mismatch = True
        break
  if mismatch:
    # Loop back through and build up an informative error message (this is very
    # slow, so we don't do it unless we found an error above).
    expected_type = original_expected_type
    original_item_str = None
    for item in items:
      if item is not None:
        item_type = item.dtype.base_dtype
        if not expected_type:
          expected_type = item_type
          original_item_str = item.name if hasattr(item, 'name') else str(item)
        elif expected_type != item_type:
          raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
              item.name if hasattr(item, 'name') else str(item),
              item_type, expected_type,
              (' as %s' % original_item_str) if original_item_str else ''))
    return expected_type  # Should be unreachable
  else:
    return expected_type


@tf_export(
    'debugging.assert_same_float_dtype',
    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
  """Validate and return float type based on `tensors` and `dtype`.

  For ops such as matrix multiplication, inputs and weights must be of the
  same float type. This function validates that all `tensors` are the same type,
  validates that type is `dtype` (if supplied), and returns the type. Type must
  be a floating point type. If neither `tensors` nor `dtype` is supplied,
  the function will return `dtypes.float32`.

  Args:
    tensors: Tensors of input values. Can include `None` elements, which will be
        ignored.
    dtype: Expected type.

  Returns:
    Validated type.

  Raises:
    ValueError: if neither `tensors` nor `dtype` is supplied, or result is not
        float, or the common type of the inputs is not a floating point type.
  """
  if tensors:
    dtype = _assert_same_base_type(tensors, dtype)
  if not dtype:
    dtype = dtypes.float32
  elif not dtype.is_floating:
    raise ValueError('Expected floating point type, got %s.' % dtype)
  return dtype


@tf_export('debugging.assert_scalar', v1=[])
@dispatch.add_dispatch_support
def assert_scalar_v2(tensor, message=None, name=None):
  """Asserts that the given `tensor` is a scalar.

  This function raises `ValueError` unless it can be certain that the given
  `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is
  unknown.

  This is always checked statically, so this method returns nothing.

  Args:
    tensor: A `Tensor`.
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_scalar"

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  assert_scalar(tensor=tensor, message=message, name=name)


@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
  """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).

  This function raises `ValueError` unless it can be certain that the given
  `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is
  unknown.

  Args:
    tensor: A `Tensor`.
    name:  A name for this operation. Defaults to "assert_scalar"
    message: A string to prefix to the default message.

  Returns:
    The input tensor (potentially converted to a `Tensor`).

  Raises:
    ValueError: If the tensor is not scalar (rank 0), or if its shape is
      unknown.
  """
  with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
    tensor = ops.convert_to_tensor(tensor, name=name_scope)
    shape = tensor.get_shape()
    # An unknown shape has ndims None, which also fails this check.
    if shape.ndims != 0:
      if context.executing_eagerly():
        raise ValueError('%sExpected scalar shape, saw shape: %s.'
                         % (message or '', shape,))
      else:
        raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
                         % (message or '', tensor.name, shape))
    return tensor


@tf_export('ensure_shape')
@dispatch.add_dispatch_support
def ensure_shape(x, shape, name=None):
  """Updates the shape of a tensor and checks at runtime that the shape holds.

  When executed, this operation asserts that the input tensor `x`'s shape
  is compatible with the `shape` argument.
  See `tf.TensorShape.is_compatible_with` for details.

  >>> x = tf.constant([[1, 2, 3],
  ...                  [4, 5, 6]])
  >>> x = tf.ensure_shape(x, [2, 3])

  Use `None` for unknown dimensions:

  >>> x = tf.ensure_shape(x, [None, 3])
  >>> x = tf.ensure_shape(x, [2, None])

  If the tensor's shape is not compatible with the `shape` argument, an error
  is raised:

  >>> x = tf.ensure_shape(x, [5])
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
    compatible with expected shape [5]. [Op:EnsureShape]

  During graph construction (typically tracing a `tf.function`),
  `tf.ensure_shape` updates the static-shape of the **result** tensor by
  merging the two shapes. See `tf.TensorShape.merge_with` for details.

  This is most useful when **you** know a shape that can't be determined
  statically by TensorFlow.

  The following trivial `tf.function` prints the input tensor's
  static-shape before and after `ensure_shape` is applied.

  >>> @tf.function
  ... def f(tensor):
  ...   print("Static-shape before:", tensor.shape)
  ...   tensor = tf.ensure_shape(tensor, [None, 3])
  ...   print("Static-shape after:", tensor.shape)
  ...   return tensor

  This lets you see the effect of `tf.ensure_shape` when the function is traced:
  >>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
  Static-shape before: (None, None)
  Static-shape after: (None, 3)

  >>> cf(tf.zeros([3, 3])) # Passes
  >>> cf(tf.constant([1, 2, 3])) # fails
  Traceback (most recent call last):
  ...
  InvalidArgumentError:  Shape of tensor x [3] is not compatible with expected shape [3,3].

  The above example raises `tf.errors.InvalidArgumentError`, because `x`'s
  shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)`

  Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
  runtime shapes. This is stricter than `tf.Tensor.set_shape` which only
  checks the buildtime shape.

  Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape
  of the resulting tensor and enforces it at runtime, raising an error if the
  tensor's runtime shape is incompatible with the specified shape.
  `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it
  at runtime, which may result in inconsistencies between the statically-known
  shape of tensors and the runtime value of tensors.

  For example, when loading images of a known size:

  >>> @tf.function
  ... def decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image = tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  When tracing a function, no ops are being executed, shapes may be unknown.
  See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
  for details.

  >>> concrete_decode = decode_image.get_concrete_function(
  ...     tf.TensorSpec([], dtype=tf.string))
  Initial shape:  (None, None, 3)
  Final shape:  (28, 28, 3)

  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
  >>> image = tf.cast(image,tf.uint8)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  >>> print(image2.shape)
  (28, 28, 3)

  >>> image = tf.concat([image,image], axis=0)
  >>> print(image.shape)
  (56, 28, 3)
  >>> png = tf.image.encode_png(image)
  >>> image2 = concrete_decode(png)
  Traceback (most recent call last):
  ...
  tf.errors.InvalidArgumentError:  Shape of tensor DecodePng [56,28,3] is not
    compatible with expected shape [28,28,3].

  Caution: if you don't use the result of `tf.ensure_shape` the check may not
  run.

  >>> @tf.function
  ... def bad_decode_image(png):
  ...   image = tf.image.decode_png(png, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   # BAD: forgot to use the returned tensor.
  ...   tf.ensure_shape(image,[28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  >>> image = bad_decode_image(png)
  Initial shape:  (None, None, 3)
  Final shape:  (None, None, 3)
  >>> print(image.shape)
  (56, 28, 3)

  Args:
    x: A `Tensor`.
    shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.
    name: A name for this operation (optional). Defaults to "EnsureShape".

  Returns:
    A `Tensor`. Has the same type and contents as `x`.

  Raises:
    tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape
    of `x`.
  """
  if not isinstance(shape, tensor_shape.TensorShape):
    shape = tensor_shape.TensorShape(shape)

  return array_ops.ensure_shape(x, shape, name=name)


@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient for `EnsureShape`: passes the incoming gradient through as-is."""
  del op  # Unused.
  return grad