1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=g-short-docstring-punctuation 16"""Asserts and Boolean Checks.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23 24import numpy as np 25 26from tensorflow.python.eager import context 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import math_ops 36from tensorflow.python.util import compat 37from tensorflow.python.util import deprecation 38from tensorflow.python.util import dispatch 39from tensorflow.python.util.tf_export import tf_export 40 41NUMERIC_TYPES = frozenset( 42 [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32, 43 dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8, 44 dtypes.complex64]) 45 46__all__ = [ 47 'assert_negative', 48 'assert_positive', 49 'assert_proper_iterable', 50 'assert_non_negative', 51 'assert_non_positive', 52 'assert_equal', 53 'assert_none_equal', 54 'assert_near', 55 'assert_integer', 56 'assert_less', 57 'assert_less_equal', 58 'assert_greater', 59 'assert_greater_equal', 60 'assert_rank', 61 'assert_rank_at_least', 62 'assert_rank_in', 63 'assert_same_float_dtype', 64 'assert_scalar', 65 'assert_type', 66 'assert_shapes', 67 'is_non_decreasing', 68 'is_numeric_tensor', 69 'is_strictly_increasing', 70] 71 72 73def _maybe_constant_value_string(t): 74 if not isinstance(t, ops.Tensor): 75 return str(t) 76 const_t = tensor_util.constant_value(t) 77 if const_t is not None: 78 return str(const_t) 79 return t 80 81 82def _assert_static(condition, data): 83 """Raises a InvalidArgumentError with as much information as possible.""" 84 if not condition: 85 data_static = [_maybe_constant_value_string(x) for x in data] 86 raise errors.InvalidArgumentError(node_def=None, op=None, 87 message='\n'.join(data_static)) 88 89 90def _shape_and_dtype_str(tensor): 91 """Returns a string containing tensor's shape and dtype.""" 92 return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name) 93 94 95def _unary_assert_doc(sym, sym_name): 96 """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor. 97 98 Args: 99 sym: Mathematical symbol for the check performed on each element, i.e. "> 0" 100 sym_name: English-language name for the op described by sym 101 102 Returns: 103 Decorator that adds the appropriate docstring to the function for symbol 104 `sym`. 105 """ 106 107 def _decorator(func): 108 """Generated decorator that adds the appropriate docstring to the function for symbol `sym`. 109 110 Args: 111 func: Function for a TensorFlow op 112 113 Returns: 114 Version of `func` with documentation attached. 115 """ 116 opname = func.__name__ 117 cap_sym_name = sym_name.capitalize() 118 119 func.__doc__ = """ 120 Assert the condition `x {sym}` holds element-wise. 121 122 When running in graph mode, you should add a dependency on this operation 123 to ensure that it runs. Example of adding a dependency to an operation: 124 125 ```python 126 with tf.control_dependencies([tf.debugging.{opname}(x, y)]): 127 output = tf.reduce_sum(x) 128 ``` 129 130 {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`. 131 If `x` is empty this is trivially satisfied. 132 133 Args: 134 x: Numeric `Tensor`. 135 data: The tensors to print out if the condition is False. Defaults to 136 error message and first few entries of `x`. 137 summarize: Print this many entries of each tensor. 138 message: A string to prefix to the default message. 139 name: A name for this operation (optional). Defaults to "{opname}". 140 141 Returns: 142 Op that raises `InvalidArgumentError` if `x {sym}` is False. 143 @compatibility(eager) 144 returns None 145 @end_compatibility 146 147 Raises: 148 InvalidArgumentError: if the check can be performed immediately and 149 `x {sym}` is False. The check can be performed immediately during 150 eager execution or if `x` is statically known. 151 """.format( 152 sym=sym, sym_name=cap_sym_name, opname=opname) 153 return func 154 155 return _decorator 156 157 158def _binary_assert_doc(sym): 159 """Common docstring for most of the v1 assert_* ops that compare two tensors element-wise. 160 161 Args: 162 sym: Binary operation symbol, i.e. "==" 163 164 Returns: 165 Decorator that adds the appropriate docstring to the function for 166 symbol `sym`. 167 """ 168 169 def _decorator(func): 170 """Generated decorator that adds the appropriate docstring to the function for symbol `sym`. 171 172 Args: 173 func: Function for a TensorFlow op 174 175 Returns: 176 A version of `func` with documentation attached. 177 """ 178 opname = func.__name__ 179 180 func.__doc__ = """ 181 Assert the condition `x {sym} y` holds element-wise. 182 183 This condition holds if for every pair of (possibly broadcast) elements 184 `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`. 185 If both `x` and `y` are empty, this is trivially satisfied. 186 187 When running in graph mode, you should add a dependency on this operation 188 to ensure that it runs. Example of adding a dependency to an operation: 189 190 ```python 191 with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]): 192 output = tf.reduce_sum(x) 193 ``` 194 195 Args: 196 x: Numeric `Tensor`. 197 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 198 data: The tensors to print out if the condition is False. Defaults to 199 error message and first few entries of `x`, `y`. 200 summarize: Print this many entries of each tensor. 201 message: A string to prefix to the default message. 202 name: A name for this operation (optional). Defaults to "{opname}". 203 204 Returns: 205 Op that raises `InvalidArgumentError` if `x {sym} y` is False. 206 @compatibility(eager) 207 returns None 208 @end_compatibility 209 210 Raises: 211 InvalidArgumentError: if the check can be performed immediately and 212 `x {sym} y` is False. The check can be performed immediately during 213 eager execution or if `x` and `y` are statically known. 214 """.format( 215 sym=sym, opname=opname) 216 return func 217 218 return _decorator 219 220 221def _make_assert_msg_data(sym, x, y, summarize, test_op): 222 """Subroutine of _binary_assert that generates the components of the default error message when running in eager mode. 223 224 Args: 225 sym: Mathematical symbol for the test to apply to pairs of tensor elements, 226 i.e. "==" 227 x: First input to the assertion after applying `convert_to_tensor()` 228 y: Second input to the assertion 229 summarize: Value of the "summarize" parameter to the original assert_* call; 230 tells how many elements of each tensor to print. 231 test_op: TensorFlow op that returns a Boolean tensor with True in each 232 position where the assertion is satisfied. 233 234 Returns: 235 List of tensors and scalars that, when stringified and concatenated, 236 will produce the error message string. 237 """ 238 # Prepare a message with first elements of x and y. 239 data = [] 240 241 data.append('Condition x %s y did not hold.' % sym) 242 243 if summarize > 0: 244 if x.shape == y.shape and x.shape.as_list(): 245 # If the shapes of x and y are the same (and not scalars), 246 # Get the values that actually differed and their indices. 247 # If shapes are different this information is more confusing 248 # than useful. 249 mask = math_ops.logical_not(test_op) 250 indices = array_ops.where(mask) 251 indices_np = indices.numpy() 252 x_vals = array_ops.boolean_mask(x, mask) 253 y_vals = array_ops.boolean_mask(y, mask) 254 num_vals = min(summarize, indices_np.shape[0]) 255 data.append('Indices of first %d different values:' % num_vals) 256 data.append(indices_np[:num_vals]) 257 data.append('Corresponding x values:') 258 data.append(x_vals.numpy().reshape((-1,))[:num_vals]) 259 data.append('Corresponding y values:') 260 data.append(y_vals.numpy().reshape((-1,))[:num_vals]) 261 262 # reshape((-1,)) is the fastest way to get a flat array view. 263 x_np = x.numpy().reshape((-1,)) 264 y_np = y.numpy().reshape((-1,)) 265 x_sum = min(x_np.size, summarize) 266 y_sum = min(y_np.size, summarize) 267 data.append('First %d elements of x:' % x_sum) 268 data.append(x_np[:x_sum]) 269 data.append('First %d elements of y:' % y_sum) 270 data.append(y_np[:y_sum]) 271 272 return data 273 274 275def _pretty_print(data_item, summarize): 276 """Format a data item for use in an error message in eager mode. 277 278 Args: 279 data_item: One of the items in the "data" argument to an assert_* function. 280 Can be a Tensor or a scalar value. 281 summarize: How many elements to retain of each tensor-valued entry in data. 282 283 Returns: 284 An appropriate string representation of data_item 285 """ 286 if isinstance(data_item, ops.Tensor): 287 arr = data_item.numpy() 288 if np.isscalar(arr): 289 # Tensor.numpy() returns a scalar for zero-dimensional tensors 290 return str(arr) 291 else: 292 flat = arr.reshape((-1,)) 293 lst = [str(x) for x in flat[:summarize]] 294 if len(lst) < flat.size: 295 lst.append('...') 296 return str(lst) 297 else: 298 return str(data_item) 299 300 301def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize, 302 message, name): 303 """Generic binary elementwise assertion. 304 305 Implements the behavior described in _binary_assert_doc() above. 306 Args: 307 sym: Mathematical symbol for the test to apply to pairs of tensor elements, 308 i.e. "==" 309 opname: Name of the assert op in the public API, i.e. "assert_equal" 310 op_func: Function that, if passed the two Tensor inputs to the assertion (x 311 and y), will return the test to be passed to reduce_all() i.e. 312 static_func: Function that, if passed numpy ndarray versions of the two 313 inputs to the assertion, will return a Boolean ndarray with containing 314 True in all positions where the assertion PASSES. 315 i.e. np.equal for assert_equal() 316 x: Numeric `Tensor`. 317 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 318 data: The tensors to print out if the condition is False. Defaults to 319 error message and first few entries of `x`, `y`. 320 summarize: Print this many entries of each tensor. 321 message: A string to prefix to the default message. 322 name: A name for this operation (optional). Defaults to the value of 323 `opname`. 324 325 Returns: 326 See docstring template in _binary_assert_doc(). 327 """ 328 with ops.name_scope(name, opname, [x, y, data]): 329 x = ops.convert_to_tensor(x, name='x') 330 y = ops.convert_to_tensor(y, name='y') 331 332 if context.executing_eagerly(): 333 test_op = op_func(x, y) 334 condition = math_ops.reduce_all(test_op) 335 if condition: 336 return 337 338 # If we get here, the assertion has failed. 339 # Default to printing 3 elements like control_flow_ops.Assert (used 340 # by graph mode) does. Also treat negative values as "print 341 # everything" for consistency with Tensor::SummarizeValue(). 342 if summarize is None: 343 summarize = 3 344 elif summarize < 0: 345 summarize = 1e9 # Code below will find exact size of x and y. 346 347 if data is None: 348 data = _make_assert_msg_data(sym, x, y, summarize, test_op) 349 350 if message is not None: 351 data = [message] + list(data) 352 353 raise errors.InvalidArgumentError( 354 node_def=None, 355 op=None, 356 message=('\n'.join(_pretty_print(d, summarize) for d in data))) 357 358 else: # not context.executing_eagerly() 359 if data is None: 360 data = [ 361 'Condition x %s y did not hold element-wise:' % sym, 362 'x (%s) = ' % x.name, x, 363 'y (%s) = ' % y.name, y 364 ] 365 if message is not None: 366 data = [message] + list(data) 367 condition = math_ops.reduce_all(op_func(x, y)) 368 x_static = tensor_util.constant_value(x) 369 y_static = tensor_util.constant_value(y) 370 if x_static is not None and y_static is not None: 371 condition_static = np.all(static_func(x_static, y_static)) 372 _assert_static(condition_static, data) 373 return control_flow_ops.Assert(condition, data, summarize=summarize) 374 375 376@tf_export( 377 'debugging.assert_proper_iterable', 378 v1=['debugging.assert_proper_iterable', 'assert_proper_iterable']) 379@dispatch.add_dispatch_support 380@deprecation.deprecated_endpoints('assert_proper_iterable') 381def assert_proper_iterable(values): 382 """Static assert that values is a "proper" iterable. 383 384 `Ops` that expect iterables of `Tensor` can call this to validate input. 385 Useful since `Tensor`, `ndarray`, byte/text type are all iterables themselves. 386 387 Args: 388 values: Object to be checked. 389 390 Raises: 391 TypeError: If `values` is not iterable or is one of 392 `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`. 393 """ 394 unintentional_iterables = ( 395 (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) 396 + compat.bytes_or_text_types 397 ) 398 if isinstance(values, unintentional_iterables): 399 raise TypeError( 400 'Expected argument "values" to be a "proper" iterable. Found: %s' % 401 type(values)) 402 403 if not hasattr(values, '__iter__'): 404 raise TypeError( 405 'Expected argument "values" to be iterable. Found: %s' % type(values)) 406 407 408@tf_export('debugging.assert_negative', v1=[]) 409@dispatch.add_dispatch_support 410def assert_negative_v2(x, message=None, summarize=None, name=None): 411 """Assert the condition `x < 0` holds element-wise. 412 413 This Op checks that `x[i] < 0` holds for every element of `x`. If `x` is 414 empty, this is trivially satisfied. 415 416 If `x` is not negative everywhere, `message`, as well as the first `summarize` 417 entries of `x` are printed, and `InvalidArgumentError` is raised. 418 419 Args: 420 x: Numeric `Tensor`. 421 message: A string to prefix to the default message. 422 summarize: Print this many entries of each tensor. 423 name: A name for this operation (optional). Defaults to "assert_negative". 424 425 Returns: 426 Op raising `InvalidArgumentError` unless `x` is all negative. This can be 427 used with `tf.control_dependencies` inside of `tf.function`s to block 428 followup computation until the check has executed. 429 @compatibility(eager) 430 returns None 431 @end_compatibility 432 433 Raises: 434 InvalidArgumentError: if the check can be performed immediately and 435 `x[i] < 0` is False. The check can be performed immediately during eager 436 execution or if `x` is statically known. 437 """ 438 return assert_negative(x=x, message=message, summarize=summarize, name=name) 439 440 441@tf_export(v1=['debugging.assert_negative', 'assert_negative']) 442@dispatch.add_dispatch_support 443@deprecation.deprecated_endpoints('assert_negative') 444@_unary_assert_doc('< 0', 'negative') 445def assert_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 446 message = message or '' 447 with ops.name_scope(name, 'assert_negative', [x, data]): 448 x = ops.convert_to_tensor(x, name='x') 449 if data is None: 450 if context.executing_eagerly(): 451 name = _shape_and_dtype_str(x) 452 else: 453 name = x.name 454 data = [ 455 message, 456 'Condition x < 0 did not hold element-wise:', 457 'x (%s) = ' % name, x] 458 zero = ops.convert_to_tensor(0, dtype=x.dtype) 459 return assert_less(x, zero, data=data, summarize=summarize) 460 461 462@tf_export('debugging.assert_positive', v1=[]) 463@dispatch.add_dispatch_support 464def assert_positive_v2(x, message=None, summarize=None, name=None): 465 """Assert the condition `x > 0` holds element-wise. 466 467 This Op checks that `x[i] > 0` holds for every element of `x`. If `x` is 468 empty, this is trivially satisfied. 469 470 If `x` is not positive everywhere, `message`, as well as the first `summarize` 471 entries of `x` are printed, and `InvalidArgumentError` is raised. 472 473 Args: 474 x: Numeric `Tensor`. 475 message: A string to prefix to the default message. 476 summarize: Print this many entries of each tensor. 477 name: A name for this operation (optional). Defaults to "assert_positive". 478 479 Returns: 480 Op raising `InvalidArgumentError` unless `x` is all positive. This can be 481 used with `tf.control_dependencies` inside of `tf.function`s to block 482 followup computation until the check has executed. 483 @compatibility(eager) 484 returns None 485 @end_compatibility 486 487 Raises: 488 InvalidArgumentError: if the check can be performed immediately and 489 `x[i] > 0` is False. The check can be performed immediately during eager 490 execution or if `x` is statically known. 491 """ 492 return assert_positive(x=x, summarize=summarize, message=message, name=name) 493 494 495@tf_export(v1=['debugging.assert_positive', 'assert_positive']) 496@dispatch.add_dispatch_support 497@deprecation.deprecated_endpoints('assert_positive') 498@_unary_assert_doc('> 0', 'positive') 499def assert_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 500 message = message or '' 501 with ops.name_scope(name, 'assert_positive', [x, data]): 502 x = ops.convert_to_tensor(x, name='x') 503 if data is None: 504 if context.executing_eagerly(): 505 name = _shape_and_dtype_str(x) 506 else: 507 name = x.name 508 data = [ 509 message, 'Condition x > 0 did not hold element-wise:', 510 'x (%s) = ' % name, x] 511 zero = ops.convert_to_tensor(0, dtype=x.dtype) 512 return assert_less(zero, x, data=data, summarize=summarize) 513 514 515@tf_export('debugging.assert_non_negative', v1=[]) 516@dispatch.add_dispatch_support 517def assert_non_negative_v2(x, message=None, summarize=None, name=None): 518 """Assert the condition `x >= 0` holds element-wise. 519 520 This Op checks that `x[i] >= 0` holds for every element of `x`. If `x` is 521 empty, this is trivially satisfied. 522 523 If `x` is not >= 0 everywhere, `message`, as well as the first `summarize` 524 entries of `x` are printed, and `InvalidArgumentError` is raised. 525 526 Args: 527 x: Numeric `Tensor`. 528 message: A string to prefix to the default message. 529 summarize: Print this many entries of each tensor. 530 name: A name for this operation (optional). Defaults to 531 "assert_non_negative". 532 533 Returns: 534 Op raising `InvalidArgumentError` unless `x` is all non-negative. This can 535 be used with `tf.control_dependencies` inside of `tf.function`s to block 536 followup computation until the check has executed. 537 @compatibility(eager) 538 returns None 539 @end_compatibility 540 541 Raises: 542 InvalidArgumentError: if the check can be performed immediately and 543 `x[i] >= 0` is False. The check can be performed immediately during eager 544 execution or if `x` is statically known. 545 """ 546 return assert_non_negative(x=x, summarize=summarize, message=message, 547 name=name) 548 549 550@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative']) 551@dispatch.add_dispatch_support 552@deprecation.deprecated_endpoints('assert_non_negative') 553@_unary_assert_doc('>= 0', 'non-negative') 554def assert_non_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 555 message = message or '' 556 with ops.name_scope(name, 'assert_non_negative', [x, data]): 557 x = ops.convert_to_tensor(x, name='x') 558 if data is None: 559 if context.executing_eagerly(): 560 name = _shape_and_dtype_str(x) 561 else: 562 name = x.name 563 data = [ 564 message, 565 'Condition x >= 0 did not hold element-wise:', 566 'x (%s) = ' % name, x] 567 zero = ops.convert_to_tensor(0, dtype=x.dtype) 568 return assert_less_equal(zero, x, data=data, summarize=summarize) 569 570 571@tf_export('debugging.assert_non_positive', v1=[]) 572@dispatch.add_dispatch_support 573def assert_non_positive_v2(x, message=None, summarize=None, name=None): 574 """Assert the condition `x <= 0` holds element-wise. 575 576 This Op checks that `x[i] <= 0` holds for every element of `x`. If `x` is 577 empty, this is trivially satisfied. 578 579 If `x` is not <= 0 everywhere, `message`, as well as the first `summarize` 580 entries of `x` are printed, and `InvalidArgumentError` is raised. 581 582 Args: 583 x: Numeric `Tensor`. 584 message: A string to prefix to the default message. 585 summarize: Print this many entries of each tensor. 586 name: A name for this operation (optional). Defaults to 587 "assert_non_positive". 588 589 Returns: 590 Op raising `InvalidArgumentError` unless `x` is all non-positive. This can 591 be used with `tf.control_dependencies` inside of `tf.function`s to block 592 followup computation until the check has executed. 593 @compatibility(eager) 594 returns None 595 @end_compatibility 596 597 Raises: 598 InvalidArgumentError: if the check can be performed immediately and 599 `x[i] <= 0` is False. The check can be performed immediately during eager 600 execution or if `x` is statically known. 601 """ 602 return assert_non_positive(x=x, summarize=summarize, message=message, 603 name=name) 604 605 606@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive']) 607@dispatch.add_dispatch_support 608@deprecation.deprecated_endpoints('assert_non_positive') 609@_unary_assert_doc('<= 0', 'non-positive') 610def assert_non_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 611 message = message or '' 612 with ops.name_scope(name, 'assert_non_positive', [x, data]): 613 x = ops.convert_to_tensor(x, name='x') 614 if data is None: 615 if context.executing_eagerly(): 616 name = _shape_and_dtype_str(x) 617 else: 618 name = x.name 619 data = [ 620 message, 621 'Condition x <= 0 did not hold element-wise:' 622 'x (%s) = ' % name, x] 623 zero = ops.convert_to_tensor(0, dtype=x.dtype) 624 return assert_less_equal(x, zero, data=data, summarize=summarize) 625 626 627@tf_export('debugging.assert_equal', 'assert_equal', v1=[]) 628@dispatch.add_dispatch_support 629def assert_equal_v2(x, y, message=None, summarize=None, name=None): 630 """Assert the condition `x == y` holds element-wise. 631 632 This Op checks that `x[i] == y[i]` holds for every pair of (possibly 633 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 634 trivially satisfied. 635 636 If `x` and `y` are not equal, `message`, as well as the first `summarize` 637 entries of `x` and `y` are printed, and `InvalidArgumentError` is raised. 638 639 Args: 640 x: Numeric `Tensor`. 641 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 642 message: A string to prefix to the default message. 643 summarize: Print this many entries of each tensor. 644 name: A name for this operation (optional). Defaults to "assert_equal". 645 646 Returns: 647 Op that raises `InvalidArgumentError` if `x == y` is False. This can be 648 used with `tf.control_dependencies` inside of `tf.function`s to block 649 followup computation until the check has executed. 650 @compatibility(eager) 651 returns None 652 @end_compatibility 653 654 Raises: 655 InvalidArgumentError: if the check can be performed immediately and 656 `x == y` is False. The check can be performed immediately during eager 657 execution or if `x` and `y` are statically known. 658 """ 659 return assert_equal(x=x, y=y, summarize=summarize, message=message, name=name) 660 661 662@tf_export(v1=['debugging.assert_equal', 'assert_equal']) 663@dispatch.add_dispatch_support 664@_binary_assert_doc('==') 665def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 666 with ops.name_scope(name, 'assert_equal', [x, y, data]): 667 # Short-circuit if x and y are the same tensor. 668 if x is y: 669 return None if context.executing_eagerly() else control_flow_ops.no_op() 670 return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y, 671 data, summarize, message, name) 672 673 674@tf_export('debugging.assert_none_equal', v1=[]) 675@dispatch.add_dispatch_support 676def assert_none_equal_v2(x, y, summarize=None, message=None, name=None): 677 """Assert the condition `x != y` holds for all elements. 678 679 This Op checks that `x[i] != y[i]` holds for every pair of (possibly 680 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 681 trivially satisfied. 682 683 If any elements of `x` and `y` are equal, `message`, as well as the first 684 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` 685 is raised. 686 687 Args: 688 x: Numeric `Tensor`. 689 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 690 summarize: Print this many entries of each tensor. 691 message: A string to prefix to the default message. 692 name: A name for this operation (optional). Defaults to 693 "assert_none_equal". 694 695 Returns: 696 Op that raises `InvalidArgumentError` if `x != y` is ever False. This can 697 be used with `tf.control_dependencies` inside of `tf.function`s to block 698 followup computation until the check has executed. 699 @compatibility(eager) 700 returns None 701 @end_compatibility 702 703 Raises: 704 InvalidArgumentError: if the check can be performed immediately and 705 `x != y` is False for any pair of elements in `x` and `y`. The check can 706 be performed immediately during eager execution or if `x` and `y` are 707 statically known. 708 """ 709 return assert_none_equal(x=x, y=y, summarize=summarize, message=message, 710 name=name) 711 712 713@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal']) 714@dispatch.add_dispatch_support 715@deprecation.deprecated_endpoints('assert_none_equal') 716@_binary_assert_doc('!=') 717def assert_none_equal( 718 x, y, data=None, summarize=None, message=None, name=None): 719 return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal, 720 np.not_equal, x, y, data, summarize, message, name) 721 722 723@tf_export('debugging.assert_near', v1=[]) 724@dispatch.add_dispatch_support 725def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None, 726 name=None): 727 """Assert the condition `x` and `y` are close element-wise. 728 729 This Op checks that `x[i] - y[i] < atol + rtol * tf.abs(y[i])` holds for every 730 pair of (possibly broadcast) elements of `x` and `y`. If both `x` and `y` are 731 empty, this is trivially satisfied. 732 733 If any elements of `x` and `y` are not close, `message`, as well as the first 734 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` 735 is raised. 736 737 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 738 representable positive number such that `1 + eps != 1`. This is about 739 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 740 See `numpy.finfo`. 741 742 Args: 743 x: Float or complex `Tensor`. 744 y: Float or complex `Tensor`, same dtype as and broadcastable to `x`. 745 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 746 The relative tolerance. Default is `10 * eps`. 747 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 748 The absolute tolerance. Default is `10 * eps`. 749 message: A string to prefix to the default message. 750 summarize: Print this many entries of each tensor. 751 name: A name for this operation (optional). Defaults to "assert_near". 752 753 Returns: 754 Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. 755 This can be used with `tf.control_dependencies` inside of `tf.function`s 756 to block followup computation until the check has executed. 757 @compatibility(eager) 758 returns None 759 @end_compatibility 760 761 Raises: 762 InvalidArgumentError: if the check can be performed immediately and 763 `x != y` is False for any pair of elements in `x` and `y`. The check can 764 be performed immediately during eager execution or if `x` and `y` are 765 statically known. 766 767 @compatibility(numpy) 768 Similar to `numpy.testing.assert_allclose`, except tolerance depends on data 769 type. This is due to the fact that `TensorFlow` is often used with `32bit`, 770 `64bit`, and even `16bit` data. 771 @end_compatibility 772 """ 773 return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize, 774 message=message, name=name) 775 776 777@tf_export(v1=['debugging.assert_near', 'assert_near']) 778@dispatch.add_dispatch_support 779@deprecation.deprecated_endpoints('assert_near') 780def assert_near( 781 x, y, rtol=None, atol=None, data=None, summarize=None, message=None, 782 name=None): 783 """Assert the condition `x` and `y` are close element-wise. 784 785 Example of adding a dependency to an operation: 786 787 ```python 788 with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]): 789 output = tf.reduce_sum(x) 790 ``` 791 792 This condition holds if for every pair of (possibly broadcast) elements 793 `x[i]`, `y[i]`, we have 794 795 ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```. 796 797 If both `x` and `y` are empty, this is trivially satisfied. 798 799 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 800 representable positive number such that `1 + eps != 1`. This is about 801 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 802 See `numpy.finfo`. 803 804 Args: 805 x: Float or complex `Tensor`. 806 y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`. 807 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 808 The relative tolerance. Default is `10 * eps`. 809 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 810 The absolute tolerance. Default is `10 * eps`. 811 data: The tensors to print out if the condition is False. Defaults to 812 error message and first few entries of `x`, `y`. 813 summarize: Print this many entries of each tensor. 814 message: A string to prefix to the default message. 815 name: A name for this operation (optional). Defaults to "assert_near". 816 817 Returns: 818 Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. 819 820 @compatibility(numpy) 821 Similar to `numpy.testing.assert_allclose`, except tolerance depends on data 822 type. This is due to the fact that `TensorFlow` is often used with `32bit`, 823 `64bit`, and even `16bit` data. 824 @end_compatibility 825 """ 826 message = message or '' 827 with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]): 828 x = ops.convert_to_tensor(x, name='x') 829 y = ops.convert_to_tensor(y, name='y', dtype=x.dtype) 830 831 dtype = x.dtype 832 if dtype.is_complex: 833 dtype = dtype.real_dtype 834 eps = np.finfo(dtype.as_numpy_dtype).eps 835 rtol = 10 * eps if rtol is None else rtol 836 atol = 10 * eps if atol is None else atol 837 838 rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype) 839 atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype) 840 841 if context.executing_eagerly(): 842 x_name = _shape_and_dtype_str(x) 843 y_name = _shape_and_dtype_str(y) 844 else: 845 x_name = x.name 846 y_name = y.name 847 848 if data is None: 849 data = [ 850 message, 851 'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol), 852 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 853 ] 854 tol = atol + rtol * math_ops.abs(y) 855 diff = math_ops.abs(x - y) 856 condition = math_ops.reduce_all(math_ops.less(diff, tol)) 857 return control_flow_ops.Assert(condition, data, summarize=summarize) 858 859 860@tf_export('debugging.assert_less', 'assert_less', v1=[]) 861@dispatch.add_dispatch_support 862def assert_less_v2(x, y, message=None, summarize=None, name=None): 863 """Assert the condition `x < y` holds element-wise. 864 865 This Op checks that `x[i] < y[i]` holds for every pair of (possibly 866 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 867 trivially satisfied. 868 869 If `x` is not less than `y` element-wise, `message`, as well as the first 870 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is 871 raised. 872 873 Args: 874 x: Numeric `Tensor`. 875 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 876 message: A string to prefix to the default message. 877 summarize: Print this many entries of each tensor. 878 name: A name for this operation (optional). Defaults to "assert_less". 879 880 Returns: 881 Op that raises `InvalidArgumentError` if `x < y` is False. 882 This can be used with `tf.control_dependencies` inside of `tf.function`s 883 to block followup computation until the check has executed. 884 @compatibility(eager) 885 returns None 886 @end_compatibility 887 888 Raises: 889 InvalidArgumentError: if the check can be performed immediately and 890 `x < y` is False. The check can be performed immediately during eager 891 execution or if `x` and `y` are statically known. 892 """ 893 return assert_less(x=x, y=y, summarize=summarize, message=message, name=name) 894 895 896@tf_export(v1=['debugging.assert_less', 'assert_less']) 897@dispatch.add_dispatch_support 898@_binary_assert_doc('<') 899def assert_less(x, y, data=None, summarize=None, message=None, name=None): 900 return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data, 901 summarize, message, name) 902 903 904@tf_export('debugging.assert_less_equal', v1=[]) 905@dispatch.add_dispatch_support 906def assert_less_equal_v2(x, y, message=None, summarize=None, name=None): 907 """Assert the condition `x <= y` holds element-wise. 908 909 This Op checks that `x[i] <= y[i]` holds for every pair of (possibly 910 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 911 trivially satisfied. 912 913 If `x` is not less or equal than `y` element-wise, `message`, as well as the 914 first `summarize` entries of `x` and `y` are printed, and 915 `InvalidArgumentError` is raised. 916 917 Args: 918 x: Numeric `Tensor`. 919 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 920 message: A string to prefix to the default message. 921 summarize: Print this many entries of each tensor. 922 name: A name for this operation (optional). Defaults to "assert_less_equal". 923 924 Returns: 925 Op that raises `InvalidArgumentError` if `x <= y` is False. This can be 926 used with `tf.control_dependencies` inside of `tf.function`s to block 927 followup computation until the check has executed. 928 @compatibility(eager) 929 returns None 930 @end_compatibility 931 932 Raises: 933 InvalidArgumentError: if the check can be performed immediately and 934 `x <= y` is False. The check can be performed immediately during eager 935 execution or if `x` and `y` are statically known. 936 """ 937 return assert_less_equal(x=x, y=y, 938 summarize=summarize, message=message, name=name) 939 940 941@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal']) 942@dispatch.add_dispatch_support 943@deprecation.deprecated_endpoints('assert_less_equal') 944@_binary_assert_doc('<=') 945def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): 946 return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal, 947 np.less_equal, x, y, data, summarize, message, name) 948 949 950@tf_export('debugging.assert_greater', 'assert_greater', v1=[]) 951@dispatch.add_dispatch_support 952def assert_greater_v2(x, y, message=None, summarize=None, name=None): 953 """Assert the condition `x > y` holds element-wise. 954 955 This Op checks that `x[i] > y[i]` holds for every pair of (possibly 956 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 957 trivially satisfied. 958 959 If `x` is not greater than `y` element-wise, `message`, as well as the first 960 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is 961 raised. 962 963 Args: 964 x: Numeric `Tensor`. 965 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 966 message: A string to prefix to the default message. 967 summarize: Print this many entries of each tensor. 968 name: A name for this operation (optional). Defaults to "assert_greater". 969 970 Returns: 971 Op that raises `InvalidArgumentError` if `x > y` is False. This can be 972 used with `tf.control_dependencies` inside of `tf.function`s to block 973 followup computation until the check has executed. 974 @compatibility(eager) 975 returns None 976 @end_compatibility 977 978 Raises: 979 InvalidArgumentError: if the check can be performed immediately and 980 `x > y` is False. The check can be performed immediately during eager 981 execution or if `x` and `y` are statically known. 982 """ 983 return assert_greater(x=x, y=y, summarize=summarize, message=message, 984 name=name) 985 986 987@tf_export(v1=['debugging.assert_greater', 'assert_greater']) 988@dispatch.add_dispatch_support 989@_binary_assert_doc('>') 990def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 991 return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x, 992 y, data, summarize, message, name) 993 994 995@tf_export('debugging.assert_greater_equal', v1=[]) 996@dispatch.add_dispatch_support 997def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None): 998 """Assert the condition `x >= y` holds element-wise. 999 1000 This Op checks that `x[i] >= y[i]` holds for every pair of (possibly 1001 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 1002 trivially satisfied. 1003 1004 If `x` is not greater or equal to `y` element-wise, `message`, as well as the 1005 first `summarize` entries of `x` and `y` are printed, and 1006 `InvalidArgumentError` is raised. 1007 1008 Args: 1009 x: Numeric `Tensor`. 1010 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 1011 message: A string to prefix to the default message. 1012 summarize: Print this many entries of each tensor. 1013 name: A name for this operation (optional). Defaults to 1014 "assert_greater_equal". 1015 1016 Returns: 1017 Op that raises `InvalidArgumentError` if `x >= y` is False. This can be 1018 used with `tf.control_dependencies` inside of `tf.function`s to block 1019 followup computation until the check has executed. 1020 @compatibility(eager) 1021 returns None 1022 @end_compatibility 1023 1024 Raises: 1025 InvalidArgumentError: if the check can be performed immediately and 1026 `x >= y` is False. The check can be performed immediately during eager 1027 execution or if `x` and `y` are statically known. 1028 """ 1029 return assert_greater_equal(x=x, y=y, summarize=summarize, message=message, 1030 name=name) 1031 1032 1033@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal']) 1034@dispatch.add_dispatch_support 1035@deprecation.deprecated_endpoints('assert_greater_equal') 1036@_binary_assert_doc('>=') 1037def assert_greater_equal(x, y, data=None, summarize=None, message=None, 1038 name=None): 1039 return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal, 1040 np.greater_equal, x, y, data, summarize, message, name) 1041 1042 1043def _assert_rank_condition( 1044 x, rank, static_condition, dynamic_condition, data, summarize): 1045 """Assert `x` has a rank that satisfies a given condition. 1046 1047 Args: 1048 x: Numeric `Tensor`. 1049 rank: Scalar `Tensor`. 1050 static_condition: A python function that takes `[actual_rank, given_rank]` 1051 and returns `True` if the condition is satisfied, `False` otherwise. 1052 dynamic_condition: An `op` that takes [actual_rank, given_rank] and return 1053 `True` if the condition is satisfied, `False` otherwise. 1054 data: The tensors to print out if the condition is false. Defaults to 1055 error message and first few entries of `x`. 1056 summarize: Print this many entries of each tensor. 1057 1058 Returns: 1059 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 1060 1061 Raises: 1062 ValueError: If static checks determine `x` fails static_condition. 1063 """ 1064 assert_type(rank, dtypes.int32) 1065 1066 # Attempt to statically defined rank. 1067 rank_static = tensor_util.constant_value(rank) 1068 if rank_static is not None: 1069 if rank_static.ndim != 0: 1070 raise ValueError('Rank must be a scalar.') 1071 1072 x_rank_static = x.get_shape().ndims 1073 if x_rank_static is not None: 1074 if not static_condition(x_rank_static, rank_static): 1075 raise ValueError( 1076 'Static rank condition failed', x_rank_static, rank_static) 1077 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 1078 1079 condition = dynamic_condition(array_ops.rank(x), rank) 1080 1081 # Add the condition that `rank` must have rank zero. Prevents the bug where 1082 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 1083 if rank_static is None: 1084 this_data = ['Rank must be a scalar. Received rank: ', rank] 1085 rank_check = assert_rank(rank, 0, data=this_data) 1086 condition = control_flow_ops.with_dependencies([rank_check], condition) 1087 1088 return control_flow_ops.Assert(condition, data, summarize=summarize) 1089 1090 1091@tf_export('debugging.assert_rank', 'assert_rank', v1=[]) 1092@dispatch.add_dispatch_support 1093def assert_rank_v2(x, rank, message=None, name=None): 1094 """Assert that `x` has rank equal to `rank`. 1095 1096 This Op checks that the rank of `x` is equal to `rank`. 1097 1098 If `x` has a different rank, `message`, as well as the shape of `x` are 1099 printed, and `InvalidArgumentError` is raised. 1100 1101 Args: 1102 x: `Tensor`. 1103 rank: Scalar integer `Tensor`. 1104 message: A string to prefix to the default message. 1105 name: A name for this operation (optional). Defaults to 1106 "assert_rank". 1107 1108 Returns: 1109 Op raising `InvalidArgumentError` unless `x` has specified rank. 1110 If static checks determine `x` has correct rank, a `no_op` is returned. 1111 This can be used with `tf.control_dependencies` inside of `tf.function`s 1112 to block followup computation until the check has executed. 1113 @compatibility(eager) 1114 returns None 1115 @end_compatibility 1116 1117 Raises: 1118 InvalidArgumentError: if the check can be performed immediately and 1119 `x` does not have rank `rank`. The check can be performed immediately 1120 during eager execution or if the shape of `x` is statically known. 1121 """ 1122 return assert_rank(x=x, rank=rank, message=message, name=name) 1123 1124 1125@tf_export(v1=['debugging.assert_rank', 'assert_rank']) 1126@dispatch.add_dispatch_support 1127def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): 1128 """Assert `x` has rank equal to `rank`. 1129 1130 Example of adding a dependency to an operation: 1131 1132 ```python 1133 with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]): 1134 output = tf.reduce_sum(x) 1135 ``` 1136 1137 Args: 1138 x: Numeric `Tensor`. 1139 rank: Scalar integer `Tensor`. 1140 data: The tensors to print out if the condition is False. Defaults to 1141 error message and the shape of `x`. 1142 summarize: Print this many entries of each tensor. 1143 message: A string to prefix to the default message. 1144 name: A name for this operation (optional). Defaults to "assert_rank". 1145 1146 Returns: 1147 Op raising `InvalidArgumentError` unless `x` has specified rank. 1148 If static checks determine `x` has correct rank, a `no_op` is returned. 1149 1150 Raises: 1151 ValueError: If static checks determine `x` has wrong rank. 1152 """ 1153 with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])): 1154 if not isinstance(x, sparse_tensor.SparseTensor): 1155 x = ops.convert_to_tensor(x, name='x') 1156 rank = ops.convert_to_tensor(rank, name='rank') 1157 message = message or '' 1158 1159 static_condition = lambda actual_rank, given_rank: actual_rank == given_rank 1160 dynamic_condition = math_ops.equal 1161 1162 if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor): 1163 name = '' 1164 else: 1165 name = x.name 1166 1167 if data is None: 1168 data = [ 1169 message, 1170 'Tensor %s must have rank' % name, rank, 'Received shape: ', 1171 array_ops.shape(x) 1172 ] 1173 1174 try: 1175 assert_op = _assert_rank_condition(x, rank, static_condition, 1176 dynamic_condition, data, summarize) 1177 1178 except ValueError as e: 1179 if e.args[0] == 'Static rank condition failed': 1180 raise ValueError( 1181 '%s. Tensor %s must have rank %d. Received rank %d, shape %s' % 1182 (message, name, e.args[2], e.args[1], x.get_shape())) 1183 else: 1184 raise 1185 1186 return assert_op 1187 1188 1189@tf_export('debugging.assert_rank_at_least', v1=[]) 1190@dispatch.add_dispatch_support 1191def assert_rank_at_least_v2(x, rank, message=None, name=None): 1192 """Assert that `x` has rank of at least `rank`. 1193 1194 This Op checks that the rank of `x` is greater or equal to `rank`. 1195 1196 If `x` has a rank lower than `rank`, `message`, as well as the shape of `x` 1197 are printed, and `InvalidArgumentError` is raised. 1198 1199 Args: 1200 x: `Tensor`. 1201 rank: Scalar integer `Tensor`. 1202 message: A string to prefix to the default message. 1203 name: A name for this operation (optional). Defaults to 1204 "assert_rank_at_least". 1205 1206 Returns: 1207 Op raising `InvalidArgumentError` unless `x` has specified rank or higher. 1208 If static checks determine `x` has correct rank, a `no_op` is returned. 1209 This can be used with `tf.control_dependencies` inside of `tf.function`s 1210 to block followup computation until the check has executed. 1211 @compatibility(eager) 1212 returns None 1213 @end_compatibility 1214 1215 Raises: 1216 InvalidArgumentError: `x` does not have rank at least `rank`, but the rank 1217 cannot be statically determined. 1218 ValueError: If static checks determine `x` has mismatched rank. 1219 """ 1220 return assert_rank_at_least(x=x, rank=rank, message=message, name=name) 1221 1222 1223@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least']) 1224@dispatch.add_dispatch_support 1225@deprecation.deprecated_endpoints('assert_rank_at_least') 1226def assert_rank_at_least( 1227 x, rank, data=None, summarize=None, message=None, name=None): 1228 """Assert `x` has rank equal to `rank` or higher. 1229 1230 Example of adding a dependency to an operation: 1231 1232 ```python 1233 with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]): 1234 output = tf.reduce_sum(x) 1235 ``` 1236 1237 Args: 1238 x: Numeric `Tensor`. 1239 rank: Scalar `Tensor`. 1240 data: The tensors to print out if the condition is False. Defaults to 1241 error message and first few entries of `x`. 1242 summarize: Print this many entries of each tensor. 1243 message: A string to prefix to the default message. 1244 name: A name for this operation (optional). 1245 Defaults to "assert_rank_at_least". 1246 1247 Returns: 1248 Op raising `InvalidArgumentError` unless `x` has specified rank or higher. 1249 If static checks determine `x` has correct rank, a `no_op` is returned. 1250 1251 Raises: 1252 ValueError: If static checks determine `x` has wrong rank. 1253 """ 1254 with ops.name_scope( 1255 name, 'assert_rank_at_least', (x, rank) + tuple(data or [])): 1256 x = ops.convert_to_tensor(x, name='x') 1257 rank = ops.convert_to_tensor(rank, name='rank') 1258 message = message or '' 1259 1260 static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank 1261 dynamic_condition = math_ops.greater_equal 1262 1263 if context.executing_eagerly(): 1264 name = '' 1265 else: 1266 name = x.name 1267 1268 if data is None: 1269 data = [ 1270 message, 1271 'Tensor %s must have rank at least' % name, rank, 1272 'Received shape: ', array_ops.shape(x) 1273 ] 1274 1275 try: 1276 assert_op = _assert_rank_condition(x, rank, static_condition, 1277 dynamic_condition, data, summarize) 1278 1279 except ValueError as e: 1280 if e.args[0] == 'Static rank condition failed': 1281 raise ValueError( 1282 '%s. Tensor %s must have rank at least %d. Received rank %d, ' 1283 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1284 else: 1285 raise 1286 1287 return assert_op 1288 1289 1290def _static_rank_in(actual_rank, given_ranks): 1291 return actual_rank in given_ranks 1292 1293 1294def _dynamic_rank_in(actual_rank, given_ranks): 1295 if len(given_ranks) < 1: 1296 return ops.convert_to_tensor(False) 1297 result = math_ops.equal(given_ranks[0], actual_rank) 1298 for given_rank in given_ranks[1:]: 1299 result = math_ops.logical_or( 1300 result, math_ops.equal(given_rank, actual_rank)) 1301 return result 1302 1303 1304def _assert_ranks_condition( 1305 x, ranks, static_condition, dynamic_condition, data, summarize): 1306 """Assert `x` has a rank that satisfies a given condition. 1307 1308 Args: 1309 x: Numeric `Tensor`. 1310 ranks: Scalar `Tensor`. 1311 static_condition: A python function that takes 1312 `[actual_rank, given_ranks]` and returns `True` if the condition is 1313 satisfied, `False` otherwise. 1314 dynamic_condition: An `op` that takes [actual_rank, given_ranks] 1315 and return `True` if the condition is satisfied, `False` otherwise. 1316 data: The tensors to print out if the condition is false. Defaults to 1317 error message and first few entries of `x`. 1318 summarize: Print this many entries of each tensor. 1319 1320 Returns: 1321 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 1322 1323 Raises: 1324 ValueError: If static checks determine `x` fails static_condition. 1325 """ 1326 for rank in ranks: 1327 assert_type(rank, dtypes.int32) 1328 1329 # Attempt to statically defined rank. 1330 ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks]) 1331 if not any(r is None for r in ranks_static): 1332 for rank_static in ranks_static: 1333 if rank_static.ndim != 0: 1334 raise ValueError('Rank must be a scalar.') 1335 1336 x_rank_static = x.get_shape().ndims 1337 if x_rank_static is not None: 1338 if not static_condition(x_rank_static, ranks_static): 1339 raise ValueError( 1340 'Static rank condition failed', x_rank_static, ranks_static) 1341 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 1342 1343 condition = dynamic_condition(array_ops.rank(x), ranks) 1344 1345 # Add the condition that `rank` must have rank zero. Prevents the bug where 1346 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 1347 for rank, rank_static in zip(ranks, ranks_static): 1348 if rank_static is None: 1349 this_data = ['Rank must be a scalar. Received rank: ', rank] 1350 rank_check = assert_rank(rank, 0, data=this_data) 1351 condition = control_flow_ops.with_dependencies([rank_check], condition) 1352 1353 return control_flow_ops.Assert(condition, data, summarize=summarize) 1354 1355 1356@tf_export('debugging.assert_rank_in', v1=[]) 1357@dispatch.add_dispatch_support 1358def assert_rank_in_v2(x, ranks, message=None, name=None): 1359 """Assert that `x` has a rank in `ranks`. 1360 1361 This Op checks that the rank of `x` is in `ranks`. 1362 1363 If `x` has a different rank, `message`, as well as the shape of `x` are 1364 printed, and `InvalidArgumentError` is raised. 1365 1366 Args: 1367 x: `Tensor`. 1368 ranks: `Iterable` of scalar `Tensor` objects. 1369 message: A string to prefix to the default message. 1370 name: A name for this operation (optional). Defaults to "assert_rank_in". 1371 1372 Returns: 1373 Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. 1374 If static checks determine `x` has matching rank, a `no_op` is returned. 1375 This can be used with `tf.control_dependencies` inside of `tf.function`s 1376 to block followup computation until the check has executed. 1377 @compatibility(eager) 1378 returns None 1379 @end_compatibility 1380 1381 Raises: 1382 InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot 1383 be statically determined. 1384 ValueError: If static checks determine `x` has mismatched rank. 1385 """ 1386 return assert_rank_in(x=x, ranks=ranks, message=message, name=name) 1387 1388 1389@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in']) 1390@dispatch.add_dispatch_support 1391@deprecation.deprecated_endpoints('assert_rank_in') 1392def assert_rank_in( 1393 x, ranks, data=None, summarize=None, message=None, name=None): 1394 """Assert `x` has rank in `ranks`. 1395 1396 Example of adding a dependency to an operation: 1397 1398 ```python 1399 with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]): 1400 output = tf.reduce_sum(x) 1401 ``` 1402 1403 Args: 1404 x: Numeric `Tensor`. 1405 ranks: Iterable of scalar `Tensor` objects. 1406 data: The tensors to print out if the condition is False. Defaults to 1407 error message and first few entries of `x`. 1408 summarize: Print this many entries of each tensor. 1409 message: A string to prefix to the default message. 1410 name: A name for this operation (optional). 1411 Defaults to "assert_rank_in". 1412 1413 Returns: 1414 Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. 1415 If static checks determine `x` has matching rank, a `no_op` is returned. 1416 1417 Raises: 1418 ValueError: If static checks determine `x` has mismatched rank. 1419 """ 1420 with ops.name_scope( 1421 name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])): 1422 if not isinstance(x, sparse_tensor.SparseTensor): 1423 x = ops.convert_to_tensor(x, name='x') 1424 ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) 1425 message = message or '' 1426 1427 if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor): 1428 name = '' 1429 else: 1430 name = x.name 1431 1432 if data is None: 1433 data = [ 1434 message, 'Tensor %s must have rank in' % name 1435 ] + list(ranks) + [ 1436 'Received shape: ', array_ops.shape(x) 1437 ] 1438 1439 try: 1440 assert_op = _assert_ranks_condition(x, ranks, _static_rank_in, 1441 _dynamic_rank_in, data, summarize) 1442 1443 except ValueError as e: 1444 if e.args[0] == 'Static rank condition failed': 1445 raise ValueError( 1446 '%s. Tensor %s must have rank in %s. Received rank %d, ' 1447 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1448 else: 1449 raise 1450 1451 return assert_op 1452 1453 1454@tf_export('debugging.assert_integer', v1=[]) 1455@dispatch.add_dispatch_support 1456def assert_integer_v2(x, message=None, name=None): 1457 """Assert that `x` is of integer dtype. 1458 1459 If `x` has a non-integer type, `message`, as well as the dtype of `x` are 1460 printed, and `InvalidArgumentError` is raised. 1461 1462 This can always be checked statically, so this method returns nothing. 1463 1464 Args: 1465 x: A `Tensor`. 1466 message: A string to prefix to the default message. 1467 name: A name for this operation (optional). Defaults to "assert_integer". 1468 1469 Raises: 1470 TypeError: If `x.dtype` is not a non-quantized integer type. 1471 """ 1472 assert_integer(x=x, message=message, name=name) 1473 1474 1475@tf_export(v1=['debugging.assert_integer', 'assert_integer']) 1476@dispatch.add_dispatch_support 1477@deprecation.deprecated_endpoints('assert_integer') 1478def assert_integer(x, message=None, name=None): 1479 """Assert that `x` is of integer dtype. 1480 1481 Example of adding a dependency to an operation: 1482 1483 ```python 1484 with tf.control_dependencies([tf.compat.v1.assert_integer(x)]): 1485 output = tf.reduce_sum(x) 1486 ``` 1487 1488 Args: 1489 x: `Tensor` whose basetype is integer and is not quantized. 1490 message: A string to prefix to the default message. 1491 name: A name for this operation (optional). Defaults to "assert_integer". 1492 1493 Raises: 1494 TypeError: If `x.dtype` is anything other than non-quantized integer. 1495 1496 Returns: 1497 A `no_op` that does nothing. Type can be determined statically. 1498 """ 1499 message = message or '' 1500 with ops.name_scope(name, 'assert_integer', [x]): 1501 x = ops.convert_to_tensor(x, name='x') 1502 if not x.dtype.is_integer: 1503 if context.executing_eagerly(): 1504 name = 'tensor' 1505 else: 1506 name = x.name 1507 err_msg = ( 1508 '%s Expected "x" to be integer type. Found: %s of dtype %s' 1509 % (message, name, x.dtype)) 1510 raise TypeError(err_msg) 1511 1512 return control_flow_ops.no_op('statically_determined_was_integer') 1513 1514 1515@tf_export('debugging.assert_type', v1=[]) 1516@dispatch.add_dispatch_support 1517def assert_type_v2(tensor, tf_type, message=None, name=None): 1518 """Asserts that the given `Tensor` is of the specified type. 1519 1520 This can always be checked statically, so this method returns nothing. 1521 1522 Args: 1523 tensor: A `Tensor` or `SparseTensor`. 1524 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1525 etc). 1526 message: A string to prefix to the default message. 1527 name: A name for this operation. Defaults to "assert_type" 1528 1529 Raises: 1530 TypeError: If the tensor's data type doesn't match `tf_type`. 1531 """ 1532 assert_type(tensor=tensor, tf_type=tf_type, message=message, name=name) 1533 1534 1535@tf_export(v1=['debugging.assert_type', 'assert_type']) 1536@dispatch.add_dispatch_support 1537@deprecation.deprecated_endpoints('assert_type') 1538def assert_type(tensor, tf_type, message=None, name=None): 1539 """Statically asserts that the given `Tensor` is of the specified type. 1540 1541 Args: 1542 tensor: A `Tensor` or `SparseTensor`. 1543 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1544 etc). 1545 message: A string to prefix to the default message. 1546 name: A name to give this `Op`. Defaults to "assert_type" 1547 1548 Raises: 1549 TypeError: If the tensors data type doesn't match `tf_type`. 1550 1551 Returns: 1552 A `no_op` that does nothing. Type can be determined statically. 1553 """ 1554 message = message or '' 1555 tf_type = dtypes.as_dtype(tf_type) 1556 with ops.name_scope(name, 'assert_type', [tensor]): 1557 if not isinstance(tensor, sparse_tensor.SparseTensor): 1558 tensor = ops.convert_to_tensor(tensor, name='tensor') 1559 if tensor.dtype != tf_type: 1560 if context.executing_eagerly(): 1561 raise TypeError('%s tensor must be of type %s' % (message, tf_type)) 1562 else: 1563 raise TypeError( 1564 '%s %s must be of type %s' % 1565 (message, tensor.name if hasattr(tensor, 'name') else '', tf_type)) 1566 1567 return control_flow_ops.no_op('statically_determined_correct_type') 1568 1569 1570def _dimension_sizes(x): 1571 """Gets the dimension sizes of a tensor `x`. 1572 1573 If a size can be determined statically it is returned as an integer, 1574 otherwise as a tensor. 1575 1576 If `x` is a scalar it is treated as rank 1 size 1. 1577 1578 Args: 1579 x: A `Tensor`. 1580 1581 Returns: 1582 Dimension sizes. 1583 """ 1584 dynamic_shape = array_ops.shape(x) 1585 rank = x.get_shape().rank 1586 rank_is_known = rank is not None 1587 if rank_is_known and rank == 0: 1588 return (1,) 1589 if rank_is_known and rank > 0: 1590 static_shape = x.get_shape().as_list() 1591 sizes = [ 1592 int(size) if size is not None else dynamic_shape[i] 1593 for i, size in enumerate(static_shape) 1594 ] 1595 return sizes 1596 has_rank_zero = math_ops.equal(array_ops.rank(x), 0) 1597 return control_flow_ops.cond( 1598 has_rank_zero, lambda: array_ops.constant([1]), lambda: dynamic_shape) 1599 1600 1601def _symbolic_dimension_sizes(symbolic_shape): 1602 # If len(symbolic_shape) == 0 construct a tuple 1603 if not symbolic_shape: 1604 return tuple([1]) 1605 1606 return symbolic_shape 1607 1608 1609def _has_known_value(dimension_size): 1610 not_none = dimension_size is not None 1611 try: 1612 int(dimension_size) 1613 can_be_parsed_as_int = True 1614 except (ValueError, TypeError): 1615 can_be_parsed_as_int = False 1616 return not_none and can_be_parsed_as_int 1617 1618 1619def _is_symbol_for_any_size(symbol): 1620 return symbol in [None, '.'] 1621 1622 1623_TensorDimSizes = collections.namedtuple( 1624 '_TensorDimSizes', 1625 ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes']) 1626 1627 1628@tf_export('debugging.assert_shapes', v1=[]) 1629@dispatch.add_dispatch_support 1630def assert_shapes_v2(shapes, data=None, summarize=None, message=None, 1631 name=None): 1632 """Assert tensor shapes and dimension size relationships between tensors. 1633 1634 This Op checks that a collection of tensors shape relationships 1635 satisfies given constraints. 1636 1637 Example: 1638 1639 >>> n = 10 1640 >>> q = 3 1641 >>> d = 7 1642 >>> x = tf.zeros([n,q]) 1643 >>> y = tf.ones([n,d]) 1644 >>> param = tf.Variable([1.0, 2.0, 3.0]) 1645 >>> scalar = 1.0 1646 >>> tf.debugging.assert_shapes([ 1647 ... (x, ('N', 'Q')), 1648 ... (y, ('N', 'D')), 1649 ... (param, ('Q',)), 1650 ... (scalar, ()), 1651 ... ]) 1652 1653 >>> tf.debugging.assert_shapes([ 1654 ... (x, ('N', 'D')), 1655 ... (y, ('N', 'D')) 1656 ... ]) 1657 Traceback (most recent call last): 1658 ... 1659 ValueError: ... 1660 1661 If `x`, `y`, `param` or `scalar` does not have a shape that satisfies 1662 all specified constraints, `message`, as well as the first `summarize` entries 1663 of the first encountered violating tensor are printed, and 1664 `InvalidArgumentError` is raised. 1665 1666 Size entries in the specified shapes are checked against other entries by 1667 their __hash__, except: 1668 - a size entry is interpreted as an explicit size if it can be parsed as an 1669 integer primitive. 1670 - a size entry is interpreted as *any* size if it is None or '.'. 1671 1672 If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates 1673 a variable number of outer dimensions of unspecified size, i.e. the constraint 1674 applies to the inner-most dimensions only. 1675 1676 Scalar tensors and specified shapes of length zero (excluding the 'inner-most' 1677 prefix) are both treated as having a single dimension of size one. 1678 1679 Args: 1680 shapes: dictionary with (`Tensor` to shape) items, or a list of 1681 (`Tensor`, shape) tuples. A shape must be an iterable. 1682 data: The tensors to print out if the condition is False. Defaults to error 1683 message and first few entries of the violating tensor. 1684 summarize: Print this many entries of the tensor. 1685 message: A string to prefix to the default message. 1686 name: A name for this operation (optional). Defaults to "assert_shapes". 1687 1688 Raises: 1689 ValueError: If static checks determine any shape constraint is violated. 1690 """ 1691 assert_shapes( 1692 shapes, data=data, summarize=summarize, message=message, name=name) 1693 1694 1695@tf_export(v1=['debugging.assert_shapes']) 1696@dispatch.add_dispatch_support 1697def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): 1698 """Assert tensor shapes and dimension size relationships between tensors. 1699 1700 This Op checks that a collection of tensors shape relationships 1701 satisfies given constraints. 1702 1703 Example: 1704 1705 >>> n = 10 1706 >>> q = 3 1707 >>> d = 7 1708 >>> x = tf.zeros([n,q]) 1709 >>> y = tf.ones([n,d]) 1710 >>> param = tf.Variable([1.0, 2.0, 3.0]) 1711 >>> scalar = 1.0 1712 >>> tf.debugging.assert_shapes([ 1713 ... (x, ('N', 'Q')), 1714 ... (y, ('N', 'D')), 1715 ... (param, ('Q',)), 1716 ... (scalar, ()), 1717 ... ]) 1718 1719 >>> tf.debugging.assert_shapes([ 1720 ... (x, ('N', 'D')), 1721 ... (y, ('N', 'D')) 1722 ... ]) 1723 Traceback (most recent call last): 1724 ... 1725 ValueError: ... 1726 1727 Example of adding a dependency to an operation: 1728 1729 ```python 1730 with tf.control_dependencies([tf.assert_shapes(shapes)]): 1731 output = tf.matmul(x, y, transpose_a=True) 1732 ``` 1733 1734 If `x`, `y`, `param` or `scalar` does not have a shape that satisfies 1735 all specified constraints, `message`, as well as the first `summarize` entries 1736 of the first encountered violating tensor are printed, and 1737 `InvalidArgumentError` is raised. 1738 1739 Size entries in the specified shapes are checked against other entries by 1740 their __hash__, except: 1741 - a size entry is interpreted as an explicit size if it can be parsed as an 1742 integer primitive. 1743 - a size entry is interpreted as *any* size if it is None or '.'. 1744 1745 If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates 1746 a variable number of outer dimensions of unspecified size, i.e. the constraint 1747 applies to the inner-most dimensions only. 1748 1749 Scalar tensors and specified shapes of length zero (excluding the 'inner-most' 1750 prefix) are both treated as having a single dimension of size one. 1751 1752 Args: 1753 shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the 1754 expected shape of `Tensor`. See the example code above. The `shape` must 1755 be an iterable. Each element of the iterable can be either a concrete 1756 integer value or a string that abstractly represents the dimension. 1757 For example, 1758 - `('N', 'Q')` specifies a 2D shape wherein the first and second 1759 dimensions of shape may or may not be equal. 1760 - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second 1761 dimensions are equal. 1762 - `(1, 'N')` specifies a 2D shape wherein the first dimension is 1763 exactly 1 and the second dimension can be any value. 1764 Note that the abstract dimension letters take effect across different 1765 tuple elements of the list. For example, 1766 `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))]` asserts 1767 that both `x` and `y` are rank-2 tensors and their first dimensions are 1768 equal (`N`). 1769 `shape` can also be a `tf.TensorShape`. 1770 data: The tensors to print out if the condition is False. Defaults to error 1771 message and first few entries of the violating tensor. 1772 summarize: Print this many entries of the tensor. 1773 message: A string to prefix to the default message. 1774 name: A name for this operation (optional). Defaults to "assert_shapes". 1775 1776 Returns: 1777 Op raising `InvalidArgumentError` unless all shape constraints are 1778 satisfied. 1779 If static checks determine all constraints are satisfied, a `no_op` is 1780 returned. 1781 1782 Raises: 1783 ValueError: If static checks determine any shape constraint is violated. 1784 """ 1785 # If the user manages to assemble a dict containing tensors (possible in 1786 # Graph mode only), make sure we still accept that. 1787 if isinstance(shapes, dict): 1788 shapes = shapes.items() 1789 1790 message = message or '' 1791 with ops.name_scope(name, 'assert_shapes', [shapes, data]): 1792 # Shape specified as None implies no constraint 1793 shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else 1794 ops.convert_to_tensor(x), s) 1795 for x, s in shapes if s is not None] 1796 1797 executing_eagerly = context.executing_eagerly() 1798 1799 def tensor_name(x): 1800 if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor): 1801 return _shape_and_dtype_str(x) 1802 return x.name 1803 1804 tensor_dim_sizes = [] 1805 for tensor, symbolic_shape in shape_constraints: 1806 is_iterable = ( 1807 hasattr(symbolic_shape, '__iter__') or 1808 hasattr(symbolic_shape, '__getitem__') # For Python 2 compat. 1809 ) 1810 if not is_iterable: 1811 raise ValueError( 1812 '%s. ' 1813 'Tensor %s. Specified shape must be an iterable. ' 1814 'An iterable has the attribute `__iter__` or `__getitem__`. ' 1815 'Received specified shape: %s' % 1816 (message, tensor_name(tensor), symbolic_shape)) 1817 1818 # We convert this into a tuple to handle strings, lists and numpy arrays 1819 symbolic_shape_tuple = tuple(symbolic_shape) 1820 1821 tensors_specified_innermost = False 1822 for i, symbol in enumerate(symbolic_shape_tuple): 1823 if symbol not in [Ellipsis, '*']: 1824 continue 1825 1826 if i != 0: 1827 raise ValueError( 1828 '%s. ' 1829 'Tensor %s specified shape index %d. ' 1830 'Symbol `...` or `*` for a variable number of ' 1831 'unspecified dimensions is only allowed as the first entry' % 1832 (message, tensor_name(tensor), i)) 1833 1834 tensors_specified_innermost = True 1835 1836 # Only include the size of the specified dimensions since the 0th symbol 1837 # is either ellipsis or * 1838 tensor_dim_sizes.append( 1839 _TensorDimSizes( 1840 tensor, tensors_specified_innermost, _dimension_sizes(tensor), 1841 _symbolic_dimension_sizes( 1842 symbolic_shape_tuple[1:] 1843 if tensors_specified_innermost else symbolic_shape_tuple))) 1844 1845 rank_assertions = [] 1846 for sizes in tensor_dim_sizes: 1847 rank = len(sizes.symbolic_sizes) 1848 rank_zero_or_one = rank in [0, 1] 1849 if sizes.unspecified_dim: 1850 if rank_zero_or_one: 1851 # No assertion of rank needed as `x` only need to have rank at least 1852 # 0. See elif rank_zero_or_one case comment. 1853 continue 1854 assertion = assert_rank_at_least( 1855 x=sizes.x, 1856 rank=rank, 1857 data=data, 1858 summarize=summarize, 1859 message=message, 1860 name=name) 1861 elif rank_zero_or_one: 1862 # Rank 0 is treated as rank 1 size 1, i.e. there is 1863 # no distinction between the two in terms of rank. 1864 # See _dimension_sizes. 1865 assertion = assert_rank_in( 1866 x=sizes.x, 1867 ranks=[0, 1], 1868 data=data, 1869 summarize=summarize, 1870 message=message, 1871 name=name) 1872 else: 1873 assertion = assert_rank( 1874 x=sizes.x, 1875 rank=rank, 1876 data=data, 1877 summarize=summarize, 1878 message=message, 1879 name=name) 1880 rank_assertions.append(assertion) 1881 1882 size_assertions = [] 1883 size_specifications = {} 1884 for sizes in tensor_dim_sizes: 1885 for i, size_symbol in enumerate(sizes.symbolic_sizes): 1886 1887 if _is_symbol_for_any_size(size_symbol): 1888 # Size specified as any implies no constraint 1889 continue 1890 1891 if sizes.unspecified_dim: 1892 tensor_dim = i - len(sizes.symbolic_sizes) 1893 else: 1894 tensor_dim = i 1895 1896 if size_symbol in size_specifications or _has_known_value(size_symbol): 1897 if _has_known_value(size_symbol): 1898 specified_size = int(size_symbol) 1899 size_check_message = 'Specified explicitly' 1900 else: 1901 specified_size, specified_by_y, specified_at_dim = \ 1902 size_specifications[size_symbol] 1903 size_check_message = ( 1904 'Specified by tensor %s dimension %d' % 1905 (tensor_name(specified_by_y), specified_at_dim)) 1906 1907 # This is extremely subtle. If actual_sizes is dynamic, we must 1908 # make sure a control dependency is inserted here so that this slice 1909 # can not execute until the rank is asserted to be enough for the 1910 # slice to not fail. 1911 with ops.control_dependencies(rank_assertions): 1912 actual_size = sizes.actual_sizes[tensor_dim] 1913 if _has_known_value(actual_size) and _has_known_value(specified_size): 1914 if int(actual_size) != int(specified_size): 1915 raise ValueError( 1916 '%s. %s. Tensor %s dimension %s must have size %d. ' 1917 'Received size %d, shape %s' % 1918 (message, size_check_message, tensor_name(sizes.x), 1919 tensor_dim, specified_size, actual_size, 1920 sizes.x.get_shape())) 1921 # No dynamic assertion needed 1922 continue 1923 1924 condition = math_ops.equal( 1925 ops.convert_to_tensor(actual_size), 1926 ops.convert_to_tensor(specified_size)) 1927 data_ = data 1928 if data is None: 1929 data_ = [ 1930 message, size_check_message, 1931 'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim, 1932 'must have size', specified_size, 'Received shape: ', 1933 array_ops.shape(sizes.x) 1934 ] 1935 size_assertions.append( 1936 control_flow_ops.Assert(condition, data_, summarize=summarize)) 1937 else: 1938 # Not sure if actual_sizes is a constant, but for safety, guard 1939 # on rank. See explanation above about actual_sizes need for safety. 1940 with ops.control_dependencies(rank_assertions): 1941 size = sizes.actual_sizes[tensor_dim] 1942 size_specifications[size_symbol] = (size, sizes.x, tensor_dim) 1943 1944 # Ensure both assertions actually occur. 1945 with ops.control_dependencies(rank_assertions): 1946 shapes_assertion = control_flow_ops.group(size_assertions) 1947 1948 return shapes_assertion 1949 1950 1951# pylint: disable=line-too-long 1952def _get_diff_for_monotonic_comparison(x): 1953 """Gets the difference x[1:] - x[:-1].""" 1954 x = array_ops.reshape(x, [-1]) 1955 if not is_numeric_tensor(x): 1956 raise TypeError('Expected x to be numeric, instead found: %s' % x) 1957 1958 # If x has less than 2 elements, there is nothing to compare. So return []. 1959 is_shorter_than_two = math_ops.less(array_ops.size(x), 2) 1960 short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype) 1961 1962 # With 2 or more elements, return x[1:] - x[:-1] 1963 s_len = array_ops.shape(x) - 1 1964 diff = lambda: array_ops.strided_slice(x, [1], [1] + s_len)- array_ops.strided_slice(x, [0], s_len) 1965 return control_flow_ops.cond(is_shorter_than_two, short_result, diff) 1966 1967 1968@tf_export( 1969 'debugging.is_numeric_tensor', 1970 v1=['debugging.is_numeric_tensor', 'is_numeric_tensor']) 1971@deprecation.deprecated_endpoints('is_numeric_tensor') 1972def is_numeric_tensor(tensor): 1973 """Returns `True` if the elements of `tensor` are numbers. 1974 1975 Specifically, returns `True` if the dtype of `tensor` is one of the following: 1976 1977 * `tf.float32` 1978 * `tf.float64` 1979 * `tf.int8` 1980 * `tf.int16` 1981 * `tf.int32` 1982 * `tf.int64` 1983 * `tf.uint8` 1984 * `tf.qint8` 1985 * `tf.qint32` 1986 * `tf.quint8` 1987 * `tf.complex64` 1988 1989 Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not 1990 a `tf.Tensor` object. 1991 """ 1992 return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES 1993 1994 1995@tf_export( 1996 'math.is_non_decreasing', 1997 v1=[ 1998 'math.is_non_decreasing', 'debugging.is_non_decreasing', 1999 'is_non_decreasing' 2000 ]) 2001@dispatch.add_dispatch_support 2002@deprecation.deprecated_endpoints('debugging.is_non_decreasing', 2003 'is_non_decreasing') 2004def is_non_decreasing(x, name=None): 2005 """Returns `True` if `x` is non-decreasing. 2006 2007 Elements of `x` are compared in row-major order. The tensor `[x[0],...]` 2008 is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`. 2009 If `x` has less than two elements, it is trivially non-decreasing. 2010 2011 See also: `is_strictly_increasing` 2012 2013 >>> x1 = tf.constant([1.0, 1.0, 3.0]) 2014 >>> tf.math.is_non_decreasing(x1) 2015 <tf.Tensor: shape=(), dtype=bool, numpy=True> 2016 >>> x2 = tf.constant([3.0, 1.0, 2.0]) 2017 >>> tf.math.is_non_decreasing(x2) 2018 <tf.Tensor: shape=(), dtype=bool, numpy=False> 2019 2020 Args: 2021 x: Numeric `Tensor`. 2022 name: A name for this operation (optional). Defaults to "is_non_decreasing" 2023 2024 Returns: 2025 Boolean `Tensor`, equal to `True` iff `x` is non-decreasing. 2026 2027 Raises: 2028 TypeError: if `x` is not a numeric tensor. 2029 """ 2030 with ops.name_scope(name, 'is_non_decreasing', [x]): 2031 diff = _get_diff_for_monotonic_comparison(x) 2032 # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True. 2033 zero = ops.convert_to_tensor(0, dtype=diff.dtype) 2034 return math_ops.reduce_all(math_ops.less_equal(zero, diff)) 2035 2036 2037@tf_export( 2038 'math.is_strictly_increasing', 2039 v1=[ 2040 'math.is_strictly_increasing', 'debugging.is_strictly_increasing', 2041 'is_strictly_increasing' 2042 ]) 2043@dispatch.add_dispatch_support 2044@deprecation.deprecated_endpoints('debugging.is_strictly_increasing', 2045 'is_strictly_increasing') 2046def is_strictly_increasing(x, name=None): 2047 """Returns `True` if `x` is strictly increasing. 2048 2049 Elements of `x` are compared in row-major order. The tensor `[x[0],...]` 2050 is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`. 2051 If `x` has less than two elements, it is trivially strictly increasing. 2052 2053 See also: `is_non_decreasing` 2054 2055 >>> x1 = tf.constant([1.0, 2.0, 3.0]) 2056 >>> tf.math.is_strictly_increasing(x1) 2057 <tf.Tensor: shape=(), dtype=bool, numpy=True> 2058 >>> x2 = tf.constant([3.0, 1.0, 2.0]) 2059 >>> tf.math.is_strictly_increasing(x2) 2060 <tf.Tensor: shape=(), dtype=bool, numpy=False> 2061 2062 Args: 2063 x: Numeric `Tensor`. 2064 name: A name for this operation (optional). 2065 Defaults to "is_strictly_increasing" 2066 2067 Returns: 2068 Boolean `Tensor`, equal to `True` iff `x` is strictly increasing. 2069 2070 Raises: 2071 TypeError: if `x` is not a numeric tensor. 2072 """ 2073 with ops.name_scope(name, 'is_strictly_increasing', [x]): 2074 diff = _get_diff_for_monotonic_comparison(x) 2075 # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True. 2076 zero = ops.convert_to_tensor(0, dtype=diff.dtype) 2077 return math_ops.reduce_all(math_ops.less(zero, diff)) 2078 2079 2080def _assert_same_base_type(items, expected_type=None): 2081 r"""Asserts all items are of the same base type. 2082 2083 Args: 2084 items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, 2085 `Operation`, or `IndexedSlices`). Can include `None` elements, which 2086 will be ignored. 2087 expected_type: Expected type. If not specified, assert all items are 2088 of the same base type. 2089 2090 Returns: 2091 Validated type, or none if neither expected_type nor items provided. 2092 2093 Raises: 2094 ValueError: If any types do not match. 2095 """ 2096 original_expected_type = expected_type 2097 mismatch = False 2098 for item in items: 2099 if item is not None: 2100 item_type = item.dtype.base_dtype 2101 if not expected_type: 2102 expected_type = item_type 2103 elif expected_type != item_type: 2104 mismatch = True 2105 break 2106 if mismatch: 2107 # Loop back through and build up an informative error message (this is very 2108 # slow, so we don't do it unless we found an error above). 2109 expected_type = original_expected_type 2110 original_item_str = None 2111 for item in items: 2112 if item is not None: 2113 item_type = item.dtype.base_dtype 2114 if not expected_type: 2115 expected_type = item_type 2116 original_item_str = item.name if hasattr(item, 'name') else str(item) 2117 elif expected_type != item_type: 2118 raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( 2119 item.name if hasattr(item, 'name') else str(item), 2120 item_type, expected_type, 2121 (' as %s' % original_item_str) if original_item_str else '')) 2122 return expected_type # Should be unreachable 2123 else: 2124 return expected_type 2125 2126 2127@tf_export( 2128 'debugging.assert_same_float_dtype', 2129 v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype']) 2130@dispatch.add_dispatch_support 2131@deprecation.deprecated_endpoints('assert_same_float_dtype') 2132def assert_same_float_dtype(tensors=None, dtype=None): 2133 """Validate and return float type based on `tensors` and `dtype`. 2134 2135 For ops such as matrix multiplication, inputs and weights must be of the 2136 same float type. This function validates that all `tensors` are the same type, 2137 validates that type is `dtype` (if supplied), and returns the type. Type must 2138 be a floating point type. If neither `tensors` nor `dtype` is supplied, 2139 the function will return `dtypes.float32`. 2140 2141 Args: 2142 tensors: Tensors of input values. Can include `None` elements, which will be 2143 ignored. 2144 dtype: Expected type. 2145 2146 Returns: 2147 Validated type. 2148 2149 Raises: 2150 ValueError: if neither `tensors` nor `dtype` is supplied, or result is not 2151 float, or the common type of the inputs is not a floating point type. 2152 """ 2153 if tensors: 2154 dtype = _assert_same_base_type(tensors, dtype) 2155 if not dtype: 2156 dtype = dtypes.float32 2157 elif not dtype.is_floating: 2158 raise ValueError('Expected floating point type, got %s.' % dtype) 2159 return dtype 2160 2161 2162@tf_export('debugging.assert_scalar', v1=[]) 2163@dispatch.add_dispatch_support 2164def assert_scalar_v2(tensor, message=None, name=None): 2165 """Asserts that the given `tensor` is a scalar. 2166 2167 This function raises `ValueError` unless it can be certain that the given 2168 `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is 2169 unknown. 2170 2171 This is always checked statically, so this method returns nothing. 2172 2173 Args: 2174 tensor: A `Tensor`. 2175 message: A string to prefix to the default message. 2176 name: A name for this operation. Defaults to "assert_scalar" 2177 2178 Raises: 2179 ValueError: If the tensor is not scalar (rank 0), or if its shape is 2180 unknown. 2181 """ 2182 assert_scalar(tensor=tensor, message=message, name=name) 2183 2184 2185@tf_export(v1=['debugging.assert_scalar', 'assert_scalar']) 2186@dispatch.add_dispatch_support 2187@deprecation.deprecated_endpoints('assert_scalar') 2188def assert_scalar(tensor, name=None, message=None): 2189 """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional). 2190 2191 This function raises `ValueError` unless it can be certain that the given 2192 `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is 2193 unknown. 2194 2195 Args: 2196 tensor: A `Tensor`. 2197 name: A name for this operation. Defaults to "assert_scalar" 2198 message: A string to prefix to the default message. 2199 2200 Returns: 2201 The input tensor (potentially converted to a `Tensor`). 2202 2203 Raises: 2204 ValueError: If the tensor is not scalar (rank 0), or if its shape is 2205 unknown. 2206 """ 2207 with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope: 2208 tensor = ops.convert_to_tensor(tensor, name=name_scope) 2209 shape = tensor.get_shape() 2210 if shape.ndims != 0: 2211 if context.executing_eagerly(): 2212 raise ValueError('%sExpected scalar shape, saw shape: %s.' 2213 % (message or '', shape,)) 2214 else: 2215 raise ValueError('%sExpected scalar shape for %s, saw shape: %s.' 2216 % (message or '', tensor.name, shape)) 2217 return tensor 2218 2219 2220@tf_export('ensure_shape') 2221@dispatch.add_dispatch_support 2222def ensure_shape(x, shape, name=None): 2223 """Updates the shape of a tensor and checks at runtime that the shape holds. 2224 2225 When executed, this operation asserts that the input tensor `x`'s shape 2226 is compatible with the `shape` argument. 2227 See `tf.TensorShape.is_compatible_with` for details. 2228 2229 >>> x = tf.constant([[1, 2, 3], 2230 ... [4, 5, 6]]) 2231 >>> x = tf.ensure_shape(x, [2, 3]) 2232 2233 Use `None` for unknown dimensions: 2234 2235 >>> x = tf.ensure_shape(x, [None, 3]) 2236 >>> x = tf.ensure_shape(x, [2, None]) 2237 2238 If the tensor's shape is not compatible with the `shape` argument, an error 2239 is raised: 2240 2241 >>> x = tf.ensure_shape(x, [5]) 2242 Traceback (most recent call last): 2243 ... 2244 tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not 2245 compatible with expected shape [5]. [Op:EnsureShape] 2246 2247 During graph construction (typically tracing a `tf.function`), 2248 `tf.ensure_shape` updates the static-shape of the **result** tensor by 2249 merging the two shapes. See `tf.TensorShape.merge_with` for details. 2250 2251 This is most useful when **you** know a shape that can't be determined 2252 statically by TensorFlow. 2253 2254 The following trivial `tf.function` prints the input tensor's 2255 static-shape before and after `ensure_shape` is applied. 2256 2257 >>> @tf.function 2258 ... def f(tensor): 2259 ... print("Static-shape before:", tensor.shape) 2260 ... tensor = tf.ensure_shape(tensor, [None, 3]) 2261 ... print("Static-shape after:", tensor.shape) 2262 ... return tensor 2263 2264 This lets you see the effect of `tf.ensure_shape` when the function is traced: 2265 >>> cf = f.get_concrete_function(tf.TensorSpec([None, None])) 2266 Static-shape before: (None, None) 2267 Static-shape after: (None, 3) 2268 2269 >>> cf(tf.zeros([3, 3])) # Passes 2270 >>> cf(tf.constant([1, 2, 3])) # fails 2271 Traceback (most recent call last): 2272 ... 2273 InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3]. 2274 2275 The above example raises `tf.errors.InvalidArgumentError`, because `x`'s 2276 shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)` 2277 2278 Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and 2279 runtime shapes. This is stricter than `tf.Tensor.set_shape` which only 2280 checks the buildtime shape. 2281 2282 Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape 2283 of the resulting tensor and enforces it at runtime, raising an error if the 2284 tensor's runtime shape is incompatible with the specified shape. 2285 `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it 2286 at runtime, which may result in inconsistencies between the statically-known 2287 shape of tensors and the runtime value of tensors. 2288 2289 For example, of loading images of a known size: 2290 2291 >>> @tf.function 2292 ... def decode_image(png): 2293 ... image = tf.image.decode_png(png, channels=3) 2294 ... # the `print` executes during tracing. 2295 ... print("Initial shape: ", image.shape) 2296 ... image = tf.ensure_shape(image,[28, 28, 3]) 2297 ... print("Final shape: ", image.shape) 2298 ... return image 2299 2300 When tracing a function, no ops are being executed, shapes may be unknown. 2301 See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function) 2302 for details. 2303 2304 >>> concrete_decode = decode_image.get_concrete_function( 2305 ... tf.TensorSpec([], dtype=tf.string)) 2306 Initial shape: (None, None, 3) 2307 Final shape: (28, 28, 3) 2308 2309 >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32) 2310 >>> image = tf.cast(image,tf.uint8) 2311 >>> png = tf.image.encode_png(image) 2312 >>> image2 = concrete_decode(png) 2313 >>> print(image2.shape) 2314 (28, 28, 3) 2315 2316 >>> image = tf.concat([image,image], axis=0) 2317 >>> print(image.shape) 2318 (56, 28, 3) 2319 >>> png = tf.image.encode_png(image) 2320 >>> image2 = concrete_decode(png) 2321 Traceback (most recent call last): 2322 ... 2323 tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not 2324 compatible with expected shape [28,28,3]. 2325 2326 Caution: if you don't use the result of `tf.ensure_shape` the check may not 2327 run. 2328 2329 >>> @tf.function 2330 ... def bad_decode_image(png): 2331 ... image = tf.image.decode_png(png, channels=3) 2332 ... # the `print` executes during tracing. 2333 ... print("Initial shape: ", image.shape) 2334 ... # BAD: forgot to use the returned tensor. 2335 ... tf.ensure_shape(image,[28, 28, 3]) 2336 ... print("Final shape: ", image.shape) 2337 ... return image 2338 2339 >>> image = bad_decode_image(png) 2340 Initial shape: (None, None, 3) 2341 Final shape: (None, None, 3) 2342 >>> print(image.shape) 2343 (56, 28, 3) 2344 2345 Args: 2346 x: A `Tensor`. 2347 shape: A `TensorShape` representing the shape of this tensor, a 2348 `TensorShapeProto`, a list, a tuple, or None. 2349 name: A name for this operation (optional). Defaults to "EnsureShape". 2350 2351 Returns: 2352 A `Tensor`. Has the same type and contents as `x`. 2353 2354 Raises: 2355 tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape 2356 of `x`. 2357 """ 2358 if not isinstance(shape, tensor_shape.TensorShape): 2359 shape = tensor_shape.TensorShape(shape) 2360 2361 return array_ops.ensure_shape(x, shape, name=name) 2362 2363 2364@ops.RegisterGradient('EnsureShape') 2365def _ensure_shape_grad(op, grad): 2366 del op # Unused. 2367 return grad 2368