1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=g-short-docstring-punctuation 16"""Asserts and Boolean Checks.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import numpy as np 23 24from tensorflow.python.eager import context 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import errors 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import tensor_util 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.util import compat 35from tensorflow.python.util import deprecation 36from tensorflow.python.util.tf_export import tf_export 37 38NUMERIC_TYPES = frozenset( 39 [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32, 40 dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8, 41 dtypes.complex64]) 42 43__all__ = [ 44 'assert_negative', 45 'assert_positive', 46 'assert_proper_iterable', 47 'assert_non_negative', 48 'assert_non_positive', 49 'assert_equal', 50 'assert_none_equal', 51 'assert_near', 52 'assert_integer', 53 'assert_less', 54 'assert_less_equal', 55 'assert_greater', 56 'assert_greater_equal', 57 'assert_rank', 58 'assert_rank_at_least', 59 'assert_rank_in', 60 'assert_same_float_dtype', 61 'assert_scalar', 62 'assert_type', 63 'is_non_decreasing', 64 'is_numeric_tensor', 65 'is_strictly_increasing', 66] 67 68 69def _maybe_constant_value_string(t): 70 if not isinstance(t, ops.Tensor): 71 return str(t) 72 const_t = tensor_util.constant_value(t) 73 if const_t is not None: 74 return str(const_t) 75 return t 76 77 78def _assert_static(condition, data): 79 """Raises a InvalidArgumentError with as much information as possible.""" 80 if not condition: 81 data_static = [_maybe_constant_value_string(x) for x in data] 82 raise errors.InvalidArgumentError(node_def=None, op=None, 83 message='\n'.join(data_static)) 84 85 86def _shape_and_dtype_str(tensor): 87 """Returns a string containing tensor's shape and dtype.""" 88 return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name) 89 90 91@tf_export( 92 'debugging.assert_proper_iterable', 93 v1=['debugging.assert_proper_iterable', 'assert_proper_iterable']) 94@deprecation.deprecated_endpoints('assert_proper_iterable') 95def assert_proper_iterable(values): 96 """Static assert that values is a "proper" iterable. 97 98 `Ops` that expect iterables of `Tensor` can call this to validate input. 99 Useful since `Tensor`, `ndarray`, byte/text type are all iterables themselves. 100 101 Args: 102 values: Object to be checked. 103 104 Raises: 105 TypeError: If `values` is not iterable or is one of 106 `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`. 107 """ 108 unintentional_iterables = ( 109 (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) 110 + compat.bytes_or_text_types 111 ) 112 if isinstance(values, unintentional_iterables): 113 raise TypeError( 114 'Expected argument "values" to be a "proper" iterable. Found: %s' % 115 type(values)) 116 117 if not hasattr(values, '__iter__'): 118 raise TypeError( 119 'Expected argument "values" to be iterable. Found: %s' % type(values)) 120 121 122@tf_export('debugging.assert_negative', v1=[]) 123def assert_negative_v2(x, message=None, summarize=None, name=None): 124 """Assert the condition `x < 0` holds element-wise. 125 126 This Op checks that `x[i] < 0` holds for every element of `x`. If `x` is 127 empty, this is trivially satisfied. 128 129 If `x` is not negative everywhere, `message`, as well as the first `summarize` 130 entries of `x` are printed, and `InvalidArgumentError` is raised. 131 132 Args: 133 x: Numeric `Tensor`. 134 message: A string to prefix to the default message. 135 summarize: Print this many entries of each tensor. 136 name: A name for this operation (optional). Defaults to "assert_negative". 137 138 Raises: 139 InvalidArgumentError: if the check can be performed immediately and 140 `x[i] < 0` is False. The check can be performed immediately during eager 141 execution or if `x` is statically known. 142 """ 143 assert_negative(x=x, message=message, summarize=summarize, name=name) 144 145 146@tf_export(v1=['debugging.assert_negative', 'assert_negative']) 147@deprecation.deprecated_endpoints('assert_negative') 148def assert_negative(x, data=None, summarize=None, message=None, name=None): 149 """Assert the condition `x < 0` holds element-wise. 150 151 Example of adding a dependency to an operation: 152 153 ```python 154 with tf.control_dependencies([tf.assert_negative(x)]): 155 output = tf.reduce_sum(x) 156 ``` 157 158 Negative means, for every element `x[i]` of `x`, we have `x[i] < 0`. 159 If `x` is empty this is trivially satisfied. 160 161 Args: 162 x: Numeric `Tensor`. 163 data: The tensors to print out if the condition is False. Defaults to 164 error message and first few entries of `x`. 165 summarize: Print this many entries of each tensor. 166 message: A string to prefix to the default message. 167 name: A name for this operation (optional). Defaults to "assert_negative". 168 169 Returns: 170 Op raising `InvalidArgumentError` unless `x` is all negative. 171 """ 172 message = message or '' 173 with ops.name_scope(name, 'assert_negative', [x, data]): 174 x = ops.convert_to_tensor(x, name='x') 175 if data is None: 176 if context.executing_eagerly(): 177 name = _shape_and_dtype_str(x) 178 else: 179 name = x.name 180 data = [ 181 message, 182 'Condition x < 0 did not hold element-wise:', 183 'x (%s) = ' % name, x] 184 zero = ops.convert_to_tensor(0, dtype=x.dtype) 185 return assert_less(x, zero, data=data, summarize=summarize) 186 187 188@tf_export('debugging.assert_positive', v1=[]) 189def assert_positive_v2(x, message=None, summarize=None, name=None): 190 """Assert the condition `x > 0` holds element-wise. 191 192 This Op checks that `x[i] > 0` holds for every element of `x`. If `x` is 193 empty, this is trivially satisfied. 194 195 If `x` is not positive everywhere, `message`, as well as the first `summarize` 196 entries of `x` are printed, and `InvalidArgumentError` is raised. 197 198 Args: 199 x: Numeric `Tensor`. 200 message: A string to prefix to the default message. 201 summarize: Print this many entries of each tensor. 202 name: A name for this operation (optional). Defaults to "assert_positive". 203 204 Raises: 205 InvalidArgumentError: if the check can be performed immediately and 206 `x[i] > 0` is False. The check can be performed immediately during eager 207 execution or if `x` is statically known. 208 """ 209 assert_positive(x=x, summarize=summarize, message=message, name=name) 210 211 212@tf_export(v1=['debugging.assert_positive', 'assert_positive']) 213@deprecation.deprecated_endpoints('assert_positive') 214def assert_positive(x, data=None, summarize=None, message=None, name=None): 215 """Assert the condition `x > 0` holds element-wise. 216 217 Example of adding a dependency to an operation: 218 219 ```python 220 with tf.control_dependencies([tf.assert_positive(x)]): 221 output = tf.reduce_sum(x) 222 ``` 223 224 Positive means, for every element `x[i]` of `x`, we have `x[i] > 0`. 225 If `x` is empty this is trivially satisfied. 226 227 Args: 228 x: Numeric `Tensor`. 229 data: The tensors to print out if the condition is False. Defaults to 230 error message and first few entries of `x`. 231 summarize: Print this many entries of each tensor. 232 message: A string to prefix to the default message. 233 name: A name for this operation (optional). Defaults to "assert_positive". 234 235 Returns: 236 Op raising `InvalidArgumentError` unless `x` is all positive. 237 """ 238 message = message or '' 239 with ops.name_scope(name, 'assert_positive', [x, data]): 240 x = ops.convert_to_tensor(x, name='x') 241 if data is None: 242 if context.executing_eagerly(): 243 name = _shape_and_dtype_str(x) 244 else: 245 name = x.name 246 data = [ 247 message, 'Condition x > 0 did not hold element-wise:', 248 'x (%s) = ' % name, x] 249 zero = ops.convert_to_tensor(0, dtype=x.dtype) 250 return assert_less(zero, x, data=data, summarize=summarize) 251 252 253@tf_export('debugging.assert_non_negative', v1=[]) 254def assert_non_negative_v2(x, message=None, summarize=None, name=None): 255 """Assert the condition `x >= 0` holds element-wise. 256 257 This Op checks that `x[i] >= 0` holds for every element of `x`. If `x` is 258 empty, this is trivially satisfied. 259 260 If `x` is not >= 0 everywhere, `message`, as well as the first `summarize` 261 entries of `x` are printed, and `InvalidArgumentError` is raised. 262 263 Args: 264 x: Numeric `Tensor`. 265 message: A string to prefix to the default message. 266 summarize: Print this many entries of each tensor. 267 name: A name for this operation (optional). Defaults to 268 "assert_non_negative". 269 270 Raises: 271 InvalidArgumentError: if the check can be performed immediately and 272 `x[i] >= 0` is False. The check can be performed immediately during eager 273 execution or if `x` is statically known. 274 """ 275 assert_non_negative(x=x, summarize=summarize, message=message, name=name) 276 277 278@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative']) 279@deprecation.deprecated_endpoints('assert_non_negative') 280def assert_non_negative(x, data=None, summarize=None, message=None, name=None): 281 """Assert the condition `x >= 0` holds element-wise. 282 283 Example of adding a dependency to an operation: 284 285 ```python 286 with tf.control_dependencies([tf.assert_non_negative(x)]): 287 output = tf.reduce_sum(x) 288 ``` 289 290 Non-negative means, for every element `x[i]` of `x`, we have `x[i] >= 0`. 291 If `x` is empty this is trivially satisfied. 292 293 Args: 294 x: Numeric `Tensor`. 295 data: The tensors to print out if the condition is False. Defaults to 296 error message and first few entries of `x`. 297 summarize: Print this many entries of each tensor. 298 message: A string to prefix to the default message. 299 name: A name for this operation (optional). 300 Defaults to "assert_non_negative". 301 302 Returns: 303 Op raising `InvalidArgumentError` unless `x` is all non-negative. 304 """ 305 message = message or '' 306 with ops.name_scope(name, 'assert_non_negative', [x, data]): 307 x = ops.convert_to_tensor(x, name='x') 308 if data is None: 309 if context.executing_eagerly(): 310 name = _shape_and_dtype_str(x) 311 else: 312 name = x.name 313 data = [ 314 message, 315 'Condition x >= 0 did not hold element-wise:', 316 'x (%s) = ' % name, x] 317 zero = ops.convert_to_tensor(0, dtype=x.dtype) 318 return assert_less_equal(zero, x, data=data, summarize=summarize) 319 320 321@tf_export('debugging.assert_non_positive', v1=[]) 322def assert_non_positive_v2(x, message=None, summarize=None, name=None): 323 """Assert the condition `x <= 0` holds element-wise. 324 325 This Op checks that `x[i] <= 0` holds for every element of `x`. If `x` is 326 empty, this is trivially satisfied. 327 328 If `x` is not <= 0 everywhere, `message`, as well as the first `summarize` 329 entries of `x` are printed, and `InvalidArgumentError` is raised. 330 331 Args: 332 x: Numeric `Tensor`. 333 message: A string to prefix to the default message. 334 summarize: Print this many entries of each tensor. 335 name: A name for this operation (optional). Defaults to 336 "assert_non_positive". 337 338 Raises: 339 InvalidArgumentError: if the check can be performed immediately and 340 `x[i] <= 0` is False. The check can be performed immediately during eager 341 execution or if `x` is statically known. 342 """ 343 assert_non_positive(x=x, summarize=summarize, message=message, name=name) 344 345 346@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive']) 347@deprecation.deprecated_endpoints('assert_non_positive') 348def assert_non_positive(x, data=None, summarize=None, message=None, name=None): 349 """Assert the condition `x <= 0` holds element-wise. 350 351 Example of adding a dependency to an operation: 352 353 ```python 354 with tf.control_dependencies([tf.assert_non_positive(x)]): 355 output = tf.reduce_sum(x) 356 ``` 357 358 Non-positive means, for every element `x[i]` of `x`, we have `x[i] <= 0`. 359 If `x` is empty this is trivially satisfied. 360 361 Args: 362 x: Numeric `Tensor`. 363 data: The tensors to print out if the condition is False. Defaults to 364 error message and first few entries of `x`. 365 summarize: Print this many entries of each tensor. 366 message: A string to prefix to the default message. 367 name: A name for this operation (optional). 368 Defaults to "assert_non_positive". 369 370 Returns: 371 Op raising `InvalidArgumentError` unless `x` is all non-positive. 372 """ 373 message = message or '' 374 with ops.name_scope(name, 'assert_non_positive', [x, data]): 375 x = ops.convert_to_tensor(x, name='x') 376 if data is None: 377 if context.executing_eagerly(): 378 name = _shape_and_dtype_str(x) 379 else: 380 name = x.name 381 data = [ 382 message, 383 'Condition x <= 0 did not hold element-wise:' 384 'x (%s) = ' % name, x] 385 zero = ops.convert_to_tensor(0, dtype=x.dtype) 386 return assert_less_equal(x, zero, data=data, summarize=summarize) 387 388 389@tf_export('debugging.assert_equal', 'assert_equal', v1=[]) 390def assert_equal_v2(x, y, message=None, summarize=None, name=None): 391 """Assert the condition `x == y` holds element-wise. 392 393 This Op checks that `x[i] == y[i]` holds for every pair of (possibly 394 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 395 trivially satisfied. 396 397 If `x` and `y` are not equal, `message`, as well as the first `summarize` 398 entries of `x` and `y` are printed, and `InvalidArgumentError` is raised. 399 400 Args: 401 x: Numeric `Tensor`. 402 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 403 message: A string to prefix to the default message. 404 summarize: Print this many entries of each tensor. 405 name: A name for this operation (optional). Defaults to "assert_equal". 406 407 Raises: 408 InvalidArgumentError: if the check can be performed immediately and 409 `x == y` is False. The check can be performed immediately during eager 410 execution or if `x` and `y` are statically known. 411 """ 412 assert_equal(x=x, y=y, summarize=summarize, message=message, name=name) 413 414 415@tf_export(v1=['debugging.assert_equal', 'assert_equal']) 416def assert_equal(x, y, data=None, summarize=None, message=None, name=None): 417 """Assert the condition `x == y` holds element-wise. 418 419 Example of adding a dependency to an operation: 420 421 ```python 422 with tf.control_dependencies([tf.assert_equal(x, y)]): 423 output = tf.reduce_sum(x) 424 ``` 425 426 This condition holds if for every pair of (possibly broadcast) elements 427 `x[i]`, `y[i]`, we have `x[i] == y[i]`. 428 If both `x` and `y` are empty, this is trivially satisfied. 429 430 Args: 431 x: Numeric `Tensor`. 432 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 433 data: The tensors to print out if the condition is False. Defaults to 434 error message and first few entries of `x`, `y`. 435 summarize: Print this many entries of each tensor. 436 message: A string to prefix to the default message. 437 name: A name for this operation (optional). Defaults to "assert_equal". 438 439 Returns: 440 Op that raises `InvalidArgumentError` if `x == y` is False. 441 @compatibility{eager} returns None 442 443 Raises: 444 InvalidArgumentError: if the check can be performed immediately and 445 `x == y` is False. The check can be performed immediately during eager 446 execution or if `x` and `y` are statically known. 447 """ 448 message = message or '' 449 with ops.name_scope(name, 'assert_equal', [x, y, data]): 450 x = ops.convert_to_tensor(x, name='x') 451 y = ops.convert_to_tensor(y, name='y') 452 453 if context.executing_eagerly(): 454 eq = math_ops.equal(x, y) 455 condition = math_ops.reduce_all(eq) 456 if not condition: 457 # Prepare a message with first elements of x and y. 458 summary_msg = '' 459 # Default to printing 3 elements like control_flow_ops.Assert (used 460 # by graph mode) does. 461 summarize = 3 if summarize is None else summarize 462 if summarize: 463 # reshape((-1,)) is the fastest way to get a flat array view. 464 x_np = x.numpy().reshape((-1,)) 465 y_np = y.numpy().reshape((-1,)) 466 x_sum = min(x_np.size, summarize) 467 y_sum = min(y_np.size, summarize) 468 summary_msg = ('First %d elements of x:\n%s\n' 469 'First %d elements of y:\n%s\n' % 470 (x_sum, x_np[:x_sum], 471 y_sum, y_np[:y_sum])) 472 473 index_and_values_str = '' 474 if x.shape == y.shape and x.shape.as_list(): 475 # If the shapes of x and y are the same (and not scalars), 476 # Get the values that actually differed and their indices. 477 # If shapes are different this information is more confusing 478 # than useful. 479 mask = math_ops.logical_not(eq) 480 indices = array_ops.where(mask) 481 indices_np = indices.numpy() 482 x_vals = array_ops.boolean_mask(x, mask) 483 y_vals = array_ops.boolean_mask(y, mask) 484 summarize = min(summarize, indices_np.shape[0]) 485 index_and_values_str = ( 486 'Indices of first %s different values:\n%s\n' 487 'Corresponding x values:\n%s\n' 488 'Corresponding y values:\n%s\n' % 489 (summarize, indices_np[:summarize], 490 x_vals.numpy().reshape((-1,))[:summarize], 491 y_vals.numpy().reshape((-1,))[:summarize])) 492 493 raise errors.InvalidArgumentError( 494 node_def=None, op=None, 495 message=('%s\nCondition x == y did not hold.\n%s%s' % 496 (message or '', index_and_values_str, summary_msg))) 497 return 498 499 if data is None: 500 data = [ 501 message, 502 'Condition x == y did not hold element-wise:', 503 'x (%s) = ' % x.name, x, 504 'y (%s) = ' % y.name, y 505 ] 506 condition = math_ops.reduce_all(math_ops.equal(x, y)) 507 x_static = tensor_util.constant_value(x) 508 y_static = tensor_util.constant_value(y) 509 if x_static is not None and y_static is not None: 510 condition_static = (x_static == y_static).all() 511 _assert_static(condition_static, data) 512 return control_flow_ops.Assert(condition, data, summarize=summarize) 513 514 515@tf_export('debugging.assert_none_equal', v1=[]) 516def assert_none_equal_v2(x, y, summarize=None, message=None, name=None): 517 """Assert the condition `x != y` holds for all elements. 518 519 This Op checks that `x[i] != y[i]` holds for every pair of (possibly 520 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 521 trivially satisfied. 522 523 If any elements of `x` and `y` are equal, `message`, as well as the first 524 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` 525 is raised. 526 527 Args: 528 x: Numeric `Tensor`. 529 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 530 summarize: Print this many entries of each tensor. 531 message: A string to prefix to the default message. 532 name: A name for this operation (optional). Defaults to 533 "assert_none_equal". 534 535 Raises: 536 InvalidArgumentError: if the check can be performed immediately and 537 `x != y` is False for any pair of elements in `x` and `y`. The check can 538 be performed immediately during eager execution or if `x` and `y` are 539 statically known. 540 """ 541 assert_none_equal(x=x, y=y, summarize=summarize, message=message, name=name) 542 543 544@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal']) 545@deprecation.deprecated_endpoints('assert_none_equal') 546def assert_none_equal( 547 x, y, data=None, summarize=None, message=None, name=None): 548 """Assert the condition `x != y` holds for all elements. 549 550 Example of adding a dependency to an operation: 551 552 ```python 553 with tf.control_dependencies([tf.assert_none_equal(x, y)]): 554 output = tf.reduce_sum(x) 555 ``` 556 557 This condition holds if for every pair of (possibly broadcast) elements 558 `x[i]`, `y[i]`, we have `x[i] != y[i]`. 559 If both `x` and `y` are empty, this is trivially satisfied. 560 561 Args: 562 x: Numeric `Tensor`. 563 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 564 data: The tensors to print out if the condition is False. Defaults to 565 error message and first few entries of `x`, `y`. 566 summarize: Print this many entries of each tensor. 567 message: A string to prefix to the default message. 568 name: A name for this operation (optional). 569 Defaults to "assert_none_equal". 570 571 Returns: 572 Op that raises `InvalidArgumentError` if `x != y` is ever False. 573 """ 574 message = message or '' 575 with ops.name_scope(name, 'assert_none_equal', [x, y, data]): 576 x = ops.convert_to_tensor(x, name='x') 577 y = ops.convert_to_tensor(y, name='y') 578 if context.executing_eagerly(): 579 x_name = _shape_and_dtype_str(x) 580 y_name = _shape_and_dtype_str(y) 581 else: 582 x_name = x.name 583 y_name = y.name 584 585 if data is None: 586 data = [ 587 message, 588 'Condition x != y did not hold for every single element:', 589 'x (%s) = ' % x_name, x, 590 'y (%s) = ' % y_name, y 591 ] 592 condition = math_ops.reduce_all(math_ops.not_equal(x, y)) 593 return control_flow_ops.Assert(condition, data, summarize=summarize) 594 595 596@tf_export('debugging.assert_near', v1=[]) 597def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None, 598 name=None): 599 """Assert the condition `x` and `y` are close element-wise. 600 601 This Op checks that `x[i] - y[i] < atol + rtol * tf.abs(y[i])` holds for every 602 pair of (possibly broadcast) elements of `x` and `y`. If both `x` and `y` are 603 empty, this is trivially satisfied. 604 605 If any elements of `x` and `y` are not close, `message`, as well as the first 606 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` 607 is raised. 608 609 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 610 representable positive number such that `1 + eps != 1`. This is about 611 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 612 See `numpy.finfo`. 613 614 Args: 615 x: Float or complex `Tensor`. 616 y: Float or complex `Tensor`, same dtype as and broadcastable to `x`. 617 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 618 The relative tolerance. Default is `10 * eps`. 619 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 620 The absolute tolerance. Default is `10 * eps`. 621 message: A string to prefix to the default message. 622 summarize: Print this many entries of each tensor. 623 name: A name for this operation (optional). Defaults to "assert_near". 624 625 Raises: 626 InvalidArgumentError: if the check can be performed immediately and 627 `x != y` is False for any pair of elements in `x` and `y`. The check can 628 be performed immediately during eager execution or if `x` and `y` are 629 statically known. 630 631 @compatibility(numpy) 632 Similar to `numpy.assert_allclose`, except tolerance depends on data type. 633 This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`, 634 and even `16bit` data. 635 @end_compatibility 636 """ 637 assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize, 638 message=message, name=name) 639 640 641@tf_export(v1=['debugging.assert_near', 'assert_near']) 642@deprecation.deprecated_endpoints('assert_near') 643def assert_near( 644 x, y, rtol=None, atol=None, data=None, summarize=None, message=None, 645 name=None): 646 """Assert the condition `x` and `y` are close element-wise. 647 648 Example of adding a dependency to an operation: 649 650 ```python 651 with tf.control_dependencies([tf.assert_near(x, y)]): 652 output = tf.reduce_sum(x) 653 ``` 654 655 This condition holds if for every pair of (possibly broadcast) elements 656 `x[i]`, `y[i]`, we have 657 658 ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```. 659 660 If both `x` and `y` are empty, this is trivially satisfied. 661 662 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 663 representable positive number such that `1 + eps != 1`. This is about 664 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 665 See `numpy.finfo`. 666 667 Args: 668 x: Float or complex `Tensor`. 669 y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`. 670 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 671 The relative tolerance. Default is `10 * eps`. 672 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 673 The absolute tolerance. Default is `10 * eps`. 674 data: The tensors to print out if the condition is False. Defaults to 675 error message and first few entries of `x`, `y`. 676 summarize: Print this many entries of each tensor. 677 message: A string to prefix to the default message. 678 name: A name for this operation (optional). Defaults to "assert_near". 679 680 Returns: 681 Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. 682 683 @compatibility(numpy) 684 Similar to `numpy.assert_allclose`, except tolerance depends on data type. 685 This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`, 686 and even `16bit` data. 687 @end_compatibility 688 """ 689 message = message or '' 690 with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]): 691 x = ops.convert_to_tensor(x, name='x') 692 y = ops.convert_to_tensor(y, name='y', dtype=x.dtype) 693 694 eps = np.finfo(x.dtype.as_numpy_dtype).eps 695 rtol = 10 * eps if rtol is None else rtol 696 atol = 10 * eps if atol is None else atol 697 698 rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype) 699 atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype) 700 701 if context.executing_eagerly(): 702 x_name = _shape_and_dtype_str(x) 703 y_name = _shape_and_dtype_str(y) 704 else: 705 x_name = x.name 706 y_name = y.name 707 708 if data is None: 709 data = [ 710 message, 711 'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol), 712 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 713 ] 714 tol = atol + rtol * math_ops.abs(y) 715 diff = math_ops.abs(x - y) 716 condition = math_ops.reduce_all(math_ops.less(diff, tol)) 717 return control_flow_ops.Assert(condition, data, summarize=summarize) 718 719 720@tf_export('debugging.assert_less', 'assert_less', v1=[]) 721def assert_less_v2(x, y, message=None, summarize=None, name=None): 722 """Assert the condition `x < y` holds element-wise. 723 724 This Op checks that `x[i] < y[i]` holds for every pair of (possibly 725 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 726 trivially satisfied. 727 728 If `x` is not less than `y` element-wise, `message`, as well as the first 729 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is 730 raised. 731 732 Args: 733 x: Numeric `Tensor`. 734 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 735 message: A string to prefix to the default message. 736 summarize: Print this many entries of each tensor. 737 name: A name for this operation (optional). Defaults to "assert_less". 738 739 Raises: 740 InvalidArgumentError: if the check can be performed immediately and 741 `x < y` is False. The check can be performed immediately during eager 742 execution or if `x` and `y` are statically known. 743 """ 744 assert_less(x=x, y=y, summarize=summarize, message=message, name=name) 745 746 747@tf_export(v1=['debugging.assert_less', 'assert_less']) 748def assert_less(x, y, data=None, summarize=None, message=None, name=None): 749 """Assert the condition `x < y` holds element-wise. 750 751 Example of adding a dependency to an operation: 752 753 ```python 754 with tf.control_dependencies([tf.assert_less(x, y)]): 755 output = tf.reduce_sum(x) 756 ``` 757 758 This condition holds if for every pair of (possibly broadcast) elements 759 `x[i]`, `y[i]`, we have `x[i] < y[i]`. 760 If both `x` and `y` are empty, this is trivially satisfied. 761 762 Args: 763 x: Numeric `Tensor`. 764 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 765 data: The tensors to print out if the condition is False. Defaults to 766 error message and first few entries of `x`, `y`. 767 summarize: Print this many entries of each tensor. 768 message: A string to prefix to the default message. 769 name: A name for this operation (optional). Defaults to "assert_less". 770 771 Returns: 772 Op that raises `InvalidArgumentError` if `x < y` is False. 773 """ 774 message = message or '' 775 with ops.name_scope(name, 'assert_less', [x, y, data]): 776 x = ops.convert_to_tensor(x, name='x') 777 y = ops.convert_to_tensor(y, name='y') 778 if context.executing_eagerly(): 779 x_name = _shape_and_dtype_str(x) 780 y_name = _shape_and_dtype_str(y) 781 else: 782 x_name = x.name 783 y_name = y.name 784 785 if data is None: 786 data = [ 787 message, 788 'Condition x < y did not hold element-wise:', 789 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 790 ] 791 condition = math_ops.reduce_all(math_ops.less(x, y)) 792 return control_flow_ops.Assert(condition, data, summarize=summarize) 793 794 795@tf_export('debugging.assert_less_equal', v1=[]) 796def assert_less_equal_v2(x, y, message=None, summarize=None, name=None): 797 """Assert the condition `x <= y` holds element-wise. 798 799 This Op checks that `x[i] <= y[i]` holds for every pair of (possibly 800 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 801 trivially satisfied. 802 803 If `x` is not less or equal than `y` element-wise, `message`, as well as the 804 first `summarize` entries of `x` and `y` are printed, and 805 `InvalidArgumentError` is raised. 806 807 Args: 808 x: Numeric `Tensor`. 809 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 810 message: A string to prefix to the default message. 811 summarize: Print this many entries of each tensor. 812 name: A name for this operation (optional). Defaults to "assert_less_equal". 813 814 Raises: 815 InvalidArgumentError: if the check can be performed immediately and 816 `x <= y` is False. The check can be performed immediately during eager 817 execution or if `x` and `y` are statically known. 818 """ 819 assert_less_equal(x=x, y=y, summarize=summarize, message=message, name=name) 820 821 822@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal']) 823@deprecation.deprecated_endpoints('assert_less_equal') 824def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): 825 """Assert the condition `x <= y` holds element-wise. 826 827 Example of adding a dependency to an operation: 828 829 ```python 830 with tf.control_dependencies([tf.assert_less_equal(x, y)]): 831 output = tf.reduce_sum(x) 832 ``` 833 834 This condition holds if for every pair of (possibly broadcast) elements 835 `x[i]`, `y[i]`, we have `x[i] <= y[i]`. 836 If both `x` and `y` are empty, this is trivially satisfied. 837 838 Args: 839 x: Numeric `Tensor`. 840 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 841 data: The tensors to print out if the condition is False. Defaults to 842 error message and first few entries of `x`, `y`. 843 summarize: Print this many entries of each tensor. 844 message: A string to prefix to the default message. 845 name: A name for this operation (optional). Defaults to "assert_less_equal" 846 847 Returns: 848 Op that raises `InvalidArgumentError` if `x <= y` is False. 849 """ 850 message = message or '' 851 with ops.name_scope(name, 'assert_less_equal', [x, y, data]): 852 x = ops.convert_to_tensor(x, name='x') 853 y = ops.convert_to_tensor(y, name='y') 854 if context.executing_eagerly(): 855 x_name = _shape_and_dtype_str(x) 856 y_name = _shape_and_dtype_str(y) 857 else: 858 x_name = x.name 859 y_name = y.name 860 861 if data is None: 862 data = [ 863 message, 864 'Condition x <= y did not hold element-wise:' 865 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 866 ] 867 condition = math_ops.reduce_all(math_ops.less_equal(x, y)) 868 return control_flow_ops.Assert(condition, data, summarize=summarize) 869 870 871@tf_export('debugging.assert_greater', 'assert_greater', v1=[]) 872def assert_greater_v2(x, y, message=None, summarize=None, name=None): 873 """Assert the condition `x > y` holds element-wise. 874 875 This Op checks that `x[i] > y[i]` holds for every pair of (possibly 876 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 877 trivially satisfied. 878 879 If `x` is not greater than `y` element-wise, `message`, as well as the first 880 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is 881 raised. 882 883 Args: 884 x: Numeric `Tensor`. 885 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 886 message: A string to prefix to the default message. 887 summarize: Print this many entries of each tensor. 888 name: A name for this operation (optional). Defaults to "assert_greater". 889 890 Raises: 891 InvalidArgumentError: if the check can be performed immediately and 892 `x > y` is False. The check can be performed immediately during eager 893 execution or if `x` and `y` are statically known. 894 """ 895 assert_greater(x=x, y=y, summarize=summarize, message=message, name=name) 896 897 898@tf_export(v1=['debugging.assert_greater', 'assert_greater']) 899def assert_greater(x, y, data=None, summarize=None, message=None, name=None): 900 """Assert the condition `x > y` holds element-wise. 901 902 Example of adding a dependency to an operation: 903 904 ```python 905 with tf.control_dependencies([tf.assert_greater(x, y)]): 906 output = tf.reduce_sum(x) 907 ``` 908 909 This condition holds if for every pair of (possibly broadcast) elements 910 `x[i]`, `y[i]`, we have `x[i] > y[i]`. 911 If both `x` and `y` are empty, this is trivially satisfied. 912 913 Args: 914 x: Numeric `Tensor`. 915 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 916 data: The tensors to print out if the condition is False. Defaults to 917 error message and first few entries of `x`, `y`. 918 summarize: Print this many entries of each tensor. 919 message: A string to prefix to the default message. 920 name: A name for this operation (optional). Defaults to "assert_greater". 921 922 Returns: 923 Op that raises `InvalidArgumentError` if `x > y` is False. 924 """ 925 message = message or '' 926 with ops.name_scope(name, 'assert_greater', [x, y, data]): 927 x = ops.convert_to_tensor(x, name='x') 928 y = ops.convert_to_tensor(y, name='y') 929 if context.executing_eagerly(): 930 x_name = _shape_and_dtype_str(x) 931 y_name = _shape_and_dtype_str(y) 932 else: 933 x_name = x.name 934 y_name = y.name 935 936 if data is None: 937 data = [ 938 message, 939 'Condition x > y did not hold element-wise:' 940 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 941 ] 942 condition = math_ops.reduce_all(math_ops.greater(x, y)) 943 return control_flow_ops.Assert(condition, data, summarize=summarize) 944 945 946@tf_export('debugging.assert_greater_equal', v1=[]) 947def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None): 948 """Assert the condition `x >= y` holds element-wise. 949 950 This Op checks that `x[i] >= y[i]` holds for every pair of (possibly 951 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 952 trivially satisfied. 953 954 If `x` is not greater or equal to `y` element-wise, `message`, as well as the 955 first `summarize` entries of `x` and `y` are printed, and 956 `InvalidArgumentError` is raised. 957 958 Args: 959 x: Numeric `Tensor`. 960 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 961 message: A string to prefix to the default message. 962 summarize: Print this many entries of each tensor. 963 name: A name for this operation (optional). Defaults to 964 "assert_greater_equal". 965 966 Raises: 967 InvalidArgumentError: if the check can be performed immediately and 968 `x >= y` is False. The check can be performed immediately during eager 969 execution or if `x` and `y` are statically known. 970 """ 971 assert_greater_equal(x=x, y=y, summarize=summarize, message=message, 972 name=name) 973 974 975@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal']) 976@deprecation.deprecated_endpoints('assert_greater_equal') 977def assert_greater_equal(x, y, data=None, summarize=None, message=None, 978 name=None): 979 """Assert the condition `x >= y` holds element-wise. 980 981 Example of adding a dependency to an operation: 982 983 ```python 984 with tf.control_dependencies([tf.assert_greater_equal(x, y)]): 985 output = tf.reduce_sum(x) 986 ``` 987 988 This condition holds if for every pair of (possibly broadcast) elements 989 `x[i]`, `y[i]`, we have `x[i] >= y[i]`. 990 If both `x` and `y` are empty, this is trivially satisfied. 991 992 Args: 993 x: Numeric `Tensor`. 994 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 995 data: The tensors to print out if the condition is False. Defaults to 996 error message and first few entries of `x`, `y`. 997 summarize: Print this many entries of each tensor. 998 message: A string to prefix to the default message. 999 name: A name for this operation (optional). Defaults to 1000 "assert_greater_equal" 1001 1002 Returns: 1003 Op that raises `InvalidArgumentError` if `x >= y` is False. 1004 """ 1005 message = message or '' 1006 with ops.name_scope(name, 'assert_greater_equal', [x, y, data]): 1007 x = ops.convert_to_tensor(x, name='x') 1008 y = ops.convert_to_tensor(y, name='y') 1009 if context.executing_eagerly(): 1010 x_name = _shape_and_dtype_str(x) 1011 y_name = _shape_and_dtype_str(y) 1012 else: 1013 x_name = x.name 1014 y_name = y.name 1015 1016 if data is None: 1017 data = [ 1018 message, 1019 'Condition x >= y did not hold element-wise:' 1020 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 1021 ] 1022 condition = math_ops.reduce_all(math_ops.greater_equal(x, y)) 1023 return control_flow_ops.Assert(condition, data, summarize=summarize) 1024 1025 1026def _assert_rank_condition( 1027 x, rank, static_condition, dynamic_condition, data, summarize): 1028 """Assert `x` has a rank that satisfies a given condition. 1029 1030 Args: 1031 x: Numeric `Tensor`. 1032 rank: Scalar `Tensor`. 1033 static_condition: A python function that takes `[actual_rank, given_rank]` 1034 and returns `True` if the condition is satisfied, `False` otherwise. 1035 dynamic_condition: An `op` that takes [actual_rank, given_rank] 1036 and return `True` if the condition is satisfied, `False` otherwise. 1037 data: The tensors to print out if the condition is false. Defaults to 1038 error message and first few entries of `x`. 1039 summarize: Print this many entries of each tensor. 1040 1041 Returns: 1042 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 1043 1044 Raises: 1045 ValueError: If static checks determine `x` fails static_condition. 1046 """ 1047 assert_type(rank, dtypes.int32) 1048 1049 # Attempt to statically defined rank. 1050 rank_static = tensor_util.constant_value(rank) 1051 if rank_static is not None: 1052 if rank_static.ndim != 0: 1053 raise ValueError('Rank must be a scalar.') 1054 1055 x_rank_static = x.get_shape().ndims 1056 if x_rank_static is not None: 1057 if not static_condition(x_rank_static, rank_static): 1058 raise ValueError( 1059 'Static rank condition failed', x_rank_static, rank_static) 1060 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 1061 1062 condition = dynamic_condition(array_ops.rank(x), rank) 1063 1064 # Add the condition that `rank` must have rank zero. Prevents the bug where 1065 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 1066 if rank_static is None: 1067 this_data = ['Rank must be a scalar. Received rank: ', rank] 1068 rank_check = assert_rank(rank, 0, data=this_data) 1069 condition = control_flow_ops.with_dependencies([rank_check], condition) 1070 1071 return control_flow_ops.Assert(condition, data, summarize=summarize) 1072 1073 1074@tf_export('debugging.assert_rank', 'assert_rank', v1=[]) 1075def assert_rank_v2(x, rank, message=None, name=None): 1076 """Assert that `x` has rank equal to `rank`. 1077 1078 This Op checks that the rank of `x` is equal to `rank`. 1079 1080 If `x` has a different rank, `message`, as well as the shape of `x` are 1081 printed, and `InvalidArgumentError` is raised. 1082 1083 Args: 1084 x: `Tensor`. 1085 rank: Scalar integer `Tensor`. 1086 message: A string to prefix to the default message. 1087 name: A name for this operation (optional). Defaults to 1088 "assert_rank". 1089 1090 Raises: 1091 InvalidArgumentError: if the check can be performed immediately and 1092 `x` does not have rank `rank`. The check can be performed immediately 1093 during eager execution or if the shape of `x` is statically known. 1094 """ 1095 assert_rank(x=x, rank=rank, message=message, name=name) 1096 1097 1098@tf_export(v1=['debugging.assert_rank', 'assert_rank']) 1099def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): 1100 """Assert `x` has rank equal to `rank`. 1101 1102 Example of adding a dependency to an operation: 1103 1104 ```python 1105 with tf.control_dependencies([tf.assert_rank(x, 2)]): 1106 output = tf.reduce_sum(x) 1107 ``` 1108 1109 Args: 1110 x: Numeric `Tensor`. 1111 rank: Scalar integer `Tensor`. 1112 data: The tensors to print out if the condition is False. Defaults to 1113 error message and the shape of `x`. 1114 summarize: Print this many entries of each tensor. 1115 message: A string to prefix to the default message. 1116 name: A name for this operation (optional). Defaults to "assert_rank". 1117 1118 Returns: 1119 Op raising `InvalidArgumentError` unless `x` has specified rank. 1120 If static checks determine `x` has correct rank, a `no_op` is returned. 1121 1122 Raises: 1123 ValueError: If static checks determine `x` has wrong rank. 1124 """ 1125 with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])): 1126 x = ops.convert_to_tensor(x, name='x') 1127 rank = ops.convert_to_tensor(rank, name='rank') 1128 message = message or '' 1129 1130 static_condition = lambda actual_rank, given_rank: actual_rank == given_rank 1131 dynamic_condition = math_ops.equal 1132 1133 if context.executing_eagerly(): 1134 name = '' 1135 else: 1136 name = x.name 1137 1138 if data is None: 1139 data = [ 1140 message, 1141 'Tensor %s must have rank' % name, rank, 'Received shape: ', 1142 array_ops.shape(x) 1143 ] 1144 1145 try: 1146 assert_op = _assert_rank_condition(x, rank, static_condition, 1147 dynamic_condition, data, summarize) 1148 1149 except ValueError as e: 1150 if e.args[0] == 'Static rank condition failed': 1151 raise ValueError( 1152 '%s. Tensor %s must have rank %d. Received rank %d, shape %s' % 1153 (message, name, e.args[2], e.args[1], x.get_shape())) 1154 else: 1155 raise 1156 1157 return assert_op 1158 1159 1160@tf_export('debugging.assert_rank_at_least', v1=[]) 1161def assert_rank_at_least_v2(x, rank, message=None, name=None): 1162 """Assert that `x` has rank of at least `rank`. 1163 1164 This Op checks that the rank of `x` is greater or equal to `rank`. 1165 1166 If `x` has a rank lower than `rank`, `message`, as well as the shape of `x` 1167 are printed, and `InvalidArgumentError` is raised. 1168 1169 Args: 1170 x: `Tensor`. 1171 rank: Scalar integer `Tensor`. 1172 message: A string to prefix to the default message. 1173 name: A name for this operation (optional). Defaults to 1174 "assert_rank_at_least". 1175 1176 Raises: 1177 InvalidArgumentError: `x` does not have rank at least `rank`, but the rank 1178 cannot be statically determined. 1179 ValueError: If static checks determine `x` has mismatched rank. 1180 """ 1181 assert_rank_at_least(x=x, rank=rank, message=message, name=name) 1182 1183 1184@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least']) 1185@deprecation.deprecated_endpoints('assert_rank_at_least') 1186def assert_rank_at_least( 1187 x, rank, data=None, summarize=None, message=None, name=None): 1188 """Assert `x` has rank equal to `rank` or higher. 1189 1190 Example of adding a dependency to an operation: 1191 1192 ```python 1193 with tf.control_dependencies([tf.assert_rank_at_least(x, 2)]): 1194 output = tf.reduce_sum(x) 1195 ``` 1196 1197 Args: 1198 x: Numeric `Tensor`. 1199 rank: Scalar `Tensor`. 1200 data: The tensors to print out if the condition is False. Defaults to 1201 error message and first few entries of `x`. 1202 summarize: Print this many entries of each tensor. 1203 message: A string to prefix to the default message. 1204 name: A name for this operation (optional). 1205 Defaults to "assert_rank_at_least". 1206 1207 Returns: 1208 Op raising `InvalidArgumentError` unless `x` has specified rank or higher. 1209 If static checks determine `x` has correct rank, a `no_op` is returned. 1210 1211 Raises: 1212 ValueError: If static checks determine `x` has wrong rank. 1213 """ 1214 with ops.name_scope( 1215 name, 'assert_rank_at_least', (x, rank) + tuple(data or [])): 1216 x = ops.convert_to_tensor(x, name='x') 1217 rank = ops.convert_to_tensor(rank, name='rank') 1218 message = message or '' 1219 1220 static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank 1221 dynamic_condition = math_ops.greater_equal 1222 1223 if context.executing_eagerly(): 1224 name = '' 1225 else: 1226 name = x.name 1227 1228 if data is None: 1229 data = [ 1230 message, 1231 'Tensor %s must have rank at least' % name, rank, 1232 'Received shape: ', array_ops.shape(x) 1233 ] 1234 1235 try: 1236 assert_op = _assert_rank_condition(x, rank, static_condition, 1237 dynamic_condition, data, summarize) 1238 1239 except ValueError as e: 1240 if e.args[0] == 'Static rank condition failed': 1241 raise ValueError( 1242 '%s. Tensor %s must have rank at least %d. Received rank %d, ' 1243 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1244 else: 1245 raise 1246 1247 return assert_op 1248 1249 1250def _static_rank_in(actual_rank, given_ranks): 1251 return actual_rank in given_ranks 1252 1253 1254def _dynamic_rank_in(actual_rank, given_ranks): 1255 if len(given_ranks) < 1: 1256 return ops.convert_to_tensor(False) 1257 result = math_ops.equal(given_ranks[0], actual_rank) 1258 for given_rank in given_ranks[1:]: 1259 result = math_ops.logical_or( 1260 result, math_ops.equal(given_rank, actual_rank)) 1261 return result 1262 1263 1264def _assert_ranks_condition( 1265 x, ranks, static_condition, dynamic_condition, data, summarize): 1266 """Assert `x` has a rank that satisfies a given condition. 1267 1268 Args: 1269 x: Numeric `Tensor`. 1270 ranks: Scalar `Tensor`. 1271 static_condition: A python function that takes 1272 `[actual_rank, given_ranks]` and returns `True` if the condition is 1273 satisfied, `False` otherwise. 1274 dynamic_condition: An `op` that takes [actual_rank, given_ranks] 1275 and return `True` if the condition is satisfied, `False` otherwise. 1276 data: The tensors to print out if the condition is false. Defaults to 1277 error message and first few entries of `x`. 1278 summarize: Print this many entries of each tensor. 1279 1280 Returns: 1281 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 1282 1283 Raises: 1284 ValueError: If static checks determine `x` fails static_condition. 1285 """ 1286 for rank in ranks: 1287 assert_type(rank, dtypes.int32) 1288 1289 # Attempt to statically defined rank. 1290 ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks]) 1291 if not any(r is None for r in ranks_static): 1292 for rank_static in ranks_static: 1293 if rank_static.ndim != 0: 1294 raise ValueError('Rank must be a scalar.') 1295 1296 x_rank_static = x.get_shape().ndims 1297 if x_rank_static is not None: 1298 if not static_condition(x_rank_static, ranks_static): 1299 raise ValueError( 1300 'Static rank condition failed', x_rank_static, ranks_static) 1301 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 1302 1303 condition = dynamic_condition(array_ops.rank(x), ranks) 1304 1305 # Add the condition that `rank` must have rank zero. Prevents the bug where 1306 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 1307 for rank, rank_static in zip(ranks, ranks_static): 1308 if rank_static is None: 1309 this_data = ['Rank must be a scalar. Received rank: ', rank] 1310 rank_check = assert_rank(rank, 0, data=this_data) 1311 condition = control_flow_ops.with_dependencies([rank_check], condition) 1312 1313 return control_flow_ops.Assert(condition, data, summarize=summarize) 1314 1315 1316@tf_export('debugging.assert_rank_in', v1=[]) 1317def assert_rank_in_v2(x, ranks, message=None, name=None): 1318 """Assert that `x` has a rank in `ranks`. 1319 1320 This Op checks that the rank of `x` is in `ranks`. 1321 1322 If `x` has a different rank, `message`, as well as the shape of `x` are 1323 printed, and `InvalidArgumentError` is raised. 1324 1325 Args: 1326 x: `Tensor`. 1327 ranks: `Iterable` of scalar `Tensor` objects. 1328 message: A string to prefix to the default message. 1329 name: A name for this operation (optional). Defaults to "assert_rank_in". 1330 1331 Raises: 1332 InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot 1333 be statically determined. 1334 ValueError: If static checks determine `x` has mismatched rank. 1335 """ 1336 assert_rank_in(x=x, ranks=ranks, message=message, name=name) 1337 1338 1339@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in']) 1340@deprecation.deprecated_endpoints('assert_rank_in') 1341def assert_rank_in( 1342 x, ranks, data=None, summarize=None, message=None, name=None): 1343 """Assert `x` has rank in `ranks`. 1344 1345 Example of adding a dependency to an operation: 1346 1347 ```python 1348 with tf.control_dependencies([tf.assert_rank_in(x, (2, 4))]): 1349 output = tf.reduce_sum(x) 1350 ``` 1351 1352 Args: 1353 x: Numeric `Tensor`. 1354 ranks: Iterable of scalar `Tensor` objects. 1355 data: The tensors to print out if the condition is False. Defaults to 1356 error message and first few entries of `x`. 1357 summarize: Print this many entries of each tensor. 1358 message: A string to prefix to the default message. 1359 name: A name for this operation (optional). 1360 Defaults to "assert_rank_in". 1361 1362 Returns: 1363 Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. 1364 If static checks determine `x` has matching rank, a `no_op` is returned. 1365 1366 Raises: 1367 ValueError: If static checks determine `x` has mismatched rank. 1368 """ 1369 with ops.name_scope( 1370 name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])): 1371 x = ops.convert_to_tensor(x, name='x') 1372 ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) 1373 message = message or '' 1374 1375 if context.executing_eagerly(): 1376 name = '' 1377 else: 1378 name = x.name 1379 1380 if data is None: 1381 data = [ 1382 message, 'Tensor %s must have rank in' % name 1383 ] + list(ranks) + [ 1384 'Received shape: ', array_ops.shape(x) 1385 ] 1386 1387 try: 1388 assert_op = _assert_ranks_condition(x, ranks, _static_rank_in, 1389 _dynamic_rank_in, data, summarize) 1390 1391 except ValueError as e: 1392 if e.args[0] == 'Static rank condition failed': 1393 raise ValueError( 1394 '%s. Tensor %s must have rank in %s. Received rank %d, ' 1395 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1396 else: 1397 raise 1398 1399 return assert_op 1400 1401 1402@tf_export('debugging.assert_integer', v1=[]) 1403def assert_integer_v2(x, message=None, name=None): 1404 """Assert that `x` is of integer dtype. 1405 1406 If `x` has a non-integer type, `message`, as well as the dtype of `x` are 1407 printed, and `InvalidArgumentError` is raised. 1408 1409 Args: 1410 x: A `Tensor`. 1411 message: A string to prefix to the default message. 1412 name: A name for this operation (optional). Defaults to "assert_integer". 1413 1414 Raises: 1415 TypeError: If `x.dtype` is not a non-quantized integer type. 1416 """ 1417 assert_integer(x=x, message=message, name=name) 1418 1419 1420@tf_export(v1=['debugging.assert_integer', 'assert_integer']) 1421@deprecation.deprecated_endpoints('assert_integer') 1422def assert_integer(x, message=None, name=None): 1423 """Assert that `x` is of integer dtype. 1424 1425 Example of adding a dependency to an operation: 1426 1427 ```python 1428 with tf.control_dependencies([tf.assert_integer(x)]): 1429 output = tf.reduce_sum(x) 1430 ``` 1431 1432 Args: 1433 x: `Tensor` whose basetype is integer and is not quantized. 1434 message: A string to prefix to the default message. 1435 name: A name for this operation (optional). Defaults to "assert_integer". 1436 1437 Raises: 1438 TypeError: If `x.dtype` is anything other than non-quantized integer. 1439 1440 Returns: 1441 A `no_op` that does nothing. Type can be determined statically. 1442 """ 1443 message = message or '' 1444 with ops.name_scope(name, 'assert_integer', [x]): 1445 x = ops.convert_to_tensor(x, name='x') 1446 if not x.dtype.is_integer: 1447 if context.executing_eagerly(): 1448 name = 'tensor' 1449 else: 1450 name = x.name 1451 err_msg = ( 1452 '%s Expected "x" to be integer type. Found: %s of dtype %s' 1453 % (message, name, x.dtype)) 1454 raise TypeError(err_msg) 1455 1456 return control_flow_ops.no_op('statically_determined_was_integer') 1457 1458 1459@tf_export('debugging.assert_type', v1=[]) 1460def assert_type_v2(tensor, tf_type, message=None, name=None): 1461 """Asserts that the given `Tensor` is of the specified type. 1462 1463 Args: 1464 tensor: A `Tensor`. 1465 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1466 etc). 1467 message: A string to prefix to the default message. 1468 name: A name for this operation. Defaults to "assert_type" 1469 1470 Raises: 1471 TypeError: If the tensor's data type doesn't match `tf_type`. 1472 """ 1473 assert_type(tensor=tensor, tf_type=tf_type, message=message, name=name) 1474 1475 1476@tf_export(v1=['debugging.assert_type', 'assert_type']) 1477@deprecation.deprecated_endpoints('assert_type') 1478def assert_type(tensor, tf_type, message=None, name=None): 1479 """Statically asserts that the given `Tensor` is of the specified type. 1480 1481 Args: 1482 tensor: A `Tensor`. 1483 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1484 etc). 1485 message: A string to prefix to the default message. 1486 name: A name to give this `Op`. Defaults to "assert_type" 1487 1488 Raises: 1489 TypeError: If the tensors data type doesn't match `tf_type`. 1490 1491 Returns: 1492 A `no_op` that does nothing. Type can be determined statically. 1493 """ 1494 message = message or '' 1495 with ops.name_scope(name, 'assert_type', [tensor]): 1496 tensor = ops.convert_to_tensor(tensor, name='tensor') 1497 if tensor.dtype != tf_type: 1498 if context.executing_eagerly(): 1499 raise TypeError('%s tensor must be of type %s' % (message, tf_type)) 1500 else: 1501 raise TypeError('%s %s must be of type %s' % (message, tensor.name, 1502 tf_type)) 1503 1504 return control_flow_ops.no_op('statically_determined_correct_type') 1505 1506 1507# pylint: disable=line-too-long 1508def _get_diff_for_monotonic_comparison(x): 1509 """Gets the difference x[1:] - x[:-1].""" 1510 x = array_ops.reshape(x, [-1]) 1511 if not is_numeric_tensor(x): 1512 raise TypeError('Expected x to be numeric, instead found: %s' % x) 1513 1514 # If x has less than 2 elements, there is nothing to compare. So return []. 1515 is_shorter_than_two = math_ops.less(array_ops.size(x), 2) 1516 short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype) 1517 1518 # With 2 or more elements, return x[1:] - x[:-1] 1519 s_len = array_ops.shape(x) - 1 1520 diff = lambda: array_ops.strided_slice(x, [1], [1] + s_len)- array_ops.strided_slice(x, [0], s_len) 1521 return control_flow_ops.cond(is_shorter_than_two, short_result, diff) 1522 1523 1524@tf_export( 1525 'debugging.is_numeric_tensor', 1526 v1=['debugging.is_numeric_tensor', 'is_numeric_tensor']) 1527@deprecation.deprecated_endpoints('is_numeric_tensor') 1528def is_numeric_tensor(tensor): 1529 """Returns `True` if the elements of `tensor` are numbers. 1530 1531 Specifically, returns `True` if the dtype of `tensor` is one of the following: 1532 1533 * `tf.float32` 1534 * `tf.float64` 1535 * `tf.int8` 1536 * `tf.int16` 1537 * `tf.int32` 1538 * `tf.int64` 1539 * `tf.uint8` 1540 * `tf.qint8` 1541 * `tf.qint32` 1542 * `tf.quint8` 1543 * `tf.complex64` 1544 1545 Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not 1546 a `tf.Tensor` object. 1547 """ 1548 return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES 1549 1550 1551@tf_export( 1552 'math.is_non_decreasing', 1553 v1=[ 1554 'math.is_non_decreasing', 'debugging.is_non_decreasing', 1555 'is_non_decreasing' 1556 ]) 1557@deprecation.deprecated_endpoints('debugging.is_non_decreasing', 1558 'is_non_decreasing') 1559def is_non_decreasing(x, name=None): 1560 """Returns `True` if `x` is non-decreasing. 1561 1562 Elements of `x` are compared in row-major order. The tensor `[x[0],...]` 1563 is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`. 1564 If `x` has less than two elements, it is trivially non-decreasing. 1565 1566 See also: `is_strictly_increasing` 1567 1568 Args: 1569 x: Numeric `Tensor`. 1570 name: A name for this operation (optional). Defaults to "is_non_decreasing" 1571 1572 Returns: 1573 Boolean `Tensor`, equal to `True` iff `x` is non-decreasing. 1574 1575 Raises: 1576 TypeError: if `x` is not a numeric tensor. 1577 """ 1578 with ops.name_scope(name, 'is_non_decreasing', [x]): 1579 diff = _get_diff_for_monotonic_comparison(x) 1580 # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True. 1581 zero = ops.convert_to_tensor(0, dtype=diff.dtype) 1582 return math_ops.reduce_all(math_ops.less_equal(zero, diff)) 1583 1584 1585@tf_export( 1586 'math.is_strictly_increasing', 1587 v1=[ 1588 'math.is_strictly_increasing', 'debugging.is_strictly_increasing', 1589 'is_strictly_increasing' 1590 ]) 1591@deprecation.deprecated_endpoints('debugging.is_strictly_increasing', 1592 'is_strictly_increasing') 1593def is_strictly_increasing(x, name=None): 1594 """Returns `True` if `x` is strictly increasing. 1595 1596 Elements of `x` are compared in row-major order. The tensor `[x[0],...]` 1597 is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`. 1598 If `x` has less than two elements, it is trivially strictly increasing. 1599 1600 See also: `is_non_decreasing` 1601 1602 Args: 1603 x: Numeric `Tensor`. 1604 name: A name for this operation (optional). 1605 Defaults to "is_strictly_increasing" 1606 1607 Returns: 1608 Boolean `Tensor`, equal to `True` iff `x` is strictly increasing. 1609 1610 Raises: 1611 TypeError: if `x` is not a numeric tensor. 1612 """ 1613 with ops.name_scope(name, 'is_strictly_increasing', [x]): 1614 diff = _get_diff_for_monotonic_comparison(x) 1615 # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True. 1616 zero = ops.convert_to_tensor(0, dtype=diff.dtype) 1617 return math_ops.reduce_all(math_ops.less(zero, diff)) 1618 1619 1620def _assert_same_base_type(items, expected_type=None): 1621 r"""Asserts all items are of the same base type. 1622 1623 Args: 1624 items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, 1625 `Operation`, or `IndexedSlices`). Can include `None` elements, which 1626 will be ignored. 1627 expected_type: Expected type. If not specified, assert all items are 1628 of the same base type. 1629 1630 Returns: 1631 Validated type, or none if neither expected_type nor items provided. 1632 1633 Raises: 1634 ValueError: If any types do not match. 1635 """ 1636 original_expected_type = expected_type 1637 mismatch = False 1638 for item in items: 1639 if item is not None: 1640 item_type = item.dtype.base_dtype 1641 if not expected_type: 1642 expected_type = item_type 1643 elif expected_type != item_type: 1644 mismatch = True 1645 break 1646 if mismatch: 1647 # Loop back through and build up an informative error message (this is very 1648 # slow, so we don't do it unless we found an error above). 1649 expected_type = original_expected_type 1650 original_item_str = None 1651 for item in items: 1652 if item is not None: 1653 item_type = item.dtype.base_dtype 1654 if not expected_type: 1655 expected_type = item_type 1656 original_item_str = item.name if hasattr(item, 'name') else str(item) 1657 elif expected_type != item_type: 1658 raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( 1659 item.name if hasattr(item, 'name') else str(item), 1660 item_type, expected_type, 1661 (' as %s' % original_item_str) if original_item_str else '')) 1662 return expected_type # Should be unreachable 1663 else: 1664 return expected_type 1665 1666 1667@tf_export( 1668 'debugging.assert_same_float_dtype', 1669 v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype']) 1670@deprecation.deprecated_endpoints('assert_same_float_dtype') 1671def assert_same_float_dtype(tensors=None, dtype=None): 1672 """Validate and return float type based on `tensors` and `dtype`. 1673 1674 For ops such as matrix multiplication, inputs and weights must be of the 1675 same float type. This function validates that all `tensors` are the same type, 1676 validates that type is `dtype` (if supplied), and returns the type. Type must 1677 be a floating point type. If neither `tensors` nor `dtype` is supplied, 1678 the function will return `dtypes.float32`. 1679 1680 Args: 1681 tensors: Tensors of input values. Can include `None` elements, which will be 1682 ignored. 1683 dtype: Expected type. 1684 1685 Returns: 1686 Validated type. 1687 1688 Raises: 1689 ValueError: if neither `tensors` nor `dtype` is supplied, or result is not 1690 float, or the common type of the inputs is not a floating point type. 1691 """ 1692 if tensors: 1693 dtype = _assert_same_base_type(tensors, dtype) 1694 if not dtype: 1695 dtype = dtypes.float32 1696 elif not dtype.is_floating: 1697 raise ValueError('Expected floating point type, got %s.' % dtype) 1698 return dtype 1699 1700 1701@tf_export('debugging.assert_scalar', v1=[]) 1702def assert_scalar_v2(tensor, message=None, name=None): 1703 """Asserts that the given `tensor` is a scalar. 1704 1705 This function raises `ValueError` unless it can be certain that the given 1706 `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is 1707 unknown. 1708 1709 Args: 1710 tensor: A `Tensor`. 1711 message: A string to prefix to the default message. 1712 name: A name for this operation. Defaults to "assert_scalar" 1713 1714 Raises: 1715 ValueError: If the tensor is not scalar (rank 0), or if its shape is 1716 unknown. 1717 """ 1718 assert_scalar(tensor=tensor, message=message, name=name) 1719 1720 1721@tf_export(v1=['debugging.assert_scalar', 'assert_scalar']) 1722@deprecation.deprecated_endpoints('assert_scalar') 1723def assert_scalar(tensor, name=None, message=None): 1724 """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional). 1725 1726 This function raises `ValueError` unless it can be certain that the given 1727 `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is 1728 unknown. 1729 1730 Args: 1731 tensor: A `Tensor`. 1732 name: A name for this operation. Defaults to "assert_scalar" 1733 message: A string to prefix to the default message. 1734 1735 Returns: 1736 The input tensor (potentially converted to a `Tensor`). 1737 1738 Raises: 1739 ValueError: If the tensor is not scalar (rank 0), or if its shape is 1740 unknown. 1741 """ 1742 with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope: 1743 tensor = ops.convert_to_tensor(tensor, name=name_scope) 1744 shape = tensor.get_shape() 1745 if shape.ndims != 0: 1746 if context.executing_eagerly(): 1747 raise ValueError('%sExpected scalar shape, saw shape: %s.' 1748 % (message or '', shape,)) 1749 else: 1750 raise ValueError('%sExpected scalar shape for %s, saw shape: %s.' 1751 % (message or '', tensor.name, shape)) 1752 return tensor 1753 1754 1755@tf_export('ensure_shape') 1756def ensure_shape(x, shape, name=None): 1757 """Updates the shape of a tensor and checks at runtime that the shape holds. 1758 1759 For example: 1760 ```python 1761 x = tf.placeholder(tf.int32) 1762 print(x.shape) 1763 ==> TensorShape(None) 1764 y = x * 2 1765 print(y.shape) 1766 ==> TensorShape(None) 1767 1768 y = tf.ensure_shape(y, (None, 3, 3)) 1769 print(y.shape) 1770 ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)]) 1771 1772 with tf.Session() as sess: 1773 # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not 1774 # compatible with the shape (None, 3, 3) 1775 sess.run(y, feed_dict={x: [1, 2, 3]}) 1776 1777 ``` 1778 1779 NOTE: This differs from `Tensor.set_shape` in that it sets the static shape 1780 of the resulting tensor and enforces it at runtime, raising an error if the 1781 tensor's runtime shape is incompatible with the specified shape. 1782 `Tensor.set_shape` sets the static shape of the tensor without enforcing it 1783 at runtime, which may result in inconsistencies between the statically-known 1784 shape of tensors and the runtime value of tensors. 1785 1786 Args: 1787 x: A `Tensor`. 1788 shape: A `TensorShape` representing the shape of this tensor, a 1789 `TensorShapeProto`, a list, a tuple, or None. 1790 name: A name for this operation (optional). Defaults to "EnsureShape". 1791 1792 Returns: 1793 A `Tensor`. Has the same type and contents as `x`. At runtime, raises a 1794 `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape 1795 of `x`. 1796 """ 1797 if not isinstance(shape, tensor_shape.TensorShape): 1798 shape = tensor_shape.TensorShape(shape) 1799 1800 return array_ops.ensure_shape(x, shape, name=name) 1801 1802 1803@ops.RegisterGradient('EnsureShape') 1804def _ensure_shape_grad(op, grad): 1805 del op # Unused. 1806 return grad 1807