1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Test utils for tensorflow.""" 18import collections 19from collections import OrderedDict 20import contextlib 21import functools 22import gc 23import itertools 24import math 25import os 26import random 27import re 28import tempfile 29import threading 30import time 31import unittest 32 33from absl.testing import parameterized 34import numpy as np 35 36from google.protobuf import descriptor_pool 37from google.protobuf import text_format 38 39from tensorflow.core.config import flags 40from tensorflow.core.framework import graph_pb2 41from tensorflow.core.protobuf import rewriter_config_pb2 42from tensorflow.python import pywrap_sanitizers 43from tensorflow.python import tf2 44from tensorflow.python.client import device_lib 45from tensorflow.python.client import pywrap_tf_session 46from tensorflow.python.client import session 47from tensorflow.python.compat.compat import forward_compatibility_horizon 48from tensorflow.python.eager import backprop 49from tensorflow.python.eager import context 50from tensorflow.python.eager import def_function 51from tensorflow.python.eager import tape 52from tensorflow.python.framework import _test_metrics_util 53from tensorflow.python.framework import config 54from tensorflow.python.framework import device 
as pydev 55from tensorflow.python.framework import dtypes 56from tensorflow.python.framework import errors 57from tensorflow.python.framework import errors_impl 58from tensorflow.python.framework import gpu_util 59from tensorflow.python.framework import importer 60from tensorflow.python.framework import indexed_slices 61from tensorflow.python.framework import ops 62from tensorflow.python.framework import random_seed 63from tensorflow.python.framework import sparse_tensor 64from tensorflow.python.framework import tensor_shape 65from tensorflow.python.framework import tensor_util 66from tensorflow.python.framework import tfrt_utils 67from tensorflow.python.framework import versions 68from tensorflow.python.ops import array_ops 69from tensorflow.python.ops import control_flow_util 70from tensorflow.python.ops import control_flow_util_v2 71from tensorflow.python.ops import gradients_impl 72from tensorflow.python.ops import math_ops 73from tensorflow.python.ops import script_ops 74from tensorflow.python.ops import summary_ops_v2 75from tensorflow.python.ops import variables 76from tensorflow.python.ops.ragged import ragged_ops # pylint: disable=unused-import 77from tensorflow.python.ops.ragged import ragged_tensor 78from tensorflow.python.ops.ragged import ragged_tensor_value 79from tensorflow.python.platform import _pywrap_stacktrace_handler 80from tensorflow.python.platform import googletest 81from tensorflow.python.platform import tf_logging as logging 82from tensorflow.python.training import server_lib 83from tensorflow.python.util import _pywrap_util_port 84from tensorflow.python.util import compat 85from tensorflow.python.util import deprecation 86from tensorflow.python.util import nest 87from tensorflow.python.util import tf_decorator 88from tensorflow.python.util import tf_inspect 89from tensorflow.python.util import traceback_utils 90from tensorflow.python.util.compat import collections_abc 91from tensorflow.python.util.protobuf import compare 92from 
tensorflow.python.util.tf_export import tf_export 93 94 95# If the below import is made available through the BUILD rule, then this 96# function is overridden and will instead return True and cause Tensorflow 97# graphs to be compiled with XLA. 98def is_xla_enabled(): 99 return False 100 101 102try: 103 from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top, unused-import 104except Exception: # pylint: disable=broad-except 105 pass 106 107 108# Uses the same mechanism as above to selectively enable/disable MLIR 109# compilation. 110def is_mlir_bridge_enabled(): 111 return None 112 113 114try: 115 from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled # pylint: disable=g-import-not-at-top, unused-import 116except ImportError: 117 try: 118 from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled # pylint: disable=g-import-not-at-top, unused-import 119 except ImportError: 120 pass 121 122 123def is_asan_enabled(): 124 """Check if ASAN is enabled.""" 125 return pywrap_sanitizers.is_asan_enabled() 126 127 128def is_msan_enabled(): 129 """Check if MSAN is enabled.""" 130 return pywrap_sanitizers.is_msan_enabled() 131 132 133def is_tsan_enabled(): 134 """Check if TSAN is enabled.""" 135 return pywrap_sanitizers.is_tsan_enabled() 136 137 138def is_ubsan_enabled(): 139 """Check if UBSAN is enabled.""" 140 return pywrap_sanitizers.is_ubsan_enabled() 141 142 143def _get_object_count_by_type(exclude=()): 144 return ( 145 collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) - 146 collections.Counter([type(obj).__name__ for obj in exclude])) 147 148 149@tf_export("test.gpu_device_name") 150def gpu_device_name(): 151 """Returns the name of a GPU device if available or a empty string. 152 153 This method should only be used in tests written with `tf.test.TestCase`. 154 155 >>> class MyTest(tf.test.TestCase): 156 ... 157 ... 
def test_add_on_gpu(self): 158 ... if not tf.test.is_built_with_gpu_support(): 159 ... self.skipTest("test is only applicable on GPU") 160 ... 161 ... with tf.device(tf.test.gpu_device_name()): 162 ... self.assertEqual(tf.math.add(1.0, 2.0), 3.0) 163 164 """ 165 for x in device_lib.list_local_devices(): 166 if x.device_type == "GPU": 167 return compat.as_str(x.name) 168 return "" 169 170 171def assert_ops_in_graph(expected_ops, graph): 172 """Assert all expected operations are found. 173 174 Args: 175 expected_ops: `dict<string, string>` of op name to op type. 176 graph: Graph to check. 177 178 Returns: 179 `dict<string, node>` of node name to node. 180 181 Raises: 182 ValueError: If the expected ops are not present in the graph. 183 """ 184 actual_ops = {} 185 gd = graph.as_graph_def() 186 for node in gd.node: 187 if node.name in expected_ops: 188 if expected_ops[node.name] != node.op: 189 raise ValueError("Expected op for node %s is different. %s vs %s" % 190 (node.name, expected_ops[node.name], node.op)) 191 actual_ops[node.name] = node 192 if set(expected_ops.keys()) != set(actual_ops.keys()): 193 raise ValueError("Not all expected ops are present. Expected %s, found %s" % 194 (expected_ops.keys(), actual_ops.keys())) 195 return actual_ops 196 197 198@tf_export("test.assert_equal_graph_def", v1=[]) 199def assert_equal_graph_def_v2(expected, actual): 200 """Asserts that two `GraphDef`s are (mostly) the same. 201 202 Compares two `GraphDef` protos for equality, ignoring versions and ordering of 203 nodes, attrs, and control inputs. Node names are used to match up nodes 204 between the graphs, so the naming of nodes must be consistent. This function 205 ignores randomized attribute values that may appear in V2 checkpoints. 206 207 Args: 208 expected: The `GraphDef` we expected. 209 actual: The `GraphDef` we have. 210 211 Raises: 212 AssertionError: If the `GraphDef`s do not match. 213 TypeError: If either argument is not a `GraphDef`. 
214 """ 215 assert_equal_graph_def(actual, expected, checkpoint_v2=True, 216 hash_table_shared_name=True) 217 218 219@tf_export(v1=["test.assert_equal_graph_def"]) 220def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False, 221 hash_table_shared_name=False): 222 """Asserts that two `GraphDef`s are (mostly) the same. 223 224 Compares two `GraphDef` protos for equality, ignoring versions and ordering of 225 nodes, attrs, and control inputs. Node names are used to match up nodes 226 between the graphs, so the naming of nodes must be consistent. 227 228 Args: 229 actual: The `GraphDef` we have. 230 expected: The `GraphDef` we expected. 231 checkpoint_v2: boolean determining whether to ignore randomized attribute 232 values that appear in V2 checkpoints. 233 hash_table_shared_name: boolean determining whether to ignore randomized 234 shared_names that appear in HashTableV2 op defs. 235 236 Raises: 237 AssertionError: If the `GraphDef`s do not match. 238 TypeError: If either argument is not a `GraphDef`. 
239 """ 240 assert_equal_graph_def(actual, expected, checkpoint_v2, 241 hash_table_shared_name) 242 243 244def assert_equal_graph_def(actual, expected, checkpoint_v2=False, 245 hash_table_shared_name=False): 246 if not isinstance(actual, graph_pb2.GraphDef): 247 raise TypeError("Expected tf.GraphDef for actual, got %s" % 248 type(actual).__name__) 249 if not isinstance(expected, graph_pb2.GraphDef): 250 raise TypeError("Expected tf.GraphDef for expected, got %s" % 251 type(expected).__name__) 252 253 if checkpoint_v2: 254 _strip_checkpoint_v2_randomized(actual) 255 _strip_checkpoint_v2_randomized(expected) 256 257 if hash_table_shared_name: 258 _strip_hash_table_shared_name(actual) 259 _strip_hash_table_shared_name(expected) 260 261 diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(), 262 expected.SerializeToString()) 263 if diff: 264 raise AssertionError(compat.as_str(diff)) 265 266 267def assert_meta_graph_protos_equal(tester, a, b): 268 """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.""" 269 # Carefully check the collection_defs 270 tester.assertEqual(set(a.collection_def), set(b.collection_def)) 271 collection_keys = a.collection_def.keys() 272 for k in collection_keys: 273 a_value = a.collection_def[k] 274 b_value = b.collection_def[k] 275 proto_type = ops.get_collection_proto_type(k) 276 if proto_type: 277 a_proto = proto_type() 278 b_proto = proto_type() 279 # Number of entries in the collections is the same 280 tester.assertEqual( 281 len(a_value.bytes_list.value), len(b_value.bytes_list.value)) 282 for (a_value_item, b_value_item) in zip(a_value.bytes_list.value, 283 b_value.bytes_list.value): 284 a_proto.ParseFromString(a_value_item) 285 b_proto.ParseFromString(b_value_item) 286 tester.assertProtoEquals(a_proto, b_proto) 287 else: 288 tester.assertEquals(a_value, b_value) 289 # Compared the fields directly, remove their raw values from the 290 # proto comparison below. 
291 a.ClearField("collection_def") 292 b.ClearField("collection_def") 293 294 # Check the graph_defs. 295 assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True) 296 # Check graph_def versions (ignored by assert_equal_graph_def). 297 tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions) 298 # Compared the fields directly, remove their raw values from the 299 # proto comparison below. 300 a.ClearField("graph_def") 301 b.ClearField("graph_def") 302 303 tester.assertProtoEquals(a, b) 304 305 306# Matches attributes named via _SHARDED_SUFFIX in 307# tensorflow/python/training/saver.py 308_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part" 309 310 311def _strip_checkpoint_v2_randomized(graph_def): 312 for node in graph_def.node: 313 delete_keys = [] 314 for attr_key in node.attr: 315 attr_tensor_value = node.attr[attr_key].tensor 316 if attr_tensor_value and len(attr_tensor_value.string_val) == 1: 317 attr_tensor_string_value = attr_tensor_value.string_val[0] 318 if (attr_tensor_string_value and 319 re.match(compat.as_bytes(_SHARDED_SAVE_OP_PATTERN), 320 attr_tensor_string_value)): 321 delete_keys.append(attr_key) 322 for attr_key in delete_keys: 323 del node.attr[attr_key] 324 325 326_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+" 327 328 329def _strip_hash_table_shared_name(graph_def): 330 for node in graph_def.node: 331 delete_keys = [] 332 if node.op == "HashTableV2" and "shared_name" in node.attr: 333 if re.match(compat.as_bytes(_TABLE_SHARED_NAME_PATTERN), 334 node.attr["shared_name"].s): 335 delete_keys.append("shared_name") 336 for attr_key in delete_keys: 337 del node.attr[attr_key] 338 339 340def IsGoogleCudaEnabled(): 341 return _pywrap_util_port.IsGoogleCudaEnabled() 342 343 344def IsBuiltWithROCm(): 345 return _pywrap_util_port.IsBuiltWithROCm() 346 347 348def IsBuiltWithXLA(): 349 return _pywrap_util_port.IsBuiltWithXLA() 350 351 352def IsBuiltWithNvcc(): 353 return _pywrap_util_port.IsBuiltWithNvcc() 354 355 356def 
GpuSupportsHalfMatMulAndConv(): 357 return _pywrap_util_port.GpuSupportsHalfMatMulAndConv() 358 359 360def IsMklEnabled(): 361 return (_pywrap_util_port.IsMklEnabled() or 362 os.getenv("TF_ENABLE_ONEDNN_OPTS", "False").lower() in ["true", "1"]) 363 364 365def InstallStackTraceHandler(): 366 _pywrap_stacktrace_handler.InstallStacktraceHandler() 367 368 369def NHWCToNCHW(input_tensor): 370 """Converts the input from the NHWC format to NCHW. 371 372 Args: 373 input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape 374 375 Returns: 376 converted tensor or shape array 377 """ 378 # tensor dim -> new axis order 379 new_axes = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]} 380 if isinstance(input_tensor, ops.Tensor): 381 ndims = input_tensor.shape.ndims 382 return array_ops.transpose(input_tensor, new_axes[ndims]) 383 else: 384 ndims = len(input_tensor) 385 return [input_tensor[a] for a in new_axes[ndims]] 386 387 388def NHWCToNCHW_VECT_C(input_shape_or_tensor): 389 """Transforms the input from the NHWC layout to NCHW_VECT_C layout. 390 391 Note: Does not include quantization or type conversion steps, which should 392 be applied afterwards. 393 394 Args: 395 input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape 396 397 Returns: 398 tensor or shape array transformed into NCHW_VECT_C 399 400 Raises: 401 ValueError: if last dimension of `input_shape_or_tensor` is not evenly 402 divisible by 4. 
403 """ 404 permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} 405 is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) 406 temp_shape = ( 407 input_shape_or_tensor.shape.as_list() 408 if is_tensor else input_shape_or_tensor) 409 if temp_shape[-1] % 4 != 0: 410 raise ValueError( 411 "Last dimension of input must be evenly divisible by 4 to convert to " 412 "NCHW_VECT_C.") 413 temp_shape[-1] //= 4 414 temp_shape.append(4) 415 permutation = permutations[len(temp_shape)] 416 if is_tensor: 417 t = array_ops.reshape(input_shape_or_tensor, temp_shape) 418 return array_ops.transpose(t, permutation) 419 else: 420 return [temp_shape[a] for a in permutation] 421 422 423def NCHW_VECT_CToNHWC(input_shape_or_tensor): 424 """Transforms the input from the NCHW_VECT_C layout to NHWC layout. 425 426 Note: Does not include de-quantization or type conversion steps, which should 427 be applied beforehand. 428 429 Args: 430 input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape 431 432 Returns: 433 tensor or shape array transformed into NHWC 434 435 Raises: 436 ValueError: if last dimension of `input_shape_or_tensor` is not 4. 437 """ 438 permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} 439 is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) 440 input_shape = ( 441 input_shape_or_tensor.shape.as_list() 442 if is_tensor else input_shape_or_tensor) 443 if input_shape[-1] != 4: 444 raise ValueError("Last dimension of NCHW_VECT_C must be 4.") 445 permutation = permutations[len(input_shape)] 446 nhwc_shape = [input_shape[a] for a in permutation[:-1]] 447 nhwc_shape[-1] *= input_shape[-1] 448 if is_tensor: 449 t = array_ops.transpose(input_shape_or_tensor, permutation) 450 return array_ops.reshape(t, nhwc_shape) 451 else: 452 return nhwc_shape 453 454 455def NCHWToNHWC(input_tensor): 456 """Converts the input from the NCHW format to NHWC. 
457 458 Args: 459 input_tensor: a 4- or 5-D tensor, or an array representing shape 460 461 Returns: 462 converted tensor or shape array 463 """ 464 # tensor dim -> new axis order 465 new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]} 466 if isinstance(input_tensor, ops.Tensor): 467 ndims = input_tensor.shape.ndims 468 return array_ops.transpose(input_tensor, new_axes[ndims]) 469 else: 470 ndims = len(input_tensor) 471 return [input_tensor[a] for a in new_axes[ndims]] 472 473 474def skip_if(condition): 475 """Skips the decorated function if condition is or evaluates to True. 476 477 Args: 478 condition: Either an expression that can be used in "if not condition" 479 statement, or a callable whose result should be a boolean. 480 481 Returns: 482 The wrapped function 483 """ 484 485 def real_skip_if(fn): 486 487 def wrapper(*args, **kwargs): 488 if callable(condition): 489 skip = condition() 490 else: 491 skip = condition 492 if not skip: 493 return fn(*args, **kwargs) 494 495 return wrapper 496 497 return real_skip_if 498 499 500@contextlib.contextmanager 501def skip_if_error(test_obj, error_type, messages=None): 502 """Context manager to skip cases not considered failures by the tests. 503 504 Note that this does not work if used in setUpClass/tearDownClass. 505 Usage in setUp/tearDown works fine just like regular test methods. 506 507 Args: 508 test_obj: A test object provided as `self` in the test methods; this object 509 is usually an instance of `unittest.TestCase`'s subclass and should have 510 `skipTest` method. 511 error_type: The error type to skip. Note that if `messages` are given, both 512 `error_type` and `messages` need to match for the test to be skipped. 513 messages: Optional, a string or list of strings. If `None`, the test will be 514 skipped if `error_type` matches what is raised; otherwise, the test is 515 skipped if any of the `messages` is contained in the message of the error 516 raised, and `error_type` matches the error raised. 
517 518 Yields: 519 Nothing. 520 """ 521 if messages: 522 messages = nest.flatten(messages) 523 try: 524 yield 525 except error_type as e: 526 if not messages or any(message in str(e) for message in messages): 527 test_obj.skipTest("Skipping error: {}: {}".format(type(e), str(e))) 528 else: 529 raise 530 531 532def enable_c_shapes(fn): 533 """No-op. TODO(b/74620627): Remove this.""" 534 return fn 535 536 537def with_c_shapes(cls): 538 """No-op. TODO(b/74620627): Remove this.""" 539 return cls 540 541 542def enable_control_flow_v2(fn): 543 """Decorator for enabling CondV2 and WhileV2 on a test. 544 545 Note this enables using CondV2 and WhileV2 after running the test class's 546 setup/teardown methods. 547 548 In addition to this, callers must import the while_v2 module in order to set 549 the _while_v2 module in control_flow_ops. 550 551 Args: 552 fn: the function to be wrapped 553 554 Returns: 555 The wrapped function 556 """ 557 558 def wrapper(*args, **kwargs): 559 enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2 560 control_flow_util.ENABLE_CONTROL_FLOW_V2 = True 561 try: 562 return fn(*args, **kwargs) 563 finally: 564 control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old 565 566 return wrapper 567 568 569def with_control_flow_v2(cls): 570 """Adds methods that call original methods with WhileV2 and CondV2 enabled. 571 572 Note this enables CondV2 and WhileV2 in new methods after running the test 573 class's setup method. 574 575 In addition to this, callers must import the while_v2 module in order to set 576 the _while_v2 module in control_flow_ops. 577 578 If a test function has _disable_control_flow_v2 attr set to True (using the 579 @disable_control_flow_v2 decorator), the v2 function is not generated for it. 580 581 Example: 582 583 @test_util.with_control_flow_v2 584 class ControlFlowTest(test.TestCase): 585 586 def testEnabledForV2(self): 587 ... 
588 589 @test_util.disable_control_flow_v2("b/xyzabc") 590 def testDisabledForV2(self): 591 ... 592 593 Generated class: 594 class ControlFlowTest(test.TestCase): 595 596 def testEnabledForV2(self): 597 ... 598 599 def testEnabledForV2WithControlFlowV2(self): 600 // Enable V2 flags. 601 testEnabledForV2(self) 602 // Restore V2 flags. 603 604 def testDisabledForV2(self): 605 ... 606 607 Args: 608 cls: class to decorate 609 610 Returns: 611 cls with new test methods added 612 """ 613 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 614 return cls 615 616 for name, value in cls.__dict__.copy().items(): 617 if (callable(value) and 618 name.startswith(unittest.TestLoader.testMethodPrefix) and 619 not getattr(value, "_disable_control_flow_v2", False)): 620 setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value)) 621 return cls 622 623 624def disable_control_flow_v2(unused_msg): 625 """Decorator for a function in a with_control_flow_v2 enabled test class. 626 627 Blocks the function from being run with v2 control flow ops. 628 629 Args: 630 unused_msg: Reason for disabling. 631 632 Returns: 633 The wrapped function with _disable_control_flow_v2 attr set to True. 634 """ 635 636 def wrapper(func): 637 func._disable_control_flow_v2 = True 638 return func 639 640 return wrapper 641 642 643def enable_output_all_intermediates(fn): 644 """Force-enable outputing all intermediates from functional control flow ops. 
645 646 Args: 647 fn: the function to be wrapped 648 649 Returns: 650 The wrapped function 651 """ 652 653 def wrapper(*args, **kwargs): 654 output_all_intermediates_old = \ 655 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE 656 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True 657 try: 658 return fn(*args, **kwargs) 659 finally: 660 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \ 661 output_all_intermediates_old 662 663 return wrapper 664 665 666def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2): 667 """Decorator for asserting that no new Python objects persist after a test. 668 669 Runs the test multiple times executing eagerly, first as a warmup and then to 670 let objects accumulate. The warmup helps ignore caches which do not grow as 671 the test is run repeatedly. 672 673 Useful for checking that there are no missing Py_DECREFs in the C exercised by 674 a bit of Python. 675 676 Args: 677 func: The function to test. 678 warmup_iters: The numer of warmup iterations, excluded from measuring. 679 680 Returns: 681 The wrapped function performing the test. 682 """ 683 684 def wrap_f(f): 685 def decorator(self, *args, **kwargs): 686 """Warms up, gets object counts, runs the test, checks for new objects.""" 687 with context.eager_mode(): 688 gc.disable() 689 # Run the test 2 times as warmup, in an attempt to fill up caches, which 690 # should not grow as the test is run repeatedly below. 691 # 692 # TODO(b/117156879): Running warmup twice is black magic; we have seen 693 # tests that fail with 1 warmup run, and pass with 2, on various 694 # versions of python2.7.x. 695 for _ in range(warmup_iters): 696 f(self, *args, **kwargs) 697 # Since we aren't in the normal test lifecycle, we need to manually run 698 # cleanups to clear out their object references. 699 self.doCleanups() 700 701 # Some objects are newly created by _get_object_count_by_type(). 
So 702 # create and save as a dummy variable to include it as a baseline. 703 obj_count_by_type = _get_object_count_by_type() 704 gc.collect() 705 706 # Make sure any registered functions are cleaned up in the C++ runtime. 707 registered_function_names = context.context().list_function_names() 708 709 # unittest.doCleanups adds to self._outcome with each unwound call. 710 # These objects are retained across gc collections so we exclude them 711 # from the object count calculation. 712 obj_count_by_type = _get_object_count_by_type( 713 exclude=gc.get_referents(self._outcome.errors, 714 self._outcome.skipped)) 715 716 if ops.has_default_graph(): 717 collection_sizes_before = { 718 collection: len(ops.get_collection(collection)) 719 for collection in ops.get_default_graph().collections 720 } 721 for _ in range(3): 722 f(self, *args, **kwargs) 723 # Since we aren't in the normal test lifecycle, we need to manually run 724 # cleanups to clear out their object references. 725 self.doCleanups() 726 # Note that gc.get_objects misses anything that isn't subject to garbage 727 # collection (C types). Collections are a common source of leaks, so we 728 # test for collection sizes explicitly. 729 if ops.has_default_graph(): 730 for collection_key in ops.get_default_graph().collections: 731 collection = ops.get_collection(collection_key) 732 size_before = collection_sizes_before.get(collection_key, 0) 733 if len(collection) > size_before: 734 raise AssertionError( 735 ("Collection %s increased in size from " 736 "%d to %d (current items %s).") % 737 (collection_key, size_before, len(collection), collection)) 738 # Make sure our collection checks don't show up as leaked memory by 739 # removing references to temporary variables. 740 del collection 741 del collection_key 742 del size_before 743 del collection_sizes_before 744 gc.collect() 745 746 # There should be no new Python objects hanging around. 
747 obj_count_by_type = ( 748 _get_object_count_by_type( 749 exclude=gc.get_referents(self._outcome.errors, 750 self._outcome.skipped)) - 751 obj_count_by_type) 752 753 # There should be no newly registered functions hanging around. 754 leftover_functions = ( 755 context.context().list_function_names() - registered_function_names) 756 assert not leftover_functions, ( 757 "The following functions were newly created: %s" % 758 leftover_functions) 759 760 # In some cases (specifically on MacOS), new_count is somehow 761 # smaller than previous_count. 762 # Using plain assert because not all classes using this decorator 763 # have assertLessEqual 764 assert not obj_count_by_type, ( 765 "The following objects were newly created: %s" % 766 str(obj_count_by_type)) 767 gc.enable() 768 return decorator 769 770 if func is None: 771 return wrap_f 772 else: 773 return wrap_f(func) 774 775 776def assert_no_new_tensors(f): 777 """Decorator for asserting that no new Tensors persist after a test. 778 779 Mainly useful for checking that code using the Python C API has correctly 780 manipulated reference counts. 781 782 Clears the caches that it knows about, runs the garbage collector, then checks 783 that there are no Tensor or Tensor-like objects still around. This includes 784 Tensors to which something still has a reference (e.g. from missing 785 Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one 786 of the objects has __del__ defined). 787 788 Args: 789 f: The test case to run. 790 791 Returns: 792 The decorated test case. 793 """ 794 795 def decorator(self, **kwargs): 796 """Finds existing Tensors, runs the test, checks for new Tensors.""" 797 798 def _is_tensorflow_object(obj): 799 try: 800 return isinstance(obj, 801 (ops.Tensor, variables.Variable, 802 tensor_shape.Dimension, tensor_shape.TensorShape)) 803 except (ReferenceError, AttributeError): 804 # If the object no longer exists, we don't care about it. 
805 return False 806 807 tensors_before = set( 808 id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) 809 outside_executed_eagerly = context.executing_eagerly() 810 # Run the test in a new graph so that collections get cleared when it's 811 # done, but inherit the graph key so optimizers behave. 812 outside_graph_key = ops.get_default_graph()._graph_key 813 with ops.Graph().as_default(): 814 ops.get_default_graph()._graph_key = outside_graph_key 815 if outside_executed_eagerly: 816 with context.eager_mode(): 817 result = f(self, **kwargs) 818 else: 819 result = f(self, **kwargs) 820 # Make an effort to clear caches, which would otherwise look like leaked 821 # Tensors. 822 context.context()._clear_caches() # pylint: disable=protected-access 823 gc.collect() 824 tensors_after = [ 825 obj for obj in gc.get_objects() 826 if _is_tensorflow_object(obj) and id(obj) not in tensors_before 827 ] 828 if tensors_after: 829 raise AssertionError(("%d Tensors not deallocated after test: %s" % ( 830 len(tensors_after), 831 str(tensors_after), 832 ))) 833 return result 834 835 return decorator 836 837 838def _find_reference_cycle(objects, idx): 839 840 def get_ignore_reason(obj, denylist): 841 """Tests whether an object should be omitted from the dependency graph.""" 842 if len(denylist) > 100: 843 return "<depth limit>" 844 if tf_inspect.isframe(obj): 845 if "test_util.py" in tf_inspect.getframeinfo(obj)[0]: 846 return "<test code>" 847 for b in denylist: 848 if b is obj: 849 return "<test code>" 850 if obj is denylist: 851 return "<test code>" 852 return None 853 854 # Note: this function is meant to help with diagnostics. Its output is purely 855 # a human-readable representation, so you may freely modify it to suit your 856 # needs. 857 def describe(obj, denylist, leaves_only=False): 858 """Returns a custom human-readable summary of obj. 859 860 Args: 861 obj: the value to describe. 862 denylist: same as denylist in get_ignore_reason. 
863 leaves_only: boolean flag used when calling describe recursively. Useful 864 for summarizing collections. 865 """ 866 if get_ignore_reason(obj, denylist): 867 return "{}{}".format(get_ignore_reason(obj, denylist), type(obj)) 868 if tf_inspect.isframe(obj): 869 return "frame: {}".format(tf_inspect.getframeinfo(obj)) 870 elif tf_inspect.ismodule(obj): 871 return "module: {}".format(obj.__name__) 872 else: 873 if leaves_only: 874 return "{}, {}".format(type(obj), id(obj)) 875 elif isinstance(obj, list): 876 return "list({}): {}".format( 877 id(obj), [describe(e, denylist, leaves_only=True) for e in obj]) 878 elif isinstance(obj, tuple): 879 return "tuple({}): {}".format( 880 id(obj), [describe(e, denylist, leaves_only=True) for e in obj]) 881 elif isinstance(obj, dict): 882 return "dict({}): {} keys".format(id(obj), len(obj.keys())) 883 elif tf_inspect.isfunction(obj): 884 return "function({}) {}; globals ID: {}".format( 885 id(obj), obj.__name__, id(obj.__globals__)) 886 else: 887 return "{}, {}".format(type(obj), id(obj)) 888 889 def build_ref_graph(obj, graph, reprs, denylist): 890 """Builds a reference graph as <referrer> -> <list of referents>. 891 892 Args: 893 obj: The object to start from. The graph will be built by recursively 894 adding its referrers. 895 graph: Dict holding the graph to be built. To avoid creating extra 896 references, the graph holds object IDs rather than actual objects. 897 reprs: Auxiliary structure that maps object IDs to their human-readable 898 description. 899 denylist: List of objects to ignore. 
900 """ 901 referrers = gc.get_referrers(obj) 902 denylist = denylist + (referrers,) 903 904 obj_id = id(obj) 905 for r in referrers: 906 if get_ignore_reason(r, denylist) is None: 907 r_id = id(r) 908 if r_id not in graph: 909 graph[r_id] = [] 910 if obj_id not in graph[r_id]: 911 graph[r_id].append(obj_id) 912 build_ref_graph(r, graph, reprs, denylist) 913 reprs[r_id] = describe(r, denylist) 914 915 def find_cycle(el, graph, reprs, path): 916 """Finds and prints a single cycle in the dependency graph.""" 917 if el not in graph: 918 return 919 for r in graph[el]: 920 if r in path: 921 logging.error("Reference cycle sample:") 922 for p in path + (r,): 923 logging.error(reprs.get(p, "unknown object " + str(p))) 924 return True 925 else: 926 if find_cycle(r, graph, reprs, path + (r,)): 927 return True 928 return False 929 930 obj = objects[idx] 931 graph = {} # referrer ID -> object ID 932 reprs = {} # object ID -> description 933 build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason, 934 describe, build_ref_graph, find_cycle)) 935 for k in graph: 936 if find_cycle(k, graph, reprs, ()): 937 return True 938 return False 939 940 941def assert_no_garbage_created(f): 942 """Test method decorator to assert that no garbage has been created. 943 944 Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters 945 cannot be un-set (i.e. will disable garbage collection for any other unit 946 tests in the same file/shard). 947 948 Args: 949 f: The function to decorate. 950 951 Returns: 952 The decorated function. 953 """ 954 955 def decorator(self, **kwargs): 956 """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" 957 # Force-load `distribution_strategy_context` to prevent GC at 958 # test time when using eager. Remove once b/117329403 is resolved. 
959 tape.distribution_strategy_context.get_strategy() 960 961 gc.disable() 962 previous_debug_flags = gc.get_debug() 963 gc.set_debug(gc.DEBUG_SAVEALL) 964 gc.collect() 965 previous_garbage = len(gc.garbage) 966 result = f(self, **kwargs) 967 gc.collect() 968 new_garbage = len(gc.garbage) 969 if new_garbage > previous_garbage: 970 971 for i, obj in enumerate(gc.garbage[previous_garbage:]): 972 # Known false positive for ast.fix_missing_locations. 973 if getattr(obj, "__module__", "") == "ast": 974 new_garbage -= 3 975 976 if new_garbage > previous_garbage: 977 logging.error( 978 "The decorated test created work for Python's garbage collector, " 979 "likely due to a reference cycle. New objects in cycle(s):") 980 for i, obj in enumerate(gc.garbage[previous_garbage:]): 981 try: 982 logging.error("Object %d of %d", i, 983 len(gc.garbage) - previous_garbage) 984 985 def _safe_object_str(obj): 986 return "<%s %d>" % (obj.__class__.__name__, id(obj)) 987 988 logging.error(" Object type: %s", _safe_object_str(obj)) 989 logging.error( 990 " Referrer types: %s", ", ".join( 991 [_safe_object_str(ref) for ref in gc.get_referrers(obj)])) 992 logging.error( 993 " Referent types: %s", ", ".join( 994 [_safe_object_str(ref) for ref in gc.get_referents(obj)])) 995 logging.error(" Object attribute names: %s", dir(obj)) 996 logging.error(" Object __str__:") 997 logging.error(obj) 998 logging.error(" Object __repr__:") 999 logging.error(repr(obj)) 1000 except Exception: # pylint: disable=broad-except 1001 logging.error("(Exception while printing object)") 1002 1003 # When garbage is created, this call can help identify reference cycles, 1004 # which are typically the cause of such garbage. 1005 if new_garbage > previous_garbage: 1006 for i in range(previous_garbage, new_garbage): 1007 if _find_reference_cycle(gc.garbage, i): 1008 break 1009 1010 # This will fail if any garbage has been created, typically because of a 1011 # reference cycle. 
1012 self.assertEqual(previous_garbage, new_garbage) 1013 # TODO(allenl): Figure out why this debug flag reset doesn't work. It would 1014 # be nice to be able to decorate arbitrary tests in a large test suite and 1015 # not hold on to every object in other tests. 1016 gc.set_debug(previous_debug_flags) 1017 gc.enable() 1018 return result 1019 1020 return decorator 1021 1022 1023def _combine_named_parameters(**kwargs): 1024 """Generate combinations based on its keyword arguments. 1025 1026 Two sets of returned combinations can be concatenated using +. Their product 1027 can be computed using `times()`. 1028 1029 Args: 1030 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 1031 `option=the_only_possibility`. 1032 1033 Returns: 1034 a list of dictionaries for each combination. Keys in the dictionaries are 1035 the keyword argument names. Each key has one value - one of the 1036 corresponding keyword argument values. 1037 """ 1038 sort_by_key = lambda k: k[0] 1039 combinations = [] 1040 for key, values in sorted(kwargs.items(), key=sort_by_key): 1041 if not isinstance(values, list): 1042 values = [values] 1043 combinations.append([(key, value) for value in values]) 1044 1045 return [OrderedDict(result) for result in itertools.product(*combinations)] 1046 1047 1048def generate_combinations_with_testcase_name(**kwargs): 1049 """Generate combinations based on its keyword arguments using combine(). 1050 1051 This function calls combine() and appends a testcase name to the list of 1052 dictionaries returned. The 'testcase_name' key is a required for named 1053 parameterized tests. 1054 1055 Args: 1056 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 1057 `option=the_only_possibility`. 1058 1059 Returns: 1060 a list of dictionaries for each combination. Keys in the dictionaries are 1061 the keyword argument names. Each key has one value - one of the 1062 corresponding keyword argument values. 
1063 """ 1064 combinations = _combine_named_parameters(**kwargs) 1065 named_combinations = [] 1066 for combination in combinations: 1067 assert isinstance(combination, OrderedDict) 1068 name = "".join([ 1069 "_{}_{}".format("".join(filter(str.isalnum, key)), 1070 "".join(filter(str.isalnum, str(value)))) 1071 for key, value in combination.items() 1072 ]) 1073 named_combinations.append( 1074 OrderedDict( 1075 list(combination.items()) + 1076 [("testcase_name", "_test{}".format(name))])) 1077 1078 return named_combinations 1079 1080 1081def run_all_in_graph_and_eager_modes(cls): 1082 """Execute all test methods in the given class with and without eager.""" 1083 base_decorator = run_in_graph_and_eager_modes 1084 for name in dir(cls): 1085 if (not name.startswith(unittest.TestLoader.testMethodPrefix) or 1086 name.startswith("testSkipEager") or 1087 name.startswith("test_skip_eager") or 1088 name == "test_session"): 1089 continue 1090 value = getattr(cls, name, None) 1091 if callable(value): 1092 setattr(cls, name, base_decorator(value)) 1093 return cls 1094 1095 1096def enable_nested_function_shape_inference(fn): 1097 """Decorator for enabling nested_function_shape_inference on a test. 1098 1099 This function returns a decorator intended to be applied to test methods in 1100 a `tf.test.TestCase` class. Doing so will set nested_function_shape_inference, 1101 reset the context, execute the test, then reset the context to the state 1102 it was in prior to this test. 1103 1104 Example: 1105 1106 class MyTest(test.TestCase): 1107 1108 @enable_nested_function_shape_inference 1109 def testFoo(self): 1110 ... 1111 1112 Args: 1113 fn: the function to be wrapped. 1114 1115 Returns: 1116 The wrapped function. 1117 """ 1118 1119 def wrapper(*args, **kwargs): 1120 # If `nested_function_shape_inference` is already enabled do nothing. 
1121 if flags.config().enable_nested_function_shape_inference.value(): 1122 return fn(*args, **kwargs) 1123 1124 flags.config().enable_nested_function_shape_inference.reset(True) 1125 try: 1126 return fn(*args, **kwargs) 1127 finally: 1128 flags.config().enable_nested_function_shape_inference.reset(False) 1129 1130 return wrapper 1131 1132 1133def enable_eager_op_as_function(fn): 1134 """Returns the same fn. This will be removed once all usages are removed. 1135 1136 Args: 1137 fn: the function to be wrapped. 1138 1139 Returns: 1140 The wrapped function. 1141 """ 1142 1143 def wrapper(*args, **kwargs): 1144 return fn(*args, **kwargs) 1145 1146 return wrapper 1147 1148 1149@tf_export("test.with_eager_op_as_function") 1150def with_eager_op_as_function(cls=None, only_as_function=False): # pylint: disable=unused-argument 1151 """Returns the same class. This will be removed once all usages are removed. 1152 1153 Args: 1154 cls: class to decorate. 1155 only_as_function: unused argument. 1156 1157 Returns: 1158 cls 1159 """ 1160 1161 def decorator(cls): 1162 return cls 1163 1164 if cls is not None: 1165 return decorator(cls) 1166 1167 return decorator 1168 1169 1170def enable_graph_building_optimization(fn): 1171 """Decorator for enabling graph_building_optimization on a test. 1172 1173 This function returns a decorator intended to be applied to test methods in 1174 a `tf.test.TestCase` class. Doing so will enable graph_building_optimization, 1175 execute the test, then reset the feature flag to its default value. 1176 1177 Example: 1178 1179 class MyTest(test.TestCase): 1180 1181 @enable_graph_building_optimization 1182 def testFoo(self): 1183 ... 1184 1185 Args: 1186 fn: the function to be wrapped. 1187 1188 Returns: 1189 The wrapped function. 1190 """ 1191 1192 def wrapper(*args, **kwargs): 1193 # If `graph_building_optimization` is already enabled do nothing. 
1194 if flags.config().graph_building_optimization.value(): 1195 return fn(*args, **kwargs) 1196 1197 flags.config().graph_building_optimization.reset(True) 1198 try: 1199 return fn(*args, **kwargs) 1200 finally: 1201 flags.config().graph_building_optimization.reset(False) 1202 1203 return wrapper 1204 1205 1206def add_graph_building_optimization_tests(cls=None): 1207 """Adds methods with graph_building_optimization enabled to the test suite. 1208 1209 Example: 1210 1211 @test_util.add_graph_building_optimization_tests 1212 class FooTest(test.TestCase): 1213 1214 def testBar(self): 1215 ... 1216 1217 Generated class: 1218 class FooTest(test.TestCase): 1219 1220 def testBar(self): 1221 ... 1222 1223 def testBarWithGraphBuildingOptimization(self): 1224 // Enable graph_building_optimization 1225 testBar(self) 1226 // Disable graph_building_optimization 1227 1228 Args: 1229 cls: class to decorate. 1230 1231 Returns: 1232 cls with new test methods added. 1233 """ 1234 1235 def decorator(cls): 1236 if flags.config().graph_building_optimization.value(): 1237 return cls 1238 1239 for name, value in cls.__dict__.copy().items(): 1240 if (callable(value) and 1241 (name.startswith(unittest.TestLoader.testMethodPrefix) or 1242 name.startswith("benchmark"))): 1243 setattr(cls, name + "WithGraphBuildingOptimization", 1244 enable_graph_building_optimization(value)) 1245 return cls 1246 1247 if cls is not None: 1248 return decorator(cls) 1249 1250 return decorator 1251 1252 1253def disable_eager_op_as_function(unused_msg): 1254 """Decorator for a function in a with_eager_op_as_function enabled test class. 1255 1256 Blocks the function from being run with eager_op_as_function enabled. 1257 1258 Args: 1259 unused_msg: Reason for disabling. 1260 1261 Returns: 1262 The wrapped function with _disable_eager_op_as_function attr set to True. 
1263 """ 1264 return _disable_test(execute_func=False) 1265 1266 1267def set_xla_env_flag(func=None, flag=""): 1268 """Decorator for setting XLA_FLAGS prior to running a test. 1269 1270 This function returns a decorator intended to be applied to test methods in 1271 a `tf.test.TestCase` class. Doing so will allow users to set any xla flags 1272 exposed via the XLA_FLAGS environment variable, execute the test, then reset 1273 the XLA_FLAGS to the state it was in prior to this test. 1274 1275 Example: 1276 1277 class MyTest(test.TestCase): 1278 1279 @set_xla_env_flag(flag='--xla_gpu_enable_fast_min_max=false') 1280 def testFoo(self): 1281 ... 1282 1283 Args: 1284 func: The function to be wrapped. 1285 flag: The xla flag to be set in the XLA_FLAGS env variable. 1286 1287 Returns: 1288 The wrapped function. 1289 """ 1290 1291 def decorator(f): 1292 1293 @functools.wraps(f) 1294 def decorated(*args, **kwargs): 1295 original_xla_flags = os.environ.get("XLA_FLAGS") 1296 new_xla_flags = flag 1297 if original_xla_flags: 1298 new_xla_flags = new_xla_flags + " " + original_xla_flags 1299 os.environ["XLA_FLAGS"] = new_xla_flags 1300 try: 1301 return f(*args, **kwargs) 1302 finally: 1303 if original_xla_flags is None: 1304 del os.environ["XLA_FLAGS"] 1305 else: 1306 os.environ["XLA_FLAGS"] = original_xla_flags 1307 1308 return decorated 1309 1310 if func is not None: 1311 return decorator(func) 1312 1313 return decorator 1314 1315 1316def build_as_function_and_v1_graph(func=None): 1317 """Run a test case in v1 graph mode and inside tf.function in eager mode. 1318 1319 WARNING: This decorator can only be used in test cases that statically checks 1320 generated graph. Attempting to evaluate graph or function results via. 1321 session.run() or self.evaluate() will fail. 1322 1323 WARNING: This decorator can only be used for test cases that inherit from 1324 absl.testing.parameterized.TestCase. 1325 1326 Args: 1327 func: Test case function to be decorated. 
1328 1329 Returns: 1330 Decorated test case function. 1331 """ 1332 1333 def decorator(f): 1334 if tf_inspect.isclass(f): 1335 raise ValueError( 1336 "`run_in_graph_mode_and_function` only supports test methods.") 1337 1338 @parameterized.named_parameters(("_v1_graph", "v1_graph"), 1339 ("_function", "function")) 1340 @functools.wraps(f) 1341 def decorated(self, run_mode, *args, **kwargs): 1342 if run_mode == "v1_graph": 1343 with ops.Graph().as_default(): 1344 f(self, *args, **kwargs) 1345 elif run_mode == "function": 1346 1347 @def_function.function 1348 def function_in_eager(): 1349 f(self, *args, **kwargs) 1350 1351 # Create a new graph for the eagerly executed version of this test for 1352 # better isolation. 1353 graph_for_eager_test = ops.Graph() 1354 with graph_for_eager_test.as_default(), context.eager_mode(): 1355 function_in_eager() 1356 ops.dismantle_graph(graph_for_eager_test) 1357 else: 1358 raise ValueError("Unknown run mode %s" % run_mode) 1359 1360 return decorated 1361 1362 if func is not None: 1363 return decorator(func) 1364 1365 return decorator 1366 1367 1368def run_in_async_and_sync_mode(f): 1369 """Execute the test in async mode and sync mode.""" 1370 1371 @parameterized.named_parameters([("Async", True), ("", False)]) 1372 @functools.wraps(f) 1373 def decorator(self, async_mode, *args, **kwargs): 1374 if async_mode: 1375 with context.execution_mode(context.ASYNC): 1376 f(self, *args, **kwargs) 1377 else: 1378 with context.execution_mode(context.SYNC): 1379 f(self, *args, **kwargs) 1380 return decorator 1381 1382 1383def run_in_graph_and_eager_modes(func=None, 1384 config=None, 1385 use_gpu=True, 1386 assert_no_eager_garbage=False): 1387 """Execute the decorated test with and without enabling eager execution. 1388 1389 This function returns a decorator intended to be applied to test methods in 1390 a `tf.test.TestCase` class. 
Doing so will cause the contents of the test 1391 method to be executed twice - once normally, and once with eager execution 1392 enabled. This allows unittests to confirm the equivalence between eager 1393 and graph execution (see `tf.compat.v1.enable_eager_execution`). 1394 1395 For example, consider the following unittest: 1396 1397 ```python 1398 class MyTests(tf.test.TestCase): 1399 1400 @run_in_graph_and_eager_modes 1401 def test_foo(self): 1402 x = tf.constant([1, 2]) 1403 y = tf.constant([3, 4]) 1404 z = tf.add(x, y) 1405 self.assertAllEqual([4, 6], self.evaluate(z)) 1406 1407 if __name__ == "__main__": 1408 tf.test.main() 1409 ``` 1410 1411 This test validates that `tf.add()` has the same behavior when computed with 1412 eager execution enabled as it does when constructing a TensorFlow graph and 1413 executing the `z` tensor in a session. 1414 1415 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1416 `run_in_graph_and_eager_modes` are available decorators for different 1417 v1/v2/eager/graph combinations. 1418 1419 1420 Args: 1421 func: function to be annotated. If `func` is None, this method returns a 1422 decorator the can be applied to a function. If `func` is not None this 1423 returns the decorator applied to `func`. 1424 config: An optional config_pb2.ConfigProto to use to configure the session 1425 when executing graphs. 1426 use_gpu: If True, attempt to run as many operations as possible on GPU. 1427 assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage 1428 collector and asserts that no extra garbage has been created when running 1429 the test with eager execution enabled. This will fail if there are 1430 reference cycles (e.g. a = []; a.append(a)). Off by default because some 1431 tests may create garbage for legitimate reasons (e.g. 
they define a class 1432 which inherits from `object`), and because DEBUG_SAVEALL is sticky in some 1433 Python interpreters (meaning that tests which rely on objects being 1434 collected elsewhere in the unit test file will not work). Additionally, 1435 checks that nothing still has a reference to Tensors that the test 1436 allocated. 1437 1438 Returns: 1439 Returns a decorator that will run the decorated test method twice: 1440 once by constructing and executing a graph in a session and once with 1441 eager execution enabled. 1442 """ 1443 1444 def decorator(f): 1445 if tf_inspect.isclass(f): 1446 raise ValueError( 1447 "`run_in_graph_and_eager_modes` only supports test methods. " 1448 "Did you mean to use `run_all_in_graph_and_eager_modes`?") 1449 1450 def decorated(self, *args, **kwargs): 1451 logging.info("Running %s in GRAPH mode.", f.__name__) 1452 try: 1453 with context.graph_mode(): 1454 with self.test_session(use_gpu=use_gpu, config=config): 1455 f(self, *args, **kwargs) 1456 except unittest.case.SkipTest: 1457 pass 1458 1459 def run_eagerly(self, **kwargs): 1460 logging.info("Running %s in EAGER mode.", f.__name__) 1461 if not use_gpu: 1462 with ops.device("/device:CPU:0"): 1463 f(self, *args, **kwargs) 1464 else: 1465 f(self, *args, **kwargs) 1466 1467 if assert_no_eager_garbage: 1468 ops.reset_default_graph() 1469 run_eagerly = assert_no_new_tensors( 1470 assert_no_garbage_created(run_eagerly)) 1471 1472 # This decorator runs the wrapped test twice. 1473 # Reset the test environment between runs. 1474 self.tearDown() 1475 self._tempdir = None 1476 # Create a new graph for the eagerly executed version of this test for 1477 # better isolation. 
1478 graph_for_eager_test = ops.Graph() 1479 with graph_for_eager_test.as_default(), context.eager_mode(): 1480 self.setUp() 1481 run_eagerly(self, **kwargs) 1482 ops.dismantle_graph(graph_for_eager_test) 1483 1484 return tf_decorator.make_decorator(f, decorated) 1485 1486 if func is not None: 1487 return decorator(func) 1488 1489 return decorator 1490 1491 1492def py_func_if_in_function(f): 1493 1494 def decorated(*args, **kwds): 1495 if not ops.inside_function(): 1496 return f(*args, **kwds) 1497 1498 tensor_args = [] 1499 tensor_indices = [] 1500 for i, arg in enumerate(args): 1501 if isinstance(arg, (ops.Tensor, variables.Variable)): 1502 tensor_args.append(arg) 1503 tensor_indices.append(i) 1504 1505 def inner_f(*inner_tensor_args): 1506 my_args = list(args) 1507 for i, n in zip(tensor_indices, inner_tensor_args): 1508 my_args[i] = n 1509 return f(*my_args, **kwds) 1510 1511 return script_ops.py_func(inner_f, tensor_args, []) 1512 1513 return tf_decorator.make_decorator(f, decorated) 1514 1515 1516def also_run_as_tf_function(f): 1517 """Runs the decorated test twice--once as is, once inside a tf.function. 1518 1519 This allows you to run a test both in eager execution and inside a 1520 tf.function, exercising the two execution modes supported in tf 2.0. The test 1521 assertions are automatically done inside tf.py_funcs, and tf.function ensures 1522 that they run in the proper order and with the proper side effects. 1523 1524 Currently variable creation is not supported in tests annotated with this 1525 decorator since it's tricky to ensure the variable doesn't get repeatedly 1526 created when retracing the tf.function. 1527 1528 Args: 1529 f: the test method to be decorated 1530 1531 Returns: 1532 The decorated test method, which will run both in eager and inside a 1533 tf.function. 
1534 """ 1535 1536 def decorated(*args, **kwds): 1537 1538 def bound_f(): 1539 f(*args, **kwds) 1540 1541 with context.eager_mode(): 1542 # Running in eager mode 1543 bound_f() 1544 # Running as TF function 1545 # TODO(b/121143941): Remove the autograph override. 1546 def_function.function(bound_f, autograph=False)() 1547 1548 return decorated 1549 1550 1551def deprecated_graph_mode_only(func=None): 1552 """Execute the decorated test in graph mode. 1553 1554 This function returns a decorator intended to be applied to tests that are not 1555 compatible with eager mode. When this decorator is applied, the test body will 1556 be run in an environment where API calls construct graphs instead of executing 1557 eagerly. 1558 1559 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1560 `run_in_graph_and_eager_modes` are available decorators for different 1561 v1/v2/eager/graph combinations. 1562 1563 Args: 1564 func: function to be annotated. If `func` is None, this method returns a 1565 decorator the can be applied to a function. If `func` is not None this 1566 returns the decorator applied to `func`. 1567 1568 Returns: 1569 Returns a decorator that will run the decorated test method in graph mode. 
1570 """ 1571 1572 def decorator(f): 1573 if tf_inspect.isclass(f): 1574 setup = f.__dict__.get("setUp") 1575 if setup is not None: 1576 setattr(f, "setUp", decorator(setup)) 1577 1578 for name, value in f.__dict__.copy().items(): 1579 if (callable(value) and 1580 name.startswith(unittest.TestLoader.testMethodPrefix)): 1581 setattr(f, name, decorator(value)) 1582 1583 return f 1584 1585 def decorated(self, *args, **kwargs): 1586 if context.executing_eagerly(): 1587 with context.graph_mode(): 1588 return f(self, *args, **kwargs) 1589 else: 1590 return f(self, *args, **kwargs) 1591 1592 return decorated 1593 1594 if func is not None: 1595 return decorator(func) 1596 1597 return decorator 1598 1599 1600run_deprecated_v1 = deprecated_graph_mode_only 1601 1602 1603def run_all_in_deprecated_graph_mode_only(cls): 1604 """Execute all tests in a class in graph mode.""" 1605 base_decorator = deprecated_graph_mode_only 1606 for name in dir(cls): 1607 if (not name.startswith(unittest.TestLoader.testMethodPrefix) or 1608 name == "test_session"): 1609 continue 1610 value = getattr(cls, name, None) 1611 if callable(value): 1612 setattr(cls, name, base_decorator(value)) 1613 return cls 1614 1615 1616def run_v1_only(reason, func=None): 1617 """Execute the decorated test only if running in v1 mode. 1618 1619 This function is intended to be applied to tests that exercise v1 only 1620 functionality. If the test is run in v2 mode it will simply be skipped. 1621 1622 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1623 `run_in_graph_and_eager_modes` are available decorators for different 1624 v1/v2/eager/graph combinations. 1625 1626 Args: 1627 reason: string giving a reason for limiting the test to v1 only. 1628 func: function to be annotated. If `func` is None, this method returns a 1629 decorator the can be applied to a function. If `func` is not None this 1630 returns the decorator applied to `func`. 
1631 1632 Returns: 1633 Returns a decorator that will conditionally skip the decorated test method. 1634 """ 1635 if not isinstance(reason, str): 1636 raise ValueError("'reason' should be string, got {}".format(type(reason))) 1637 1638 def decorator(f): 1639 if tf_inspect.isclass(f): 1640 # To skip an entire test suite class, we only decorate the setUp method 1641 # to skip all tests. There are cases when setUp is not defined (not 1642 # overridden in subclasses of TestCase, so not available in f.__dict__ 1643 # below). For those cases, we walk the method resolution order list and 1644 # pick the first setUp method we find (usually this should be the one in 1645 # the parent class since that's the TestCase class). 1646 for cls in type.mro(f): 1647 setup = cls.__dict__.get("setUp") 1648 if setup is not None: 1649 setattr(f, "setUp", decorator(setup)) 1650 break 1651 1652 return f 1653 else: 1654 # If f is just a function, just create a decorator for it and return it 1655 def decorated(self, *args, **kwargs): 1656 if tf2.enabled(): 1657 self.skipTest(reason) 1658 1659 return f(self, *args, **kwargs) 1660 1661 return decorated 1662 1663 if func is not None: 1664 return decorator(func) 1665 1666 return decorator 1667 1668 1669def run_v2_only(func=None): 1670 """Execute the decorated test only if running in v2 mode. 1671 1672 This function is intended to be applied to tests that exercise v2 only 1673 functionality. If the test is run in v1 mode it will simply be skipped. 1674 1675 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1676 `run_in_graph_and_eager_modes` are available decorators for different 1677 v1/v2/eager/graph combinations. 1678 1679 Args: 1680 func: function to be annotated. If `func` is None, this method returns a 1681 decorator the can be applied to a function. If `func` is not None this 1682 returns the decorator applied to `func`. 1683 1684 Returns: 1685 Returns a decorator that will conditionally skip the decorated test method. 
1686 """ 1687 1688 def decorator(f): 1689 if tf_inspect.isclass(f): 1690 raise ValueError("`run_v2_only` only supports test methods.") 1691 1692 def decorated(self, *args, **kwargs): 1693 if not tf2.enabled(): 1694 self.skipTest("Test is only compatible with v2") 1695 1696 return f(self, *args, **kwargs) 1697 1698 return decorated 1699 1700 if func is not None: 1701 return decorator(func) 1702 1703 return decorator 1704 1705 1706def run_gpu_only(func=None): 1707 """Execute the decorated test only if a GPU is available. 1708 1709 This function is intended to be applied to tests that require the presence 1710 of a GPU. If a GPU is absent, it will simply be skipped. 1711 1712 Args: 1713 func: function to be annotated. If `func` is None, this method returns a 1714 decorator the can be applied to a function. If `func` is not None this 1715 returns the decorator applied to `func`. 1716 1717 Returns: 1718 Returns a decorator that will conditionally skip the decorated test method. 1719 """ 1720 1721 def decorator(f): 1722 if tf_inspect.isclass(f): 1723 raise ValueError("`run_gpu_only` only supports test methods.") 1724 1725 def decorated(self, *args, **kwargs): 1726 if not is_gpu_available(): 1727 self.skipTest("Test requires GPU") 1728 1729 return f(self, *args, **kwargs) 1730 1731 return decorated 1732 1733 if func is not None: 1734 return decorator(func) 1735 1736 return decorator 1737 1738 1739def run_cuda_only(func=None): 1740 """Execute the decorated test only if a GPU is available. 1741 1742 This function is intended to be applied to tests that require the presence 1743 of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped. 1744 1745 Args: 1746 func: function to be annotated. If `func` is None, this method returns a 1747 decorator the can be applied to a function. If `func` is not None this 1748 returns the decorator applied to `func`. 1749 1750 Returns: 1751 Returns a decorator that will conditionally skip the decorated test method. 
1752 """ 1753 1754 def decorator(f): 1755 if tf_inspect.isclass(f): 1756 raise ValueError("`run_cuda_only` only supports test methods.") 1757 1758 def decorated(self, *args, **kwargs): 1759 if not is_gpu_available(cuda_only=True): 1760 self.skipTest("Test requires CUDA GPU") 1761 1762 return f(self, *args, **kwargs) 1763 1764 return decorated 1765 1766 if func is not None: 1767 return decorator(func) 1768 1769 return decorator 1770 1771 1772def run_gpu_or_tpu(func=None): 1773 """Execute the decorated test only if a physical GPU or TPU is available. 1774 1775 This function is intended to be applied to tests that require the presence 1776 of a physical GPU or TPU. It complies with the following rules: 1777 - If a GPU is available, the test will run on the GPU. 1778 - If a GPU is absent and a TPU is available, the test will run on the TPU. 1779 - If both GPU and TPU are absent, the test will be skipped. 1780 1781 Args: 1782 func: function to be annotated. If `func` is None, this method returns a 1783 decorator the can be applied to a function. If `func` is not None this 1784 returns the decorator applied to `func`. 1785 1786 Returns: 1787 Returns a decorator that will conditionally skip the decorated test method. 1788 """ 1789 1790 def decorator(f): 1791 if tf_inspect.isclass(f): 1792 raise ValueError("`run_gpu_or_tpu` only supports test methods.") 1793 1794 def decorated(self, *args, **kwargs): 1795 if config.list_physical_devices("GPU"): 1796 return f(self, "GPU", *args, **kwargs) 1797 1798 if config.list_physical_devices("TPU"): 1799 return f(self, "TPU", *args, **kwargs) 1800 1801 self.skipTest("Test requires GPU or TPU") 1802 1803 return decorated 1804 1805 return decorator if func is None else decorator(func) 1806 1807 1808def with_forward_compatibility_horizons(*horizons): 1809 """Executes the decorated test with the specified forward-compat horizons. 1810 1811 Args: 1812 *horizons: A list of (year, month, day) tuples. 
If the list includes 1813 `None`, then the test will also be run with no forward-compatibility 1814 horizon set. 1815 1816 Returns: 1817 A decorator that will execute the test with the specified horizons. 1818 """ 1819 if not horizons: 1820 raise ValueError("Expected at least one horizon.") 1821 for horizon in horizons: 1822 if not ((horizon is None) or 1823 (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))): 1824 raise ValueError("Bad horizon value: %r" % horizon) 1825 1826 def decorator(f): 1827 if tf_inspect.isclass(f): 1828 raise ValueError("`with_forward_compatibility_horizons` only " 1829 "supports test methods.") 1830 def decorated(self, *args, **kwargs): 1831 for horizon in horizons: 1832 if horizon is None: 1833 f(self, *args, **kwargs) 1834 else: 1835 (year, month, day) = horizon 1836 with forward_compatibility_horizon(year, month, day): 1837 f(self, *args, **kwargs) 1838 return decorated 1839 1840 return decorator 1841 1842 1843@deprecation.deprecated(None, 1844 "Use `tf.config.list_physical_devices('GPU')` instead.") 1845@tf_export("test.is_gpu_available") 1846def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): 1847 """Returns whether TensorFlow can access a GPU. 1848 1849 Warning: if a non-GPU version of the package is installed, the function would 1850 also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow 1851 was build with CUDA support. 1852 1853 For example, 1854 >>> gpu_available = tf.test.is_gpu_available() 1855 >>> is_cuda_gpu_available = tf.test.is_gpu_available(cuda_only=True) 1856 >>> is_cuda_gpu_min_3 = tf.test.is_gpu_available(True, (3,0)) 1857 1858 Args: 1859 cuda_only: limit the search to CUDA GPUs. 1860 min_cuda_compute_capability: a (major,minor) pair that indicates the minimum 1861 CUDA compute capability required, or None if no requirement. 
1862 1863 Note that the keyword arg name "cuda_only" is misleading (since routine will 1864 return true when a GPU device is available irrespective of whether TF was 1865 built with CUDA support or ROCm support. However no changes here because 1866 1867 ++ Changing the name "cuda_only" to something more generic would break 1868 backward compatibility 1869 1870 ++ Adding an equivalent "rocm_only" would require the implementation check 1871 the build type. This in turn would require doing the same for CUDA and thus 1872 potentially break backward compatibility 1873 1874 ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility, 1875 but would require most (if not all) callers to update the call to use 1876 "cuda_or_rocm_only" instead of "cuda_only" 1877 1878 Returns: 1879 True if a GPU device of the requested kind is available. 1880 """ 1881 1882 # This was needed earlier when we had support for SYCL in TensorFlow. 1883 del cuda_only 1884 1885 try: 1886 for local_device in device_lib.list_local_devices(): 1887 if local_device.device_type == "GPU": 1888 gpu_info = gpu_util.compute_capability_from_device_desc(local_device) 1889 cc = gpu_info.compute_capability or (0, 0) 1890 if not min_cuda_compute_capability or cc >= min_cuda_compute_capability: 1891 return True 1892 return False 1893 except errors_impl.NotFoundError as e: 1894 if not all(x in str(e) for x in ["CUDA", "not find"]): 1895 raise e 1896 else: 1897 logging.error(str(e)) 1898 return False 1899 1900 1901@contextlib.contextmanager 1902def device(use_gpu): 1903 """Uses gpu when requested and available.""" 1904 if use_gpu and is_gpu_available(): 1905 dev = "/device:GPU:0" 1906 else: 1907 dev = "/device:CPU:0" 1908 with ops.device(dev): 1909 yield 1910 1911 1912@contextlib.contextmanager 1913def use_gpu(): 1914 """Uses gpu when requested and available.""" 1915 with device(use_gpu=True): 1916 yield 1917 1918 1919@contextlib.contextmanager 1920def force_gpu(): 1921 """Force the gpu to be used.""" 
1922 with ops.device("/device:GPU:0"): 1923 yield 1924 1925 1926@contextlib.contextmanager 1927def force_cpu(): 1928 """Force the cpu to be used.""" 1929 with ops.device("/device:CPU:0"): 1930 yield 1931 1932 1933@contextlib.contextmanager 1934def deterministic_ops(): 1935 """Enables deterministic ops.""" 1936 try: 1937 config.enable_op_determinism() 1938 yield 1939 finally: 1940 config.disable_op_determinism() 1941 1942 1943class CapturedWrites: 1944 """A utility class to load the captured writes made to a stream.""" 1945 1946 def __init__(self, capture_location): 1947 self.capture_location = capture_location 1948 1949 def contents(self): 1950 """Get the captured writes as a single string.""" 1951 with open(self.capture_location) as tmp_file: 1952 output_data = "".join(tmp_file.readlines()) 1953 return output_data 1954 1955 1956class FakeEagerSession: 1957 """Fake session so tests that conditionally use placeholders can use eager. 1958 1959 There are a number of tests that conditionally use placeholders for shape 1960 inference. The pattern is demonstrated here: 1961 1962 ```python 1963 with self.cached_session() as sess: 1964 if static_shape: 1965 y = math_ops.matmul(x, ...) 1966 feed_dict = {} 1967 else: 1968 x_ph = array_ops.placeholder(...) 1969 y = math_ops.matmul(x_ph, ...) 1970 feed_dict = {x_ph: x} 1971 val = sess.run(y, feed_dict=feed_dict) 1972 ``` 1973 1974 Since the feed_dict is empty when not using placeholders we should be able to 1975 call self.evaluate(), however this requires rewriting the test case. 1976 This class should be considered a stop-gap solution to get tests running with 1977 eager with minimal changes to the actual test. 1978 """ 1979 1980 def __init__(self, test_case): 1981 self._test_case = test_case 1982 1983 def run(self, fetches, *args, **kwargs): 1984 """Evaluate `fetches`. 1985 1986 Fail if additional args are specified. 1987 1988 Args: 1989 fetches: A Tensor or a nested list/tuple of Tensors. 
1990 *args: Positional arguments 1991 **kwargs: Keyword arguments 1992 1993 Raises: 1994 RuntimeError: If args or kwargs are specified. 1995 1996 Returns: 1997 Tensors as numpy values. 1998 """ 1999 feed_dict = kwargs.pop("feed_dict", {}) 2000 if feed_dict: 2001 raise RuntimeError( 2002 "feed_dict is not supported when eager execution is enabled " 2003 "(in this case, sess.run(t) is shorthand for t.numpy()") 2004 2005 if args or kwargs: 2006 raise RuntimeError( 2007 "Optional args are not supported when eager execution is enabled " 2008 "(in this case, sess.run(t) is shorthand for t.numpy()") 2009 2010 return self._test_case.evaluate(fetches) 2011 2012 2013class ErrorLoggingSession(session.Session): 2014 """Wrapper around a Session that logs errors in run().""" 2015 2016 def run(self, *args, **kwargs): 2017 try: 2018 return super().run(*args, **kwargs) 2019 except Exception as e: # pylint: disable=broad-except 2020 # Note: disable the logging for OutOfRangeError, which makes the output 2021 # of tf.data tests hard to read, because OutOfRangeError is used as the 2022 # signal completion 2023 if not isinstance(e, errors.OutOfRangeError): 2024 logging.error(str(e)) 2025 raise 2026 2027 2028def disable_cudnn_autotune(func): 2029 """Disable autotuning during the call to this function. 2030 2031 Some tests want to base assertions on a graph being isomorphic with a copy. 2032 To ensure this, this decorator disables autotuning. 2033 2034 Args: 2035 func: Function to run with CuDNN autotuning turned off. 2036 2037 Returns: 2038 Decorated function. 
2039 """ 2040 2041 def decorator(f): 2042 2043 def decorated(self, *args, **kwargs): 2044 original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") 2045 os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" 2046 original_xla_flags = os.environ.get("XLA_FLAGS") 2047 new_xla_flags = "--xla_gpu_autotune_level=0" 2048 if original_xla_flags: 2049 new_xla_flags = original_xla_flags + " " + new_xla_flags 2050 os.environ["XLA_FLAGS"] = new_xla_flags 2051 2052 result = f(self, *args, **kwargs) 2053 2054 if (original_tf_cudnn_use_autotune is None): 2055 del os.environ["TF_CUDNN_USE_AUTOTUNE"] 2056 else: 2057 os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune 2058 if (original_xla_flags is None): 2059 del os.environ["XLA_FLAGS"] 2060 else: 2061 os.environ["XLA_FLAGS"] = original_xla_flags 2062 2063 return result 2064 2065 return decorated 2066 2067 if func is not None: 2068 return decorator(func) 2069 2070 return decorator 2071 2072 2073# The description is just for documentation purposes. 2074def enable_tf_xla_constant_folding(description): 2075 2076 if not isinstance(description, str): 2077 raise ValueError("'description' should be string, got {}".format( 2078 type(description))) 2079 2080 def enable_tf_xla_constant_folding_impl(func): 2081 """Enable constant folding during the call to this function. 2082 2083 Some tests fail without constant folding. 2084 2085 Args: 2086 func: Function to run with constant folding turned on. 2087 2088 Returns: 2089 Decorated function. 
2090 """ 2091 2092 def decorator(f): 2093 2094 def decorated(self, *args, **kwargs): 2095 original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() 2096 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) 2097 result = f(self, *args, **kwargs) 2098 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var) 2099 return result 2100 2101 return decorated 2102 2103 if func is not None: 2104 return decorator(func) 2105 2106 return decorator 2107 2108 return enable_tf_xla_constant_folding_impl 2109 2110 2111# Updates test function by selectively disabling it. 2112def _disable_test(execute_func): 2113 2114 def disable_test_impl(func): 2115 2116 def decorator(func): 2117 2118 def decorated(self, *args, **kwargs): 2119 if execute_func: 2120 return func(self, *args, **kwargs) 2121 2122 return tf_decorator.make_decorator(func, decorated) 2123 2124 if func is not None: 2125 return decorator(func) 2126 2127 return decorator 2128 2129 return disable_test_impl 2130 2131 2132# The description is just for documentation purposes. 2133def disable_xla(description): # pylint: disable=unused-argument 2134 """Execute the test method only if xla is not enabled.""" 2135 execute_func = not is_xla_enabled() 2136 return _disable_test(execute_func) 2137 2138 2139# The description is just for documentation purposes. 2140def disable_mlir_bridge(description): # pylint: disable=unused-argument 2141 """Execute the test method only if MLIR bridge is not enabled.""" 2142 execute_func = not is_mlir_bridge_enabled() 2143 return _disable_test(execute_func) 2144 2145 2146# The description is just for documentation purposes. 2147def disable_asan(description): # pylint: disable=unused-argument 2148 """Execute the test method only if ASAN is not enabled.""" 2149 execute_func = not is_asan_enabled() 2150 return _disable_test(execute_func) 2151 2152 2153# The description is just for documentation purposes. 
def disable_msan(description):  # pylint: disable=unused-argument
  """Execute the test method only if MSAN is not enabled."""
  execute_func = not is_msan_enabled()
  return _disable_test(execute_func)


# The description is just for documentation purposes.
def disable_tsan(description):  # pylint: disable=unused-argument
  """Execute the test method only if TSAN is not enabled."""
  execute_func = not is_tsan_enabled()
  return _disable_test(execute_func)


# The description is just for documentation purposes.
def disable_ubsan(description):  # pylint: disable=unused-argument
  """Execute the test method only if UBSAN is not enabled."""
  execute_func = not is_ubsan_enabled()
  return _disable_test(execute_func)


# The description is just for documentation purposes.
def disable_tfrt(unused_description):
  """Returns a decorator that disables a test class or method under TFRT.

  Applied to a class, the decorator returns None (removing the class) when
  TFRT is enabled. Applied to a method, the method body is skipped (returns
  None) when TFRT is enabled.
  """

  def disable_tfrt_impl(cls_or_func):
    """Execute the test only if tfrt is not enabled."""

    if tf_inspect.isclass(cls_or_func):
      if tfrt_utils.enabled():
        return None
      else:
        return cls_or_func
    else:
      def decorator(func):

        def decorated(self, *args, **kwargs):
          if tfrt_utils.enabled():
            return
          else:
            return func(self, *args, **kwargs)

        return decorated

      if cls_or_func is not None:
        return decorator(cls_or_func)

      return decorator

  return disable_tfrt_impl


def for_all_test_methods(decorator, *args, **kwargs):
  """Generate class-level decorator from given method-level decorator.

  It is expected for the given decorator to take some arguments and return
  a method that is then called on the test method to produce a decorated
  method.

  Args:
    decorator: The decorator to apply.
    *args: Positional arguments
    **kwargs: Keyword arguments
  Returns: Function that will decorate a given classes test methods with the
    decorator.
  """

  def all_test_methods_impl(cls):
    """Apply decorator to all test methods in class."""
    for name in dir(cls):
      value = getattr(cls, name)
      if callable(value) and name.startswith(
          "test") and (name != "test_session"):
        setattr(cls, name, decorator(*args, **kwargs)(value))
    return cls

  return all_test_methods_impl


# The description is just for documentation purposes.
def no_xla_auto_jit(description):  # pylint: disable=unused-argument
  """This test is not intended to be run with XLA auto jit enabled."""
  execute_func = not is_xla_enabled()
  return _disable_test(execute_func)


# The description is just for documentation purposes.
def xla_allow_fallback(description):  # pylint: disable=unused-argument
  """Returns a decorator that lets an XLA test fall back to classic TF.

  While the decorated test runs with XLA enabled, lazy compilation is turned
  on; the previous setting is restored afterwards, even if the test raises.
  """

  def xla_allow_fallback_impl(func):
    """Allow fallback to TF even though testing xla."""

    def decorator(func):

      def decorated(self, *args, **kwargs):
        if not is_xla_enabled():
          return func(self, *args, **kwargs)
        # Update the global XLABuildOpsPassFlags to enable lazy compilation,
        # which allows the compiler to fall back to TF classic. Remember the
        # old value so that we can reset it.
        old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
        try:
          return func(self, *args, **kwargs)
        finally:
          pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return xla_allow_fallback_impl


# The description is just for documentation purposes.
def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute test with TensorFloat-32 disabled.

  While almost every real-world deep learning model runs fine with
  TensorFloat-32, many tests use assertAllClose or similar methods.
  TensorFloat-32 matmuls typically will cause such methods to fail with the
  default tolerances.

  Args:
    description: A description used for documentation purposes, describing why
      the test requires TensorFloat-32 to be disabled.

  Returns:
    Decorator which runs a test with TensorFloat-32 disabled. The previous
    TensorFloat-32 setting is restored afterwards, and the test's return value
    is passed through.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      allowed = config.tensor_float_32_execution_enabled()
      try:
        config.enable_tensor_float_32_execution(False)
        return f(self, *args, **kwargs)
      finally:
        config.enable_tensor_float_32_execution(allowed)

    return decorated

  return decorator


# The description is just for documentation purposes.
def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute all tests in a class with TensorFloat-32 disabled."""
  return for_all_test_methods(run_without_tensor_float_32, description)


def matmul_without_tf32(a, b, *args, **kwargs):
  """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled.

  This effectively runs matmul without TensorFloat-32. It should only be used in
  tests when verifying some other op or functions works correctly, e.g. to test
  `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In
  such cases, the matmul itself is not being tested so it's OK to run it with
  higher precision.

  If a matmul itself is being tested, or some other op which uses matmul, use
  `run_without_tensor_float_32` instead.

  This also casts complex64 inputs to complex128, since TensorFloat-32 can also
  be used with complex64

  Args:
    a: First input to tf.linalg.matmul
    b: Second input to tf.linalg.matmul
    *args: Other positional arguments to tf.linalg.matmul
    **kwargs: Other keyword arguments to tf.linalg.matmul

  Returns:
    A tensor with the same type as `a`.
  """
  # Remember the caller's dtype BEFORE widening: the result must be cast back
  # to it, not to the widened dtype of the rebound `a`.
  orig_dtype = a.dtype
  if config.tensor_float_32_execution_enabled() and orig_dtype == "float32":
    a = math_ops.cast(a, "float64")
    b = math_ops.cast(b, "float64")
    ret = math_ops.matmul(a, b, *args, **kwargs)
    return math_ops.cast(ret, orig_dtype)
  elif config.tensor_float_32_execution_enabled() and orig_dtype == "complex64":
    a = math_ops.cast(a, "complex128")
    b = math_ops.cast(b, "complex128")
    ret = math_ops.matmul(a, b, *args, **kwargs)
    return math_ops.cast(ret, orig_dtype)
  else:
    return math_ops.matmul(a, b, *args, **kwargs)


class EagerSessionWarner:
  """Stand-in "session" yielded by `session()` in eager mode.

  Any attribute access raises AttributeError explaining that sessions are not
  usable while eager execution is enabled.
  """

  def __getattr__(self, attr):
    raise AttributeError(
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")


@tf_export("test.TestCase")
class TensorFlowTestCase(googletest.TestCase):
  """Base class for tests that need to test TensorFlow."""

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super().__init__(methodName)
    # Make sure we get unfiltered stack traces during the test
    traceback_utils.disable_traceback_filtering()
    if is_xla_enabled():
      pywrap_tf_session.TF_SetXlaAutoJitMode("2")
      pywrap_tf_session.TF_SetXlaMinClusterSize(1)
      pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
      pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
      # Constant folding secretly runs code on TF:Classic CPU, so we also
      # disable it here.
      pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)

    # Check if the mlir bridge has been explicitly enabled or disabled. If
    # is_mlir_bridge_enabled() returns None, the user did not explicitly enable
    # or disable the bridge so do not update enable_mlir_bridge.
    mlir_bridge_enabled = is_mlir_bridge_enabled()
    if mlir_bridge_enabled:
      context.context().enable_mlir_bridge = True
    elif mlir_bridge_enabled is not None:
      context.context().enable_mlir_bridge = False

    self._threads = []
    self._tempdir = None
    self._cached_session = None
    self._test_start_time = None
    # This flag provides the ability to control whether the graph mode gets
    # initialized for TF1 or not. Initializing for TF1, which is what was
    # happening earlier, was preventing enablement of 'eager mode' in the test.
    self._set_default_seed = True

  def setUp(self):
    super().setUp()
    self._ClearCachedSession()
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    if self._set_default_seed:
      random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
    # Reset summary writer in case another test used set_as_default() with their
    # summary writer.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.writer = None

    # Avoiding calling setUp() for the poorly named test_session method.
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

    self._test_start_time = time.time()

  def tearDown(self):
    # If a subclass overrides setUp and doesn't call the parent class's setUp,
    # then we may not have set the start time.
    if self._test_start_time is not None:
      logging.info("time(%s): %ss", self.id(),
                   round(time.time() - self._test_start_time, 2))

    for thread in self._threads:
      thread.check_termination()

    self._ClearCachedSession()
    super().tearDown()

  def _ClearCachedSession(self):
    if self._cached_session is not None:
      self._cached_session.close()
      self._cached_session = None

  def get_temp_dir(self):
    """Returns a unique temporary directory for the test to use.

    If you call this method multiple times during a test, it will return the
    same folder. However, across different runs the directories will be
    different. This will ensure that across different runs tests will not be
    able to pollute each others environment.
    If you need multiple unique directories within a single test, you should
    use tempfile.mkdtemp as follows:
      tempfile.mkdtemp(dir=self.get_temp_dir()):

    Returns:
      string, the path to the unique temporary directory created for this test.
    """
    if not self._tempdir:
      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    return self._tempdir

  @contextlib.contextmanager
  def captureWritesToStream(self, stream):
    """A context manager that captures the writes to a given stream.

    This context manager captures all writes to a given stream inside of a
    `CapturedWrites` object. When this context manager is created, it yields
    the `CapturedWrites` object. The captured contents can be accessed by
    calling `.contents()` on the `CapturedWrites`.

    For this function to work, the stream must have a file descriptor that
    can be modified using `os.dup` and `os.dup2`, and the stream must support
    a `.flush()` method. The default python sys.stdout and sys.stderr are
    examples of this. Note that this does not work in Colab or Jupyter
    notebooks, because those use alternate stdout streams.

    The stream is flushed on entry and again on exit, so writes still
    buffered in the Python stream object when the context ends are captured
    too.

    Example:
    ```python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        input = [1.0, 2.0, 3.0, 4.0, 5.0]
        with self.captureWritesToStream(sys.stdout) as captured:
          result = MyOperator(input).eval()
        self.assertStartsWith(captured.contents(), "This was printed.")
    ```

    Args:
      stream: The stream whose writes should be captured. This stream must have
        a file descriptor, support writing via using that file descriptor, and
        must have a `.flush()` method.

    Yields:
      A `CapturedWrites` object that contains all writes to the specified stream
      made during this context.
    """
    stream.flush()
    fd = stream.fileno()
    tmp_file, tmp_file_path = tempfile.mkstemp(dir=self.get_temp_dir())
    orig_fd = os.dup(fd)
    os.dup2(tmp_file, fd)
    try:
      yield CapturedWrites(tmp_file_path)
    finally:
      try:
        # Push Python-level buffered writes into the redirected descriptor
        # before restoring it, otherwise they escape the capture.
        stream.flush()
      finally:
        os.dup2(orig_fd, fd)
        # Release the saved duplicate and the temp file descriptor; both would
        # otherwise leak once per use of this context manager.
        os.close(orig_fd)
        os.close(tmp_file)

  def _AssertProtoEquals(self, a, b, msg=None):
    """Asserts that a and b are the same proto.

    Uses ProtoEq() first, as it returns correct results
    for floating point attributes, and then use assertProtoEqual()
    in case of failure as it provides good error messages.

    Args:
      a: a proto.
      b: another proto.
      msg: Optional message to report on failure.
    """
    if not compare.ProtoEq(a, b):
      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)

  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
    """Asserts that message is same as parsed expected_message_ascii.

    Creates another prototype of message, reads the ascii message into it and
    then compares them using self._AssertProtoEqual().

    Args:
      expected_message_maybe_ascii: proto message in original or ascii form.
      message: the message to validate.
      msg: Optional message to report on failure.
    """
    if isinstance(expected_message_maybe_ascii, type(message)):
      expected_message = expected_message_maybe_ascii
      self._AssertProtoEquals(expected_message, message, msg=msg)
    elif isinstance(expected_message_maybe_ascii, (str, bytes)):
      expected_message = type(message)()
      text_format.Merge(
          expected_message_maybe_ascii,
          expected_message,
          descriptor_pool=descriptor_pool.Default())
      self._AssertProtoEquals(expected_message, message, msg=msg)
    else:
      assert False, ("Can't compare protos of type %s and %s." %
                     (type(expected_message_maybe_ascii), type(message)))

  def assertProtoEqualsVersion(
      self,
      expected,
      actual,
      producer=versions.GRAPH_DEF_VERSION,
      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
      msg=None):
    """Like assertProtoEquals, prefixing `expected` with a versions field."""
    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
        producer, min_consumer, expected)
    self.assertProtoEquals(expected, actual, msg=msg)

  def assertStartsWith(self, actual, expected_start, msg=None):
    """Assert that actual.startswith(expected_start) is True.

    Args:
      actual: str
      expected_start: str
      msg: Optional message to report on failure.
    """
    if not actual.startswith(expected_start):
      fail_msg = "%r does not start with %r" % (actual, expected_start)
      fail_msg += " : %r" % (msg) if msg else ""
      self.fail(fail_msg)

  def _eval_tensor(self, tensor):
    if tensor is None:
      return None
    elif callable(tensor):
      return self._eval_helper(tensor())
    else:
      try:
        # for compatibility with TF1 test cases
        if sparse_tensor.is_sparse(tensor):
          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
                                                 tensor.values.numpy(),
                                                 tensor.dense_shape.numpy())
        elif ragged_tensor.is_ragged(tensor):
          return ragged_tensor_value.RaggedTensorValue(
              self._eval_tensor(tensor.values),
              self._eval_tensor(tensor.row_splits))
        elif isinstance(tensor, indexed_slices.IndexedSlices):
          return indexed_slices.IndexedSlicesValue(
              values=tensor.values.numpy(),
              indices=tensor.indices.numpy(),
              dense_shape=None
              if tensor.dense_shape is None else tensor.dense_shape.numpy())
        else:
          if hasattr(tensor, "numpy") and callable(tensor.numpy):
            return tensor.numpy()
          else:
            # Try our best to convert CompositeTensor components to NumPy
            # arrays. Officially, we don't support NumPy arrays as
            # CompositeTensor components. So don't be surprised if this doesn't
            # work.
            return nest.map_structure(lambda t: t.numpy(), tensor,
                                      expand_composites=True)
      except AttributeError as e:
        raise ValueError(f"Unsupported type {type(tensor).__name__!r}.") from e

  def _eval_helper(self, tensors):
    if tensors is None:
      return None
    return nest.map_structure(self._eval_tensor, tensors)

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      tensors numpy values.
    """
    if context.executing_eagerly():
      return self._eval_helper(tensors)
    else:
      sess = ops.get_default_session()
      if sess is None:
        with self.test_session() as sess:
          return sess.run(tensors)
      else:
        return sess.run(tensors)

  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def session(self, graph=None, config=None, use_gpu=True, force_gpu=False):
    """A context manager for a TensorFlow Session for use in executing tests.

    Note that this will set this session and the graph as global defaults.

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned to
    the CPU.

    Example:

    ``` python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.session():
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if context.executing_eagerly():
      yield EagerSessionWarner()
    else:
      with self._create_session(graph, config, force_gpu) as sess:
        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
          yield sess

  @contextlib.contextmanager
  def cached_session(self,
                     graph=None,
                     config=None,
                     use_gpu=True,
                     force_gpu=False):
    """Returns a TensorFlow Session for use in executing tests.

    This method behaves differently than self.session(): for performance reasons
    `cached_session` will by default reuse the same session within the same
    test. The session returned by this function will only be closed at the end
    of the test (in the TearDown function).

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned to
    the CPU.

    Example:
    ```python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.cached_session() as sess:
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if context.executing_eagerly():
      yield FakeEagerSession(self)
    else:
      sess = self._get_cached_session(
          graph, config, force_gpu, crash_if_inconsistent_args=True)
      with self._constrain_devices_and_set_default(sess, use_gpu,
                                                   force_gpu) as cached:
        yield cached

  @contextlib.contextmanager
  @deprecation.deprecated(None, "Use `self.session()` or "
                          "`self.cached_session()` instead.")
  def test_session(self,
                   graph=None,
                   config=None,
                   use_gpu=True,
                   force_gpu=False):
    """Use cached_session instead."""
    if self.id().endswith(".test_session"):
      self.skipTest(
          "Tests that have the name \"test_session\" are automatically skipped "
          "by TensorFlow test fixture, as the name is reserved for creating "
          "sessions within tests. Please rename your test if you have a test "
          "with this name.")
    if context.executing_eagerly():
      yield None
    else:
      if graph is None:
        sess = self._get_cached_session(
            graph, config, force_gpu, crash_if_inconsistent_args=False)
        with self._constrain_devices_and_set_default(sess, use_gpu,
                                                     force_gpu) as cached:
          yield cached
      else:
        with self.session(graph, config, use_gpu, force_gpu) as sess:
          yield sess

  # pylint: enable=g-doc-return-or-yield

  class _CheckedThread(object):
    """A wrapper class for Thread that asserts successful completion.

    This class should be created using the TensorFlowTestCase.checkedThread()
    method.
    """

    def __init__(self, testcase, target, args=None, kwargs=None):
      """Constructs a new instance of _CheckedThread.

      Args:
        testcase: The TensorFlowTestCase for which this thread is being created.
        target: A callable object representing the code to be executed in the
          thread.
        args: A tuple of positional arguments that will be passed to target.
        kwargs: A dictionary of keyword arguments that will be passed to target.
      """
      self._testcase = testcase
      self._target = target
      self._args = () if args is None else args
      self._kwargs = {} if kwargs is None else kwargs
      self._thread = threading.Thread(target=self._protected_run)
      self._exception = None

      self._is_thread_joined = False

    def _protected_run(self):
      """Target for the wrapper thread. Sets self._exception on failure."""
      try:
        self._target(*self._args, **self._kwargs)
      except Exception as e:  # pylint: disable=broad-except
        self._exception = e

    def start(self):
      """Starts the thread's activity.

      This must be called at most once per _CheckedThread object. It arranges
      for the object's target to be invoked in a separate thread of control.
      """
      self._thread.start()

    def join(self):
      """Blocks until the thread terminates.

      Raises:
        self._testcase.failureException: If the thread terminates due to
          an exception.
      """
      self._is_thread_joined = True
      self._thread.join()
      if self._exception is not None:
        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))

    def is_alive(self):
      """Returns whether the thread is alive.

      This method returns True just before the run() method starts
      until just after the run() method terminates.

      Returns:
        True if the thread is alive, otherwise False.
      """
      return self._thread.is_alive()

    def check_termination(self):
      """Checks that the checked thread was properly used and did terminate.

      Every checked thread should be "join"ed after starting, and before the
      test tears down. If it is not joined, it is possible the thread will hang
      and cause flaky failures in tests.

      Raises:
        self._testcase.failureException: If check_termination was called before
        thread was joined.

        RuntimeError: If the thread is not terminated. This means thread was not
        joined with the main thread.
      """
      if self._is_thread_joined:
        if self.is_alive():
          raise RuntimeError(
              "Thread was not joined with main thread, and is still running "
              "when the test finished.")
      else:
        self._testcase.fail("A checked thread was not joined.")

  def checkedThread(self, target, args=None, kwargs=None):
    """Returns a Thread wrapper that asserts 'target' completes successfully.

    This method should be used to create all threads in test cases, as
    otherwise there is a risk that a thread will silently fail, and/or
    assertions made in the thread will not be respected.

    Args:
      target: A callable object to be executed in the thread.
      args: The argument tuple for the target invocation. Defaults to ().
      kwargs: A dictionary of keyword arguments for the target invocation.
        Defaults to {}.

    Returns:
      A wrapper for threading.Thread that supports start() and join() methods.
    """
    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
    self._threads.append(ret)
    return ret

  # pylint: enable=invalid-name
  @py_func_if_in_function
  def assertNear(self, f1, f2, err, msg=None):
    """Asserts that two floats are near each other.

    Checks that |f1 - f2| <= err (or that f1 == f2, which also covers equal
    infinities) and asserts a test failure if not.

    Args:
      f1: A float value.
      f2: A float value.
      err: A float value.
      msg: An optional string message to append to the failure message.
    """
    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
    self.assertTrue(
        f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
        (f1, f2, err, " (%s)" % msg if msg is not None else ""))

  @py_func_if_in_function
  def assertArrayNear(self, farray1, farray2, err, msg=None):
    """Asserts that two float arrays are near each other.

    Checks that both arrays have the same length and that for all elements of
    farray1 and farray2 |f1 - f2| <= err. Asserts a test failure if not.

    Args:
      farray1: a list of float values.
      farray2: a list of float values.
      err: a float value.
      msg: Optional message to report on failure.
    """
    self.assertEqual(len(farray1), len(farray2), msg=msg)
    for f1, f2 in zip(farray1, farray2):
      self.assertNear(float(f1), float(f2), err, msg=msg)

  def _NDArrayNear(self, ndarray1, ndarray2, err):
    # True when the L2 (Frobenius, over the flattened arrays) norm of the
    # difference is strictly less than `err`.
    return np.linalg.norm(ndarray1 - ndarray2) < err

  @py_func_if_in_function
  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
    """Asserts that two numpy arrays have near values.

    Args:
      ndarray1: a numpy ndarray.
      ndarray2: a numpy ndarray.
      err: a float. The L2 norm of `ndarray1 - ndarray2` (computed over all
        elements) must be strictly less than this value.
      msg: Optional message to report on failure.
    """
    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)

  def _GetNdArray(self, a):
    # If a is tensor-like then convert it to ndarray
    if tensor_util.is_tf_type(a):
      if isinstance(a, ops._EagerTensorBase):
        a = a.numpy()
      else:
        a = self.evaluate(a)
    if not isinstance(a, np.ndarray):
      return np.array(a)
    return a

  def evaluate_if_both_tensors(self, a, b):
    """Evaluates `(a, b)` together if both are non-eager tensors."""
    if (tensor_util.is_tf_type(a) and tensor_util.is_tf_type(b) and
        not isinstance(a, ops._EagerTensorBase) and
        not isinstance(b, ops._EagerTensorBase)):
      return self.evaluate((a, b))
    else:
      return (a, b)

  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    (a, b) = self.evaluate_if_both_tensors(a, b)
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # When the array rank is small, print its contents. Numpy array printing is
    # implemented using inefficient recursion so prints can cause tests to
    # time out.
    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
                            "%s.") % (a.shape, b.shape, b)
    else:
      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
                                                                     b.shape)
    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)

    msgs = [msg]
    # np.allclose does not always work for our custom bfloat16 extension type
    # when type promotions are involved, so we first cast any bfloat16 arrays
    # to float32.
    a_dtype = a.dtype
    a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
    b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
    if not np.allclose(a, b, rtol=rtol, atol=atol):
      # Adds more details to np.testing.assert_allclose.
      #
      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
      # checks whether two arrays are element-wise equal within a
      # tolerance. The relative difference (rtol * abs(b)) and the
      # absolute difference atol are added together to compare against
      # the absolute difference between a and b.  Here, we want to
      # tell user which elements violate such conditions.
      cond = np.logical_or(
          np.abs(a - b) > atol + rtol * np.abs(b),
          np.isnan(a) != np.isnan(b))
      if a.ndim:
        x = a[np.where(cond)]
        y = b[np.where(cond)]
        msgs.append("not close where = {}".format(np.where(cond)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not close lhs = {}".format(x))
      msgs.append("not close rhs = {}".format(y))
      msgs.append("not close dif = {}".format(np.abs(x - y)))
      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
      msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
      # TODO(xpan): There seems to be a bug:
      # tensorflow/compiler/tests:binary_ops_test pass with float32
      # nan even though the equal_nan is False by default internally.
      np.testing.assert_allclose(
          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)

  def _assertAllCloseRecursive(self,
                               a,
                               b,
                               rtol=1e-6,
                               atol=1e-6,
                               path=None,
                               msg=None):
    if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
      return self._assertRaggedClose(a, b, rtol, atol, msg)
    path = path or []
    path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
    msg = msg if msg else ""

    # Check if a and/or b are namedtuples.
2990 if hasattr(a, "_asdict"): 2991 a = a._asdict() 2992 if hasattr(b, "_asdict"): 2993 b = b._asdict() 2994 a_is_dict = isinstance(a, collections_abc.Mapping) 2995 if a_is_dict != isinstance(b, collections_abc.Mapping): 2996 raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" % 2997 (path_str, path_str, msg)) 2998 if a_is_dict: 2999 self.assertItemsEqual( 3000 a.keys(), 3001 b.keys(), 3002 msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" % 3003 (path_str, a.keys(), path_str, b.keys(), msg)) 3004 for k in a: 3005 path.append(k) 3006 self._assertAllCloseRecursive( 3007 a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg) 3008 del path[-1] 3009 elif isinstance(a, (list, tuple)): 3010 # Try to directly compare a, b as ndarrays; if not work, then traverse 3011 # through the sequence, which is more expensive. 3012 try: 3013 (a, b) = self.evaluate_if_both_tensors(a, b) 3014 a_as_ndarray = self._GetNdArray(a) 3015 b_as_ndarray = self._GetNdArray(b) 3016 self._assertArrayLikeAllClose( 3017 a_as_ndarray, 3018 b_as_ndarray, 3019 rtol=rtol, 3020 atol=atol, 3021 msg="Mismatched value: a%s is different from b%s. %s" % 3022 (path_str, path_str, msg)) 3023 except (ValueError, TypeError, NotImplementedError) as e: 3024 if len(a) != len(b): 3025 raise ValueError( 3026 "Mismatched length: a%s has %d items, but b%s has %d items. %s" % 3027 (path_str, len(a), path_str, len(b), msg)) 3028 for idx, (a_ele, b_ele) in enumerate(zip(a, b)): 3029 path.append(str(idx)) 3030 self._assertAllCloseRecursive( 3031 a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg) 3032 del path[-1] 3033 # a and b are ndarray like objects 3034 else: 3035 try: 3036 self._assertArrayLikeAllClose( 3037 a, 3038 b, 3039 rtol=rtol, 3040 atol=atol, 3041 msg=("Mismatched value: a%s is different from b%s. %s" % 3042 (path_str, path_str, msg))) 3043 except TypeError as e: 3044 msg = ("Error: a%s has %s, but b%s has %s. 
%s" % 3045 (path_str, type(a), path_str, type(b), msg)) 3046 e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) 3047 raise 3048 3049 @py_func_if_in_function 3050 def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 3051 """Asserts that two structures of numpy arrays or Tensors, have near values. 3052 3053 `a` and `b` can be arbitrarily nested structures. A layer of a nested 3054 structure can be a `dict`, `namedtuple`, `tuple` or `list`. 3055 3056 Note: the implementation follows 3057 [`numpy.allclose`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html) 3058 (and numpy.testing.assert_allclose). It checks whether two arrays are 3059 element-wise equal within a tolerance. The relative difference 3060 (`rtol * abs(b)`) and the absolute difference `atol` are added together 3061 to compare against the absolute difference between `a` and `b`. 3062 3063 Args: 3064 a: The expected numpy `ndarray`, or anything that can be converted into a 3065 numpy `ndarray` (including Tensor), or any arbitrarily nested of 3066 structure of these. 3067 b: The actual numpy `ndarray`, or anything that can be converted into a 3068 numpy `ndarray` (including Tensor), or any arbitrarily nested of 3069 structure of these. 3070 rtol: relative tolerance. 3071 atol: absolute tolerance. 3072 msg: Optional message to report on failure. 3073 3074 Raises: 3075 ValueError: if only one of `a[p]` and `b[p]` is a dict or 3076 `a[p]` and `b[p]` have different length, where `[p]` denotes a path 3077 to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and 3078 `[p] = [1]['d']`, then `a[p] = (6, 7)`. 
3079 """ 3080 self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) 3081 3082 @py_func_if_in_function 3083 def assertAllCloseAccordingToType(self, 3084 a, 3085 b, 3086 rtol=1e-6, 3087 atol=1e-6, 3088 float_rtol=1e-6, 3089 float_atol=1e-6, 3090 half_rtol=1e-3, 3091 half_atol=1e-3, 3092 bfloat16_rtol=1e-2, 3093 bfloat16_atol=1e-2, 3094 msg=None): 3095 """Like assertAllClose, but also suitable for comparing fp16 arrays. 3096 3097 In particular, the tolerance is reduced to 1e-3 if at least 3098 one of the arguments is of type float16. 3099 3100 Args: 3101 a: the expected numpy ndarray or anything can be converted to one. 3102 b: the actual numpy ndarray or anything can be converted to one. 3103 rtol: relative tolerance. 3104 atol: absolute tolerance. 3105 float_rtol: relative tolerance for float32. 3106 float_atol: absolute tolerance for float32. 3107 half_rtol: relative tolerance for float16. 3108 half_atol: absolute tolerance for float16. 3109 bfloat16_rtol: relative tolerance for bfloat16. 3110 bfloat16_atol: absolute tolerance for bfloat16. 3111 msg: Optional message to report on failure. 3112 """ 3113 (a, b) = self.evaluate_if_both_tensors(a, b) 3114 a = self._GetNdArray(a) 3115 b = self._GetNdArray(b) 3116 # types with lower tol are put later to overwrite previous ones. 
3117 if (a.dtype == np.float32 or b.dtype == np.float32 or 3118 a.dtype == np.complex64 or b.dtype == np.complex64): 3119 rtol = max(rtol, float_rtol) 3120 atol = max(atol, float_atol) 3121 if a.dtype == np.float16 or b.dtype == np.float16: 3122 rtol = max(rtol, half_rtol) 3123 atol = max(atol, half_atol) 3124 if (a.dtype == dtypes.bfloat16.as_numpy_dtype or 3125 b.dtype == dtypes.bfloat16.as_numpy_dtype): 3126 rtol = max(rtol, bfloat16_rtol) 3127 atol = max(atol, bfloat16_atol) 3128 3129 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) 3130 3131 @py_func_if_in_function 3132 def assertNotAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 3133 """Assert that two numpy arrays, or Tensors, do not have near values. 3134 3135 Args: 3136 a: The expected numpy `ndarray`, or anything that can be converted into a 3137 numpy `ndarray` (including Tensor), or any arbitrarily nested of 3138 structure of these. 3139 b: The actual numpy `ndarray`, or anything that can be converted into a 3140 numpy `ndarray` (including Tensor), or any arbitrarily nested of 3141 structure of these. 3142 rtol: relative tolerance. 3143 atol: absolute tolerance. 3144 msg: Optional message to report on failure. 3145 3146 Raises: 3147 AssertionError: If `a` and `b` are unexpectedly close at all elements. 3148 """ 3149 try: 3150 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) 3151 except AssertionError: 3152 return 3153 msg = msg or "" 3154 raise AssertionError("The two values are close at all elements. %s" % msg) 3155 3156 @py_func_if_in_function 3157 def assertAllEqual(self, a, b, msg=None): 3158 """Asserts that two numpy arrays or Tensors have the same values. 3159 3160 Args: 3161 a: the expected numpy ndarray or anything can be converted to one. 3162 b: the actual numpy ndarray or anything can be converted to one. 3163 msg: Optional message to report on failure. 
3164 """ 3165 if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)): 3166 return self._assertRaggedEqual(a, b, msg) 3167 msg = msg if msg else "" 3168 (a, b) = self.evaluate_if_both_tensors(a, b) 3169 a = self._GetNdArray(a) 3170 b = self._GetNdArray(b) 3171 # Arbitrary bounds so that we don't print giant tensors. 3172 if (b.ndim <= 3 or b.size < 500): 3173 self.assertEqual( 3174 a.shape, b.shape, "Shape mismatch: expected %s, got %s." 3175 " Contents: %r. \n%s." % (a.shape, b.shape, b, msg)) 3176 else: 3177 self.assertEqual( 3178 a.shape, b.shape, "Shape mismatch: expected %s, got %s." 3179 " %s" % (a.shape, b.shape, msg)) 3180 3181 same = (a == b) 3182 3183 if (a.dtype in [ 3184 np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype 3185 ]): 3186 same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b))) 3187 msgs = [msg] 3188 if not np.all(same): 3189 # Adds more details to np.testing.assert_array_equal. 3190 diff = np.logical_not(same) 3191 if a.ndim: 3192 x = a[np.where(diff)] 3193 y = b[np.where(diff)] 3194 msgs.append("not equal where = {}".format(np.where(diff))) 3195 else: 3196 # np.where is broken for scalars 3197 x, y = a, b 3198 msgs.append("not equal lhs = %r" % x) 3199 msgs.append("not equal rhs = %r" % y) 3200 3201 if (a.dtype.kind != b.dtype.kind and 3202 {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})): 3203 a_list = [] 3204 b_list = [] 3205 # OK to flatten `a` and `b` because they are guaranteed to have the 3206 # same shape. 3207 for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]: 3208 for item in flat_arr: 3209 if isinstance(item, str): 3210 out_list.append(item.encode("utf-8")) 3211 else: 3212 out_list.append(item) 3213 a = np.array(a_list) 3214 b = np.array(b_list) 3215 3216 np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs)) 3217 3218 @py_func_if_in_function 3219 def assertNotAllEqual(self, a, b, msg=None): 3220 """Asserts that two numpy arrays or Tensors do not have the same values. 
3221 3222 Args: 3223 a: the expected numpy ndarray or anything can be converted to one. 3224 b: the actual numpy ndarray or anything can be converted to one. 3225 msg: Optional message to report on failure. 3226 """ 3227 try: 3228 self.assertAllEqual(a, b) 3229 except AssertionError: 3230 return 3231 msg = msg or "" 3232 raise AssertionError("The two values are equal at all elements. %s" % msg) 3233 3234 @py_func_if_in_function 3235 def assertAllGreater(self, a, comparison_target): 3236 """Assert element values are all greater than a target value. 3237 3238 Args: 3239 a: The numpy `ndarray`, or anything that can be converted into a numpy 3240 `ndarray` (including Tensor). 3241 comparison_target: The target value of comparison. 3242 """ 3243 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 3244 a = self._GetNdArray(a) 3245 self.assertGreater(np.min(a), comparison_target) 3246 3247 @py_func_if_in_function 3248 def assertAllLess(self, a, comparison_target): 3249 """Assert element values are all less than a target value. 3250 3251 Args: 3252 a: The numpy `ndarray`, or anything that can be converted into a numpy 3253 `ndarray` (including Tensor). 3254 comparison_target: The target value of comparison. 3255 """ 3256 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 3257 a = self._GetNdArray(a) 3258 self.assertLess(np.max(a), comparison_target) 3259 3260 @py_func_if_in_function 3261 def assertAllGreaterEqual(self, a, comparison_target): 3262 """Assert element values are all greater than or equal to a target value. 3263 3264 Args: 3265 a: The numpy `ndarray`, or anything that can be converted into a numpy 3266 `ndarray` (including Tensor). 3267 comparison_target: The target value of comparison. 
3268 """ 3269 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 3270 a = self._GetNdArray(a) 3271 self.assertGreaterEqual(np.min(a), comparison_target) 3272 3273 @py_func_if_in_function 3274 def assertAllLessEqual(self, a, comparison_target): 3275 """Assert element values are all less than or equal to a target value. 3276 3277 Args: 3278 a: The numpy `ndarray`, or anything that can be converted into a numpy 3279 `ndarray` (including Tensor). 3280 comparison_target: The target value of comparison. 3281 """ 3282 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 3283 a = self._GetNdArray(a) 3284 self.assertLessEqual(np.max(a), comparison_target) 3285 3286 def _format_subscripts(self, subscripts, value, limit=10, indent=2): 3287 """Generate a summary of ndarray subscripts as a list of str. 3288 3289 If limit == N, this method will print up to the first N subscripts on 3290 separate 3291 lines. A line of ellipses (...) will be appended at the end if the number of 3292 subscripts exceeds N. 3293 3294 Args: 3295 subscripts: The tensor (np.ndarray) subscripts, of the same format as 3296 np.where()'s return value, i.e., a tuple of arrays with each array 3297 corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])). 3298 value: (np.ndarray) value of the tensor. 3299 limit: (int) The maximum number of indices to print. 3300 indent: (int) Number of characters to indent at the beginning of each 3301 line. 3302 3303 Returns: 3304 (list of str) the multi-line representation of the subscripts and values, 3305 potentially with omission at the end. 
3306 """ 3307 lines = [] 3308 subscripts = np.transpose(subscripts) 3309 prefix = " " * indent 3310 if np.ndim(value) == 0: 3311 return [prefix + "[0] : " + str(value)] 3312 for subscript in itertools.islice(subscripts, limit): 3313 lines.append(prefix + str(subscript) + " : " + 3314 str(value[tuple(subscript)])) 3315 if len(subscripts) > limit: 3316 lines.append(prefix + "...") 3317 return lines 3318 3319 @py_func_if_in_function 3320 def assertAllInRange(self, 3321 target, 3322 lower_bound, 3323 upper_bound, 3324 open_lower_bound=False, 3325 open_upper_bound=False): 3326 """Assert that elements in a Tensor are all in a given range. 3327 3328 Args: 3329 target: The numpy `ndarray`, or anything that can be converted into a 3330 numpy `ndarray` (including Tensor). 3331 lower_bound: lower bound of the range 3332 upper_bound: upper bound of the range 3333 open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather 3334 than the default >=) 3335 open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather 3336 than the default <=) 3337 3338 Raises: 3339 AssertionError: 3340 if the value tensor does not have an ordered numeric type (float* or 3341 int*), or 3342 if there are nan values, or 3343 if any of the elements do not fall in the specified range. 3344 """ 3345 target = self._GetNdArray(target) 3346 if not (np.issubdtype(target.dtype, np.floating) or 3347 np.issubdtype(target.dtype, np.integer)): 3348 raise AssertionError( 3349 "The value of %s does not have an ordered numeric type, instead it " 3350 "has type: %s" % (target, target.dtype)) 3351 3352 nan_subscripts = np.where(np.isnan(target)) 3353 if np.size(nan_subscripts): 3354 raise AssertionError( 3355 "%d of the %d element(s) are NaN. 
" 3356 "Subscripts(s) and value(s) of the NaN element(s):\n" % 3357 (len(nan_subscripts[0]), np.size(target)) + 3358 "\n".join(self._format_subscripts(nan_subscripts, target))) 3359 3360 range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " + 3361 str(upper_bound) + (")" if open_upper_bound else "]")) 3362 3363 violations = ( 3364 np.less_equal(target, lower_bound) if open_lower_bound else np.less( 3365 target, lower_bound)) 3366 violations = np.logical_or( 3367 violations, 3368 np.greater_equal(target, upper_bound) 3369 if open_upper_bound else np.greater(target, upper_bound)) 3370 violation_subscripts = np.where(violations) 3371 if np.size(violation_subscripts): 3372 raise AssertionError( 3373 "%d of the %d element(s) are outside the range %s. " % 3374 (len(violation_subscripts[0]), np.size(target), range_str) + 3375 "Subscript(s) and value(s) of the offending elements:\n" + 3376 "\n".join(self._format_subscripts(violation_subscripts, target))) 3377 3378 @py_func_if_in_function 3379 def assertAllInSet(self, target, expected_set): 3380 """Assert that elements of a Tensor are all in a given closed set. 3381 3382 Args: 3383 target: The numpy `ndarray`, or anything that can be converted into a 3384 numpy `ndarray` (including Tensor). 3385 expected_set: (`list`, `tuple` or `set`) The closed set that the elements 3386 of the value of `target` are expected to fall into. 3387 3388 Raises: 3389 AssertionError: 3390 if any of the elements do not fall into `expected_set`. 3391 """ 3392 target = self._GetNdArray(target) 3393 3394 # Elements in target that are not in expected_set. 3395 diff = np.setdiff1d(target.flatten(), list(expected_set)) 3396 if np.size(diff): 3397 raise AssertionError("%d unique element(s) are not in the set %s: %s" % 3398 (np.size(diff), expected_set, diff)) 3399 3400 @py_func_if_in_function 3401 def assertDTypeEqual(self, target, expected_dtype): 3402 """Assert ndarray data type is equal to expected. 
3403 3404 Args: 3405 target: The numpy `ndarray`, or anything that can be converted into a 3406 numpy `ndarray` (including Tensor). 3407 expected_dtype: Expected data type. 3408 """ 3409 target = self._GetNdArray(target) 3410 if not isinstance(target, list): 3411 arrays = [target] 3412 for arr in arrays: 3413 self.assertEqual(arr.dtype, expected_dtype) 3414 3415 # pylint: disable=g-doc-return-or-yield 3416 @contextlib.contextmanager 3417 def assertRaisesWithPredicateMatch(self, exception_type, 3418 expected_err_re_or_predicate): 3419 """Returns a context manager to enclose code expected to raise an exception. 3420 3421 If the exception is an OpError, the op stack is also included in the message 3422 predicate search. 3423 3424 Args: 3425 exception_type: The expected type of exception that should be raised. 3426 expected_err_re_or_predicate: If this is callable, it should be a function 3427 of one argument that inspects the passed-in exception and returns True 3428 (success) or False (please fail the test). Otherwise, the error message 3429 is expected to match this regular expression partially. 3430 3431 Returns: 3432 A context manager to surround code that is expected to raise an 3433 exception. 
3434 """ 3435 if callable(expected_err_re_or_predicate): 3436 predicate = expected_err_re_or_predicate 3437 else: 3438 3439 def predicate(e): 3440 err_str = e.message if isinstance(e, errors.OpError) else str(e) 3441 op = e.op if isinstance(e, errors.OpError) else None 3442 while op is not None: 3443 err_str += "\nCaused by: " + op.name 3444 op = op._original_op # pylint: disable=protected-access 3445 logging.info("Searching within error strings: '%s' within '%s'", 3446 expected_err_re_or_predicate, err_str) 3447 return re.search(expected_err_re_or_predicate, err_str) 3448 3449 try: 3450 yield 3451 self.fail(exception_type.__name__ + " not raised") 3452 except Exception as e: # pylint: disable=broad-except 3453 if not isinstance(e, exception_type) or not predicate(e): 3454 raise AssertionError("Exception of type %s: %s" % 3455 (str(type(e)), str(e))) 3456 3457 # pylint: enable=g-doc-return-or-yield 3458 3459 def assertRaisesOpError(self, expected_err_re_or_predicate): 3460 return self.assertRaisesWithPredicateMatch(errors.OpError, 3461 expected_err_re_or_predicate) 3462 3463 def assertRaisesIncompatibleShapesError( 3464 self, exception_type=errors.InvalidArgumentError): 3465 return self.assertRaisesWithPredicateMatch( 3466 exception_type, r"Incompatible shapes|Dimensions must be equal|" 3467 r"required broadcastable shapes") 3468 3469 def assertShapeEqual(self, input_a, input_b, msg=None): 3470 """Asserts that two Numpy or TensorFlow objects have the same shape. 3471 3472 For Tensors, this compares statically known shapes at compile time, not 3473 dynamic shapes at runtime. 3474 3475 Args: 3476 input_a: A Numpy ndarray, Numpy scalar, or a Tensor. 3477 input_b: A Numpy ndarray, Numpy scalar, or a Tensor. 3478 msg: Optional message to report on failure. 3479 3480 Raises: 3481 TypeError: If the arguments have the wrong type. 
3482 """ 3483 if not isinstance(input_a, (np.ndarray, np.generic, ops.Tensor)): 3484 raise TypeError( 3485 "input_a must be a Numpy ndarray, Numpy scalar, or a Tensor." 3486 f"Instead received {type(input_a)}") 3487 if not isinstance(input_b, (np.ndarray, np.generic, ops.Tensor)): 3488 raise TypeError( 3489 "input_b must be a Numpy ndarray, Numpy scalar, or a Tensor." 3490 f"Instead received {type(input_b)}") 3491 shape_a = input_a.get_shape().as_list() if isinstance( 3492 input_a, ops.Tensor) else input_a.shape 3493 shape_b = input_b.get_shape().as_list() if isinstance( 3494 input_b, ops.Tensor) else input_b.shape 3495 self.assertAllEqual(shape_a, shape_b, msg=msg) 3496 3497 def assertDeviceEqual(self, device1, device2, msg=None): 3498 """Asserts that the two given devices are the same. 3499 3500 Args: 3501 device1: A string device name or TensorFlow `DeviceSpec` object. 3502 device2: A string device name or TensorFlow `DeviceSpec` object. 3503 msg: Optional message to report on failure. 3504 """ 3505 device1 = pydev.canonical_name(device1) 3506 device2 = pydev.canonical_name(device2) 3507 self.assertEqual( 3508 device1, device2, 3509 "Devices %s and %s are not equal. %s" % (device1, device2, msg)) 3510 3511 @py_func_if_in_function 3512 def assertDictEqual(self, a, b, msg=None): 3513 """Assert that two given dictionary of tensors are the same. 3514 3515 Args: 3516 a: Expected dictionary with numpy ndarray or anything else that can be 3517 converted to one as values. 3518 b: Actual dictionary with numpy ndarray or anything else that can be 3519 converted to one as values. 3520 msg: Optional message to report on failure. 3521 """ 3522 # To keep backwards compatibility, we first try the base class 3523 # assertDictEqual. If that fails we try the tensorflow one. 
3524 try: 3525 super().assertDictEqual(a, b, msg) 3526 except Exception: # pylint: disable=broad-except 3527 self.assertSameElements(a.keys(), b.keys()) # pylint: disable=g-assert-in-except 3528 for k, v in a.items(): 3529 (a_k, b_k) = self.evaluate_if_both_tensors(v, b[k]) 3530 a_k = self._GetNdArray(a_k) 3531 b_k = self._GetNdArray(b_k) 3532 if np.issubdtype(a_k.dtype, np.floating): 3533 self.assertAllClose(v, b[k], msg=k) 3534 else: 3535 self.assertAllEqual(v, b[k], msg=k) 3536 3537 def _GetPyList(self, a): 3538 """Converts `a` to a nested python list.""" 3539 if isinstance(a, ragged_tensor.RaggedTensor): 3540 return self.evaluate(a).to_list() 3541 elif isinstance(a, ops.Tensor): 3542 a = self.evaluate(a) 3543 return a.tolist() if isinstance(a, np.ndarray) else a 3544 elif isinstance(a, np.ndarray): 3545 return a.tolist() 3546 elif isinstance(a, ragged_tensor_value.RaggedTensorValue): 3547 return a.to_list() 3548 else: 3549 return np.array(a).tolist() 3550 3551 def _assertRaggedEqual(self, a, b, msg): 3552 """Asserts that two ragged tensors are equal.""" 3553 a_list = self._GetPyList(a) 3554 b_list = self._GetPyList(b) 3555 self.assertEqual(a_list, b_list, msg) 3556 3557 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 3558 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 3559 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 3560 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 3561 3562 def _assertRaggedClose(self, a, b, rtol, atol, msg=None): 3563 a_list = self._GetPyList(a) 3564 b_list = self._GetPyList(b) 3565 self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg) 3566 3567 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 3568 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 3569 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 3570 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 3571 3572 def _assertListCloseRecursive(self, a, 
b, rtol, atol, msg, path="value"): 3573 self.assertEqual(type(a), type(b)) 3574 if isinstance(a, (list, tuple)): 3575 self.assertLen(a, len(b), "Length differs for %s" % path) 3576 for i in range(len(a)): 3577 self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg, 3578 "%s[%s]" % (path, i)) 3579 else: 3580 self._assertAllCloseRecursive(a, b, rtol, atol, path, msg) 3581 3582 # Fix Python 3+ compatibility issues 3583 # pylint: disable=invalid-name 3584 3585 # Silence a deprecation warning 3586 assertRaisesRegexp = googletest.TestCase.assertRaisesRegex 3587 3588 # assertItemsEqual is assertCountEqual as of 3.2. 3589 assertItemsEqual = googletest.TestCase.assertCountEqual 3590 3591 # pylint: enable=invalid-name 3592 3593 @contextlib.contextmanager 3594 def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): 3595 """Set the session and its graph to global default and constrain devices.""" 3596 if context.executing_eagerly(): 3597 yield None 3598 else: 3599 with sess.graph.as_default(), sess.as_default(): 3600 if force_gpu: 3601 # Use the name of an actual device if one is detected, or 3602 # '/device:GPU:0' otherwise 3603 gpu_name = gpu_device_name() 3604 if not gpu_name: 3605 gpu_name = "/device:GPU:0" 3606 with sess.graph.device(gpu_name): 3607 yield sess 3608 elif use_gpu: 3609 yield sess 3610 else: 3611 with sess.graph.device("/device:CPU:0"): 3612 yield sess 3613 3614 def _create_session(self, graph, config, force_gpu): 3615 """See session() for details.""" 3616 3617 def prepare_config(config): 3618 """Returns a config for sessions. 3619 3620 Args: 3621 config: An optional config_pb2.ConfigProto to use to configure the 3622 session. 3623 3624 Returns: 3625 A config_pb2.ConfigProto object. 3626 """ 3627 # TODO(b/114333779): Enforce allow_soft_placement=False when 3628 # use_gpu=False. Currently many tests rely on the fact that any device 3629 # will be used even when a specific device is supposed to be used. 
3630 allow_soft_placement = not force_gpu 3631 if config is None: 3632 config = context.context().config 3633 config.allow_soft_placement = allow_soft_placement 3634 elif not allow_soft_placement and config.allow_soft_placement: 3635 config_copy = context.context().config 3636 config = config_copy 3637 config.allow_soft_placement = False 3638 # Don't perform optimizations for tests so we don't inadvertently run 3639 # gpu ops on cpu 3640 config.graph_options.optimizer_options.opt_level = -1 3641 # Disable Grappler constant folding since some tests & benchmarks 3642 # use constant input and become meaningless after constant folding. 3643 # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE 3644 # GRAPPLER TEAM. 3645 config.graph_options.rewrite_options.constant_folding = ( 3646 rewriter_config_pb2.RewriterConfig.OFF) 3647 config.graph_options.rewrite_options.pin_to_host_optimization = ( 3648 rewriter_config_pb2.RewriterConfig.OFF) 3649 return config 3650 3651 return ErrorLoggingSession(graph=graph, config=prepare_config(config)) 3652 3653 def _get_cached_session(self, 3654 graph=None, 3655 config=None, 3656 force_gpu=False, 3657 crash_if_inconsistent_args=True): 3658 """See cached_session() for documentation.""" 3659 if self._cached_session is None: 3660 sess = self._create_session( 3661 graph=graph, config=config, force_gpu=force_gpu) 3662 self._cached_session = sess 3663 self._cached_graph = graph 3664 self._cached_config = config 3665 self._cached_force_gpu = force_gpu 3666 return sess 3667 else: 3668 if crash_if_inconsistent_args and self._cached_graph is not graph: 3669 raise ValueError("The graph used to get the cached session is " 3670 "different than the one that was used to create the " 3671 "session. 
Maybe create a new session with " 3672 "self.session()") 3673 if crash_if_inconsistent_args and self._cached_config is not config: 3674 raise ValueError("The config used to get the cached session is " 3675 "different than the one that was used to create the " 3676 "session. Maybe create a new session with " 3677 "self.session()") 3678 if crash_if_inconsistent_args and (self._cached_force_gpu is 3679 not force_gpu): 3680 raise ValueError( 3681 "The force_gpu value used to get the cached session is " 3682 "different than the one that was used to create the " 3683 "session. Maybe create a new session with " 3684 "self.session()") 3685 return self._cached_session 3686 3687 3688ASSIGNED_PORTS = set() 3689lock = threading.Lock() 3690 3691 3692def pick_unused_port(): 3693 """Returns an unused and unassigned local port.""" 3694 import portpicker # pylint: disable=g-import-not-at-top 3695 3696 global ASSIGNED_PORTS 3697 with lock: 3698 while True: 3699 try: 3700 port = portpicker.pick_unused_port() 3701 except portpicker.NoFreePortFoundError as porterror: 3702 raise unittest.SkipTest("Flakes in portpicker library do not represent" 3703 " TensorFlow errors.") from porterror 3704 if port > 10000 and port not in ASSIGNED_PORTS: 3705 ASSIGNED_PORTS.add(port) 3706 logging.info("Using local port %r", port) 3707 return port 3708 3709 3710@tf_export("test.create_local_cluster") 3711def create_local_cluster(num_workers, 3712 num_ps, 3713 protocol="grpc", 3714 worker_config=None, 3715 ps_config=None): 3716 """Create and start local servers and return the associated `Server` objects. 3717 3718 "PS" stands for "parameter server": a task responsible for storing and 3719 updating the model's parameters. Other tasks send updates to these parameters 3720 as they work on optimizing the parameters. This particular division of labor 3721 between tasks is not required, but is common for distributed training. 
3722 3723 Read more at https://www.tensorflow.org/guide/extend/architecture 3724 3725 ![components](https://www.tensorflow.org/images/diag1.svg "components") 3726 3727 3728 Figure illustrates the interaction of these components. 3729 "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services. 3730 3731 3732 Example: 3733 ```python 3734 workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2) 3735 3736 worker_sessions = [tf.compat.v1.Session(w.target) for w in workers] 3737 3738 with tf.device("/job:ps/task:0"): 3739 ... 3740 with tf.device("/job:ps/task:1"): 3741 ... 3742 with tf.device("/job:worker/task:0"): 3743 ... 3744 with tf.device("/job:worker/task:1"): 3745 ... 3746 3747 worker_sessions[0].run(...) 3748 ``` 3749 3750 Args: 3751 num_workers: Number of worker servers to start. 3752 num_ps: Number of PS servers to start. 3753 protocol: Communication protocol. Allowed values are documented in the 3754 documentation of `tf.distribute.Server`. 3755 worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be 3756 used to instantiate multiple devices etc. 3757 ps_config: (optional) `tf.ConfigProto` to initialize PS servers. 3758 3759 Returns: 3760 A tuple `(worker_servers, ps_servers)`. `worker_servers` is a list 3761 of `num_workers` objects of type `tf.distribute.Server` (all running 3762 locally); 3763 and `ps_servers` is a list of `num_ps` objects of similar type. 
3764 3765 Raises: 3766 ImportError: if portpicker module was not found at load time 3767 """ 3768 worker_ports = [pick_unused_port() for _ in range(num_workers)] 3769 ps_ports = [pick_unused_port() for _ in range(num_ps)] 3770 cluster_dict = { 3771 "worker": ["localhost:%s" % port for port in worker_ports], 3772 "ps": ["localhost:%s" % port for port in ps_ports] 3773 } 3774 cs = server_lib.ClusterSpec(cluster_dict) 3775 3776 workers = [ 3777 server_lib.Server( 3778 cs, 3779 job_name="worker", 3780 protocol=protocol, 3781 task_index=ix, 3782 config=worker_config, 3783 start=True) for ix in range(num_workers) 3784 ] 3785 ps_servers = [ 3786 server_lib.Server( 3787 cs, 3788 job_name="ps", 3789 protocol=protocol, 3790 task_index=ix, 3791 config=ps_config, 3792 start=True) for ix in range(num_ps) 3793 ] 3794 3795 return workers, ps_servers 3796 3797 3798def get_node_def_from_graph(node_name, graph_def): 3799 """Returns the `NodeDef` instance for given node name in the graph def. 3800 3801 This method explores only the NodeDefs in `graph_def.node`. 3802 3803 Args: 3804 node_name: Name of the NodeDef to search for. 3805 graph_def: An instance of `GraphDef` proto. 3806 3807 Returns: 3808 the `NodeDef` instance whose name field matches the given node_name or None. 3809 """ 3810 for node_def in graph_def.node: 3811 if node_def.name == node_name: 3812 return node_def 3813 return None 3814 3815 3816def set_producer_version(graph, producer_version): 3817 """Sets graph.graph_def_versions.producer to `producer_version`.""" 3818 # The C API doesn't expose altering GraphDefVersions. We can indirectly set 3819 # it via import_graph_def though. 3820 graph_def = graph_pb2.GraphDef() 3821 graph_def.versions.producer = producer_version 3822 with graph.as_default(): 3823 importer.import_graph_def(graph_def) 3824 assert graph.graph_def_versions.producer, producer_version 3825 3826 3827@contextlib.contextmanager 3828def _fake_gradient_tape_context_manager(): 3829 """tf.gradients(...) 
implemented as tf.GradientTape context manager interface. 3830 3831 This is useful to test tf.gradients() in tests that uses tf.GradientTape(). 3832 3833 Yields: 3834 gradient tape instance that's implemented by tf.gradients() underneath. 3835 """ 3836 try: 3837 class FakeGradientTape: 3838 3839 def watch(self, x): 3840 pass 3841 3842 def gradient(self, y, x, grad_ys=None): 3843 result = gradients_impl.gradients(y, x, grad_ys) 3844 3845 # Unlike `tape.gradient()`, `tf.gradients()` returns a list for a single 3846 # element. So unpack if needed to match `tape.gradient()` behavior. 3847 if not isinstance(x, (list, tuple)): 3848 assert len(result) == 1 3849 return result[0] 3850 3851 return result 3852 3853 yield FakeGradientTape() 3854 finally: 3855 pass 3856 3857 3858class AbstractGradientTape: 3859 """Abstract GradientTape context manager that has multiple implementations. 3860 3861 This is useful to test both tf.GradientTape() and tf.gradients() without 3862 duplicating tests. 3863 """ 3864 3865 def __init__(self, use_tape, persistent=False): 3866 self._use_tape = use_tape 3867 self._persistent = persistent 3868 3869 def __enter__(self): 3870 if self._use_tape: 3871 self._tape_impl = backprop.GradientTape(persistent=self._persistent) 3872 else: 3873 self._tape_impl = _fake_gradient_tape_context_manager() 3874 return self._tape_impl.__enter__() 3875 3876 def __exit__(self, exc_type, exc_val, exc_tb): 3877 self._tape_impl.__exit__(exc_type, exc_val, exc_tb) 3878 3879 3880@contextlib.contextmanager 3881def run_functions_eagerly(run_eagerly): 3882 """Runs functions eagerly if `run_eagerly` is true. 3883 3884 WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode 3885 *WILL NOT* make the tf.function to run eagerly because eager is disabled by 3886 default in V1. Instead, tf.function will run as a traced graph function. 
3887 3888 Ensures that the state (for running functions eagerly) is back to the initial 3889 `def_function.RUN_FUNCTIONS_EAGERLY` state. 3890 3891 Args: 3892 run_eagerly: Boolean determining whether to run the function eagerly or not. 3893 3894 Raises: 3895 ValueError if `run_eagerly` is not a boolean. 3896 3897 Yields: 3898 Nothing. 3899 """ 3900 if not isinstance(run_eagerly, bool): 3901 raise ValueError( 3902 "Expected bool for `run_eagerly` but got {}".format(run_eagerly)) 3903 3904 is_eager = context.executing_eagerly() 3905 if not is_eager and run_eagerly: 3906 logging.warning( 3907 "Running tf.function eagerly in V1 graph mode is not supported. " 3908 "tf.function will be run as a traced graph function.") 3909 3910 initial_state = def_function.functions_run_eagerly() 3911 def_function.run_functions_eagerly(run_eagerly) 3912 try: 3913 yield 3914 finally: 3915 def_function.run_functions_eagerly(initial_state) 3916 3917 3918class TestDelta: 3919 """A utility class to track increments to test counters.""" 3920 3921 def __init__(self, name, label): 3922 self.name = name 3923 self.label = label 3924 self.Reset() 3925 3926 def Reset(self): 3927 self.last_value = _test_metrics_util.test_counter_value( 3928 self.name, self.label) 3929 3930 def Get(self): 3931 value = _test_metrics_util.test_counter_value(self.name, self.label) 3932 return value - self.last_value 3933