1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Test utils for tensorflow.""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23from collections import OrderedDict 24import contextlib 25import functools 26import gc 27import itertools 28import math 29import os 30import random 31import re 32import tempfile 33import threading 34import time 35import unittest 36 37from absl.testing import parameterized 38import numpy as np 39import six 40 41from google.protobuf import descriptor_pool 42from google.protobuf import text_format 43 44from tensorflow.core.framework import graph_pb2 45from tensorflow.core.protobuf import rewriter_config_pb2 46from tensorflow.python import pywrap_sanitizers 47from tensorflow.python import tf2 48from tensorflow.python.client import device_lib 49from tensorflow.python.client import pywrap_tf_session 50from tensorflow.python.client import session 51from tensorflow.python.compat.compat import forward_compatibility_horizon 52from tensorflow.python.eager import backprop 53from tensorflow.python.eager import context 54from tensorflow.python.eager import def_function 55from tensorflow.python.eager import tape 56from tensorflow.python.framework import config 57from tensorflow.python.framework import device as pydev 58from tensorflow.python.framework import dtypes 59from tensorflow.python.framework import errors 60from tensorflow.python.framework import errors_impl 61from tensorflow.python.framework import gpu_util 62from tensorflow.python.framework import importer 63from tensorflow.python.framework import ops 64from tensorflow.python.framework import random_seed 65from tensorflow.python.framework import sparse_tensor 66from tensorflow.python.framework import tensor_shape 67from tensorflow.python.framework import tensor_util 68from tensorflow.python.framework import tfrt_utils 69from tensorflow.python.framework import versions 70from tensorflow.python.ops import array_ops 71from tensorflow.python.ops import control_flow_util 72from tensorflow.python.ops import control_flow_util_v2 73from tensorflow.python.ops import gradients_impl 74from tensorflow.python.ops import math_ops 75from tensorflow.python.ops import script_ops 76from tensorflow.python.ops import summary_ops_v2 77from tensorflow.python.ops import variables 78from tensorflow.python.ops.ragged import ragged_ops # pylint: disable=unused-import 79from tensorflow.python.ops.ragged import ragged_tensor 80from tensorflow.python.ops.ragged import ragged_tensor_value 81from tensorflow.python.platform import _pywrap_stacktrace_handler 82from tensorflow.python.platform import googletest 83from tensorflow.python.platform import tf_logging as logging 84from tensorflow.python.training import server_lib 85from tensorflow.python.util import _pywrap_util_port 86from tensorflow.python.util import compat 87from tensorflow.python.util import deprecation 88from tensorflow.python.util import nest 89from tensorflow.python.util import tf_decorator 90from tensorflow.python.util import tf_inspect 91from tensorflow.python.util import traceback_utils 92from tensorflow.python.util.compat import collections_abc 93from tensorflow.python.util.protobuf import compare 94from tensorflow.python.util.tf_export import tf_export 95 96 97# If the below import is made available through the BUILD rule, then this 98# function is overridden and will instead return True and cause Tensorflow 99# graphs to be compiled with XLA. 100def is_xla_enabled(): 101 return False 102 103 104try: 105 from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top, unused-import 106except Exception: # pylint: disable=broad-except 107 pass 108 109 110# Uses the same mechanism as above to selectively enable/disable MLIR 111# compilation. 112def is_mlir_bridge_enabled(): 113 return None 114 115 116try: 117 from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled # pylint: disable=g-import-not-at-top, unused-import 118except ImportError: 119 try: 120 from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled # pylint: disable=g-import-not-at-top, unused-import 121 except ImportError: 122 pass 123 124 125def is_asan_enabled(): 126 """Check if ASAN is enabled.""" 127 return pywrap_sanitizers.is_asan_enabled() 128 129 130def is_msan_enabled(): 131 """Check if MSAN is enabled.""" 132 return pywrap_sanitizers.is_msan_enabled() 133 134 135def is_tsan_enabled(): 136 """Check if TSAN is enabled.""" 137 return pywrap_sanitizers.is_tsan_enabled() 138 139 140def is_ubsan_enabled(): 141 """Check if UBSAN is enabled.""" 142 return pywrap_sanitizers.is_ubsan_enabled() 143 144 145def _get_object_count_by_type(exclude=()): 146 return ( 147 collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) - 148 collections.Counter([type(obj).__name__ for obj in exclude])) 149 150 151@tf_export("test.gpu_device_name") 152def gpu_device_name(): 153 """Returns the name of a GPU device if available or a empty string. 154 155 This method should only be used in tests written with `tf.test.TestCase`. 156 157 >>> class MyTest(tf.test.TestCase): 158 ... 159 ... def test_add_on_gpu(self): 160 ... if not tf.test.is_built_with_gpu_support(): 161 ... self.skipTest("test is only applicable on GPU") 162 ... 163 ... with tf.device(tf.test.gpu_device_name()): 164 ... self.assertEqual(tf.math.add(1.0, 2.0), 3.0) 165 166 """ 167 for x in device_lib.list_local_devices(): 168 if x.device_type == "GPU": 169 return compat.as_str(x.name) 170 return "" 171 172 173def assert_ops_in_graph(expected_ops, graph): 174 """Assert all expected operations are found. 175 176 Args: 177 expected_ops: `dict<string, string>` of op name to op type. 178 graph: Graph to check. 179 180 Returns: 181 `dict<string, node>` of node name to node. 182 183 Raises: 184 ValueError: If the expected ops are not present in the graph. 185 """ 186 actual_ops = {} 187 gd = graph.as_graph_def() 188 for node in gd.node: 189 if node.name in expected_ops: 190 if expected_ops[node.name] != node.op: 191 raise ValueError("Expected op for node %s is different. %s vs %s" % 192 (node.name, expected_ops[node.name], node.op)) 193 actual_ops[node.name] = node 194 if set(expected_ops.keys()) != set(actual_ops.keys()): 195 raise ValueError("Not all expected ops are present. Expected %s, found %s" % 196 (expected_ops.keys(), actual_ops.keys())) 197 return actual_ops 198 199 200@tf_export("test.assert_equal_graph_def", v1=[]) 201def assert_equal_graph_def_v2(expected, actual): 202 """Asserts that two `GraphDef`s are (mostly) the same. 203 204 Compares two `GraphDef` protos for equality, ignoring versions and ordering of 205 nodes, attrs, and control inputs. Node names are used to match up nodes 206 between the graphs, so the naming of nodes must be consistent. This function 207 ignores randomized attribute values that may appear in V2 checkpoints. 208 209 Args: 210 expected: The `GraphDef` we expected. 211 actual: The `GraphDef` we have. 212 213 Raises: 214 AssertionError: If the `GraphDef`s do not match. 215 TypeError: If either argument is not a `GraphDef`. 216 """ 217 assert_equal_graph_def(actual, expected, checkpoint_v2=True, 218 hash_table_shared_name=True) 219 220 221@tf_export(v1=["test.assert_equal_graph_def"]) 222def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False, 223 hash_table_shared_name=False): 224 """Asserts that two `GraphDef`s are (mostly) the same. 225 226 Compares two `GraphDef` protos for equality, ignoring versions and ordering of 227 nodes, attrs, and control inputs. Node names are used to match up nodes 228 between the graphs, so the naming of nodes must be consistent. 229 230 Args: 231 actual: The `GraphDef` we have. 232 expected: The `GraphDef` we expected. 233 checkpoint_v2: boolean determining whether to ignore randomized attribute 234 values that appear in V2 checkpoints. 235 hash_table_shared_name: boolean determining whether to ignore randomized 236 shared_names that appear in HashTableV2 op defs. 237 238 Raises: 239 AssertionError: If the `GraphDef`s do not match. 240 TypeError: If either argument is not a `GraphDef`. 241 """ 242 assert_equal_graph_def(actual, expected, checkpoint_v2, 243 hash_table_shared_name) 244 245 246def assert_equal_graph_def(actual, expected, checkpoint_v2=False, 247 hash_table_shared_name=False): 248 if not isinstance(actual, graph_pb2.GraphDef): 249 raise TypeError("Expected tf.GraphDef for actual, got %s" % 250 type(actual).__name__) 251 if not isinstance(expected, graph_pb2.GraphDef): 252 raise TypeError("Expected tf.GraphDef for expected, got %s" % 253 type(expected).__name__) 254 255 if checkpoint_v2: 256 _strip_checkpoint_v2_randomized(actual) 257 _strip_checkpoint_v2_randomized(expected) 258 259 if hash_table_shared_name: 260 _strip_hash_table_shared_name(actual) 261 _strip_hash_table_shared_name(expected) 262 263 diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(), 264 expected.SerializeToString()) 265 if diff: 266 raise AssertionError(compat.as_str(diff)) 267 268 269def assert_meta_graph_protos_equal(tester, a, b): 270 """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.""" 271 # Carefully check the collection_defs 272 tester.assertEqual(set(a.collection_def), set(b.collection_def)) 273 collection_keys = a.collection_def.keys() 274 for k in collection_keys: 275 a_value = a.collection_def[k] 276 b_value = b.collection_def[k] 277 proto_type = ops.get_collection_proto_type(k) 278 if proto_type: 279 a_proto = proto_type() 280 b_proto = proto_type() 281 # Number of entries in the collections is the same 282 tester.assertEqual( 283 len(a_value.bytes_list.value), len(b_value.bytes_list.value)) 284 for (a_value_item, b_value_item) in zip(a_value.bytes_list.value, 285 b_value.bytes_list.value): 286 a_proto.ParseFromString(a_value_item) 287 b_proto.ParseFromString(b_value_item) 288 tester.assertProtoEquals(a_proto, b_proto) 289 else: 290 tester.assertEquals(a_value, b_value) 291 # Compared the fields directly, remove their raw values from the 292 # proto comparison below. 293 a.ClearField("collection_def") 294 b.ClearField("collection_def") 295 296 # Check the graph_defs. 297 assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True) 298 # Check graph_def versions (ignored by assert_equal_graph_def). 299 tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions) 300 # Compared the fields directly, remove their raw values from the 301 # proto comparison below. 302 a.ClearField("graph_def") 303 b.ClearField("graph_def") 304 305 tester.assertProtoEquals(a, b) 306 307 308# Matches attributes named via _SHARDED_SUFFIX in 309# tensorflow/python/training/saver.py 310_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part" 311 312 313def _strip_checkpoint_v2_randomized(graph_def): 314 for node in graph_def.node: 315 delete_keys = [] 316 for attr_key in node.attr: 317 attr_tensor_value = node.attr[attr_key].tensor 318 if attr_tensor_value and len(attr_tensor_value.string_val) == 1: 319 attr_tensor_string_value = attr_tensor_value.string_val[0] 320 if (attr_tensor_string_value and 321 re.match(compat.as_bytes(_SHARDED_SAVE_OP_PATTERN), 322 attr_tensor_string_value)): 323 delete_keys.append(attr_key) 324 for attr_key in delete_keys: 325 del node.attr[attr_key] 326 327 328_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+" 329 330 331def _strip_hash_table_shared_name(graph_def): 332 for node in graph_def.node: 333 delete_keys = [] 334 if node.op == "HashTableV2" and "shared_name" in node.attr: 335 if re.match(compat.as_bytes(_TABLE_SHARED_NAME_PATTERN), 336 node.attr["shared_name"].s): 337 delete_keys.append("shared_name") 338 for attr_key in delete_keys: 339 del node.attr[attr_key] 340 341 342def IsGoogleCudaEnabled(): 343 return _pywrap_util_port.IsGoogleCudaEnabled() 344 345 346def IsBuiltWithROCm(): 347 return _pywrap_util_port.IsBuiltWithROCm() 348 349 350def IsBuiltWithXLA(): 351 return _pywrap_util_port.IsBuiltWithXLA() 352 353 354def IsBuiltWithNvcc(): 355 return _pywrap_util_port.IsBuiltWithNvcc() 356 357 358def GpuSupportsHalfMatMulAndConv(): 359 return _pywrap_util_port.GpuSupportsHalfMatMulAndConv() 360 361 362def IsMklEnabled(): 363 return (_pywrap_util_port.IsMklEnabled() or 364 os.getenv("TF_ENABLE_ONEDNN_OPTS", "False").lower() in ["true", "1"]) 365 366 367def InstallStackTraceHandler(): 368 _pywrap_stacktrace_handler.InstallStacktraceHandler() 369 370 371def NHWCToNCHW(input_tensor): 372 """Converts the input from the NHWC format to NCHW. 373 374 Args: 375 input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape 376 377 Returns: 378 converted tensor or shape array 379 """ 380 # tensor dim -> new axis order 381 new_axes = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]} 382 if isinstance(input_tensor, ops.Tensor): 383 ndims = input_tensor.shape.ndims 384 return array_ops.transpose(input_tensor, new_axes[ndims]) 385 else: 386 ndims = len(input_tensor) 387 return [input_tensor[a] for a in new_axes[ndims]] 388 389 390def NHWCToNCHW_VECT_C(input_shape_or_tensor): 391 """Transforms the input from the NHWC layout to NCHW_VECT_C layout. 392 393 Note: Does not include quantization or type conversion steps, which should 394 be applied afterwards. 395 396 Args: 397 input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape 398 399 Returns: 400 tensor or shape array transformed into NCHW_VECT_C 401 402 Raises: 403 ValueError: if last dimension of `input_shape_or_tensor` is not evenly 404 divisible by 4. 405 """ 406 permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} 407 is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) 408 temp_shape = ( 409 input_shape_or_tensor.shape.as_list() 410 if is_tensor else input_shape_or_tensor) 411 if temp_shape[-1] % 4 != 0: 412 raise ValueError( 413 "Last dimension of input must be evenly divisible by 4 to convert to " 414 "NCHW_VECT_C.") 415 temp_shape[-1] //= 4 416 temp_shape.append(4) 417 permutation = permutations[len(temp_shape)] 418 if is_tensor: 419 t = array_ops.reshape(input_shape_or_tensor, temp_shape) 420 return array_ops.transpose(t, permutation) 421 else: 422 return [temp_shape[a] for a in permutation] 423 424 425def NCHW_VECT_CToNHWC(input_shape_or_tensor): 426 """Transforms the input from the NCHW_VECT_C layout to NHWC layout. 427 428 Note: Does not include de-quantization or type conversion steps, which should 429 be applied beforehand. 430 431 Args: 432 input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape 433 434 Returns: 435 tensor or shape array transformed into NHWC 436 437 Raises: 438 ValueError: if last dimension of `input_shape_or_tensor` is not 4. 439 """ 440 permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} 441 is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) 442 input_shape = ( 443 input_shape_or_tensor.shape.as_list() 444 if is_tensor else input_shape_or_tensor) 445 if input_shape[-1] != 4: 446 raise ValueError("Last dimension of NCHW_VECT_C must be 4.") 447 permutation = permutations[len(input_shape)] 448 nhwc_shape = [input_shape[a] for a in permutation[:-1]] 449 nhwc_shape[-1] *= input_shape[-1] 450 if is_tensor: 451 t = array_ops.transpose(input_shape_or_tensor, permutation) 452 return array_ops.reshape(t, nhwc_shape) 453 else: 454 return nhwc_shape 455 456 457def NCHWToNHWC(input_tensor): 458 """Converts the input from the NCHW format to NHWC. 459 460 Args: 461 input_tensor: a 4- or 5-D tensor, or an array representing shape 462 463 Returns: 464 converted tensor or shape array 465 """ 466 # tensor dim -> new axis order 467 new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]} 468 if isinstance(input_tensor, ops.Tensor): 469 ndims = input_tensor.shape.ndims 470 return array_ops.transpose(input_tensor, new_axes[ndims]) 471 else: 472 ndims = len(input_tensor) 473 return [input_tensor[a] for a in new_axes[ndims]] 474 475 476def skip_if(condition): 477 """Skips the decorated function if condition is or evaluates to True. 478 479 Args: 480 condition: Either an expression that can be used in "if not condition" 481 statement, or a callable whose result should be a boolean. 482 483 Returns: 484 The wrapped function 485 """ 486 487 def real_skip_if(fn): 488 489 def wrapper(*args, **kwargs): 490 if callable(condition): 491 skip = condition() 492 else: 493 skip = condition 494 if not skip: 495 return fn(*args, **kwargs) 496 497 return wrapper 498 499 return real_skip_if 500 501 502@contextlib.contextmanager 503def skip_if_error(test_obj, error_type, messages=None): 504 """Context manager to skip cases not considered failures by the tests. 505 506 Note that this does not work if used in setUpClass/tearDownClass. 507 Usage in setUp/tearDown works fine just like regular test methods. 508 509 Args: 510 test_obj: A test object provided as `self` in the test methods; this object 511 is usually an instance of `unittest.TestCase`'s subclass and should have 512 `skipTest` method. 513 error_type: The error type to skip. Note that if `messages` are given, both 514 `error_type` and `messages` need to match for the test to be skipped. 515 messages: Optional, a string or list of strings. If `None`, the test will be 516 skipped if `error_type` matches what is raised; otherwise, the test is 517 skipped if any of the `messages` is contained in the message of the error 518 raised, and `error_type` matches the error raised. 519 520 Yields: 521 Nothing. 522 """ 523 if messages: 524 messages = nest.flatten(messages) 525 try: 526 yield 527 except error_type as e: 528 if not messages or any(message in str(e) for message in messages): 529 test_obj.skipTest("Skipping error: {}: {}".format(type(e), str(e))) 530 else: 531 raise 532 533 534def enable_c_shapes(fn): 535 """No-op. TODO(b/74620627): Remove this.""" 536 return fn 537 538 539def with_c_shapes(cls): 540 """No-op. TODO(b/74620627): Remove this.""" 541 return cls 542 543 544def enable_control_flow_v2(fn): 545 """Decorator for enabling CondV2 and WhileV2 on a test. 546 547 Note this enables using CondV2 and WhileV2 after running the test class's 548 setup/teardown methods. 549 550 In addition to this, callers must import the while_v2 module in order to set 551 the _while_v2 module in control_flow_ops. 552 553 Args: 554 fn: the function to be wrapped 555 556 Returns: 557 The wrapped function 558 """ 559 560 def wrapper(*args, **kwargs): 561 enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2 562 control_flow_util.ENABLE_CONTROL_FLOW_V2 = True 563 try: 564 return fn(*args, **kwargs) 565 finally: 566 control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old 567 568 return wrapper 569 570 571def with_control_flow_v2(cls): 572 """Adds methods that call original methods with WhileV2 and CondV2 enabled. 573 574 Note this enables CondV2 and WhileV2 in new methods after running the test 575 class's setup method. 576 577 In addition to this, callers must import the while_v2 module in order to set 578 the _while_v2 module in control_flow_ops. 579 580 If a test function has _disable_control_flow_v2 attr set to True (using the 581 @disable_control_flow_v2 decorator), the v2 function is not generated for it. 582 583 Example: 584 585 @test_util.with_control_flow_v2 586 class ControlFlowTest(test.TestCase): 587 588 def testEnabledForV2(self): 589 ... 590 591 @test_util.disable_control_flow_v2("b/xyzabc") 592 def testDisabledForV2(self): 593 ... 594 595 Generated class: 596 class ControlFlowTest(test.TestCase): 597 598 def testEnabledForV2(self): 599 ... 600 601 def testEnabledForV2WithControlFlowV2(self): 602 // Enable V2 flags. 603 testEnabledForV2(self) 604 // Restore V2 flags. 605 606 def testDisabledForV2(self): 607 ... 608 609 Args: 610 cls: class to decorate 611 612 Returns: 613 cls with new test methods added 614 """ 615 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 616 return cls 617 618 for name, value in cls.__dict__.copy().items(): 619 if (callable(value) and 620 name.startswith(unittest.TestLoader.testMethodPrefix) and 621 not getattr(value, "_disable_control_flow_v2", False)): 622 setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value)) 623 return cls 624 625 626def disable_control_flow_v2(unused_msg): 627 """Decorator for a function in a with_control_flow_v2 enabled test class. 628 629 Blocks the function from being run with v2 control flow ops. 630 631 Args: 632 unused_msg: Reason for disabling. 633 634 Returns: 635 The wrapped function with _disable_control_flow_v2 attr set to True. 636 """ 637 638 def wrapper(func): 639 func._disable_control_flow_v2 = True 640 return func 641 642 return wrapper 643 644 645def enable_output_all_intermediates(fn): 646 """Force-enable outputing all intermediates from functional control flow ops. 647 648 Args: 649 fn: the function to be wrapped 650 651 Returns: 652 The wrapped function 653 """ 654 655 def wrapper(*args, **kwargs): 656 output_all_intermediates_old = \ 657 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE 658 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True 659 try: 660 return fn(*args, **kwargs) 661 finally: 662 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \ 663 output_all_intermediates_old 664 665 return wrapper 666 667 668def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2): 669 """Decorator for asserting that no new Python objects persist after a test. 670 671 Runs the test multiple times executing eagerly, first as a warmup and then to 672 let objects accumulate. The warmup helps ignore caches which do not grow as 673 the test is run repeatedly. 674 675 Useful for checking that there are no missing Py_DECREFs in the C exercised by 676 a bit of Python. 677 678 Args: 679 func: The function to test. 680 warmup_iters: The numer of warmup iterations, excluded from measuring. 681 682 Returns: 683 The wrapped function performing the test. 684 """ 685 686 def wrap_f(f): 687 def decorator(self, *args, **kwargs): 688 """Warms up, gets object counts, runs the test, checks for new objects.""" 689 with context.eager_mode(): 690 gc.disable() 691 # Run the test 2 times as warmup, in an attempt to fill up caches, which 692 # should not grow as the test is run repeatedly below. 693 # 694 # TODO(b/117156879): Running warmup twice is black magic; we have seen 695 # tests that fail with 1 warmup run, and pass with 2, on various 696 # versions of python2.7.x. 697 for _ in range(warmup_iters): 698 f(self, *args, **kwargs) 699 # Since we aren't in the normal test lifecycle, we need to manually run 700 # cleanups to clear out their object references. 701 self.doCleanups() 702 703 # Some objects are newly created by _get_object_count_by_type(). So 704 # create and save as a dummy variable to include it as a baseline. 705 obj_count_by_type = _get_object_count_by_type() 706 gc.collect() 707 708 # Make sure any registered functions are cleaned up in the C++ runtime. 709 registered_function_names = context.context().list_function_names() 710 711 # unittest.doCleanups adds to self._outcome with each unwound call. 712 # These objects are retained across gc collections so we exclude them 713 # from the object count calculation. 714 obj_count_by_type = _get_object_count_by_type( 715 exclude=gc.get_referents(self._outcome.errors, 716 self._outcome.skipped)) 717 718 if ops.has_default_graph(): 719 collection_sizes_before = { 720 collection: len(ops.get_collection(collection)) 721 for collection in ops.get_default_graph().collections 722 } 723 for _ in range(3): 724 f(self, *args, **kwargs) 725 # Since we aren't in the normal test lifecycle, we need to manually run 726 # cleanups to clear out their object references. 727 self.doCleanups() 728 # Note that gc.get_objects misses anything that isn't subject to garbage 729 # collection (C types). Collections are a common source of leaks, so we 730 # test for collection sizes explicitly. 731 if ops.has_default_graph(): 732 for collection_key in ops.get_default_graph().collections: 733 collection = ops.get_collection(collection_key) 734 size_before = collection_sizes_before.get(collection_key, 0) 735 if len(collection) > size_before: 736 raise AssertionError( 737 ("Collection %s increased in size from " 738 "%d to %d (current items %s).") % 739 (collection_key, size_before, len(collection), collection)) 740 # Make sure our collection checks don't show up as leaked memory by 741 # removing references to temporary variables. 742 del collection 743 del collection_key 744 del size_before 745 del collection_sizes_before 746 gc.collect() 747 748 # There should be no new Python objects hanging around. 749 obj_count_by_type = ( 750 _get_object_count_by_type( 751 exclude=gc.get_referents(self._outcome.errors, 752 self._outcome.skipped)) - 753 obj_count_by_type) 754 755 # There should be no newly registered functions hanging around. 756 leftover_functions = ( 757 context.context().list_function_names() - registered_function_names) 758 assert not leftover_functions, ( 759 "The following functions were newly created: %s" % 760 leftover_functions) 761 762 # In some cases (specifically on MacOS), new_count is somehow 763 # smaller than previous_count. 764 # Using plain assert because not all classes using this decorator 765 # have assertLessEqual 766 assert not obj_count_by_type, ( 767 "The following objects were newly created: %s" % 768 str(obj_count_by_type)) 769 gc.enable() 770 return decorator 771 772 if func is None: 773 return wrap_f 774 else: 775 return wrap_f(func) 776 777 778def assert_no_new_tensors(f): 779 """Decorator for asserting that no new Tensors persist after a test. 780 781 Mainly useful for checking that code using the Python C API has correctly 782 manipulated reference counts. 783 784 Clears the caches that it knows about, runs the garbage collector, then checks 785 that there are no Tensor or Tensor-like objects still around. This includes 786 Tensors to which something still has a reference (e.g. from missing 787 Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one 788 of the objects has __del__ defined). 789 790 Args: 791 f: The test case to run. 792 793 Returns: 794 The decorated test case. 795 """ 796 797 def decorator(self, **kwargs): 798 """Finds existing Tensors, runs the test, checks for new Tensors.""" 799 800 def _is_tensorflow_object(obj): 801 try: 802 return isinstance(obj, 803 (ops.Tensor, variables.Variable, 804 tensor_shape.Dimension, tensor_shape.TensorShape)) 805 except (ReferenceError, AttributeError): 806 # If the object no longer exists, we don't care about it. 807 return False 808 809 tensors_before = set( 810 id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) 811 outside_executed_eagerly = context.executing_eagerly() 812 # Run the test in a new graph so that collections get cleared when it's 813 # done, but inherit the graph key so optimizers behave. 814 outside_graph_key = ops.get_default_graph()._graph_key 815 with ops.Graph().as_default(): 816 ops.get_default_graph()._graph_key = outside_graph_key 817 if outside_executed_eagerly: 818 with context.eager_mode(): 819 result = f(self, **kwargs) 820 else: 821 result = f(self, **kwargs) 822 # Make an effort to clear caches, which would otherwise look like leaked 823 # Tensors. 824 context.context()._clear_caches() # pylint: disable=protected-access 825 gc.collect() 826 tensors_after = [ 827 obj for obj in gc.get_objects() 828 if _is_tensorflow_object(obj) and id(obj) not in tensors_before 829 ] 830 if tensors_after: 831 raise AssertionError(("%d Tensors not deallocated after test: %s" % ( 832 len(tensors_after), 833 str(tensors_after), 834 ))) 835 return result 836 837 return decorator 838 839 840def _find_reference_cycle(objects, idx): 841 842 def get_ignore_reason(obj, denylist): 843 """Tests whether an object should be omitted from the dependency graph.""" 844 if len(denylist) > 100: 845 return "<depth limit>" 846 if tf_inspect.isframe(obj): 847 if "test_util.py" in tf_inspect.getframeinfo(obj)[0]: 848 return "<test code>" 849 for b in denylist: 850 if b is obj: 851 return "<test code>" 852 if obj is denylist: 853 return "<test code>" 854 return None 855 856 # Note: this function is meant to help with diagnostics. Its output is purely 857 # a human-readable representation, so you may freely modify it to suit your 858 # needs. 859 def describe(obj, denylist, leaves_only=False): 860 """Returns a custom human-readable summary of obj. 861 862 Args: 863 obj: the value to describe. 864 denylist: same as denylist in get_ignore_reason. 865 leaves_only: boolean flag used when calling describe recursively. Useful 866 for summarizing collections. 867 """ 868 if get_ignore_reason(obj, denylist): 869 return "{}{}".format(get_ignore_reason(obj, denylist), type(obj)) 870 if tf_inspect.isframe(obj): 871 return "frame: {}".format(tf_inspect.getframeinfo(obj)) 872 elif tf_inspect.ismodule(obj): 873 return "module: {}".format(obj.__name__) 874 else: 875 if leaves_only: 876 return "{}, {}".format(type(obj), id(obj)) 877 elif isinstance(obj, list): 878 return "list({}): {}".format( 879 id(obj), [describe(e, denylist, leaves_only=True) for e in obj]) 880 elif isinstance(obj, tuple): 881 return "tuple({}): {}".format( 882 id(obj), [describe(e, denylist, leaves_only=True) for e in obj]) 883 elif isinstance(obj, dict): 884 return "dict({}): {} keys".format(id(obj), len(obj.keys())) 885 elif tf_inspect.isfunction(obj): 886 return "function({}) {}; globals ID: {}".format( 887 id(obj), obj.__name__, id(obj.__globals__)) 888 else: 889 return "{}, {}".format(type(obj), id(obj)) 890 891 def build_ref_graph(obj, graph, reprs, denylist): 892 """Builds a reference graph as <referrer> -> <list of referents>. 893 894 Args: 895 obj: The object to start from. The graph will be built by recursively 896 adding its referrers. 897 graph: Dict holding the graph to be built. To avoid creating extra 898 references, the graph holds object IDs rather than actual objects. 899 reprs: Auxiliary structure that maps object IDs to their human-readable 900 description. 901 denylist: List of objects to ignore. 902 """ 903 referrers = gc.get_referrers(obj) 904 denylist = denylist + (referrers,) 905 906 obj_id = id(obj) 907 for r in referrers: 908 if get_ignore_reason(r, denylist) is None: 909 r_id = id(r) 910 if r_id not in graph: 911 graph[r_id] = [] 912 if obj_id not in graph[r_id]: 913 graph[r_id].append(obj_id) 914 build_ref_graph(r, graph, reprs, denylist) 915 reprs[r_id] = describe(r, denylist) 916 917 def find_cycle(el, graph, reprs, path): 918 """Finds and prints a single cycle in the dependency graph.""" 919 if el not in graph: 920 return 921 for r in graph[el]: 922 if r in path: 923 logging.error("Reference cycle sample:") 924 for p in path + (r,): 925 logging.error(reprs.get(p, "unknown object " + str(p))) 926 return True 927 else: 928 if find_cycle(r, graph, reprs, path + (r,)): 929 return True 930 return False 931 932 obj = objects[idx] 933 graph = {} # referrer ID -> object ID 934 reprs = {} # object ID -> description 935 build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason, 936 describe, build_ref_graph, find_cycle)) 937 for k in graph: 938 if find_cycle(k, graph, reprs, ()): 939 return True 940 return False 941 942 943def assert_no_garbage_created(f): 944 """Test method decorator to assert that no garbage has been created. 945 946 Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters 947 cannot be un-set (i.e. will disable garbage collection for any other unit 948 tests in the same file/shard). 949 950 Args: 951 f: The function to decorate. 952 953 Returns: 954 The decorated function. 955 """ 956 957 def decorator(self, **kwargs): 958 """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" 959 # Force-load `distribution_strategy_context` to prevent GC at 960 # test time when using eager. Remove once b/117329403 is resolved. 961 tape.distribution_strategy_context.get_strategy() 962 963 gc.disable() 964 previous_debug_flags = gc.get_debug() 965 gc.set_debug(gc.DEBUG_SAVEALL) 966 gc.collect() 967 previous_garbage = len(gc.garbage) 968 result = f(self, **kwargs) 969 gc.collect() 970 new_garbage = len(gc.garbage) 971 if new_garbage > previous_garbage: 972 973 for i, obj in enumerate(gc.garbage[previous_garbage:]): 974 # Known false positive for ast.fix_missing_locations. 975 if getattr(obj, "__module__", "") == "ast": 976 new_garbage -= 3 977 978 if new_garbage > previous_garbage: 979 logging.error( 980 "The decorated test created work for Python's garbage collector, " 981 "likely due to a reference cycle. New objects in cycle(s):") 982 for i, obj in enumerate(gc.garbage[previous_garbage:]): 983 try: 984 logging.error("Object %d of %d", i, 985 len(gc.garbage) - previous_garbage) 986 987 def _safe_object_str(obj): 988 return "<%s %d>" % (obj.__class__.__name__, id(obj)) 989 990 logging.error(" Object type: %s", _safe_object_str(obj)) 991 logging.error( 992 " Referrer types: %s", ", ".join( 993 [_safe_object_str(ref) for ref in gc.get_referrers(obj)])) 994 logging.error( 995 " Referent types: %s", ", ".join( 996 [_safe_object_str(ref) for ref in gc.get_referents(obj)])) 997 logging.error(" Object attribute names: %s", dir(obj)) 998 logging.error(" Object __str__:") 999 logging.error(obj) 1000 logging.error(" Object __repr__:") 1001 logging.error(repr(obj)) 1002 except Exception: # pylint: disable=broad-except 1003 logging.error("(Exception while printing object)") 1004 1005 # When garbage is created, this call can help identify reference cycles, 1006 # which are typically the cause of such garbage. 1007 if new_garbage > previous_garbage: 1008 for i in range(previous_garbage, new_garbage): 1009 if _find_reference_cycle(gc.garbage, i): 1010 break 1011 1012 # This will fail if any garbage has been created, typically because of a 1013 # reference cycle. 1014 self.assertEqual(previous_garbage, new_garbage) 1015 # TODO(allenl): Figure out why this debug flag reset doesn't work. It would 1016 # be nice to be able to decorate arbitrary tests in a large test suite and 1017 # not hold on to every object in other tests. 1018 gc.set_debug(previous_debug_flags) 1019 gc.enable() 1020 return result 1021 1022 return decorator 1023 1024 1025def _combine_named_parameters(**kwargs): 1026 """Generate combinations based on its keyword arguments. 1027 1028 Two sets of returned combinations can be concatenated using +. Their product 1029 can be computed using `times()`. 1030 1031 Args: 1032 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 1033 `option=the_only_possibility`. 1034 1035 Returns: 1036 a list of dictionaries for each combination. Keys in the dictionaries are 1037 the keyword argument names. Each key has one value - one of the 1038 corresponding keyword argument values. 1039 """ 1040 sort_by_key = lambda k: k[0] 1041 combinations = [] 1042 for key, values in sorted(kwargs.items(), key=sort_by_key): 1043 if not isinstance(values, list): 1044 values = [values] 1045 combinations.append([(key, value) for value in values]) 1046 1047 return [OrderedDict(result) for result in itertools.product(*combinations)] 1048 1049 1050def generate_combinations_with_testcase_name(**kwargs): 1051 """Generate combinations based on its keyword arguments using combine(). 1052 1053 This function calls combine() and appends a testcase name to the list of 1054 dictionaries returned. The 'testcase_name' key is a required for named 1055 parameterized tests. 1056 1057 Args: 1058 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 1059 `option=the_only_possibility`. 1060 1061 Returns: 1062 a list of dictionaries for each combination. Keys in the dictionaries are 1063 the keyword argument names. Each key has one value - one of the 1064 corresponding keyword argument values. 1065 """ 1066 combinations = _combine_named_parameters(**kwargs) 1067 named_combinations = [] 1068 for combination in combinations: 1069 assert isinstance(combination, OrderedDict) 1070 name = "".join([ 1071 "_{}_{}".format("".join(filter(str.isalnum, key)), 1072 "".join(filter(str.isalnum, str(value)))) 1073 for key, value in combination.items() 1074 ]) 1075 named_combinations.append( 1076 OrderedDict( 1077 list(combination.items()) + 1078 [("testcase_name", "_test{}".format(name))])) 1079 1080 return named_combinations 1081 1082 1083def run_all_in_graph_and_eager_modes(cls): 1084 """Execute all test methods in the given class with and without eager.""" 1085 base_decorator = run_in_graph_and_eager_modes 1086 for name in dir(cls): 1087 if (not name.startswith(unittest.TestLoader.testMethodPrefix) or 1088 name.startswith("testSkipEager") or 1089 name.startswith("test_skip_eager") or 1090 name == "test_session"): 1091 continue 1092 value = getattr(cls, name, None) 1093 if callable(value): 1094 setattr(cls, name, base_decorator(value)) 1095 return cls 1096 1097 1098def build_as_function_and_v1_graph(func=None): 1099 """Run a test case in v1 graph mode and inside tf.function in eager mode. 1100 1101 WARNING: This decorator can only be used in test cases that statically checks 1102 generated graph. Attempting to evaluate graph or function results via. 1103 session.run() or self.evaluate() will fail. 1104 1105 WARNING: This decorator can only be used for test cases that inherit from 1106 absl.testing.parameterized.TestCase. 1107 1108 Args: 1109 func: Test case function to be decorated. 1110 1111 Returns: 1112 Decorated test case function. 1113 """ 1114 1115 def decorator(f): 1116 if tf_inspect.isclass(f): 1117 raise ValueError( 1118 "`run_in_graph_mode_and_function` only supports test methods.") 1119 1120 @parameterized.named_parameters(("_v1_graph", "v1_graph"), 1121 ("_function", "function")) 1122 @functools.wraps(f) 1123 def decorated(self, run_mode, *args, **kwargs): 1124 if run_mode == "v1_graph": 1125 with ops.Graph().as_default(): 1126 f(self, *args, **kwargs) 1127 elif run_mode == "function": 1128 1129 @def_function.function 1130 def function_in_eager(): 1131 f(self, *args, **kwargs) 1132 1133 # Create a new graph for the eagerly executed version of this test for 1134 # better isolation. 1135 graph_for_eager_test = ops.Graph() 1136 with graph_for_eager_test.as_default(), context.eager_mode(): 1137 function_in_eager() 1138 ops.dismantle_graph(graph_for_eager_test) 1139 else: 1140 return ValueError("Unknown run mode %s" % run_mode) 1141 1142 return decorated 1143 1144 if func is not None: 1145 return decorator(func) 1146 1147 return decorator 1148 1149 1150def run_in_async_and_sync_mode(f): 1151 """Execute the test in async mode and sync mode.""" 1152 1153 @parameterized.named_parameters([("Async", True), ("", False)]) 1154 @functools.wraps(f) 1155 def decorator(self, async_mode, *args, **kwargs): 1156 if async_mode: 1157 with context.execution_mode(context.ASYNC): 1158 f(self, *args, **kwargs) 1159 else: 1160 with context.execution_mode(context.SYNC): 1161 f(self, *args, **kwargs) 1162 return decorator 1163 1164 1165def run_in_graph_and_eager_modes(func=None, 1166 config=None, 1167 use_gpu=True, 1168 assert_no_eager_garbage=False): 1169 """Execute the decorated test with and without enabling eager execution. 1170 1171 This function returns a decorator intended to be applied to test methods in 1172 a `tf.test.TestCase` class. Doing so will cause the contents of the test 1173 method to be executed twice - once normally, and once with eager execution 1174 enabled. This allows unittests to confirm the equivalence between eager 1175 and graph execution (see `tf.compat.v1.enable_eager_execution`). 1176 1177 For example, consider the following unittest: 1178 1179 ```python 1180 class MyTests(tf.test.TestCase): 1181 1182 @run_in_graph_and_eager_modes 1183 def test_foo(self): 1184 x = tf.constant([1, 2]) 1185 y = tf.constant([3, 4]) 1186 z = tf.add(x, y) 1187 self.assertAllEqual([4, 6], self.evaluate(z)) 1188 1189 if __name__ == "__main__": 1190 tf.test.main() 1191 ``` 1192 1193 This test validates that `tf.add()` has the same behavior when computed with 1194 eager execution enabled as it does when constructing a TensorFlow graph and 1195 executing the `z` tensor in a session. 1196 1197 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1198 `run_in_graph_and_eager_modes` are available decorators for different 1199 v1/v2/eager/graph combinations. 1200 1201 1202 Args: 1203 func: function to be annotated. If `func` is None, this method returns a 1204 decorator the can be applied to a function. If `func` is not None this 1205 returns the decorator applied to `func`. 1206 config: An optional config_pb2.ConfigProto to use to configure the session 1207 when executing graphs. 1208 use_gpu: If True, attempt to run as many operations as possible on GPU. 1209 assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage 1210 collector and asserts that no extra garbage has been created when running 1211 the test with eager execution enabled. This will fail if there are 1212 reference cycles (e.g. a = []; a.append(a)). Off by default because some 1213 tests may create garbage for legitimate reasons (e.g. they define a class 1214 which inherits from `object`), and because DEBUG_SAVEALL is sticky in some 1215 Python interpreters (meaning that tests which rely on objects being 1216 collected elsewhere in the unit test file will not work). Additionally, 1217 checks that nothing still has a reference to Tensors that the test 1218 allocated. 1219 1220 Returns: 1221 Returns a decorator that will run the decorated test method twice: 1222 once by constructing and executing a graph in a session and once with 1223 eager execution enabled. 1224 """ 1225 1226 def decorator(f): 1227 if tf_inspect.isclass(f): 1228 raise ValueError( 1229 "`run_in_graph_and_eager_modes` only supports test methods. " 1230 "Did you mean to use `run_all_in_graph_and_eager_modes`?") 1231 1232 def decorated(self, *args, **kwargs): 1233 try: 1234 with context.graph_mode(): 1235 with self.test_session(use_gpu=use_gpu, config=config): 1236 f(self, *args, **kwargs) 1237 except unittest.case.SkipTest: 1238 pass 1239 1240 def run_eagerly(self, **kwargs): 1241 if not use_gpu: 1242 with ops.device("/device:CPU:0"): 1243 f(self, *args, **kwargs) 1244 else: 1245 f(self, *args, **kwargs) 1246 1247 if assert_no_eager_garbage: 1248 ops.reset_default_graph() 1249 run_eagerly = assert_no_new_tensors( 1250 assert_no_garbage_created(run_eagerly)) 1251 1252 # This decorator runs the wrapped test twice. 1253 # Reset the test environment between runs. 1254 self.tearDown() 1255 self._tempdir = None 1256 # Create a new graph for the eagerly executed version of this test for 1257 # better isolation. 1258 graph_for_eager_test = ops.Graph() 1259 with graph_for_eager_test.as_default(), context.eager_mode(): 1260 self.setUp() 1261 run_eagerly(self, **kwargs) 1262 ops.dismantle_graph(graph_for_eager_test) 1263 1264 return tf_decorator.make_decorator(f, decorated) 1265 1266 if func is not None: 1267 return decorator(func) 1268 1269 return decorator 1270 1271 1272def py_func_if_in_function(f): 1273 1274 def decorated(*args, **kwds): 1275 if not ops.inside_function(): 1276 return f(*args, **kwds) 1277 1278 tensor_args = [] 1279 tensor_indices = [] 1280 for i, arg in enumerate(args): 1281 if isinstance(arg, (ops.Tensor, variables.Variable)): 1282 tensor_args.append(arg) 1283 tensor_indices.append(i) 1284 1285 def inner_f(*inner_tensor_args): 1286 my_args = list(args) 1287 for i, n in zip(tensor_indices, inner_tensor_args): 1288 my_args[i] = n 1289 return f(*my_args, **kwds) 1290 1291 return script_ops.py_func(inner_f, tensor_args, []) 1292 1293 return tf_decorator.make_decorator(f, decorated) 1294 1295 1296def also_run_as_tf_function(f): 1297 """Runs the decorated test twice--once as is, once inside a tf.function. 1298 1299 This allows you to run a test both in eager execution and inside a 1300 tf.function, exercising the two execution modes supported in tf 2.0. The test 1301 assertions are automatically done inside tf.py_funcs, and tf.function ensures 1302 that they run in the proper order and with the proper side effects. 1303 1304 Currently variable creation is not supported in tests annotated with this 1305 decorator since it's tricky to ensure the variable doesn't get repeatedly 1306 created when retracing the tf.function. 1307 1308 Args: 1309 f: the test method to be decorated 1310 1311 Returns: 1312 The decorated test method, which will run both in eager and inside a 1313 tf.function. 1314 """ 1315 1316 def decorated(*args, **kwds): 1317 1318 def bound_f(): 1319 f(*args, **kwds) 1320 1321 with context.eager_mode(): 1322 # Running in eager mode 1323 bound_f() 1324 # Running as TF function 1325 # TODO(b/121143941): Remove the autograph override. 1326 def_function.function(bound_f, autograph=False)() 1327 1328 return decorated 1329 1330 1331def deprecated_graph_mode_only(func=None): 1332 """Execute the decorated test in graph mode. 1333 1334 This function returns a decorator intended to be applied to tests that are not 1335 compatible with eager mode. When this decorator is applied, the test body will 1336 be run in an environment where API calls construct graphs instead of executing 1337 eagerly. 1338 1339 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1340 `run_in_graph_and_eager_modes` are available decorators for different 1341 v1/v2/eager/graph combinations. 1342 1343 Args: 1344 func: function to be annotated. If `func` is None, this method returns a 1345 decorator the can be applied to a function. If `func` is not None this 1346 returns the decorator applied to `func`. 1347 1348 Returns: 1349 Returns a decorator that will run the decorated test method in graph mode. 1350 """ 1351 1352 def decorator(f): 1353 if tf_inspect.isclass(f): 1354 setup = f.__dict__.get("setUp") 1355 if setup is not None: 1356 setattr(f, "setUp", decorator(setup)) 1357 1358 for name, value in f.__dict__.copy().items(): 1359 if (callable(value) and 1360 name.startswith(unittest.TestLoader.testMethodPrefix)): 1361 setattr(f, name, decorator(value)) 1362 1363 return f 1364 1365 def decorated(self, *args, **kwargs): 1366 if context.executing_eagerly(): 1367 with context.graph_mode(): 1368 return f(self, *args, **kwargs) 1369 else: 1370 return f(self, *args, **kwargs) 1371 1372 return decorated 1373 1374 if func is not None: 1375 return decorator(func) 1376 1377 return decorator 1378 1379 1380run_deprecated_v1 = deprecated_graph_mode_only 1381 1382 1383def run_all_in_deprecated_graph_mode_only(cls): 1384 """Execute all tests in a class in graph mode.""" 1385 base_decorator = deprecated_graph_mode_only 1386 for name in dir(cls): 1387 if (not name.startswith(unittest.TestLoader.testMethodPrefix) or 1388 name == "test_session"): 1389 continue 1390 value = getattr(cls, name, None) 1391 if callable(value): 1392 setattr(cls, name, base_decorator(value)) 1393 return cls 1394 1395 1396def run_v1_only(reason, func=None): 1397 """Execute the decorated test only if running in v1 mode. 1398 1399 This function is intended to be applied to tests that exercise v1 only 1400 functionality. If the test is run in v2 mode it will simply be skipped. 1401 1402 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1403 `run_in_graph_and_eager_modes` are available decorators for different 1404 v1/v2/eager/graph combinations. 1405 1406 Args: 1407 reason: string giving a reason for limiting the test to v1 only. 1408 func: function to be annotated. If `func` is None, this method returns a 1409 decorator the can be applied to a function. If `func` is not None this 1410 returns the decorator applied to `func`. 1411 1412 Returns: 1413 Returns a decorator that will conditionally skip the decorated test method. 1414 """ 1415 if not isinstance(reason, str): 1416 raise ValueError("'reason' should be string, got {}".format(type(reason))) 1417 1418 def decorator(f): 1419 if tf_inspect.isclass(f): 1420 # To skip an entire test suite class, we only decorate the setUp method 1421 # to skip all tests. There are cases when setUp is not defined (not 1422 # overridden in subclasses of TestCase, so not available in f.__dict__ 1423 # below). For those cases, we walk the method resolution order list and 1424 # pick the first setUp method we find (usually this should be the one in 1425 # the parent class since that's the TestCase class). 1426 for cls in type.mro(f): 1427 setup = cls.__dict__.get("setUp") 1428 if setup is not None: 1429 setattr(f, "setUp", decorator(setup)) 1430 break 1431 1432 return f 1433 else: 1434 # If f is just a function, just create a decorator for it and return it 1435 def decorated(self, *args, **kwargs): 1436 if tf2.enabled(): 1437 self.skipTest(reason) 1438 1439 return f(self, *args, **kwargs) 1440 1441 return decorated 1442 1443 if func is not None: 1444 return decorator(func) 1445 1446 return decorator 1447 1448 1449def run_v2_only(func=None): 1450 """Execute the decorated test only if running in v2 mode. 1451 1452 This function is intended to be applied to tests that exercise v2 only 1453 functionality. If the test is run in v1 mode it will simply be skipped. 1454 1455 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1456 `run_in_graph_and_eager_modes` are available decorators for different 1457 v1/v2/eager/graph combinations. 1458 1459 Args: 1460 func: function to be annotated. If `func` is None, this method returns a 1461 decorator the can be applied to a function. If `func` is not None this 1462 returns the decorator applied to `func`. 1463 1464 Returns: 1465 Returns a decorator that will conditionally skip the decorated test method. 1466 """ 1467 1468 def decorator(f): 1469 if tf_inspect.isclass(f): 1470 raise ValueError("`run_v2_only` only supports test methods.") 1471 1472 def decorated(self, *args, **kwargs): 1473 if not tf2.enabled(): 1474 self.skipTest("Test is only compatible with v2") 1475 1476 return f(self, *args, **kwargs) 1477 1478 return decorated 1479 1480 if func is not None: 1481 return decorator(func) 1482 1483 return decorator 1484 1485 1486def run_gpu_only(func=None): 1487 """Execute the decorated test only if a GPU is available. 1488 1489 This function is intended to be applied to tests that require the presence 1490 of a GPU. If a GPU is absent, it will simply be skipped. 1491 1492 Args: 1493 func: function to be annotated. If `func` is None, this method returns a 1494 decorator the can be applied to a function. If `func` is not None this 1495 returns the decorator applied to `func`. 1496 1497 Returns: 1498 Returns a decorator that will conditionally skip the decorated test method. 1499 """ 1500 1501 def decorator(f): 1502 if tf_inspect.isclass(f): 1503 raise ValueError("`run_gpu_only` only supports test methods.") 1504 1505 def decorated(self, *args, **kwargs): 1506 if not is_gpu_available(): 1507 self.skipTest("Test requires GPU") 1508 1509 return f(self, *args, **kwargs) 1510 1511 return decorated 1512 1513 if func is not None: 1514 return decorator(func) 1515 1516 return decorator 1517 1518 1519def run_cuda_only(func=None): 1520 """Execute the decorated test only if a GPU is available. 1521 1522 This function is intended to be applied to tests that require the presence 1523 of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped. 1524 1525 Args: 1526 func: function to be annotated. If `func` is None, this method returns a 1527 decorator the can be applied to a function. If `func` is not None this 1528 returns the decorator applied to `func`. 1529 1530 Returns: 1531 Returns a decorator that will conditionally skip the decorated test method. 1532 """ 1533 1534 def decorator(f): 1535 if tf_inspect.isclass(f): 1536 raise ValueError("`run_cuda_only` only supports test methods.") 1537 1538 def decorated(self, *args, **kwargs): 1539 if not is_gpu_available(cuda_only=True): 1540 self.skipTest("Test requires CUDA GPU") 1541 1542 return f(self, *args, **kwargs) 1543 1544 return decorated 1545 1546 if func is not None: 1547 return decorator(func) 1548 1549 return decorator 1550 1551 1552def run_gpu_or_tpu(func=None): 1553 """Execute the decorated test only if a physical GPU or TPU is available. 1554 1555 This function is intended to be applied to tests that require the presence 1556 of a physical GPU or TPU. It complies with the following rules: 1557 - If a GPU is available, the test will run on the GPU. 1558 - If a GPU is absent and a TPU is available, the test will run on the TPU. 1559 - If both GPU and TPU are absent, the test will be skipped. 1560 1561 Args: 1562 func: function to be annotated. If `func` is None, this method returns a 1563 decorator the can be applied to a function. If `func` is not None this 1564 returns the decorator applied to `func`. 1565 1566 Returns: 1567 Returns a decorator that will conditionally skip the decorated test method. 1568 """ 1569 1570 def decorator(f): 1571 if tf_inspect.isclass(f): 1572 raise ValueError("`run_gpu_or_tpu` only supports test methods.") 1573 1574 def decorated(self, *args, **kwargs): 1575 if config.list_physical_devices("GPU"): 1576 return f(self, "GPU", *args, **kwargs) 1577 1578 if config.list_physical_devices("TPU"): 1579 return f(self, "TPU", *args, **kwargs) 1580 1581 self.skipTest("Test requires GPU or TPU") 1582 1583 return decorated 1584 1585 return decorator if func is None else decorator(func) 1586 1587 1588def with_forward_compatibility_horizons(*horizons): 1589 """Executes the decorated test with the specified forward-compat horizons. 1590 1591 Args: 1592 *horizons: A list of (year, month, day) tuples. If the list includes 1593 `None`, then the test will also be run with no forward-compatibility 1594 horizon set. 1595 1596 Returns: 1597 A decorator that will execute the test with the specified horizons. 1598 """ 1599 if not horizons: 1600 raise ValueError("Expected at least one horizon.") 1601 for horizon in horizons: 1602 if not ((horizon is None) or 1603 (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))): 1604 raise ValueError("Bad horizon value: %r" % horizon) 1605 1606 def decorator(f): 1607 if tf_inspect.isclass(f): 1608 raise ValueError("`with_forward_compatibility_horizons` only " 1609 "supports test methods.") 1610 def decorated(self, *args, **kwargs): 1611 for horizon in horizons: 1612 if horizon is None: 1613 f(self, *args, **kwargs) 1614 else: 1615 (year, month, day) = horizon 1616 with forward_compatibility_horizon(year, month, day): 1617 f(self, *args, **kwargs) 1618 return decorated 1619 1620 return decorator 1621 1622 1623@deprecation.deprecated(None, 1624 "Use `tf.config.list_physical_devices('GPU')` instead.") 1625@tf_export("test.is_gpu_available") 1626def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): 1627 """Returns whether TensorFlow can access a GPU. 1628 1629 Warning: if a non-GPU version of the package is installed, the function would 1630 also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow 1631 was build with CUDA support. 1632 1633 For example, 1634 >>> gpu_available = tf.test.is_gpu_available() 1635 >>> is_cuda_gpu_available = tf.test.is_gpu_available(cuda_only=True) 1636 >>> is_cuda_gpu_min_3 = tf.test.is_gpu_available(True, (3,0)) 1637 1638 Args: 1639 cuda_only: limit the search to CUDA GPUs. 1640 min_cuda_compute_capability: a (major,minor) pair that indicates the minimum 1641 CUDA compute capability required, or None if no requirement. 1642 1643 Note that the keyword arg name "cuda_only" is misleading (since routine will 1644 return true when a GPU device is available irrespective of whether TF was 1645 built with CUDA support or ROCm support. However no changes here because 1646 1647 ++ Changing the name "cuda_only" to something more generic would break 1648 backward compatibility 1649 1650 ++ Adding an equivalent "rocm_only" would require the implementation check 1651 the build type. This in turn would require doing the same for CUDA and thus 1652 potentially break backward compatibility 1653 1654 ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility, 1655 but would require most (if not all) callers to update the call to use 1656 "cuda_or_rocm_only" instead of "cuda_only" 1657 1658 Returns: 1659 True if a GPU device of the requested kind is available. 1660 """ 1661 1662 # This was needed earlier when we had support for SYCL in TensorFlow. 1663 del cuda_only 1664 1665 try: 1666 for local_device in device_lib.list_local_devices(): 1667 if local_device.device_type == "GPU": 1668 gpu_info = gpu_util.compute_capability_from_device_desc(local_device) 1669 cc = gpu_info.compute_capability or (0, 0) 1670 if not min_cuda_compute_capability or cc >= min_cuda_compute_capability: 1671 return True 1672 return False 1673 except errors_impl.NotFoundError as e: 1674 if not all(x in str(e) for x in ["CUDA", "not find"]): 1675 raise e 1676 else: 1677 logging.error(str(e)) 1678 return False 1679 1680 1681@contextlib.contextmanager 1682def device(use_gpu): 1683 """Uses gpu when requested and available.""" 1684 if use_gpu and is_gpu_available(): 1685 dev = "/device:GPU:0" 1686 else: 1687 dev = "/device:CPU:0" 1688 with ops.device(dev): 1689 yield 1690 1691 1692@contextlib.contextmanager 1693def use_gpu(): 1694 """Uses gpu when requested and available.""" 1695 with device(use_gpu=True): 1696 yield 1697 1698 1699@contextlib.contextmanager 1700def force_gpu(): 1701 """Force the gpu to be used.""" 1702 with ops.device("/device:GPU:0"): 1703 yield 1704 1705 1706@contextlib.contextmanager 1707def force_cpu(): 1708 """Force the cpu to be used.""" 1709 with ops.device("/device:CPU:0"): 1710 yield 1711 1712 1713class CapturedWrites(object): 1714 """A utility class to load the captured writes made to a stream.""" 1715 1716 def __init__(self, capture_location): 1717 self.capture_location = capture_location 1718 1719 def contents(self): 1720 """Get the captured writes as a single string.""" 1721 with open(self.capture_location) as tmp_file: 1722 output_data = "".join(tmp_file.readlines()) 1723 return output_data 1724 1725 1726class FakeEagerSession(object): 1727 """Fake session so tests that conditionally use placeholders can use eager. 1728 1729 There are a number of tests that conditionally use placeholders for shape 1730 inference. The pattern is demonstrated here: 1731 1732 ```python 1733 with self.cached_session() as sess: 1734 if static_shape: 1735 y = math_ops.matmul(x, ...) 1736 feed_dict = {} 1737 else: 1738 x_ph = array_ops.placeholder(...) 1739 y = math_ops.matmul(x_ph, ...) 1740 feed_dict = {x_ph: x} 1741 val = sess.run(y, feed_dict=feed_dict) 1742 ``` 1743 1744 Since the feed_dict is empty when not using placeholders we should be able to 1745 call self.evaluate(), however this requires rewriting the test case. 1746 This class should be considered a stop-gap solution to get tests running with 1747 eager with minimal changes to the actual test. 1748 """ 1749 1750 def __init__(self, test_case): 1751 self._test_case = test_case 1752 1753 def run(self, fetches, *args, **kwargs): 1754 """Evaluate `fetches`. 1755 1756 Fail if additional args are specified. 1757 1758 Args: 1759 fetches: A Tensor or a nested list/tuple of Tensors. 1760 *args: Positional arguments 1761 **kwargs: Keyword arguments 1762 1763 Raises: 1764 RuntimeError: If args or kwargs are specified. 1765 1766 Returns: 1767 Tensors as numpy values. 1768 """ 1769 feed_dict = kwargs.pop("feed_dict", {}) 1770 if feed_dict: 1771 raise RuntimeError( 1772 "feed_dict is not supported when eager execution is enabled " 1773 "(in this case, sess.run(t) is shorthand for t.numpy()") 1774 1775 if args or kwargs: 1776 raise RuntimeError( 1777 "Optional args are not supported when eager execution is enabled " 1778 "(in this case, sess.run(t) is shorthand for t.numpy()") 1779 1780 return self._test_case.evaluate(fetches) 1781 1782 1783class ErrorLoggingSession(session.Session): 1784 """Wrapper around a Session that logs errors in run().""" 1785 1786 def run(self, *args, **kwargs): 1787 try: 1788 return super(ErrorLoggingSession, self).run(*args, **kwargs) 1789 except Exception as e: # pylint: disable=broad-except 1790 # Note: disable the logging for OutOfRangeError, which makes the output 1791 # of tf.data tests hard to read, because OutOfRangeError is used as the 1792 # signal completion 1793 if not isinstance(e, errors.OutOfRangeError): 1794 logging.error(str(e)) 1795 raise 1796 1797 1798def disable_cudnn_autotune(func): 1799 """Disable autotuning during the call to this function. 1800 1801 Some tests want to base assertions on a graph being isomorphic with a copy. 1802 To ensure this, this decorator disables autotuning. 1803 1804 Args: 1805 func: Function to run with CuDNN autotuning turned off. 1806 1807 Returns: 1808 Decorated function. 1809 """ 1810 1811 def decorator(f): 1812 1813 def decorated(self, *args, **kwargs): 1814 original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") 1815 os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" 1816 original_xla_flags = os.environ.get("XLA_FLAGS") 1817 new_xla_flags = "--xla_gpu_autotune_level=0" 1818 if original_xla_flags: 1819 new_xla_flags = original_xla_flags + " " + new_xla_flags 1820 os.environ["XLA_FLAGS"] = new_xla_flags 1821 1822 result = f(self, *args, **kwargs) 1823 1824 if (original_tf_cudnn_use_autotune is None): 1825 del os.environ["TF_CUDNN_USE_AUTOTUNE"] 1826 else: 1827 os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune 1828 if (original_xla_flags is None): 1829 del os.environ["XLA_FLAGS"] 1830 else: 1831 os.environ["XLA_FLAGS"] = original_xla_flags 1832 1833 return result 1834 1835 return decorated 1836 1837 if func is not None: 1838 return decorator(func) 1839 1840 return decorator 1841 1842 1843# The description is just for documentation purposes. 1844def enable_tf_xla_constant_folding(description): 1845 1846 if not isinstance(description, str): 1847 raise ValueError("'description' should be string, got {}".format( 1848 type(description))) 1849 1850 def enable_tf_xla_constant_folding_impl(func): 1851 """Enable constant folding during the call to this function. 1852 1853 Some tests fail without constant folding. 1854 1855 Args: 1856 func: Function to run with constant folding turned on. 1857 1858 Returns: 1859 Decorated function. 1860 """ 1861 1862 def decorator(f): 1863 1864 def decorated(self, *args, **kwargs): 1865 original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() 1866 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) 1867 result = f(self, *args, **kwargs) 1868 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var) 1869 return result 1870 1871 return decorated 1872 1873 if func is not None: 1874 return decorator(func) 1875 1876 return decorator 1877 1878 return enable_tf_xla_constant_folding_impl 1879 1880 1881# Updates test function by selectively disabling it. 1882def _disable_test(execute_func): 1883 1884 def disable_test_impl(func): 1885 1886 def decorator(func): 1887 1888 def decorated(self, *args, **kwargs): 1889 if execute_func: 1890 return func(self, *args, **kwargs) 1891 1892 return tf_decorator.make_decorator(func, decorated) 1893 1894 if func is not None: 1895 return decorator(func) 1896 1897 return decorator 1898 1899 return disable_test_impl 1900 1901 1902# The description is just for documentation purposes. 1903def disable_xla(description): # pylint: disable=unused-argument 1904 """Execute the test method only if xla is not enabled.""" 1905 execute_func = not is_xla_enabled() 1906 return _disable_test(execute_func) 1907 1908 1909# The description is just for documentation purposes. 1910def disable_mlir_bridge(description): # pylint: disable=unused-argument 1911 """Execute the test method only if MLIR bridge is not enabled.""" 1912 execute_func = not is_mlir_bridge_enabled() 1913 return _disable_test(execute_func) 1914 1915 1916# The description is just for documentation purposes. 1917def disable_asan(description): # pylint: disable=unused-argument 1918 """Execute the test method only if ASAN is not enabled.""" 1919 execute_func = not is_asan_enabled() 1920 return _disable_test(execute_func) 1921 1922 1923# The description is just for documentation purposes. 1924def disable_msan(description): # pylint: disable=unused-argument 1925 """Execute the test method only if MSAN is not enabled.""" 1926 execute_func = not is_msan_enabled() 1927 return _disable_test(execute_func) 1928 1929 1930# The description is just for documentation purposes. 1931def disable_tsan(description): # pylint: disable=unused-argument 1932 """Execute the test method only if TSAN is not enabled.""" 1933 execute_func = not is_tsan_enabled() 1934 return _disable_test(execute_func) 1935 1936 1937# The description is just for documentation purposes. 1938def disable_ubsan(description): # pylint: disable=unused-argument 1939 """Execute the test method only if UBSAN is not enabled.""" 1940 execute_func = not is_ubsan_enabled() 1941 return _disable_test(execute_func) 1942 1943 1944# The description is just for documentation purposes. 1945def disable_tfrt(unused_description): 1946 1947 def disable_tfrt_impl(cls_or_func): 1948 """Execute the test only if tfrt is not enabled.""" 1949 1950 if tf_inspect.isclass(cls_or_func): 1951 if tfrt_utils.enabled(): 1952 return None 1953 else: 1954 return cls_or_func 1955 else: 1956 def decorator(func): 1957 1958 def decorated(self, *args, **kwargs): 1959 if tfrt_utils.enabled(): 1960 return 1961 else: 1962 return func(self, *args, **kwargs) 1963 1964 return decorated 1965 1966 if cls_or_func is not None: 1967 return decorator(cls_or_func) 1968 1969 return decorator 1970 1971 return disable_tfrt_impl 1972 1973 1974def for_all_test_methods(decorator, *args, **kwargs): 1975 """Generate class-level decorator from given method-level decorator. 1976 1977 It is expected for the given decorator to take some arguments and return 1978 a method that is then called on the test method to produce a decorated 1979 method. 1980 1981 Args: 1982 decorator: The decorator to apply. 1983 *args: Positional arguments 1984 **kwargs: Keyword arguments 1985 Returns: Function that will decorate a given classes test methods with the 1986 decorator. 1987 """ 1988 1989 def all_test_methods_impl(cls): 1990 """Apply decorator to all test methods in class.""" 1991 for name in dir(cls): 1992 value = getattr(cls, name) 1993 if callable(value) and name.startswith( 1994 "test") and (name != "test_session"): 1995 setattr(cls, name, decorator(*args, **kwargs)(value)) 1996 return cls 1997 1998 return all_test_methods_impl 1999 2000 2001# The description is just for documentation purposes. 2002def no_xla_auto_jit(description): # pylint: disable=unused-argument 2003 """This test is not intended to be run with XLA auto jit enabled.""" 2004 execute_func = not is_xla_enabled() 2005 return _disable_test(execute_func) 2006 2007 2008# The description is just for documentation purposes. 2009def xla_allow_fallback(description): # pylint: disable=unused-argument 2010 2011 def xla_allow_fallback_impl(func): 2012 """Allow fallback to TF even though testing xla.""" 2013 2014 def decorator(func): 2015 2016 def decorated(self, *args, **kwargs): 2017 if is_xla_enabled(): 2018 # Update the global XLABuildOpsPassFlags to enable lazy compilation, 2019 # which allows the compiler to fall back to TF classic. Remember the 2020 # old value so that we can reset it. 2021 old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True) 2022 result = func(self, *args, **kwargs) 2023 pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value) 2024 return result 2025 else: 2026 return func(self, *args, **kwargs) 2027 2028 return decorated 2029 2030 if func is not None: 2031 return decorator(func) 2032 2033 return decorator 2034 2035 return xla_allow_fallback_impl 2036 2037 2038# The description is just for documentation purposes. 2039def run_without_tensor_float_32(description): # pylint: disable=unused-argument 2040 """Execute test with TensorFloat-32 disabled. 2041 2042 While almost every real-world deep learning model runs fine with 2043 TensorFloat-32, many tests use assertAllClose or similar methods. 2044 TensorFloat-32 matmuls typically will cause such methods to fail with the 2045 default tolerances. 2046 2047 Args: 2048 description: A description used for documentation purposes, describing why 2049 the test requires TensorFloat-32 to be disabled. 2050 2051 Returns: 2052 Decorator which runs a test with TensorFloat-32 disabled. 2053 """ 2054 2055 def decorator(f): 2056 2057 @functools.wraps(f) 2058 def decorated(self, *args, **kwargs): 2059 allowed = config.tensor_float_32_execution_enabled() 2060 try: 2061 config.enable_tensor_float_32_execution(False) 2062 f(self, *args, **kwargs) 2063 finally: 2064 config.enable_tensor_float_32_execution(allowed) 2065 2066 return decorated 2067 2068 return decorator 2069 2070 2071# The description is just for documentation purposes. 2072def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument 2073 """Execute all tests in a class with TensorFloat-32 disabled.""" 2074 return for_all_test_methods(run_without_tensor_float_32, description) 2075 2076 2077def matmul_without_tf32(a, b, *args, **kwargs): 2078 """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled. 2079 2080 This effectively runs matmul without TensorFloat-32. It should only be used in 2081 tests when verifying some other op or functions works correctly, e.g. to test 2082 `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In 2083 such cases, the matmul itself is not being tested so it's OK to run it with 2084 higher precision. 2085 2086 If a matmul itself is being tested, or some other op which uses matmul, use 2087 `run_without_tensor_float_32` instead. 2088 2089 This also casts complex64 inputs to complex128, since TensorFloat-32 can also 2090 be used with complex64 2091 2092 Args: 2093 a: First input to tf.linalg.matmul 2094 b: Second input to tf.linalg.matmul 2095 args: Other positional arguments to tf.linalg.matmul 2096 **kwargs: Other keyword arguments to tf.linalg.matmul 2097 2098 Returns: 2099 A tensor with the same type as `a`. 2100 """ 2101 if config.tensor_float_32_execution_enabled() and a.dtype == "float32": 2102 a = math_ops.cast(a, "float64") 2103 b = math_ops.cast(b, "float64") 2104 ret = math_ops.matmul(a, b, *args, **kwargs) 2105 return math_ops.cast(ret, a.dtype) 2106 elif config.tensor_float_32_execution_enabled() and a.dtype == "complex64": 2107 a = math_ops.cast(a, "complex128") 2108 b = math_ops.cast(b, "complex128") 2109 ret = math_ops.matmul(a, b, *args, **kwargs) 2110 return math_ops.cast(ret, a.dtype) 2111 else: 2112 return math_ops.matmul(a, b, *args, **kwargs) 2113 2114 2115class EagerSessionWarner(object): 2116 2117 def __getattr__(self, attr): 2118 raise AttributeError( 2119 "Trying to access properties or call methods on the result of " 2120 "self.session(), self.cached_session(), etc while eager execution " 2121 "is enabled. If you're porting this test case to TF 2.0, either " 2122 "adapt the test to work with eager execution or insert a call to " 2123 "tf.disable_eager_execution() in the main() function of this test " 2124 "file.") 2125 2126 2127@tf_export("test.TestCase") 2128class TensorFlowTestCase(googletest.TestCase): 2129 """Base class for tests that need to test TensorFlow.""" 2130 2131 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 2132 super(TensorFlowTestCase, self).__init__(methodName) 2133 # Make sure we get unfiltered stack traces during the test 2134 traceback_utils.disable_traceback_filtering() 2135 if is_xla_enabled(): 2136 pywrap_tf_session.TF_SetXlaAutoJitMode("2") 2137 pywrap_tf_session.TF_SetXlaMinClusterSize(1) 2138 pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False) 2139 pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True) 2140 # Constant folding secretly runs code on TF:Classic CPU, so we also 2141 # disable it here. 2142 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True) 2143 2144 # Check if the mlir bridge has been explicitly enabled or disabled. If 2145 # is_mlir_bridge_enabled() returns None, the user did not explictly enable 2146 # or disable the bridge so do not update enable_mlir_bridge. 2147 if is_mlir_bridge_enabled(): 2148 context.context().enable_mlir_bridge = True 2149 elif is_mlir_bridge_enabled() is not None: 2150 context.context().enable_mlir_bridge = False 2151 2152 self._threads = [] 2153 self._tempdir = None 2154 self._cached_session = None 2155 self._test_start_time = None 2156 # This flag provides the ability to control whether the graph mode gets 2157 # initialized for TF1 or not. Initializing for TF1, which is what was 2158 # happening earlier, was preventing enablement of 'eager mode' in the test. 2159 self._set_default_seed = True 2160 2161 def setUp(self): 2162 super(TensorFlowTestCase, self).setUp() 2163 self._ClearCachedSession() 2164 random.seed(random_seed.DEFAULT_GRAPH_SEED) 2165 np.random.seed(random_seed.DEFAULT_GRAPH_SEED) 2166 # Note: The following line is necessary because some test methods may error 2167 # out from within nested graph contexts (e.g., via assertRaises and 2168 # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty 2169 # under certain versions of Python. That would cause 2170 # ops.reset_default_graph() to throw an exception if the stack were not 2171 # cleared first. 2172 ops._default_graph_stack.reset() # pylint: disable=protected-access 2173 ops.reset_default_graph() 2174 if self._set_default_seed: 2175 random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED) 2176 # Reset summary writer in case another test used set_as_default() with their 2177 # summary writer. 2178 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 2179 summary_state.writer = None 2180 2181 # Avoiding calling setUp() for the poorly named test_session method. 2182 if self.id().endswith(".test_session"): 2183 self.skipTest("Not a test.") 2184 2185 self._test_start_time = time.time() 2186 2187 def tearDown(self): 2188 # If a subclass overrides setUp and doesn't call the parent class's setUp, 2189 # then we may not have set the start time. 2190 if self._test_start_time is not None: 2191 logging.info("time(%s): %ss", self.id(), 2192 round(time.time() - self._test_start_time, 2)) 2193 2194 for thread in self._threads: 2195 thread.check_termination() 2196 2197 self._ClearCachedSession() 2198 super(TensorFlowTestCase, self).tearDown() 2199 2200 def _ClearCachedSession(self): 2201 if self._cached_session is not None: 2202 self._cached_session.close() 2203 self._cached_session = None 2204 2205 def get_temp_dir(self): 2206 """Returns a unique temporary directory for the test to use. 2207 2208 If you call this method multiple times during in a test, it will return the 2209 same folder. However, across different runs the directories will be 2210 different. This will ensure that across different runs tests will not be 2211 able to pollute each others environment. 2212 If you need multiple unique directories within a single test, you should 2213 use tempfile.mkdtemp as follows: 2214 tempfile.mkdtemp(dir=self.get_temp_dir()): 2215 2216 Returns: 2217 string, the path to the unique temporary directory created for this test. 2218 """ 2219 if not self._tempdir: 2220 self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir()) 2221 return self._tempdir 2222 2223 @contextlib.contextmanager 2224 def captureWritesToStream(self, stream): 2225 """A context manager that captures the writes to a given stream. 2226 2227 This context manager captures all writes to a given stream inside of a 2228 `CapturedWrites` object. When this context manager is created, it yields 2229 the `CapturedWrites` object. The captured contents can be accessed by 2230 calling `.contents()` on the `CapturedWrites`. 2231 2232 For this function to work, the stream must have a file descriptor that 2233 can be modified using `os.dup` and `os.dup2`, and the stream must support 2234 a `.flush()` method. The default python sys.stdout and sys.stderr are 2235 examples of this. Note that this does not work in Colab or Jupyter 2236 notebooks, because those use alternate stdout streams. 2237 2238 Example: 2239 ```python 2240 class MyOperatorTest(test_util.TensorFlowTestCase): 2241 def testMyOperator(self): 2242 input = [1.0, 2.0, 3.0, 4.0, 5.0] 2243 with self.captureWritesToStream(sys.stdout) as captured: 2244 result = MyOperator(input).eval() 2245 self.assertStartsWith(captured.contents(), "This was printed.") 2246 ``` 2247 2248 Args: 2249 stream: The stream whose writes should be captured. This stream must have 2250 a file descriptor, support writing via using that file descriptor, and 2251 must have a `.flush()` method. 2252 2253 Yields: 2254 A `CapturedWrites` object that contains all writes to the specified stream 2255 made during this context. 2256 """ 2257 stream.flush() 2258 fd = stream.fileno() 2259 tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir()) 2260 tmp_file = open(tmp_file_path, "w") 2261 orig_fd = os.dup(fd) 2262 os.dup2(tmp_file.fileno(), fd) 2263 try: 2264 yield CapturedWrites(tmp_file_path) 2265 finally: 2266 tmp_file.close() 2267 os.dup2(orig_fd, fd) 2268 2269 def _AssertProtoEquals(self, a, b, msg=None): 2270 """Asserts that a and b are the same proto. 2271 2272 Uses ProtoEq() first, as it returns correct results 2273 for floating point attributes, and then use assertProtoEqual() 2274 in case of failure as it provides good error messages. 2275 2276 Args: 2277 a: a proto. 2278 b: another proto. 2279 msg: Optional message to report on failure. 2280 """ 2281 if not compare.ProtoEq(a, b): 2282 compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) 2283 2284 def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None): 2285 """Asserts that message is same as parsed expected_message_ascii. 2286 2287 Creates another prototype of message, reads the ascii message into it and 2288 then compares them using self._AssertProtoEqual(). 2289 2290 Args: 2291 expected_message_maybe_ascii: proto message in original or ascii form. 2292 message: the message to validate. 2293 msg: Optional message to report on failure. 2294 """ 2295 if isinstance(expected_message_maybe_ascii, type(message)): 2296 expected_message = expected_message_maybe_ascii 2297 self._AssertProtoEquals(expected_message, message, msg=msg) 2298 elif isinstance(expected_message_maybe_ascii, (str, bytes)): 2299 expected_message = type(message)() 2300 text_format.Merge( 2301 expected_message_maybe_ascii, 2302 expected_message, 2303 descriptor_pool=descriptor_pool.Default()) 2304 self._AssertProtoEquals(expected_message, message, msg=msg) 2305 else: 2306 assert False, ("Can't compare protos of type %s and %s." % 2307 (type(expected_message_maybe_ascii), type(message))) 2308 2309 def assertProtoEqualsVersion( 2310 self, 2311 expected, 2312 actual, 2313 producer=versions.GRAPH_DEF_VERSION, 2314 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER, 2315 msg=None): 2316 expected = "versions { producer: %d min_consumer: %d };\n%s" % ( 2317 producer, min_consumer, expected) 2318 self.assertProtoEquals(expected, actual, msg=msg) 2319 2320 def assertStartsWith(self, actual, expected_start, msg=None): 2321 """Assert that actual.startswith(expected_start) is True. 2322 2323 Args: 2324 actual: str 2325 expected_start: str 2326 msg: Optional message to report on failure. 2327 """ 2328 if not actual.startswith(expected_start): 2329 fail_msg = "%r does not start with %r" % (actual, expected_start) 2330 fail_msg += " : %r" % (msg) if msg else "" 2331 self.fail(fail_msg) 2332 2333 def _eval_tensor(self, tensor): 2334 if tensor is None: 2335 return None 2336 elif callable(tensor): 2337 return self._eval_helper(tensor()) 2338 else: 2339 try: 2340 if sparse_tensor.is_sparse(tensor): 2341 return sparse_tensor.SparseTensorValue(tensor.indices.numpy(), 2342 tensor.values.numpy(), 2343 tensor.dense_shape.numpy()) 2344 elif ragged_tensor.is_ragged(tensor): 2345 return ragged_tensor_value.RaggedTensorValue( 2346 self._eval_tensor(tensor.values), 2347 self._eval_tensor(tensor.row_splits)) 2348 elif isinstance(tensor, ops.IndexedSlices): 2349 return ops.IndexedSlicesValue( 2350 values=tensor.values.numpy(), 2351 indices=tensor.indices.numpy(), 2352 dense_shape=tensor.dense_shape.numpy()) 2353 # Convert tensors and composite tensors to numpy arrays. 2354 return nest.map_structure(lambda t: t.numpy(), tensor, 2355 expand_composites=True) 2356 except AttributeError as e: 2357 six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e) 2358 2359 def _eval_helper(self, tensors): 2360 if tensors is None: 2361 return None 2362 return nest.map_structure(self._eval_tensor, tensors) 2363 2364 def evaluate(self, tensors): 2365 """Evaluates tensors and returns numpy values. 2366 2367 Args: 2368 tensors: A Tensor or a nested list/tuple of Tensors. 2369 2370 Returns: 2371 tensors numpy values. 2372 """ 2373 if context.executing_eagerly(): 2374 return self._eval_helper(tensors) 2375 else: 2376 sess = ops.get_default_session() 2377 if sess is None: 2378 with self.test_session() as sess: 2379 return sess.run(tensors) 2380 else: 2381 return sess.run(tensors) 2382 2383 # pylint: disable=g-doc-return-or-yield 2384 @contextlib.contextmanager 2385 def session(self, graph=None, config=None, use_gpu=True, force_gpu=False): 2386 """A context manager for a TensorFlow Session for use in executing tests. 2387 2388 Note that this will set this session and the graph as global defaults. 2389 2390 Use the `use_gpu` and `force_gpu` options to control where ops are run. If 2391 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if 2392 `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as 2393 possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to 2394 the CPU. 2395 2396 Example: 2397 2398 ``` python 2399 class MyOperatorTest(test_util.TensorFlowTestCase): 2400 def testMyOperator(self): 2401 with self.session(): 2402 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] 2403 result = MyOperator(valid_input).eval() 2404 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] 2405 invalid_input = [-1.0, 2.0, 7.0] 2406 with self.assertRaisesOpError("negative input not supported"): 2407 MyOperator(invalid_input).eval() 2408 ``` 2409 2410 Args: 2411 graph: Optional graph to use during the returned session. 2412 config: An optional config_pb2.ConfigProto to use to configure the 2413 session. 2414 use_gpu: If True, attempt to run as many ops as possible on GPU. 2415 force_gpu: If True, pin all ops to `/device:GPU:0`. 2416 2417 Yields: 2418 A Session object that should be used as a context manager to surround 2419 the graph building and execution code in a test case. 2420 """ 2421 if context.executing_eagerly(): 2422 yield EagerSessionWarner() 2423 else: 2424 with self._create_session(graph, config, force_gpu) as sess: 2425 with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu): 2426 yield sess 2427 2428 @contextlib.contextmanager 2429 def cached_session(self, 2430 graph=None, 2431 config=None, 2432 use_gpu=True, 2433 force_gpu=False): 2434 """Returns a TensorFlow Session for use in executing tests. 2435 2436 This method behaves differently than self.session(): for performance reasons 2437 `cached_session` will by default reuse the same session within the same 2438 test. The session returned by this function will only be closed at the end 2439 of the test (in the TearDown function). 2440 2441 Use the `use_gpu` and `force_gpu` options to control where ops are run. If 2442 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if 2443 `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as 2444 possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to 2445 the CPU. 2446 2447 Example: 2448 ```python 2449 class MyOperatorTest(test_util.TensorFlowTestCase): 2450 def testMyOperator(self): 2451 with self.cached_session() as sess: 2452 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] 2453 result = MyOperator(valid_input).eval() 2454 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] 2455 invalid_input = [-1.0, 2.0, 7.0] 2456 with self.assertRaisesOpError("negative input not supported"): 2457 MyOperator(invalid_input).eval() 2458 ``` 2459 2460 Args: 2461 graph: Optional graph to use during the returned session. 2462 config: An optional config_pb2.ConfigProto to use to configure the 2463 session. 2464 use_gpu: If True, attempt to run as many ops as possible on GPU. 2465 force_gpu: If True, pin all ops to `/device:GPU:0`. 2466 2467 Yields: 2468 A Session object that should be used as a context manager to surround 2469 the graph building and execution code in a test case. 2470 """ 2471 if context.executing_eagerly(): 2472 yield FakeEagerSession(self) 2473 else: 2474 sess = self._get_cached_session( 2475 graph, config, force_gpu, crash_if_inconsistent_args=True) 2476 with self._constrain_devices_and_set_default(sess, use_gpu, 2477 force_gpu) as cached: 2478 yield cached 2479 2480 @contextlib.contextmanager 2481 @deprecation.deprecated(None, "Use `self.session()` or " 2482 "`self.cached_session()` instead.") 2483 def test_session(self, 2484 graph=None, 2485 config=None, 2486 use_gpu=True, 2487 force_gpu=False): 2488 """Use cached_session instead.""" 2489 if self.id().endswith(".test_session"): 2490 self.skipTest( 2491 "Tests that have the name \"test_session\" are automatically skipped " 2492 "by TensorFlow test fixture, as the name is reserved for creating " 2493 "sessions within tests. Please rename your test if you have a test " 2494 "with this name.") 2495 if context.executing_eagerly(): 2496 yield None 2497 else: 2498 if graph is None: 2499 sess = self._get_cached_session( 2500 graph, config, force_gpu, crash_if_inconsistent_args=False) 2501 with self._constrain_devices_and_set_default(sess, use_gpu, 2502 force_gpu) as cached: 2503 yield cached 2504 else: 2505 with self.session(graph, config, use_gpu, force_gpu) as sess: 2506 yield sess 2507 2508 # pylint: enable=g-doc-return-or-yield 2509 2510 class _CheckedThread(object): 2511 """A wrapper class for Thread that asserts successful completion. 2512 2513 This class should be created using the TensorFlowTestCase.checkedThread() 2514 method. 2515 """ 2516 2517 def __init__(self, testcase, target, args=None, kwargs=None): 2518 """Constructs a new instance of _CheckedThread. 2519 2520 Args: 2521 testcase: The TensorFlowTestCase for which this thread is being created. 2522 target: A callable object representing the code to be executed in the 2523 thread. 2524 args: A tuple of positional arguments that will be passed to target. 2525 kwargs: A dictionary of keyword arguments that will be passed to target. 2526 """ 2527 self._testcase = testcase 2528 self._target = target 2529 self._args = () if args is None else args 2530 self._kwargs = {} if kwargs is None else kwargs 2531 self._thread = threading.Thread(target=self._protected_run) 2532 self._exception = None 2533 2534 self._is_thread_joined = False 2535 2536 def _protected_run(self): 2537 """Target for the wrapper thread. Sets self._exception on failure.""" 2538 try: 2539 self._target(*self._args, **self._kwargs) 2540 except Exception as e: # pylint: disable=broad-except 2541 self._exception = e 2542 2543 def start(self): 2544 """Starts the thread's activity. 2545 2546 This must be called at most once per _CheckedThread object. It arranges 2547 for the object's target to be invoked in a separate thread of control. 2548 """ 2549 self._thread.start() 2550 2551 def join(self): 2552 """Blocks until the thread terminates. 2553 2554 Raises: 2555 self._testcase.failureException: If the thread terminates with due to 2556 an exception. 2557 """ 2558 self._is_thread_joined = True 2559 self._thread.join() 2560 if self._exception is not None: 2561 self._testcase.fail("Error in checkedThread: %s" % str(self._exception)) 2562 2563 def is_alive(self): 2564 """Returns whether the thread is alive. 2565 2566 This method returns True just before the run() method starts 2567 until just after the run() method terminates. 2568 2569 Returns: 2570 True if the thread is alive, otherwise False. 2571 """ 2572 return self._thread.is_alive() 2573 2574 def check_termination(self): 2575 """Returns whether the checked thread was properly used and did terminate. 2576 2577 Every checked thread should be "join"ed after starting, and before the 2578 test tears down. If it is not joined, it is possible the thread will hang 2579 and cause flaky failures in tests. 2580 2581 Raises: 2582 self._testcase.failureException: If check_termination was called before 2583 thread was joined. 2584 2585 RuntimeError: If the thread is not terminated. This means thread was not 2586 joined with the main thread. 2587 """ 2588 if self._is_thread_joined: 2589 if self.is_alive(): 2590 raise RuntimeError( 2591 "Thread was not joined with main thread, and is still running " 2592 "when the test finished.") 2593 else: 2594 self._testcase.fail("A checked thread was not joined.") 2595 2596 def checkedThread(self, target, args=None, kwargs=None): 2597 """Returns a Thread wrapper that asserts 'target' completes successfully. 2598 2599 This method should be used to create all threads in test cases, as 2600 otherwise there is a risk that a thread will silently fail, and/or 2601 assertions made in the thread will not be respected. 2602 2603 Args: 2604 target: A callable object to be executed in the thread. 2605 args: The argument tuple for the target invocation. Defaults to (). 2606 kwargs: A dictionary of keyword arguments for the target invocation. 2607 Defaults to {}. 2608 2609 Returns: 2610 A wrapper for threading.Thread that supports start() and join() methods. 2611 """ 2612 ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs) 2613 self._threads.append(ret) 2614 return ret 2615 2616 # pylint: enable=invalid-name 2617 @py_func_if_in_function 2618 def assertNear(self, f1, f2, err, msg=None): 2619 """Asserts that two floats are near each other. 2620 2621 Checks that |f1 - f2| < err and asserts a test failure 2622 if not. 2623 2624 Args: 2625 f1: A float value. 2626 f2: A float value. 2627 err: A float value. 2628 msg: An optional string message to append to the failure message. 2629 """ 2630 # f1 == f2 is needed here as we might have: f1, f2 = inf, inf 2631 self.assertTrue( 2632 f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" % 2633 (f1, f2, err, " (%s)" % msg if msg is not None else "")) 2634 2635 @py_func_if_in_function 2636 def assertArrayNear(self, farray1, farray2, err, msg=None): 2637 """Asserts that two float arrays are near each other. 2638 2639 Checks that for all elements of farray1 and farray2 2640 |f1 - f2| < err. Asserts a test failure if not. 2641 2642 Args: 2643 farray1: a list of float values. 2644 farray2: a list of float values. 2645 err: a float value. 2646 msg: Optional message to report on failure. 2647 """ 2648 self.assertEqual(len(farray1), len(farray2), msg=msg) 2649 for f1, f2 in zip(farray1, farray2): 2650 self.assertNear(float(f1), float(f2), err, msg=msg) 2651 2652 def _NDArrayNear(self, ndarray1, ndarray2, err): 2653 return np.linalg.norm(ndarray1 - ndarray2) < err 2654 2655 @py_func_if_in_function 2656 def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None): 2657 """Asserts that two numpy arrays have near values. 2658 2659 Args: 2660 ndarray1: a numpy ndarray. 2661 ndarray2: a numpy ndarray. 2662 err: a float. The maximum absolute difference allowed. 2663 msg: Optional message to report on failure. 2664 """ 2665 self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg) 2666 2667 def _GetNdArray(self, a): 2668 # If a is tensor-like then convert it to ndarray 2669 if tensor_util.is_tf_type(a): 2670 if isinstance(a, ops._EagerTensorBase): 2671 a = a.numpy() 2672 else: 2673 a = self.evaluate(a) 2674 if not isinstance(a, np.ndarray): 2675 return np.array(a) 2676 return a 2677 2678 def evaluate_if_both_tensors(self, a, b): 2679 if (tensor_util.is_tf_type(a) and tensor_util.is_tf_type(b) and 2680 not isinstance(a, ops._EagerTensorBase) and 2681 not isinstance(b, ops._EagerTensorBase)): 2682 return self.evaluate((a, b)) 2683 else: 2684 return (a, b) 2685 2686 def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 2687 (a, b) = self.evaluate_if_both_tensors(a, b) 2688 a = self._GetNdArray(a) 2689 b = self._GetNdArray(b) 2690 # When the array rank is small, print its contents. Numpy array printing is 2691 # implemented using inefficient recursion so prints can cause tests to 2692 # time out. 2693 if a.shape != b.shape and (b.ndim <= 3 or b.size < 500): 2694 shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents " 2695 "%s.") % (a.shape, b.shape, b) 2696 else: 2697 shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape, 2698 b.shape) 2699 self.assertEqual(a.shape, b.shape, shape_mismatch_msg) 2700 2701 msgs = [msg] 2702 # np.allclose does not always work for our custom bfloat16 extension type 2703 # when type promotions are involved, so we first cast any bfloat16 arrays 2704 # to float32. 2705 a_dtype = a.dtype 2706 a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a 2707 b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b 2708 if not np.allclose(a, b, rtol=rtol, atol=atol): 2709 # Adds more details to np.testing.assert_allclose. 2710 # 2711 # NOTE: numpy.allclose (and numpy.testing.assert_allclose) 2712 # checks whether two arrays are element-wise equal within a 2713 # tolerance. The relative difference (rtol * abs(b)) and the 2714 # absolute difference atol are added together to compare against 2715 # the absolute difference between a and b. Here, we want to 2716 # tell user which elements violate such conditions. 2717 cond = np.logical_or( 2718 np.abs(a - b) > atol + rtol * np.abs(b), 2719 np.isnan(a) != np.isnan(b)) 2720 if a.ndim: 2721 x = a[np.where(cond)] 2722 y = b[np.where(cond)] 2723 msgs.append("not close where = {}".format(np.where(cond))) 2724 else: 2725 # np.where is broken for scalars 2726 x, y = a, b 2727 msgs.append("not close lhs = {}".format(x)) 2728 msgs.append("not close rhs = {}".format(y)) 2729 msgs.append("not close dif = {}".format(np.abs(x - y))) 2730 msgs.append("not close tol = {}".format(atol + rtol * np.abs(y))) 2731 msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape)) 2732 # TODO(xpan): There seems to be a bug: 2733 # tensorflow/compiler/tests:binary_ops_test pass with float32 2734 # nan even though the equal_nan is False by default internally. 2735 np.testing.assert_allclose( 2736 a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True) 2737 2738 def _assertAllCloseRecursive(self, 2739 a, 2740 b, 2741 rtol=1e-6, 2742 atol=1e-6, 2743 path=None, 2744 msg=None): 2745 path = path or [] 2746 path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "") 2747 msg = msg if msg else "" 2748 2749 # Check if a and/or b are namedtuples. 2750 if hasattr(a, "_asdict"): 2751 a = a._asdict() 2752 if hasattr(b, "_asdict"): 2753 b = b._asdict() 2754 a_is_dict = isinstance(a, collections_abc.Mapping) 2755 if a_is_dict != isinstance(b, collections_abc.Mapping): 2756 raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" % 2757 (path_str, path_str, msg)) 2758 if a_is_dict: 2759 self.assertItemsEqual( 2760 a.keys(), 2761 b.keys(), 2762 msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" % 2763 (path_str, a.keys(), path_str, b.keys(), msg)) 2764 for k in a: 2765 path.append(k) 2766 self._assertAllCloseRecursive( 2767 a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg) 2768 del path[-1] 2769 elif isinstance(a, (list, tuple)): 2770 # Try to directly compare a, b as ndarrays; if not work, then traverse 2771 # through the sequence, which is more expensive. 2772 try: 2773 (a, b) = self.evaluate_if_both_tensors(a, b) 2774 a_as_ndarray = self._GetNdArray(a) 2775 b_as_ndarray = self._GetNdArray(b) 2776 self._assertArrayLikeAllClose( 2777 a_as_ndarray, 2778 b_as_ndarray, 2779 rtol=rtol, 2780 atol=atol, 2781 msg="Mismatched value: a%s is different from b%s. %s" % 2782 (path_str, path_str, msg)) 2783 except (ValueError, TypeError) as e: 2784 if len(a) != len(b): 2785 raise ValueError( 2786 "Mismatched length: a%s has %d items, but b%s has %d items. %s" % 2787 (path_str, len(a), path_str, len(b), msg)) 2788 for idx, (a_ele, b_ele) in enumerate(zip(a, b)): 2789 path.append(str(idx)) 2790 self._assertAllCloseRecursive( 2791 a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg) 2792 del path[-1] 2793 # a and b are ndarray like objects 2794 else: 2795 try: 2796 self._assertArrayLikeAllClose( 2797 a, 2798 b, 2799 rtol=rtol, 2800 atol=atol, 2801 msg=("Mismatched value: a%s is different from b%s. %s" % 2802 (path_str, path_str, msg))) 2803 except TypeError as e: 2804 msg = ("Error: a%s has %s, but b%s has %s. %s" % 2805 (path_str, type(a), path_str, type(b), msg)) 2806 e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) 2807 raise 2808 2809 @py_func_if_in_function 2810 def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 2811 """Asserts that two structures of numpy arrays or Tensors, have near values. 2812 2813 `a` and `b` can be arbitrarily nested structures. A layer of a nested 2814 structure can be a `dict`, `namedtuple`, `tuple` or `list`. 2815 2816 Note: the implementation follows 2817 [`numpy.allclose`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html) 2818 (and numpy.testing.assert_allclose). It checks whether two arrays are 2819 element-wise equal within a tolerance. The relative difference 2820 (`rtol * abs(b)`) and the absolute difference `atol` are added together 2821 to compare against the absolute difference between `a` and `b`. 2822 2823 Args: 2824 a: The expected numpy `ndarray`, or anything that can be converted into a 2825 numpy `ndarray` (including Tensor), or any arbitrarily nested of 2826 structure of these. 2827 b: The actual numpy `ndarray`, or anything that can be converted into a 2828 numpy `ndarray` (including Tensor), or any arbitrarily nested of 2829 structure of these. 2830 rtol: relative tolerance. 2831 atol: absolute tolerance. 2832 msg: Optional message to report on failure. 2833 2834 Raises: 2835 ValueError: if only one of `a[p]` and `b[p]` is a dict or 2836 `a[p]` and `b[p]` have different length, where `[p]` denotes a path 2837 to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and 2838 `[p] = [1]['d']`, then `a[p] = (6, 7)`. 2839 """ 2840 if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b): 2841 return self._assertRaggedClose(a, b, rtol, atol, msg) 2842 self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) 2843 2844 @py_func_if_in_function 2845 def assertAllCloseAccordingToType(self, 2846 a, 2847 b, 2848 rtol=1e-6, 2849 atol=1e-6, 2850 float_rtol=1e-6, 2851 float_atol=1e-6, 2852 half_rtol=1e-3, 2853 half_atol=1e-3, 2854 bfloat16_rtol=1e-2, 2855 bfloat16_atol=1e-2, 2856 msg=None): 2857 """Like assertAllClose, but also suitable for comparing fp16 arrays. 2858 2859 In particular, the tolerance is reduced to 1e-3 if at least 2860 one of the arguments is of type float16. 2861 2862 Args: 2863 a: the expected numpy ndarray or anything can be converted to one. 2864 b: the actual numpy ndarray or anything can be converted to one. 2865 rtol: relative tolerance. 2866 atol: absolute tolerance. 2867 float_rtol: relative tolerance for float32. 2868 float_atol: absolute tolerance for float32. 2869 half_rtol: relative tolerance for float16. 2870 half_atol: absolute tolerance for float16. 2871 bfloat16_rtol: relative tolerance for bfloat16. 2872 bfloat16_atol: absolute tolerance for bfloat16. 2873 msg: Optional message to report on failure. 2874 """ 2875 (a, b) = self.evaluate_if_both_tensors(a, b) 2876 a = self._GetNdArray(a) 2877 b = self._GetNdArray(b) 2878 # types with lower tol are put later to overwrite previous ones. 2879 if (a.dtype == np.float32 or b.dtype == np.float32 or 2880 a.dtype == np.complex64 or b.dtype == np.complex64): 2881 rtol = max(rtol, float_rtol) 2882 atol = max(atol, float_atol) 2883 if a.dtype == np.float16 or b.dtype == np.float16: 2884 rtol = max(rtol, half_rtol) 2885 atol = max(atol, half_atol) 2886 if (a.dtype == dtypes.bfloat16.as_numpy_dtype or 2887 b.dtype == dtypes.bfloat16.as_numpy_dtype): 2888 rtol = max(rtol, bfloat16_rtol) 2889 atol = max(atol, bfloat16_atol) 2890 2891 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) 2892 2893 @py_func_if_in_function 2894 def assertNotAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 2895 """Assert that two numpy arrays, or Tensors, do not have near values. 2896 2897 Args: 2898 a: The expected numpy `ndarray`, or anything that can be converted into a 2899 numpy `ndarray` (including Tensor), or any arbitrarily nested of 2900 structure of these. 2901 b: The actual numpy `ndarray`, or anything that can be converted into a 2902 numpy `ndarray` (including Tensor), or any arbitrarily nested of 2903 structure of these. 2904 rtol: relative tolerance. 2905 atol: absolute tolerance. 2906 msg: Optional message to report on failure. 2907 2908 Raises: 2909 AssertionError: If `a` and `b` are unexpectedly close at all elements. 2910 """ 2911 try: 2912 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) 2913 except AssertionError: 2914 return 2915 msg = msg or "" 2916 raise AssertionError("The two values are close at all elements. %s" % msg) 2917 2918 @py_func_if_in_function 2919 def assertAllEqual(self, a, b, msg=None): 2920 """Asserts that two numpy arrays or Tensors have the same values. 2921 2922 Args: 2923 a: the expected numpy ndarray or anything can be converted to one. 2924 b: the actual numpy ndarray or anything can be converted to one. 2925 msg: Optional message to report on failure. 2926 """ 2927 if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)): 2928 return self._assertRaggedEqual(a, b, msg) 2929 msg = msg if msg else "" 2930 (a, b) = self.evaluate_if_both_tensors(a, b) 2931 a = self._GetNdArray(a) 2932 b = self._GetNdArray(b) 2933 # Arbitrary bounds so that we don't print giant tensors. 2934 if (b.ndim <= 3 or b.size < 500): 2935 self.assertEqual( 2936 a.shape, b.shape, "Shape mismatch: expected %s, got %s." 2937 " Contents: %r. \n%s." % (a.shape, b.shape, b, msg)) 2938 else: 2939 self.assertEqual( 2940 a.shape, b.shape, "Shape mismatch: expected %s, got %s." 2941 " %s" % (a.shape, b.shape, msg)) 2942 2943 same = (a == b) 2944 2945 if (a.dtype in [ 2946 np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype 2947 ]): 2948 same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b))) 2949 msgs = [msg] 2950 if not np.all(same): 2951 # Adds more details to np.testing.assert_array_equal. 2952 diff = np.logical_not(same) 2953 if a.ndim: 2954 x = a[np.where(diff)] 2955 y = b[np.where(diff)] 2956 msgs.append("not equal where = {}".format(np.where(diff))) 2957 else: 2958 # np.where is broken for scalars 2959 x, y = a, b 2960 msgs.append("not equal lhs = %r" % x) 2961 msgs.append("not equal rhs = %r" % y) 2962 2963 # Handle mixed string types as a result of PY2to3 migration. That is, the 2964 # mixing between bytes (b-prefix strings, PY2 default) and unicodes 2965 # (u-prefix strings, PY3 default). 2966 if six.PY3: 2967 if (a.dtype.kind != b.dtype.kind and 2968 {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})): 2969 a_list = [] 2970 b_list = [] 2971 # OK to flatten `a` and `b` because they are guaranteed to have the 2972 # same shape. 2973 for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]: 2974 for item in flat_arr: 2975 if isinstance(item, str): 2976 out_list.append(item.encode("utf-8")) 2977 else: 2978 out_list.append(item) 2979 a = np.array(a_list) 2980 b = np.array(b_list) 2981 2982 np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs)) 2983 2984 @py_func_if_in_function 2985 def assertNotAllEqual(self, a, b, msg=None): 2986 """Asserts that two numpy arrays or Tensors do not have the same values. 2987 2988 Args: 2989 a: the expected numpy ndarray or anything can be converted to one. 2990 b: the actual numpy ndarray or anything can be converted to one. 2991 msg: Optional message to report on failure. 2992 """ 2993 try: 2994 self.assertAllEqual(a, b) 2995 except AssertionError: 2996 return 2997 raise AssertionError("The two values are equal at all elements. %s" % msg) 2998 2999 @py_func_if_in_function 3000 def assertAllGreater(self, a, comparison_target): 3001 """Assert element values are all greater than a target value. 3002 3003 Args: 3004 a: The numpy `ndarray`, or anything that can be converted into a numpy 3005 `ndarray` (including Tensor). 3006 comparison_target: The target value of comparison. 3007 """ 3008 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 3009 a = self._GetNdArray(a) 3010 self.assertGreater(np.min(a), comparison_target) 3011 3012 @py_func_if_in_function 3013 def assertAllLess(self, a, comparison_target): 3014 """Assert element values are all less than a target value. 3015 3016 Args: 3017 a: The numpy `ndarray`, or anything that can be converted into a numpy 3018 `ndarray` (including Tensor). 3019 comparison_target: The target value of comparison. 3020 """ 3021 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 3022 a = self._GetNdArray(a) 3023 self.assertLess(np.max(a), comparison_target) 3024 3025 @py_func_if_in_function 3026 def assertAllGreaterEqual(self, a, comparison_target): 3027 """Assert element values are all greater than or equal to a target value. 3028 3029 Args: 3030 a: The numpy `ndarray`, or anything that can be converted into a numpy 3031 `ndarray` (including Tensor). 3032 comparison_target: The target value of comparison. 3033 """ 3034 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 3035 a = self._GetNdArray(a) 3036 self.assertGreaterEqual(np.min(a), comparison_target) 3037 3038 @py_func_if_in_function 3039 def assertAllLessEqual(self, a, comparison_target): 3040 """Assert element values are all less than or equal to a target value. 3041 3042 Args: 3043 a: The numpy `ndarray`, or anything that can be converted into a numpy 3044 `ndarray` (including Tensor). 3045 comparison_target: The target value of comparison. 3046 """ 3047 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 3048 a = self._GetNdArray(a) 3049 self.assertLessEqual(np.max(a), comparison_target) 3050 3051 def _format_subscripts(self, subscripts, value, limit=10, indent=2): 3052 """Generate a summary of ndarray subscripts as a list of str. 3053 3054 If limit == N, this method will print up to the first N subscripts on 3055 separate 3056 lines. A line of ellipses (...) will be appended at the end if the number of 3057 subscripts exceeds N. 3058 3059 Args: 3060 subscripts: The tensor (np.ndarray) subscripts, of the same format as 3061 np.where()'s return value, i.e., a tuple of arrays with each array 3062 corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])). 3063 value: (np.ndarray) value of the tensor. 3064 limit: (int) The maximum number of indices to print. 3065 indent: (int) Number of characters to indent at the beginning of each 3066 line. 3067 3068 Returns: 3069 (list of str) the multi-line representation of the subscripts and values, 3070 potentially with omission at the end. 3071 """ 3072 lines = [] 3073 subscripts = np.transpose(subscripts) 3074 prefix = " " * indent 3075 if np.ndim(value) == 0: 3076 return [prefix + "[0] : " + str(value)] 3077 for subscript in itertools.islice(subscripts, limit): 3078 lines.append(prefix + str(subscript) + " : " + 3079 str(value[tuple(subscript)])) 3080 if len(subscripts) > limit: 3081 lines.append(prefix + "...") 3082 return lines 3083 3084 @py_func_if_in_function 3085 def assertAllInRange(self, 3086 target, 3087 lower_bound, 3088 upper_bound, 3089 open_lower_bound=False, 3090 open_upper_bound=False): 3091 """Assert that elements in a Tensor are all in a given range. 3092 3093 Args: 3094 target: The numpy `ndarray`, or anything that can be converted into a 3095 numpy `ndarray` (including Tensor). 3096 lower_bound: lower bound of the range 3097 upper_bound: upper bound of the range 3098 open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather 3099 than the default >=) 3100 open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather 3101 than the default <=) 3102 3103 Raises: 3104 AssertionError: 3105 if the value tensor does not have an ordered numeric type (float* or 3106 int*), or 3107 if there are nan values, or 3108 if any of the elements do not fall in the specified range. 3109 """ 3110 target = self._GetNdArray(target) 3111 if not (np.issubdtype(target.dtype, np.floating) or 3112 np.issubdtype(target.dtype, np.integer)): 3113 raise AssertionError( 3114 "The value of %s does not have an ordered numeric type, instead it " 3115 "has type: %s" % (target, target.dtype)) 3116 3117 nan_subscripts = np.where(np.isnan(target)) 3118 if np.size(nan_subscripts): 3119 raise AssertionError( 3120 "%d of the %d element(s) are NaN. " 3121 "Subscripts(s) and value(s) of the NaN element(s):\n" % 3122 (len(nan_subscripts[0]), np.size(target)) + 3123 "\n".join(self._format_subscripts(nan_subscripts, target))) 3124 3125 range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " + 3126 str(upper_bound) + (")" if open_upper_bound else "]")) 3127 3128 violations = ( 3129 np.less_equal(target, lower_bound) if open_lower_bound else np.less( 3130 target, lower_bound)) 3131 violations = np.logical_or( 3132 violations, 3133 np.greater_equal(target, upper_bound) 3134 if open_upper_bound else np.greater(target, upper_bound)) 3135 violation_subscripts = np.where(violations) 3136 if np.size(violation_subscripts): 3137 raise AssertionError( 3138 "%d of the %d element(s) are outside the range %s. " % 3139 (len(violation_subscripts[0]), np.size(target), range_str) + 3140 "Subscript(s) and value(s) of the offending elements:\n" + 3141 "\n".join(self._format_subscripts(violation_subscripts, target))) 3142 3143 @py_func_if_in_function 3144 def assertAllInSet(self, target, expected_set): 3145 """Assert that elements of a Tensor are all in a given closed set. 3146 3147 Args: 3148 target: The numpy `ndarray`, or anything that can be converted into a 3149 numpy `ndarray` (including Tensor). 3150 expected_set: (`list`, `tuple` or `set`) The closed set that the elements 3151 of the value of `target` are expected to fall into. 3152 3153 Raises: 3154 AssertionError: 3155 if any of the elements do not fall into `expected_set`. 3156 """ 3157 target = self._GetNdArray(target) 3158 3159 # Elements in target that are not in expected_set. 3160 diff = np.setdiff1d(target.flatten(), list(expected_set)) 3161 if np.size(diff): 3162 raise AssertionError("%d unique element(s) are not in the set %s: %s" % 3163 (np.size(diff), expected_set, diff)) 3164 3165 @py_func_if_in_function 3166 def assertDTypeEqual(self, target, expected_dtype): 3167 """Assert ndarray data type is equal to expected. 3168 3169 Args: 3170 target: The numpy `ndarray`, or anything that can be converted into a 3171 numpy `ndarray` (including Tensor). 3172 expected_dtype: Expected data type. 3173 """ 3174 target = self._GetNdArray(target) 3175 if not isinstance(target, list): 3176 arrays = [target] 3177 for arr in arrays: 3178 self.assertEqual(arr.dtype, expected_dtype) 3179 3180 # pylint: disable=g-doc-return-or-yield 3181 @contextlib.contextmanager 3182 def assertRaisesWithPredicateMatch(self, exception_type, 3183 expected_err_re_or_predicate): 3184 """Returns a context manager to enclose code expected to raise an exception. 3185 3186 If the exception is an OpError, the op stack is also included in the message 3187 predicate search. 3188 3189 Args: 3190 exception_type: The expected type of exception that should be raised. 3191 expected_err_re_or_predicate: If this is callable, it should be a function 3192 of one argument that inspects the passed-in exception and returns True 3193 (success) or False (please fail the test). Otherwise, the error message 3194 is expected to match this regular expression partially. 3195 3196 Returns: 3197 A context manager to surround code that is expected to raise an 3198 exception. 3199 """ 3200 if callable(expected_err_re_or_predicate): 3201 predicate = expected_err_re_or_predicate 3202 else: 3203 3204 def predicate(e): 3205 err_str = e.message if isinstance(e, errors.OpError) else str(e) 3206 op = e.op if isinstance(e, errors.OpError) else None 3207 while op is not None: 3208 err_str += "\nCaused by: " + op.name 3209 op = op._original_op # pylint: disable=protected-access 3210 logging.info("Searching within error strings: '%s' within '%s'", 3211 expected_err_re_or_predicate, err_str) 3212 return re.search(expected_err_re_or_predicate, err_str) 3213 3214 try: 3215 yield 3216 self.fail(exception_type.__name__ + " not raised") 3217 except Exception as e: # pylint: disable=broad-except 3218 if not isinstance(e, exception_type) or not predicate(e): 3219 raise AssertionError("Exception of type %s: %s" % 3220 (str(type(e)), str(e))) 3221 3222 # pylint: enable=g-doc-return-or-yield 3223 3224 def assertRaisesOpError(self, expected_err_re_or_predicate): 3225 return self.assertRaisesWithPredicateMatch(errors.OpError, 3226 expected_err_re_or_predicate) 3227 3228 def assertRaisesIncompatibleShapesError( 3229 self, exception_type=errors.InvalidArgumentError): 3230 return self.assertRaisesWithPredicateMatch( 3231 exception_type, r"Incompatible shapes|Dimensions must be equal|" 3232 r"required broadcastable shapes") 3233 3234 def assertShapeEqual(self, np_array, tf_tensor, msg=None): 3235 """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape. 3236 3237 Args: 3238 np_array: A Numpy ndarray or Numpy scalar. 3239 tf_tensor: A Tensor. 3240 msg: Optional message to report on failure. 3241 3242 Raises: 3243 TypeError: If the arguments have the wrong type. 3244 """ 3245 if not isinstance(np_array, (np.ndarray, np.generic)): 3246 raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") 3247 if not isinstance(tf_tensor, ops.Tensor): 3248 raise TypeError("tf_tensor must be a Tensor") 3249 self.assertAllEqual( 3250 np_array.shape, tf_tensor.get_shape().as_list(), msg=msg) 3251 3252 def assertDeviceEqual(self, device1, device2, msg=None): 3253 """Asserts that the two given devices are the same. 3254 3255 Args: 3256 device1: A string device name or TensorFlow `DeviceSpec` object. 3257 device2: A string device name or TensorFlow `DeviceSpec` object. 3258 msg: Optional message to report on failure. 3259 """ 3260 device1 = pydev.canonical_name(device1) 3261 device2 = pydev.canonical_name(device2) 3262 self.assertEqual( 3263 device1, device2, 3264 "Devices %s and %s are not equal. %s" % (device1, device2, msg)) 3265 3266 def _GetPyList(self, a): 3267 """Converts `a` to a nested python list.""" 3268 if isinstance(a, ragged_tensor.RaggedTensor): 3269 return self.evaluate(a).to_list() 3270 elif isinstance(a, ops.Tensor): 3271 a = self.evaluate(a) 3272 return a.tolist() if isinstance(a, np.ndarray) else a 3273 elif isinstance(a, np.ndarray): 3274 return a.tolist() 3275 elif isinstance(a, ragged_tensor_value.RaggedTensorValue): 3276 return a.to_list() 3277 else: 3278 return np.array(a).tolist() 3279 3280 def _assertRaggedEqual(self, a, b, msg): 3281 """Asserts that two ragged tensors are equal.""" 3282 a_list = self._GetPyList(a) 3283 b_list = self._GetPyList(b) 3284 self.assertEqual(a_list, b_list, msg) 3285 3286 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 3287 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 3288 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 3289 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 3290 3291 def _assertRaggedClose(self, a, b, rtol, atol, msg=None): 3292 a_list = self._GetPyList(a) 3293 b_list = self._GetPyList(b) 3294 self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg) 3295 3296 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 3297 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 3298 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 3299 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 3300 3301 def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"): 3302 self.assertEqual(type(a), type(b)) 3303 if isinstance(a, (list, tuple)): 3304 self.assertLen(a, len(b), "Length differs for %s" % path) 3305 for i in range(len(a)): 3306 self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg, 3307 "%s[%s]" % (path, i)) 3308 else: 3309 self._assertAllCloseRecursive(a, b, rtol, atol, path, msg) 3310 3311 # Fix Python 3+ compatibility issues 3312 if not six.PY2: 3313 # pylint: disable=invalid-name 3314 3315 # Silence a deprecation warning 3316 assertRaisesRegexp = googletest.TestCase.assertRaisesRegex 3317 3318 # assertItemsEqual is assertCountEqual as of 3.2. 3319 assertItemsEqual = googletest.TestCase.assertCountEqual 3320 3321 # pylint: enable=invalid-name 3322 3323 @contextlib.contextmanager 3324 def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): 3325 """Set the session and its graph to global default and constrain devices.""" 3326 if context.executing_eagerly(): 3327 yield None 3328 else: 3329 with sess.graph.as_default(), sess.as_default(): 3330 if force_gpu: 3331 # Use the name of an actual device if one is detected, or 3332 # '/device:GPU:0' otherwise 3333 gpu_name = gpu_device_name() 3334 if not gpu_name: 3335 gpu_name = "/device:GPU:0" 3336 with sess.graph.device(gpu_name): 3337 yield sess 3338 elif use_gpu: 3339 yield sess 3340 else: 3341 with sess.graph.device("/device:CPU:0"): 3342 yield sess 3343 3344 def _create_session(self, graph, config, force_gpu): 3345 """See session() for details.""" 3346 3347 def prepare_config(config): 3348 """Returns a config for sessions. 3349 3350 Args: 3351 config: An optional config_pb2.ConfigProto to use to configure the 3352 session. 3353 3354 Returns: 3355 A config_pb2.ConfigProto object. 3356 """ 3357 # TODO(b/114333779): Enforce allow_soft_placement=False when 3358 # use_gpu=False. Currently many tests rely on the fact that any device 3359 # will be used even when a specific device is supposed to be used. 3360 allow_soft_placement = not force_gpu 3361 if config is None: 3362 config = context.context().config 3363 config.allow_soft_placement = allow_soft_placement 3364 elif not allow_soft_placement and config.allow_soft_placement: 3365 config_copy = context.context().config 3366 config = config_copy 3367 config.allow_soft_placement = False 3368 # Don't perform optimizations for tests so we don't inadvertently run 3369 # gpu ops on cpu 3370 config.graph_options.optimizer_options.opt_level = -1 3371 # Disable Grappler constant folding since some tests & benchmarks 3372 # use constant input and become meaningless after constant folding. 3373 # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE 3374 # GRAPPLER TEAM. 3375 config.graph_options.rewrite_options.constant_folding = ( 3376 rewriter_config_pb2.RewriterConfig.OFF) 3377 config.graph_options.rewrite_options.pin_to_host_optimization = ( 3378 rewriter_config_pb2.RewriterConfig.OFF) 3379 return config 3380 3381 return ErrorLoggingSession(graph=graph, config=prepare_config(config)) 3382 3383 def _get_cached_session(self, 3384 graph=None, 3385 config=None, 3386 force_gpu=False, 3387 crash_if_inconsistent_args=True): 3388 """See cached_session() for documentation.""" 3389 if self._cached_session is None: 3390 sess = self._create_session( 3391 graph=graph, config=config, force_gpu=force_gpu) 3392 self._cached_session = sess 3393 self._cached_graph = graph 3394 self._cached_config = config 3395 self._cached_force_gpu = force_gpu 3396 return sess 3397 else: 3398 if crash_if_inconsistent_args and self._cached_graph is not graph: 3399 raise ValueError("The graph used to get the cached session is " 3400 "different than the one that was used to create the " 3401 "session. Maybe create a new session with " 3402 "self.session()") 3403 if crash_if_inconsistent_args and self._cached_config is not config: 3404 raise ValueError("The config used to get the cached session is " 3405 "different than the one that was used to create the " 3406 "session. Maybe create a new session with " 3407 "self.session()") 3408 if crash_if_inconsistent_args and (self._cached_force_gpu is 3409 not force_gpu): 3410 raise ValueError( 3411 "The force_gpu value used to get the cached session is " 3412 "different than the one that was used to create the " 3413 "session. Maybe create a new session with " 3414 "self.session()") 3415 return self._cached_session 3416 3417 3418@tf_export("test.create_local_cluster") 3419def create_local_cluster(num_workers, 3420 num_ps, 3421 protocol="grpc", 3422 worker_config=None, 3423 ps_config=None): 3424 """Create and start local servers and return the associated `Server` objects. 3425 3426 "PS" stands for "parameter server": a task responsible for storing and 3427 updating the model's parameters. Other tasks send updates to these parameters 3428 as they work on optimizing the parameters. This particular division of labor 3429 between tasks is not required, but is common for distributed training. 3430 3431 Read more at https://www.tensorflow.org/guide/extend/architecture 3432 3433  3434 3435 3436 Figure illustrates the interaction of these components. 3437 "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services. 3438 3439 3440 Example: 3441 ```python 3442 workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2) 3443 3444 worker_sessions = [tf.compat.v1.Session(w.target) for w in workers] 3445 3446 with tf.device("/job:ps/task:0"): 3447 ... 3448 with tf.device("/job:ps/task:1"): 3449 ... 3450 with tf.device("/job:worker/task:0"): 3451 ... 3452 with tf.device("/job:worker/task:1"): 3453 ... 3454 3455 worker_sessions[0].run(...) 3456 ``` 3457 3458 Args: 3459 num_workers: Number of worker servers to start. 3460 num_ps: Number of PS servers to start. 3461 protocol: Communication protocol. Allowed values are documented in the 3462 documentation of `tf.distribute.Server`. 3463 worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be 3464 used to instantiate multiple devices etc. 3465 ps_config: (optional) `tf.ConfigProto` to initialize PS servers. 3466 3467 Returns: 3468 A tuple `(worker_servers, ps_servers)`. `worker_servers` is a list 3469 of `num_workers` objects of type `tf.distribute.Server` (all running 3470 locally); 3471 and `ps_servers` is a list of `num_ps` objects of similar type. 3472 3473 Raises: 3474 ImportError: if portpicker module was not found at load time 3475 """ 3476 import portpicker # pylint: disable=g-import-not-at-top 3477 worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] 3478 ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] 3479 cluster_dict = { 3480 "worker": ["localhost:%s" % port for port in worker_ports], 3481 "ps": ["localhost:%s" % port for port in ps_ports] 3482 } 3483 cs = server_lib.ClusterSpec(cluster_dict) 3484 3485 workers = [ 3486 server_lib.Server( 3487 cs, 3488 job_name="worker", 3489 protocol=protocol, 3490 task_index=ix, 3491 config=worker_config, 3492 start=True) for ix in range(num_workers) 3493 ] 3494 ps_servers = [ 3495 server_lib.Server( 3496 cs, 3497 job_name="ps", 3498 protocol=protocol, 3499 task_index=ix, 3500 config=ps_config, 3501 start=True) for ix in range(num_ps) 3502 ] 3503 3504 return workers, ps_servers 3505 3506 3507def get_node_def_from_graph(node_name, graph_def): 3508 """Returns the `NodeDef` instance for given node name in the graph def. 3509 3510 This method explores only the NodeDefs in `graph_def.node`. 3511 3512 Args: 3513 node_name: Name of the NodeDef to search for. 3514 graph_def: An instance of `GraphDef` proto. 3515 3516 Returns: 3517 the `NodeDef` instance whose name field matches the given node_name or None. 3518 """ 3519 for node_def in graph_def.node: 3520 if node_def.name == node_name: 3521 return node_def 3522 return None 3523 3524 3525def set_producer_version(graph, producer_version): 3526 """Sets graph.graph_def_versions.producer to `producer_version`.""" 3527 # The C API doesn't expose altering GraphDefVersions. We can indirectly set 3528 # it via import_graph_def though. 3529 graph_def = graph_pb2.GraphDef() 3530 graph_def.versions.producer = producer_version 3531 with graph.as_default(): 3532 importer.import_graph_def(graph_def) 3533 assert graph.graph_def_versions.producer, producer_version 3534 3535 3536@contextlib.contextmanager 3537def _fake_gradient_tape_context_manager(): 3538 """tf.gradients(...) implemented as tf.GradientTape context manager interface. 3539 3540 This is useful to test tf.gradients() in tests that uses tf.GradientTape(). 3541 3542 Yields: 3543 gradient tape instance that's implemented by tf.gradients() underneath. 3544 """ 3545 try: 3546 class FakeGradientTape: 3547 3548 def watch(self, x): 3549 pass 3550 3551 def gradient(self, y, x, grad_ys=None): 3552 result = gradients_impl.gradients(y, x, grad_ys) 3553 3554 # Unlike `tape.gradient()`, `tf.gradients()` returns a list for a single 3555 # element. So unpack if needed to match `tape.gradient()` behavior. 3556 if not isinstance(x, (list, tuple)): 3557 assert len(result) == 1 3558 return result[0] 3559 3560 return result 3561 3562 yield FakeGradientTape() 3563 finally: 3564 pass 3565 3566 3567class AbstractGradientTape: 3568 """Abstract GradientTape context manager that has multiple implementations. 3569 3570 This is useful to test both tf.GradientTape() and tf.gradients() without 3571 duplicating tests. 3572 """ 3573 3574 def __init__(self, use_tape, persistent=False): 3575 self._use_tape = use_tape 3576 self._persistent = persistent 3577 3578 def __enter__(self): 3579 if self._use_tape: 3580 self._tape_impl = backprop.GradientTape(persistent=self._persistent) 3581 else: 3582 self._tape_impl = _fake_gradient_tape_context_manager() 3583 return self._tape_impl.__enter__() 3584 3585 def __exit__(self, exc_type, exc_val, exc_tb): 3586 self._tape_impl.__exit__(exc_type, exc_val, exc_tb) 3587 3588 3589@contextlib.contextmanager 3590def run_functions_eagerly(run_eagerly): 3591 """Runs functions eagerly if `run_eagerly` is true. 3592 3593 WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode 3594 *WILL NOT* make the tf.function to run eagerly because eager is disabled by 3595 default in V1. Instead, tf.function will run as a traced graph function. 3596 3597 Ensures that the state (for running functions eagerly) is back to the initial 3598 `def_function.RUN_FUNCTIONS_EAGERLY` state. 3599 3600 Args: 3601 run_eagerly: Boolean determining whether to run the function eagerly or not. 3602 3603 Raises: 3604 ValueError if `run_eagerly` is not a boolean. 3605 3606 Yields: 3607 Nothing. 3608 """ 3609 if not isinstance(run_eagerly, bool): 3610 raise ValueError( 3611 "Expected bool for `run_eagerly` but got {}".format(run_eagerly)) 3612 3613 is_eager = context.executing_eagerly() 3614 if not is_eager and run_eagerly: 3615 logging.warning( 3616 "Running tf.function eagerly in V1 graph mode is not supported. " 3617 "tf.function will be run as a traced graph function.") 3618 3619 initial_state = def_function.functions_run_eagerly() 3620 def_function.run_functions_eagerly(run_eagerly) 3621 try: 3622 yield 3623 finally: 3624 def_function.run_functions_eagerly(initial_state) 3625