# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=invalid-name
"""Test utils for tensorflow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from collections import OrderedDict
import contextlib
import functools
import gc
import itertools
import math
import os
import random
import re
import tempfile
import threading
import time
import unittest

from absl.testing import parameterized
import numpy as np
import six

from google.protobuf import descriptor_pool
from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import device_lib
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.client import session
from tensorflow.python.compat.compat import forward_compatibility_horizon
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import tape
from tensorflow.python.framework import config
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import gpu_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import tfrt_utils
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_ops  # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import _pywrap_stacktrace_handler
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import _pywrap_util_port
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export
# If the below import is made available through the BUILD rule, then this
# function is overridden and will instead return True and cause Tensorflow
# graphs to be compiled with XLA.
def is_xla_enabled():
  """Returns whether tests run with XLA compilation (overridden via BUILD)."""
  return False


try:
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top, unused-import
except Exception:  # pylint: disable=broad-except
  pass


# Uses the same mechanism as above to selectively enable/disable MLIR
# compilation.
def is_mlir_bridge_enabled():
  """Returns whether the MLIR bridge is enabled; `None` means "unset"."""
  return None


try:
  from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
except ImportError:
  try:
    from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
  except ImportError:
    pass


def _get_object_count_by_type(exclude=()):
  """Counts live Python objects by type name, optionally excluding some.

  Args:
    exclude: iterable of objects whose type counts are subtracted.

  Returns:
    A `collections.Counter` mapping type name to number of live instances.
  """
  return (
      collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) -
      collections.Counter([type(obj).__name__ for obj in exclude]))


@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or the empty string."""
  for x in device_lib.list_local_devices():
    if x.device_type == "GPU":
      return compat.as_str(x.name)
  return ""


def assert_ops_in_graph(expected_ops, graph):
  """Assert all expected operations are found.

  Args:
    expected_ops: `dict<string, string>` of op name to op type.
    graph: Graph to check.

  Returns:
    `dict<string, node>` of node name to node.

  Raises:
    ValueError: If the expected ops are not present in the graph.
  """
  actual_ops = {}
  gd = graph.as_graph_def()
  for node in gd.node:
    if node.name in expected_ops:
      if expected_ops[node.name] != node.op:
        raise ValueError("Expected op for node %s is different. %s vs %s" %
                         (node.name, expected_ops[node.name], node.op))
      actual_ops[node.name] = node
  if set(expected_ops.keys()) != set(actual_ops.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), actual_ops.keys()))
  return actual_ops


@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(expected, actual):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs. Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. This function
  ignores randomized attribute values that may appear in V2 checkpoints.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(actual, expected, checkpoint_v2=True,
                         hash_table_shared_name=True)


@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False,
                              hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs. Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints.
    hash_table_shared_name: boolean determining whether to ignore randomized
      shared_names that appear in HashTableV2 op defs.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(actual, expected, checkpoint_v2,
                         hash_table_shared_name)


def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
                           hash_table_shared_name=False):
  """Shared implementation of the v1/v2 `GraphDef` equality asserts."""
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  if checkpoint_v2:
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  if hash_table_shared_name:
    _strip_hash_table_shared_name(actual)
    _strip_hash_table_shared_name(expected)

  diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))


def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`."""
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # Use assertEqual: assertEquals is a deprecated alias removed in
      # recent Python versions.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)


# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  """Removes attributes whose values match the randomized V2-checkpoint shard
  pattern, so graphs differing only in those values compare equal."""
  for node in graph_def.node:
    delete_keys = []
    for attr_key in node.attr:
      attr_tensor_value = node.attr[attr_key].tensor
      if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
        attr_tensor_string_value = attr_tensor_value.string_val[0]
        if (attr_tensor_string_value and
            re.match(compat.as_bytes(_SHARDED_SAVE_OP_PATTERN),
                     attr_tensor_string_value)):
          delete_keys.append(attr_key)
    for attr_key in delete_keys:
      del node.attr[attr_key]


_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+"


def _strip_hash_table_shared_name(graph_def):
  """Removes randomized `shared_name` attrs from HashTableV2 nodes."""
  for node in graph_def.node:
    delete_keys = []
    if node.op == "HashTableV2" and "shared_name" in node.attr:
      if re.match(compat.as_bytes(_TABLE_SHARED_NAME_PATTERN),
                  node.attr["shared_name"].s):
        delete_keys.append("shared_name")
    for attr_key in delete_keys:
      del node.attr[attr_key]


def IsGoogleCudaEnabled():
  return _pywrap_util_port.IsGoogleCudaEnabled()


def IsBuiltWithROCm():
  return _pywrap_util_port.IsBuiltWithROCm()


def IsBuiltWithXLA():
  return _pywrap_util_port.IsBuiltWithXLA()


def IsBuiltWithNvcc():
  return _pywrap_util_port.IsBuiltWithNvcc()
def GpuSupportsHalfMatMulAndConv():
  return _pywrap_util_port.GpuSupportsHalfMatMulAndConv()


def IsMklEnabled():
  return _pywrap_util_port.IsMklEnabled()


def InstallStackTraceHandler():
  _pywrap_stacktrace_handler.InstallStacktraceHandler()


def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # tensor dim -> new axis order
  new_axes = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if isinstance(input_tensor, ops.Tensor):
    ndims = input_tensor.shape.ndims
    return array_ops.transpose(input_tensor, new_axes[ndims])
  else:
    ndims = len(input_tensor)
    return [input_tensor[a] for a in new_axes[ndims]]


def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
      divisible by 4.
  """
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  temp_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]


def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  input_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = permutations[len(input_shape)]
  nhwc_shape = [input_shape[a] for a in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if is_tensor:
    t = array_ops.transpose(input_shape_or_tensor, permutation)
    return array_ops.reshape(t, nhwc_shape)
  else:
    return nhwc_shape


def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # tensor dim -> new axis order
  new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if isinstance(input_tensor, ops.Tensor):
    ndims = input_tensor.shape.ndims
    return array_ops.transpose(input_tensor, new_axes[ndims])
  else:
    ndims = len(input_tensor)
    return [input_tensor[a] for a in new_axes[ndims]]


def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.

  Returns:
    The wrapped function
  """

  def real_skip_if(fn):

    def wrapper(*args, **kwargs):
      if callable(condition):
        skip = condition()
      else:
        skip = condition
      if not skip:
        return fn(*args, **kwargs)

    return wrapper

  return real_skip_if


@contextlib.contextmanager
def skip_if_error(test_obj, error_type, messages=None):
  """Context manager to skip cases not considered failures by the tests.

  Note that this does not work if used in setUpClass/tearDownClass.
  Usage in setUp/tearDown works fine just like regular test methods.

  Args:
    test_obj: A test object provided as `self` in the test methods; this object
      is usually an instance of `unittest.TestCase`'s subclass and should have
      `skipTest` method.
    error_type: The error type to skip. Note that if `messages` are given, both
      `error_type` and `messages` need to match for the test to be skipped.
    messages: Optional, a string or list of strings. If `None`, the test will be
      skipped if `error_type` matches what is raised; otherwise, the test is
      skipped if any of the `messages` is contained in the message of the error
      raised, and `error_type` matches the error raised.

  Yields:
    Nothing.
  """
  if messages:
    messages = nest.flatten(messages)
  try:
    yield
  except error_type as e:
    if not messages or any(message in str(e) for message in messages):
      test_obj.skipTest("Skipping error: {}: {}".format(type(e), str(e)))
    else:
      raise


def enable_c_shapes(fn):
  """No-op. TODO(b/74620627): Remove this."""
  return fn


def with_c_shapes(cls):
  """No-op. TODO(b/74620627): Remove this."""
  return cls


def enable_control_flow_v2(fn):
  """Decorator for enabling CondV2 and WhileV2 on a test.

  Note this enables using CondV2 and WhileV2 after running the test class's
  setup/teardown methods.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old

  return wrapper


def with_control_flow_v2(cls):
  """Adds methods that call original methods with WhileV2 and CondV2 enabled.

  Note this enables CondV2 and WhileV2 in new methods after running the test
  class's setup method.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  If a test function has _disable_control_flow_v2 attr set to True (using the
  @disable_control_flow_v2 decorator), the v2 function is not generated for it.

  Example:

  @test_util.with_control_flow_v2
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    @test_util.disable_control_flow_v2("b/xyzabc")
    def testDisabledForV2(self):
      ...

  Generated class:
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    def testEnabledForV2WithControlFlowV2(self):
      // Enable V2 flags.
      testEnabledForV2(self)
      // Restore V2 flags.

    def testDisabledForV2(self):
      ...

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  for name, value in cls.__dict__.copy().items():
    if (callable(value) and
        name.startswith(unittest.TestLoader.testMethodPrefix) and
        not getattr(value, "_disable_control_flow_v2", False)):
      setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
  return cls


def disable_control_flow_v2(unused_msg):
  """Decorator for a function in a with_control_flow_v2 enabled test class.

  Blocks the function from being run with v2 control flow ops.

  Args:
    unused_msg: Reason for disabling.

  Returns:
    The wrapped function with _disable_control_flow_v2 attr set to True.
  """

  def wrapper(func):
    func._disable_control_flow_v2 = True
    return func

  return wrapper


def enable_output_all_intermediates(fn):
  """Force-enable outputting all intermediates from functional control flow ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    output_all_intermediates_old = \
        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
          output_all_intermediates_old

  return wrapper


def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  Args:
    func: The function to test.
    warmup_iters: The number of warmup iterations, excluded from measuring.

  Returns:
    The wrapped function performing the test.
  """

  def wrap_f(f):
    def decorator(self, *args, **kwargs):
      """Warms up, gets object counts, runs the test, checks for new objects."""
      with context.eager_mode():
        gc.disable()
        # Run the test 2 times as warmup, in an attempt to fill up caches, which
        # should not grow as the test is run repeatedly below.
        #
        # TODO(b/117156879): Running warmup twice is black magic; we have seen
        # tests that fail with 1 warmup run, and pass with 2, on various
        # versions of python2.7.x.
        for _ in range(warmup_iters):
          f(self, *args, **kwargs)
        # Since we aren't in the normal test lifecycle, we need to manually run
        # cleanups to clear out their object references.
        self.doCleanups()

        # Some objects are newly created by _get_object_count_by_type().  So
        # create and save as a dummy variable to include it as a baseline.
        obj_count_by_type = _get_object_count_by_type()
        gc.collect()

        # Make sure any registered functions are cleaned up in the C++ runtime.
        registered_function_names = context.context().list_function_names()

        # unittest.doCleanups adds to self._outcome with each unwound call.
        # These objects are retained across gc collections so we exclude them
        # from the object count calculation.
        obj_count_by_type = _get_object_count_by_type(
            exclude=gc.get_referents(self._outcome.errors,
                                     self._outcome.skipped))

        if ops.has_default_graph():
          collection_sizes_before = {
              collection: len(ops.get_collection(collection))
              for collection in ops.get_default_graph().collections
          }
        for _ in range(3):
          f(self, *args, **kwargs)
        # Since we aren't in the normal test lifecycle, we need to manually run
        # cleanups to clear out their object references.
        self.doCleanups()
        # Note that gc.get_objects misses anything that isn't subject to garbage
        # collection (C types). Collections are a common source of leaks, so we
        # test for collection sizes explicitly.
        if ops.has_default_graph():
          for collection_key in ops.get_default_graph().collections:
            collection = ops.get_collection(collection_key)
            size_before = collection_sizes_before.get(collection_key, 0)
            if len(collection) > size_before:
              raise AssertionError(
                  ("Collection %s increased in size from "
                   "%d to %d (current items %s).") %
                  (collection_key, size_before, len(collection), collection))
          # Make sure our collection checks don't show up as leaked memory by
          # removing references to temporary variables.
          del collection
          del collection_key
          del size_before
          del collection_sizes_before
        gc.collect()

        # There should be no new Python objects hanging around.
        obj_count_by_type = (
            _get_object_count_by_type(
                exclude=gc.get_referents(self._outcome.errors,
                                         self._outcome.skipped)) -
            obj_count_by_type)

        # There should be no newly registered functions hanging around.
        leftover_functions = (
            context.context().list_function_names() - registered_function_names)
        assert not leftover_functions, (
            "The following functions were newly created: %s" %
            leftover_functions)

        # In some cases (specifically on MacOS), new_count is somehow
        # smaller than previous_count.
        # Using plain assert because not all classes using this decorator
        # have assertLessEqual
        assert not obj_count_by_type, (
            "The following objects were newly created: %s" %
            str(obj_count_by_type))
        gc.enable()
    return decorator

  if func is None:
    return wrap_f
  else:
    return wrap_f(func)


def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then checks
  that there are no Tensor or Tensor-like objects still around. This includes
  Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  Args:
    f: The test case to run.

  Returns:
    The decorated test case.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensorflow_object(obj):
      try:
        return isinstance(obj,
                          (ops.Tensor, variables.Variable,
                           tensor_shape.Dimension, tensor_shape.TensorShape))
      except (ReferenceError, AttributeError):
        # If the object no longer exists, we don't care about it.
        return False

    tensors_before = set(
        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
    outside_executed_eagerly = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    outside_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = outside_graph_key
      if outside_executed_eagerly:
        with context.eager_mode():
          result = f(self, **kwargs)
      else:
        result = f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    tensors_after = [
        obj for obj in gc.get_objects()
        if _is_tensorflow_object(obj) and id(obj) not in tensors_before
    ]
    if tensors_after:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(tensors_after),
          str(tensors_after),
      )))
    return result

  return decorator


def _find_reference_cycle(objects, idx):
  """Logs one reference cycle reachable from `objects[idx]`, if any exists."""

  def get_ignore_reason(obj, denylist):
    """Tests whether an object should be omitted from the dependency graph."""
    if len(denylist) > 100:
      return "<depth limit>"
    if tf_inspect.isframe(obj):
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    for b in denylist:
      if b is obj:
        return "<test code>"
    if obj is denylist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, denylist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      denylist: same as denylist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.
    """
    if get_ignore_reason(obj, denylist):
      return "{}{}".format(get_ignore_reason(obj, denylist), type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, denylist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      denylist: List of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    denylist = denylist + (referrers,)

    obj_id = id(obj)
    for r in referrers:
      if get_ignore_reason(r, denylist) is None:
        r_id = id(r)
        if r_id not in graph:
          graph[r_id] = []
        if obj_id not in graph[r_id]:
          graph[r_id].append(obj_id)
          build_ref_graph(r, graph, reprs, denylist)
          reprs[r_id] = describe(r, denylist)

  def find_cycle(el, graph, reprs, path):
    """Finds and prints a single cycle in the dependency graph."""
    if el not in graph:
      return
    for r in graph[el]:
      if r in path:
        logging.error("Reference cycle sample:")
        for p in path + (r,):
          logging.error(reprs.get(p, "unknown object " + str(p)))
        return True
      else:
        if find_cycle(r, graph, reprs, path + (r,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> object ID
  reprs = {}  # object ID -> description
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False


def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  Args:
    f: The function to decorate.

  Returns:
    The decorated function.
  """

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    # Force-load `distribution_strategy_context` to prevent GC at
    # test time when using eager. Remove once b/117329403 is resolved.
    tape.distribution_strategy_context.get_strategy()

    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    gc.collect()
    previous_garbage = len(gc.garbage)
    result = f(self, **kwargs)
    gc.collect()
    new_garbage = len(gc.garbage)
    if new_garbage > previous_garbage:
      logging.error(
          "The decorated test created work for Python's garbage collector, "
          "likely due to a reference cycle. New objects in cycle(s):")
      for i, obj in enumerate(gc.garbage[previous_garbage:]):
        try:
          logging.error("Object %d of %d", i,
                        len(gc.garbage) - previous_garbage)

          def _safe_object_str(obj):
            return "<%s %d>" % (obj.__class__.__name__, id(obj))

          logging.error("  Object type: %s", _safe_object_str(obj))
          logging.error(
              "  Referrer types: %s", ", ".join(
                  [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
          logging.error(
              "  Referent types: %s", ", ".join(
                  [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
          logging.error("  Object attribute names: %s", dir(obj))
          logging.error("  Object __str__:")
          logging.error(obj)
          logging.error("  Object __repr__:")
          logging.error(repr(obj))
        except Exception:  # pylint: disable=broad-except
          logging.error("(Exception while printing object)")

    # When garbage is created, this call can help identify reference cycles,
    # which are typically the cause of such garbage.
    if new_garbage > previous_garbage:
      for i in range(previous_garbage, new_garbage):
        if _find_reference_cycle(gc.garbage, i):
          break

    # This will fail if any garbage has been created, typically because of a
    # reference cycle.
    self.assertEqual(previous_garbage, new_garbage)
    # TODO(allenl): Figure out why this debug flag reset doesn't work. It would
    # be nice to be able to decorate arbitrary tests in a large test suite and
    # not hold on to every object in other tests.
    gc.set_debug(previous_debug_flags)
    gc.enable()
    return result

  return decorator


def _combine_named_parameters(**kwargs):
  """Generate combinations based on its keyword arguments.

  Two sets of returned combinations can be concatenated using +. Their product
  can be computed using `times()`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names. Each key has one value - one of the
    corresponding keyword argument values.
  """
  sort_by_key = lambda k: k[0]
  combinations = []
  for key, values in sorted(kwargs.items(), key=sort_by_key):
    if not isinstance(values, list):
      values = [values]
    combinations.append([(key, value) for value in values])

  return [OrderedDict(result) for result in itertools.product(*combinations)]


def generate_combinations_with_testcase_name(**kwargs):
  """Generate combinations based on its keyword arguments using combine().

  This function calls combine() and appends a testcase name to the list of
  dictionaries returned. The 'testcase_name' key is a required for named
  parameterized tests.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names. Each key has one value - one of the
    corresponding keyword argument values.
  """
  combinations = _combine_named_parameters(**kwargs)
  named_combinations = []
  for combination in combinations:
    assert isinstance(combination, OrderedDict)
    name = "".join([
        "_{}_{}".format("".join(filter(str.isalnum, key)),
                        "".join(filter(str.isalnum, str(value))))
        for key, value in combination.items()
    ])
    named_combinations.append(
        OrderedDict(
            list(combination.items()) +
            [("testcase_name", "_test{}".format(name))]))

  return named_combinations


def run_all_in_graph_and_eager_modes(cls):
  """Execute all test methods in the given class with and without eager."""
  base_decorator = run_in_graph_and_eager_modes
  for name in dir(cls):
    if (not name.startswith(unittest.TestLoader.testMethodPrefix) or
        name.startswith("testSkipEager") or
        name.startswith("test_skip_eager") or
        name == "test_session"):
      continue
    value = getattr(cls, name, None)
    if callable(value):
      setattr(cls, name, base_decorator(value))
  return cls


# NOTE(review): the body of this function continues beyond the visible portion
# of the source; only the signature and docstring are reconstructed here.
def build_as_function_and_v1_graph(func=None):
  """Run a test case in v1 graph mode and inside tf.function in eager mode.

  WARNING: This decorator can only be used in test cases that statically checks
  generated graph. Attempting to evaluate graph or function results via.
  session.run() or self.evaluate() will fail.

  WARNING: This decorator can only be used for test cases that inherit from
  absl.testing.parameterized.TestCase.

  Args:
    func: Test case function to be decorated.

  Returns:
    Decorated test case function.
  """
1070 """ 1071 1072 def decorator(f): 1073 if tf_inspect.isclass(f): 1074 raise ValueError( 1075 "`run_in_graph_mode_and_function` only supports test methods.") 1076 1077 @parameterized.named_parameters(("_v1_graph", "v1_graph"), 1078 ("_function", "function")) 1079 @functools.wraps(f) 1080 def decorated(self, run_mode, *args, **kwargs): 1081 if run_mode == "v1_graph": 1082 with ops.Graph().as_default(): 1083 f(self, *args, **kwargs) 1084 elif run_mode == "function": 1085 1086 @def_function.function 1087 def function_in_eager(): 1088 f(self, *args, **kwargs) 1089 1090 # Create a new graph for the eagerly executed version of this test for 1091 # better isolation. 1092 graph_for_eager_test = ops.Graph() 1093 with graph_for_eager_test.as_default(), context.eager_mode(): 1094 function_in_eager() 1095 ops.dismantle_graph(graph_for_eager_test) 1096 else: 1097 return ValueError("Unknown run mode %s" % run_mode) 1098 1099 return decorated 1100 1101 if func is not None: 1102 return decorator(func) 1103 1104 return decorator 1105 1106 1107def run_in_async_and_sync_mode(f): 1108 """Execute the test in async mode and sync mode.""" 1109 1110 @parameterized.named_parameters([("Async", True), ("", False)]) 1111 @functools.wraps(f) 1112 def decorator(self, async_mode, *args, **kwargs): 1113 if async_mode: 1114 with context.execution_mode(context.ASYNC): 1115 f(self, *args, **kwargs) 1116 else: 1117 with context.execution_mode(context.SYNC): 1118 f(self, *args, **kwargs) 1119 return decorator 1120 1121 1122def run_in_graph_and_eager_modes(func=None, 1123 config=None, 1124 use_gpu=True, 1125 assert_no_eager_garbage=False): 1126 """Execute the decorated test with and without enabling eager execution. 1127 1128 This function returns a decorator intended to be applied to test methods in 1129 a `tf.test.TestCase` class. Doing so will cause the contents of the test 1130 method to be executed twice - once normally, and once with eager execution 1131 enabled. 
This allows unittests to confirm the equivalence between eager 1132 and graph execution (see `tf.compat.v1.enable_eager_execution`). 1133 1134 For example, consider the following unittest: 1135 1136 ```python 1137 class MyTests(tf.test.TestCase): 1138 1139 @run_in_graph_and_eager_modes 1140 def test_foo(self): 1141 x = tf.constant([1, 2]) 1142 y = tf.constant([3, 4]) 1143 z = tf.add(x, y) 1144 self.assertAllEqual([4, 6], self.evaluate(z)) 1145 1146 if __name__ == "__main__": 1147 tf.test.main() 1148 ``` 1149 1150 This test validates that `tf.add()` has the same behavior when computed with 1151 eager execution enabled as it does when constructing a TensorFlow graph and 1152 executing the `z` tensor in a session. 1153 1154 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1155 `run_in_graph_and_eager_modes` are available decorators for different 1156 v1/v2/eager/graph combinations. 1157 1158 1159 Args: 1160 func: function to be annotated. If `func` is None, this method returns a 1161 decorator the can be applied to a function. If `func` is not None this 1162 returns the decorator applied to `func`. 1163 config: An optional config_pb2.ConfigProto to use to configure the session 1164 when executing graphs. 1165 use_gpu: If True, attempt to run as many operations as possible on GPU. 1166 assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage 1167 collector and asserts that no extra garbage has been created when running 1168 the test with eager execution enabled. This will fail if there are 1169 reference cycles (e.g. a = []; a.append(a)). Off by default because some 1170 tests may create garbage for legitimate reasons (e.g. they define a class 1171 which inherits from `object`), and because DEBUG_SAVEALL is sticky in some 1172 Python interpreters (meaning that tests which rely on objects being 1173 collected elsewhere in the unit test file will not work). 
Additionally, 1174 checks that nothing still has a reference to Tensors that the test 1175 allocated. 1176 1177 Returns: 1178 Returns a decorator that will run the decorated test method twice: 1179 once by constructing and executing a graph in a session and once with 1180 eager execution enabled. 1181 """ 1182 1183 def decorator(f): 1184 if tf_inspect.isclass(f): 1185 raise ValueError( 1186 "`run_in_graph_and_eager_modes` only supports test methods. " 1187 "Did you mean to use `run_all_in_graph_and_eager_modes`?") 1188 1189 def decorated(self, *args, **kwargs): 1190 try: 1191 with context.graph_mode(): 1192 with self.test_session(use_gpu=use_gpu, config=config): 1193 f(self, *args, **kwargs) 1194 except unittest.case.SkipTest: 1195 pass 1196 1197 def run_eagerly(self, **kwargs): 1198 if not use_gpu: 1199 with ops.device("/device:CPU:0"): 1200 f(self, *args, **kwargs) 1201 else: 1202 f(self, *args, **kwargs) 1203 1204 if assert_no_eager_garbage: 1205 ops.reset_default_graph() 1206 run_eagerly = assert_no_new_tensors( 1207 assert_no_garbage_created(run_eagerly)) 1208 1209 # This decorator runs the wrapped test twice. 1210 # Reset the test environment between runs. 1211 self.tearDown() 1212 self._tempdir = None 1213 # Create a new graph for the eagerly executed version of this test for 1214 # better isolation. 
1215 graph_for_eager_test = ops.Graph() 1216 with graph_for_eager_test.as_default(), context.eager_mode(): 1217 self.setUp() 1218 run_eagerly(self, **kwargs) 1219 ops.dismantle_graph(graph_for_eager_test) 1220 1221 return tf_decorator.make_decorator(f, decorated) 1222 1223 if func is not None: 1224 return decorator(func) 1225 1226 return decorator 1227 1228 1229def py_func_if_in_function(f): 1230 1231 def decorated(*args, **kwds): 1232 if not ops.inside_function(): 1233 return f(*args, **kwds) 1234 1235 tensor_args = [] 1236 tensor_indices = [] 1237 for i, arg in enumerate(args): 1238 if isinstance(arg, (ops.Tensor, variables.Variable)): 1239 tensor_args.append(arg) 1240 tensor_indices.append(i) 1241 1242 def inner_f(*inner_tensor_args): 1243 my_args = list(args) 1244 for i, n in zip(tensor_indices, inner_tensor_args): 1245 my_args[i] = n 1246 return f(*my_args, **kwds) 1247 1248 return script_ops.py_func(inner_f, tensor_args, []) 1249 1250 return tf_decorator.make_decorator(f, decorated) 1251 1252 1253def also_run_as_tf_function(f): 1254 """Runs the decorated test twice--once as is, once inside a tf.function. 1255 1256 This allows you to run a test both in eager execution and inside a 1257 tf.function, exercising the two execution modes supported in tf 2.0. The test 1258 assertions are automatically done inside tf.py_funcs, and tf.function ensures 1259 that they run in the proper order and with the proper side effects. 1260 1261 Currently variable creation is not supported in tests annotated with this 1262 decorator since it's tricky to ensure the variable doesn't get repeatedly 1263 created when retracing the tf.function. 1264 1265 Args: 1266 f: the test method to be decorated 1267 1268 Returns: 1269 The decorated test method, which will run both in eager and inside a 1270 tf.function. 
1271 """ 1272 1273 def decorated(*args, **kwds): 1274 1275 def bound_f(): 1276 f(*args, **kwds) 1277 1278 with context.eager_mode(): 1279 # Running in eager mode 1280 bound_f() 1281 # Running as TF function 1282 # TODO(b/121143941): Remove the autograph override. 1283 def_function.function(bound_f, autograph=False)() 1284 1285 return decorated 1286 1287 1288def deprecated_graph_mode_only(func=None): 1289 """Execute the decorated test in graph mode. 1290 1291 This function returns a decorator intended to be applied to tests that are not 1292 compatible with eager mode. When this decorator is applied, the test body will 1293 be run in an environment where API calls construct graphs instead of executing 1294 eagerly. 1295 1296 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1297 `run_in_graph_and_eager_modes` are available decorators for different 1298 v1/v2/eager/graph combinations. 1299 1300 Args: 1301 func: function to be annotated. If `func` is None, this method returns a 1302 decorator the can be applied to a function. If `func` is not None this 1303 returns the decorator applied to `func`. 1304 1305 Returns: 1306 Returns a decorator that will run the decorated test method in graph mode. 
1307 """ 1308 1309 def decorator(f): 1310 if tf_inspect.isclass(f): 1311 setup = f.__dict__.get("setUp") 1312 if setup is not None: 1313 setattr(f, "setUp", decorator(setup)) 1314 1315 for name, value in f.__dict__.copy().items(): 1316 if (callable(value) and 1317 name.startswith(unittest.TestLoader.testMethodPrefix)): 1318 setattr(f, name, decorator(value)) 1319 1320 return f 1321 1322 def decorated(self, *args, **kwargs): 1323 if context.executing_eagerly(): 1324 with context.graph_mode(): 1325 return f(self, *args, **kwargs) 1326 else: 1327 return f(self, *args, **kwargs) 1328 1329 return decorated 1330 1331 if func is not None: 1332 return decorator(func) 1333 1334 return decorator 1335 1336 1337run_deprecated_v1 = deprecated_graph_mode_only 1338 1339 1340def run_all_in_deprecated_graph_mode_only(cls): 1341 """Execute all tests in a class in graph mode.""" 1342 base_decorator = deprecated_graph_mode_only 1343 for name in dir(cls): 1344 if (not name.startswith(unittest.TestLoader.testMethodPrefix) or 1345 name == "test_session"): 1346 continue 1347 value = getattr(cls, name, None) 1348 if callable(value): 1349 setattr(cls, name, base_decorator(value)) 1350 return cls 1351 1352 1353def run_v1_only(reason, func=None): 1354 """Execute the decorated test only if running in v1 mode. 1355 1356 This function is intended to be applied to tests that exercise v1 only 1357 functionality. If the test is run in v2 mode it will simply be skipped. 1358 1359 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1360 `run_in_graph_and_eager_modes` are available decorators for different 1361 v1/v2/eager/graph combinations. 1362 1363 Args: 1364 reason: string giving a reason for limiting the test to v1 only. 1365 func: function to be annotated. If `func` is None, this method returns a 1366 decorator the can be applied to a function. If `func` is not None this 1367 returns the decorator applied to `func`. 
1368 1369 Returns: 1370 Returns a decorator that will conditionally skip the decorated test method. 1371 """ 1372 if not isinstance(reason, str): 1373 raise ValueError("'reason' should be string, got {}".format(type(reason))) 1374 1375 def decorator(f): 1376 if tf_inspect.isclass(f): 1377 # To skip an entire test suite class, we only decorate the setUp method 1378 # to skip all tests. There are cases when setUp is not defined (not 1379 # overridden in subclasses of TestCase, so not available in f.__dict__ 1380 # below). For those cases, we walk the method resolution order list and 1381 # pick the first setUp method we find (usually this should be the one in 1382 # the parent class since that's the TestCase class). 1383 for cls in type.mro(f): 1384 setup = cls.__dict__.get("setUp") 1385 if setup is not None: 1386 setattr(f, "setUp", decorator(setup)) 1387 break 1388 1389 return f 1390 else: 1391 # If f is just a function, just create a decorator for it and return it 1392 def decorated(self, *args, **kwargs): 1393 if tf2.enabled(): 1394 self.skipTest(reason) 1395 1396 return f(self, *args, **kwargs) 1397 1398 return decorated 1399 1400 if func is not None: 1401 return decorator(func) 1402 1403 return decorator 1404 1405 1406def run_v2_only(func=None): 1407 """Execute the decorated test only if running in v2 mode. 1408 1409 This function is intended to be applied to tests that exercise v2 only 1410 functionality. If the test is run in v1 mode it will simply be skipped. 1411 1412 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1413 `run_in_graph_and_eager_modes` are available decorators for different 1414 v1/v2/eager/graph combinations. 1415 1416 Args: 1417 func: function to be annotated. If `func` is None, this method returns a 1418 decorator the can be applied to a function. If `func` is not None this 1419 returns the decorator applied to `func`. 1420 1421 Returns: 1422 Returns a decorator that will conditionally skip the decorated test method. 
1423 """ 1424 1425 def decorator(f): 1426 if tf_inspect.isclass(f): 1427 raise ValueError("`run_v2_only` only supports test methods.") 1428 1429 def decorated(self, *args, **kwargs): 1430 if not tf2.enabled(): 1431 self.skipTest("Test is only compatible with v2") 1432 1433 return f(self, *args, **kwargs) 1434 1435 return decorated 1436 1437 if func is not None: 1438 return decorator(func) 1439 1440 return decorator 1441 1442 1443def run_gpu_only(func=None): 1444 """Execute the decorated test only if a GPU is available. 1445 1446 This function is intended to be applied to tests that require the presence 1447 of a GPU. If a GPU is absent, it will simply be skipped. 1448 1449 Args: 1450 func: function to be annotated. If `func` is None, this method returns a 1451 decorator the can be applied to a function. If `func` is not None this 1452 returns the decorator applied to `func`. 1453 1454 Returns: 1455 Returns a decorator that will conditionally skip the decorated test method. 1456 """ 1457 1458 def decorator(f): 1459 if tf_inspect.isclass(f): 1460 raise ValueError("`run_gpu_only` only supports test methods.") 1461 1462 def decorated(self, *args, **kwargs): 1463 if not is_gpu_available(): 1464 self.skipTest("Test requires GPU") 1465 1466 return f(self, *args, **kwargs) 1467 1468 return decorated 1469 1470 if func is not None: 1471 return decorator(func) 1472 1473 return decorator 1474 1475 1476def run_cuda_only(func=None): 1477 """Execute the decorated test only if a GPU is available. 1478 1479 This function is intended to be applied to tests that require the presence 1480 of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped. 1481 1482 Args: 1483 func: function to be annotated. If `func` is None, this method returns a 1484 decorator the can be applied to a function. If `func` is not None this 1485 returns the decorator applied to `func`. 1486 1487 Returns: 1488 Returns a decorator that will conditionally skip the decorated test method. 
1489 """ 1490 1491 def decorator(f): 1492 if tf_inspect.isclass(f): 1493 raise ValueError("`run_cuda_only` only supports test methods.") 1494 1495 def decorated(self, *args, **kwargs): 1496 if not is_gpu_available(cuda_only=True): 1497 self.skipTest("Test requires CUDA GPU") 1498 1499 return f(self, *args, **kwargs) 1500 1501 return decorated 1502 1503 if func is not None: 1504 return decorator(func) 1505 1506 return decorator 1507 1508 1509def with_forward_compatibility_horizons(*horizons): 1510 """Executes the decorated test with the specified forward-compat horizons. 1511 1512 Args: 1513 *horizons: A list of (year, month, day) tuples. If the list includes 1514 `None`, then the test will also be run with no forward-compatibility 1515 horizon set. 1516 1517 Returns: 1518 A decorator that will execute the test with the specified horizons. 1519 """ 1520 if not horizons: 1521 raise ValueError("Expected at least one horizon.") 1522 for horizon in horizons: 1523 if not ((horizon is None) or 1524 (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))): 1525 raise ValueError("Bad horizon value: %r" % horizon) 1526 1527 def decorator(f): 1528 if tf_inspect.isclass(f): 1529 raise ValueError("`with_forward_compatibility_horizons` only " 1530 "supports test methods.") 1531 def decorated(self, *args, **kwargs): 1532 for horizon in horizons: 1533 if horizon is None: 1534 f(self, *args, **kwargs) 1535 else: 1536 (year, month, day) = horizon 1537 with forward_compatibility_horizon(year, month, day): 1538 f(self, *args, **kwargs) 1539 return decorated 1540 1541 return decorator 1542 1543 1544@deprecation.deprecated(None, 1545 "Use `tf.config.list_physical_devices('GPU')` instead.") 1546@tf_export("test.is_gpu_available") 1547def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): 1548 """Returns whether TensorFlow can access a GPU. 1549 1550 Warning: if a non-GPU version of the package is installed, the function would 1551 also return False. 
Use `tf.test.is_built_with_cuda` to validate if TensorFlow 1552 was build with CUDA support. 1553 1554 For example, 1555 >>> gpu_available = tf.test.is_gpu_available() 1556 >>> is_cuda_gpu_available = tf.test.is_gpu_available(cuda_only=True) 1557 >>> is_cuda_gpu_min_3 = tf.test.is_gpu_available(True, (3,0)) 1558 1559 Args: 1560 cuda_only: limit the search to CUDA GPUs. 1561 min_cuda_compute_capability: a (major,minor) pair that indicates the minimum 1562 CUDA compute capability required, or None if no requirement. 1563 1564 Note that the keyword arg name "cuda_only" is misleading (since routine will 1565 return true when a GPU device is available irrespective of whether TF was 1566 built with CUDA support or ROCm support. However no changes here because 1567 1568 ++ Changing the name "cuda_only" to something more generic would break 1569 backward compatibility 1570 1571 ++ Adding an equivalent "rocm_only" would require the implementation check 1572 the build type. This in turn would require doing the same for CUDA and thus 1573 potentially break backward compatibility 1574 1575 ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility, 1576 but would require most (if not all) callers to update the call to use 1577 "cuda_or_rocm_only" instead of "cuda_only" 1578 1579 Returns: 1580 True if a GPU device of the requested kind is available. 1581 """ 1582 1583 # This was needed earlier when we had support for SYCL in TensorFlow. 
1584 del cuda_only 1585 1586 try: 1587 for local_device in device_lib.list_local_devices(): 1588 if local_device.device_type == "GPU": 1589 gpu_info = gpu_util.compute_capability_from_device_desc(local_device) 1590 cc = gpu_info.compute_capability or (0, 0) 1591 if not min_cuda_compute_capability or cc >= min_cuda_compute_capability: 1592 return True 1593 return False 1594 except errors_impl.NotFoundError as e: 1595 if not all(x in str(e) for x in ["CUDA", "not find"]): 1596 raise e 1597 else: 1598 logging.error(str(e)) 1599 return False 1600 1601 1602@contextlib.contextmanager 1603def device(use_gpu): 1604 """Uses gpu when requested and available.""" 1605 if use_gpu and is_gpu_available(): 1606 dev = "/device:GPU:0" 1607 else: 1608 dev = "/device:CPU:0" 1609 with ops.device(dev): 1610 yield 1611 1612 1613@contextlib.contextmanager 1614def use_gpu(): 1615 """Uses gpu when requested and available.""" 1616 with device(use_gpu=True): 1617 yield 1618 1619 1620@contextlib.contextmanager 1621def force_gpu(): 1622 """Force the gpu to be used.""" 1623 with ops.device("/device:GPU:0"): 1624 yield 1625 1626 1627@contextlib.contextmanager 1628def force_cpu(): 1629 """Force the cpu to be used.""" 1630 with ops.device("/device:CPU:0"): 1631 yield 1632 1633 1634class CapturedWrites(object): 1635 """A utility class to load the captured writes made to a stream.""" 1636 1637 def __init__(self, capture_location): 1638 self.capture_location = capture_location 1639 1640 def contents(self): 1641 """Get the captured writes as a single string.""" 1642 with open(self.capture_location) as tmp_file: 1643 output_data = "".join(tmp_file.readlines()) 1644 return output_data 1645 1646 1647class FakeEagerSession(object): 1648 """Fake session so tests that conditionally use placeholders can use eager. 1649 1650 There are a number of tests that conditionally use placeholders for shape 1651 inference. 
The pattern is demonstrated here: 1652 1653 ```python 1654 with self.cached_session() as sess: 1655 if static_shape: 1656 y = math_ops.matmul(x, ...) 1657 feed_dict = {} 1658 else: 1659 x_ph = array_ops.placeholder(...) 1660 y = math_ops.matmul(x_ph, ...) 1661 feed_dict = {x_ph: x} 1662 val = sess.run(y, feed_dict=feed_dict) 1663 ``` 1664 1665 Since the feed_dict is empty when not using placeholders we should be able to 1666 call self.evaluate(), however this requires rewriting the test case. 1667 This class should be considered a stop-gap solution to get tests running with 1668 eager with minimal changes to the actual test. 1669 """ 1670 1671 def __init__(self, test_case): 1672 self._test_case = test_case 1673 1674 def run(self, fetches, *args, **kwargs): 1675 """Evaluate `fetches`. 1676 1677 Fail if additional args are specified. 1678 1679 Args: 1680 fetches: A Tensor or a nested list/tuple of Tensors. 1681 *args: Positional arguments 1682 **kwargs: Keyword arguments 1683 1684 Raises: 1685 RuntimeError: If args or kwargs are specified. 1686 1687 Returns: 1688 Tensors as numpy values. 
1689 """ 1690 feed_dict = kwargs.pop("feed_dict", {}) 1691 if feed_dict: 1692 raise RuntimeError( 1693 "feed_dict is not supported when eager execution is enabled " 1694 "(in this case, sess.run(t) is shorthand for t.numpy()") 1695 1696 if args or kwargs: 1697 raise RuntimeError( 1698 "Optional args are not supported when eager execution is enabled " 1699 "(in this case, sess.run(t) is shorthand for t.numpy()") 1700 1701 return self._test_case.evaluate(fetches) 1702 1703 1704class ErrorLoggingSession(session.Session): 1705 """Wrapper around a Session that logs errors in run().""" 1706 1707 def run(self, *args, **kwargs): 1708 try: 1709 return super(ErrorLoggingSession, self).run(*args, **kwargs) 1710 except Exception as e: # pylint: disable=broad-except 1711 # Note: disable the logging for OutOfRangeError, which makes the output 1712 # of tf.data tests hard to read, because OutOfRangeError is used as the 1713 # signal completion 1714 if not isinstance(e, errors.OutOfRangeError): 1715 logging.error(str(e)) 1716 raise 1717 1718 1719def disable_cudnn_autotune(func): 1720 """Disable autotuning during the call to this function. 1721 1722 Some tests want to base assertions on a graph being isomorphic with a copy. 1723 To ensure this, this decorator disables autotuning. 1724 1725 Args: 1726 func: Function to run with CuDNN autotuning turned off. 1727 1728 Returns: 1729 Decorated function. 
1730 """ 1731 1732 def decorator(f): 1733 1734 def decorated(self, *args, **kwargs): 1735 original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") 1736 os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" 1737 original_xla_flags = os.environ.get("XLA_FLAGS") 1738 new_xla_flags = "--xla_gpu_autotune_level=0" 1739 if original_xla_flags: 1740 new_xla_flags = original_xla_flags + " " + new_xla_flags 1741 os.environ["XLA_FLAGS"] = new_xla_flags 1742 1743 result = f(self, *args, **kwargs) 1744 1745 if (original_tf_cudnn_use_autotune is None): 1746 del os.environ["TF_CUDNN_USE_AUTOTUNE"] 1747 else: 1748 os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune 1749 if (original_xla_flags is None): 1750 del os.environ["XLA_FLAGS"] 1751 else: 1752 os.environ["XLA_FLAGS"] = original_xla_flags 1753 1754 return result 1755 1756 return decorated 1757 1758 if func is not None: 1759 return decorator(func) 1760 1761 return decorator 1762 1763 1764# The description is just for documentation purposes. 1765def enable_tf_xla_constant_folding(description): 1766 1767 if not isinstance(description, str): 1768 raise ValueError("'description' should be string, got {}".format( 1769 type(description))) 1770 1771 def enable_tf_xla_constant_folding_impl(func): 1772 """Enable constant folding during the call to this function. 1773 1774 Some tests fail without constant folding. 1775 1776 Args: 1777 func: Function to run with constant folding turned on. 1778 1779 Returns: 1780 Decorated function. 
1781 """ 1782 1783 def decorator(f): 1784 1785 def decorated(self, *args, **kwargs): 1786 original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() 1787 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) 1788 result = f(self, *args, **kwargs) 1789 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var) 1790 return result 1791 1792 return decorated 1793 1794 if func is not None: 1795 return decorator(func) 1796 1797 return decorator 1798 1799 return enable_tf_xla_constant_folding_impl 1800 1801 1802# Updates test function by selectively disabling it. 1803def _disable_test(execute_func): 1804 1805 def disable_test_impl(func): 1806 1807 def decorator(func): 1808 1809 def decorated(self, *args, **kwargs): 1810 if execute_func: 1811 return func(self, *args, **kwargs) 1812 1813 return decorated 1814 1815 if func is not None: 1816 return decorator(func) 1817 1818 return decorator 1819 1820 return disable_test_impl 1821 1822 1823# The description is just for documentation purposes. 1824def disable_xla(description): # pylint: disable=unused-argument 1825 """Execute the test method only if xla is not enabled.""" 1826 execute_func = not is_xla_enabled() 1827 return _disable_test(execute_func) 1828 1829 1830# The description is just for documentation purposes. 1831def disable_mlir_bridge(description): # pylint: disable=unused-argument 1832 """Execute the test method only if MLIR bridge is not enabled.""" 1833 execute_func = not is_mlir_bridge_enabled() 1834 return _disable_test(execute_func) 1835 1836 1837# The description is just for documentation purposes. 
1838def disable_tfrt(unused_description): 1839 1840 def disable_tfrt_impl(cls_or_func): 1841 """Execute the test only if tfrt is not enabled.""" 1842 1843 if tf_inspect.isclass(cls_or_func): 1844 if tfrt_utils.enabled(): 1845 return None 1846 else: 1847 return cls_or_func 1848 else: 1849 def decorator(func): 1850 1851 def decorated(self, *args, **kwargs): 1852 if tfrt_utils.enabled(): 1853 return 1854 else: 1855 return func(self, *args, **kwargs) 1856 1857 return decorated 1858 1859 if cls_or_func is not None: 1860 return decorator(cls_or_func) 1861 1862 return decorator 1863 1864 return disable_tfrt_impl 1865 1866 1867def for_all_test_methods(decorator, *args, **kwargs): 1868 """Generate class-level decorator from given method-level decorator. 1869 1870 It is expected for the given decorator to take some arguments and return 1871 a method that is then called on the test method to produce a decorated 1872 method. 1873 1874 Args: 1875 decorator: The decorator to apply. 1876 *args: Positional arguments 1877 **kwargs: Keyword arguments 1878 Returns: Function that will decorate a given classes test methods with the 1879 decorator. 1880 """ 1881 1882 def all_test_methods_impl(cls): 1883 """Apply decorator to all test methods in class.""" 1884 for name in dir(cls): 1885 value = getattr(cls, name) 1886 if callable(value) and name.startswith( 1887 "test") and (name != "test_session"): 1888 setattr(cls, name, decorator(*args, **kwargs)(value)) 1889 return cls 1890 1891 return all_test_methods_impl 1892 1893 1894# The description is just for documentation purposes. 1895def no_xla_auto_jit(description): # pylint: disable=unused-argument 1896 """This test is not intended to be run with XLA auto jit enabled.""" 1897 execute_func = not is_xla_enabled() 1898 return _disable_test(execute_func) 1899 1900 1901# The description is just for documentation purposes. 
1902def xla_allow_fallback(description): # pylint: disable=unused-argument 1903 1904 def xla_allow_fallback_impl(func): 1905 """Allow fallback to TF even though testing xla.""" 1906 1907 def decorator(func): 1908 1909 def decorated(self, *args, **kwargs): 1910 if is_xla_enabled(): 1911 # Update the global XLABuildOpsPassFlags to enable lazy compilation, 1912 # which allows the compiler to fall back to TF classic. Remember the 1913 # old value so that we can reset it. 1914 old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True) 1915 result = func(self, *args, **kwargs) 1916 pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value) 1917 return result 1918 else: 1919 return func(self, *args, **kwargs) 1920 1921 return decorated 1922 1923 if func is not None: 1924 return decorator(func) 1925 1926 return decorator 1927 1928 return xla_allow_fallback_impl 1929 1930 1931# The description is just for documentation purposes. 1932def run_without_tensor_float_32(description): # pylint: disable=unused-argument 1933 """Execute test with TensorFloat-32 disabled. 1934 1935 While almost every real-world deep learning model runs fine with 1936 TensorFloat-32, many tests use assertAllClose or similar methods. 1937 TensorFloat-32 matmuls typically will cause such methods to fail with the 1938 default tolerances. 1939 1940 Args: 1941 description: A description used for documentation purposes, describing why 1942 the test requires TensorFloat-32 to be disabled. 1943 1944 Returns: 1945 Decorator which runs a test with TensorFloat-32 disabled. 
1946 """ 1947 1948 def decorator(f): 1949 1950 @functools.wraps(f) 1951 def decorated(self, *args, **kwargs): 1952 allowed = config.tensor_float_32_execution_enabled() 1953 try: 1954 config.enable_tensor_float_32_execution(False) 1955 f(self, *args, **kwargs) 1956 finally: 1957 config.enable_tensor_float_32_execution(allowed) 1958 1959 return decorated 1960 1961 return decorator 1962 1963 1964# The description is just for documentation purposes. 1965def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument 1966 """Execute all tests in a class with TensorFloat-32 disabled.""" 1967 return for_all_test_methods(run_without_tensor_float_32, description) 1968 1969 1970def matmul_without_tf32(a, b, *args, **kwargs): 1971 """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled. 1972 1973 This effectively runs matmul without TensorFloat-32. It should only be used in 1974 tests when verifying some other op or functions works correctly, e.g. to test 1975 `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In 1976 such cases, the matmul itself is not being tested so it's OK to run it with 1977 higher precision. 1978 1979 If a matmul itself is being tested, or some other op which uses matmul, use 1980 `run_without_tensor_float_32` instead. 1981 1982 This also casts complex64 inputs to complex128, since TensorFloat-32 can also 1983 be used with complex64 1984 1985 Args: 1986 a: First input to tf.linalg.matmul 1987 b: Second input to tf.linalg.matmul 1988 args: Other positional arguments to tf.linalg.matmul 1989 **kwargs: Other keyword arguments to tf.linalg.matmul 1990 1991 Returns: 1992 A tensor with the same type as `a`. 
1993 """ 1994 if config.tensor_float_32_execution_enabled() and a.dtype == "float32": 1995 a = math_ops.cast(a, "float64") 1996 b = math_ops.cast(b, "float64") 1997 ret = math_ops.matmul(a, b, *args, **kwargs) 1998 return math_ops.cast(ret, a.dtype) 1999 elif config.tensor_float_32_execution_enabled() and a.dtype == "complex64": 2000 a = math_ops.cast(a, "complex128") 2001 b = math_ops.cast(b, "complex128") 2002 ret = math_ops.matmul(a, b, *args, **kwargs) 2003 return math_ops.cast(ret, a.dtype) 2004 else: 2005 return math_ops.matmul(a, b, *args, **kwargs) 2006 2007 2008class EagerSessionWarner(object): 2009 2010 def __getattr__(self, attr): 2011 raise AttributeError( 2012 "Trying to access properties or call methods on the result of " 2013 "self.session(), self.cached_session(), etc while eager execution " 2014 "is enabled. If you're porting this test case to TF 2.0, either " 2015 "adapt the test to work with eager execution or insert a call to " 2016 "tf.disable_eager_execution() in the main() function of this test " 2017 "file.") 2018 2019 2020@tf_export("test.TestCase") 2021class TensorFlowTestCase(googletest.TestCase): 2022 """Base class for tests that need to test TensorFlow.""" 2023 2024 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 2025 super(TensorFlowTestCase, self).__init__(methodName) 2026 if is_xla_enabled(): 2027 pywrap_tf_session.TF_SetXlaAutoJitMode("2") 2028 pywrap_tf_session.TF_SetXlaMinClusterSize(1) 2029 pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False) 2030 pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True) 2031 # Constant folding secretly runs code on TF:Classic CPU, so we also 2032 # disable it here. 2033 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True) 2034 2035 # Check if the mlir bridge has been explicitly enabled or disabled. If 2036 # is_mlir_bridge_enabled() returns None, the user did not explictly enable 2037 # or disable the bridge so do not update enable_mlir_bridge. 
    if is_mlir_bridge_enabled():
      context.context().enable_mlir_bridge = True
    elif is_mlir_bridge_enabled() is not None:
      # A non-None falsy value means the bridge was explicitly disabled.
      context.context().enable_mlir_bridge = False

    self._threads = []
    self._tempdir = None
    self._cached_session = None
    self._test_start_time = None
    # This flag provides the ability to control whether the graph mode gets
    # initialized for TF1 or not. Initializing for TF1, which is what was
    # happening earlier, was preventing enablement of 'eager mode' in the test.
    self._set_default_seed = True

  def setUp(self):
    """Resets global state (graph, RNG seeds, cached session) for each test."""
    super(TensorFlowTestCase, self).setUp()
    self._ClearCachedSession()
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    if self._set_default_seed:
      random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
    # Reset summary writer in case another test used set_as_default() with their
    # summary writer.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.writer = None

    # Avoiding calling setUp() for the poorly named test_session method.
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

    self._test_start_time = time.time()

  def tearDown(self):
    # If a subclass overrides setUp and doesn't call the parent class's setUp,
    # then we may not have set the start time.
    if self._test_start_time is not None:
      logging.info("time(%s): %ss", self.id(),
                   round(time.time() - self._test_start_time, 2))

    # Verify every thread created via checkedThread() was joined.
    for thread in self._threads:
      thread.check_termination()

    self._ClearCachedSession()
    super(TensorFlowTestCase, self).tearDown()

  def _ClearCachedSession(self):
    # Close and drop the session created by cached_session(), if any.
    if self._cached_session is not None:
      self._cached_session.close()
      self._cached_session = None

  def get_temp_dir(self):
    """Returns a unique temporary directory for the test to use.

    If you call this method multiple times during in a test, it will return the
    same folder. However, across different runs the directories will be
    different. This will ensure that across different runs tests will not be
    able to pollute each others environment.
    If you need multiple unique directories within a single test, you should
    use tempfile.mkdtemp as follows:
      tempfile.mkdtemp(dir=self.get_temp_dir()):

    Returns:
      string, the path to the unique temporary directory created for this test.
    """
    if not self._tempdir:
      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    return self._tempdir

  @contextlib.contextmanager
  def captureWritesToStream(self, stream):
    """A context manager that captures the writes to a given stream.

    This context manager captures all writes to a given stream inside of a
    `CapturedWrites` object. When this context manager is created, it yields
    the `CapturedWrites` object. The captured contents can be accessed by
    calling `.contents()` on the `CapturedWrites`.

    For this function to work, the stream must have a file descriptor that
    can be modified using `os.dup` and `os.dup2`, and the stream must support
    a `.flush()` method. The default python sys.stdout and sys.stderr are
    examples of this.
Note that this does not work in Colab or Jupyter 2127 notebooks, because those use alternate stdout streams. 2128 2129 Example: 2130 ```python 2131 class MyOperatorTest(test_util.TensorFlowTestCase): 2132 def testMyOperator(self): 2133 input = [1.0, 2.0, 3.0, 4.0, 5.0] 2134 with self.captureWritesToStream(sys.stdout) as captured: 2135 result = MyOperator(input).eval() 2136 self.assertStartsWith(captured.contents(), "This was printed.") 2137 ``` 2138 2139 Args: 2140 stream: The stream whose writes should be captured. This stream must have 2141 a file descriptor, support writing via using that file descriptor, and 2142 must have a `.flush()` method. 2143 2144 Yields: 2145 A `CapturedWrites` object that contains all writes to the specified stream 2146 made during this context. 2147 """ 2148 stream.flush() 2149 fd = stream.fileno() 2150 tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir()) 2151 tmp_file = open(tmp_file_path, "w") 2152 orig_fd = os.dup(fd) 2153 os.dup2(tmp_file.fileno(), fd) 2154 try: 2155 yield CapturedWrites(tmp_file_path) 2156 finally: 2157 tmp_file.close() 2158 os.dup2(orig_fd, fd) 2159 2160 def _AssertProtoEquals(self, a, b, msg=None): 2161 """Asserts that a and b are the same proto. 2162 2163 Uses ProtoEq() first, as it returns correct results 2164 for floating point attributes, and then use assertProtoEqual() 2165 in case of failure as it provides good error messages. 2166 2167 Args: 2168 a: a proto. 2169 b: another proto. 2170 msg: Optional message to report on failure. 2171 """ 2172 if not compare.ProtoEq(a, b): 2173 compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) 2174 2175 def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None): 2176 """Asserts that message is same as parsed expected_message_ascii. 2177 2178 Creates another prototype of message, reads the ascii message into it and 2179 then compares them using self._AssertProtoEqual(). 

    Args:
      expected_message_maybe_ascii: proto message in original or ascii form.
      message: the message to validate.
      msg: Optional message to report on failure.
    """
    if isinstance(expected_message_maybe_ascii, type(message)):
      # Already a proto of the right type; compare directly.
      expected_message = expected_message_maybe_ascii
      self._AssertProtoEquals(expected_message, message, msg=msg)
    elif isinstance(expected_message_maybe_ascii, (str, bytes)):
      # Text-format proto: parse into a fresh message of the same type.
      expected_message = type(message)()
      text_format.Merge(
          expected_message_maybe_ascii,
          expected_message,
          descriptor_pool=descriptor_pool.Default())
      self._AssertProtoEquals(expected_message, message, msg=msg)
    else:
      assert False, ("Can't compare protos of type %s and %s." %
                     (type(expected_message_maybe_ascii), type(message)))

  def assertProtoEqualsVersion(
      self,
      expected,
      actual,
      producer=versions.GRAPH_DEF_VERSION,
      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
      msg=None):
    # Prepend a versions stanza so the `expected` text can omit it.
    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
        producer, min_consumer, expected)
    self.assertProtoEquals(expected, actual, msg=msg)

  def assertStartsWith(self, actual, expected_start, msg=None):
    """Assert that actual.startswith(expected_start) is True.

    Args:
      actual: str
      expected_start: str
      msg: Optional message to report on failure.
    """
    if not actual.startswith(expected_start):
      fail_msg = "%r does not start with %r" % (actual, expected_start)
      fail_msg += " : %r" % (msg) if msg else ""
      self.fail(fail_msg)

  def _eval_tensor(self, tensor):
    # Eager-mode conversion of a single tensor-like value to numpy, handling
    # sparse, ragged, IndexedSlices and composite tensors specially.
    if tensor is None:
      return None
    elif callable(tensor):
      return self._eval_helper(tensor())
    else:
      try:
        if sparse_tensor.is_sparse(tensor):
          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
                                                 tensor.values.numpy(),
                                                 tensor.dense_shape.numpy())
        elif ragged_tensor.is_ragged(tensor):
          return ragged_tensor_value.RaggedTensorValue(
              self._eval_tensor(tensor.values),
              self._eval_tensor(tensor.row_splits))
        elif isinstance(tensor, ops.IndexedSlices):
          return ops.IndexedSlicesValue(
              values=tensor.values.numpy(),
              indices=tensor.indices.numpy(),
              dense_shape=tensor.dense_shape.numpy())
        # Convert tensors and composite tensors to numpy arrays.
        return nest.map_structure(lambda t: t.numpy(), tensor,
                                  expand_composites=True)
      except AttributeError as e:
        # Objects without .numpy() are not evaluable here.
        six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e)

  def _eval_helper(self, tensors):
    # Recursively evaluate a (possibly nested) structure of tensors.
    if tensors is None:
      return None
    return nest.map_structure(self._eval_tensor, tensors)

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      tensors numpy values.
    """
    if context.executing_eagerly():
      return self._eval_helper(tensors)
    else:
      # Graph mode: run via the default session, creating one if needed.
      sess = ops.get_default_session()
      if sess is None:
        with self.test_session() as sess:
          return sess.run(tensors)
      else:
        return sess.run(tensors)

  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def session(self, graph=None, config=None, use_gpu=True, force_gpu=False):
    """A context manager for a TensorFlow Session for use in executing tests.

    Note that this will set this session and the graph as global defaults.

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
    the CPU.

    Example:

    ``` python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.session():
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if context.executing_eagerly():
      # Sessions are meaningless in eager mode; yield a sentinel that raises a
      # helpful error on any use.
      yield EagerSessionWarner()
    else:
      with self._create_session(graph, config, force_gpu) as sess:
        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
          yield sess

  @contextlib.contextmanager
  def cached_session(self,
                     graph=None,
                     config=None,
                     use_gpu=True,
                     force_gpu=False):
    """Returns a TensorFlow Session for use in executing tests.

    This method behaves differently than self.session(): for performance reasons
    `cached_session` will by default reuse the same session within the same
    test. The session returned by this function will only be closed at the end
    of the test (in the TearDown function).

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
    the CPU.

    Example:
    ```python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.cached_session() as sess:
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if context.executing_eagerly():
      yield FakeEagerSession(self)
    else:
      sess = self._get_cached_session(
          graph, config, force_gpu, crash_if_inconsistent_args=True)
      with self._constrain_devices_and_set_default(sess, use_gpu,
                                                   force_gpu) as cached:
        yield cached

  @contextlib.contextmanager
  @deprecation.deprecated(None, "Use `self.session()` or "
                          "`self.cached_session()` instead.")
  def test_session(self,
                   graph=None,
                   config=None,
                   use_gpu=True,
                   force_gpu=False):
    """Use cached_session instead."""
    if self.id().endswith(".test_session"):
      self.skipTest(
          "Tests that have the name \"test_session\" are automatically skipped "
          "by TensorFlow test fixture, as the name is reserved for creating "
          "sessions within tests. Please rename your test if you have a test "
          "with this name.")
    if context.executing_eagerly():
      yield None
    else:
      if graph is None:
        # No explicit graph: reuse the cached session for speed.
        sess = self._get_cached_session(
            graph, config, force_gpu, crash_if_inconsistent_args=False)
        with self._constrain_devices_and_set_default(sess, use_gpu,
                                                     force_gpu) as cached:
          yield cached
      else:
        with self.session(graph, config, use_gpu, force_gpu) as sess:
          yield sess

  # pylint: enable=g-doc-return-or-yield

  class _CheckedThread(object):
    """A wrapper class for Thread that asserts successful completion.

    This class should be created using the TensorFlowTestCase.checkedThread()
    method.
    """

    def __init__(self, testcase, target, args=None, kwargs=None):
      """Constructs a new instance of _CheckedThread.

      Args:
        testcase: The TensorFlowTestCase for which this thread is being created.
        target: A callable object representing the code to be executed in the
          thread.
        args: A tuple of positional arguments that will be passed to target.
        kwargs: A dictionary of keyword arguments that will be passed to target.
      """
      self._testcase = testcase
      self._target = target
      self._args = () if args is None else args
      self._kwargs = {} if kwargs is None else kwargs
      self._thread = threading.Thread(target=self._protected_run)
      # Exception raised by the target, if any; re-reported on join().
      self._exception = None

      self._is_thread_joined = False

    def _protected_run(self):
      """Target for the wrapper thread. Sets self._exception on failure."""
      try:
        self._target(*self._args, **self._kwargs)
      except Exception as e:  # pylint: disable=broad-except
        self._exception = e

    def start(self):
      """Starts the thread's activity.

      This must be called at most once per _CheckedThread object. It arranges
      for the object's target to be invoked in a separate thread of control.
      """
      self._thread.start()

    def join(self):
      """Blocks until the thread terminates.

      Raises:
        self._testcase.failureException: If the thread terminates with due to
          an exception.
      """
      self._is_thread_joined = True
      self._thread.join()
      if self._exception is not None:
        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))

    def is_alive(self):
      """Returns whether the thread is alive.

      This method returns True just before the run() method starts
      until just after the run() method terminates.

      Returns:
        True if the thread is alive, otherwise False.
      """
      return self._thread.is_alive()

    def check_termination(self):
      """Returns whether the checked thread was properly used and did terminate.

      Every checked thread should be "join"ed after starting, and before the
      test tears down. If it is not joined, it is possible the thread will hang
      and cause flaky failures in tests.

      Raises:
        self._testcase.failureException: If check_termination was called before
          thread was joined.

        RuntimeError: If the thread is not terminated. This means thread was not
          joined with the main thread.
      """
      if self._is_thread_joined:
        if self.is_alive():
          raise RuntimeError(
              "Thread was not joined with main thread, and is still running "
              "when the test finished.")
      else:
        self._testcase.fail("A checked thread was not joined.")

  def checkedThread(self, target, args=None, kwargs=None):
    """Returns a Thread wrapper that asserts 'target' completes successfully.

    This method should be used to create all threads in test cases, as
    otherwise there is a risk that a thread will silently fail, and/or
    assertions made in the thread will not be respected.

    Args:
      target: A callable object to be executed in the thread.
      args: The argument tuple for the target invocation. Defaults to ().
      kwargs: A dictionary of keyword arguments for the target invocation.
        Defaults to {}.

    Returns:
      A wrapper for threading.Thread that supports start() and join() methods.
    """
    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
    # Track the thread so tearDown() can verify it was joined.
    self._threads.append(ret)
    return ret

  # pylint: enable=invalid-name
  @py_func_if_in_function
  def assertNear(self, f1, f2, err, msg=None):
    """Asserts that two floats are near each other.

    Checks that |f1 - f2| < err and asserts a test failure
    if not.

    Args:
      f1: A float value.
      f2: A float value.
      err: A float value.
      msg: An optional string message to append to the failure message.
    """
    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
    self.assertTrue(
        f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
        (f1, f2, err, " (%s)" % msg if msg is not None else ""))

  @py_func_if_in_function
  def assertArrayNear(self, farray1, farray2, err, msg=None):
    """Asserts that two float arrays are near each other.

    Checks that for all elements of farray1 and farray2
    |f1 - f2| < err. Asserts a test failure if not.

    Args:
      farray1: a list of float values.
      farray2: a list of float values.
      err: a float value.
      msg: Optional message to report on failure.
    """
    self.assertEqual(len(farray1), len(farray2), msg=msg)
    for f1, f2 in zip(farray1, farray2):
      self.assertNear(float(f1), float(f2), err, msg=msg)

  def _NDArrayNear(self, ndarray1, ndarray2, err):
    # True when the norm of the element-wise difference is within `err`.
    return np.linalg.norm(ndarray1 - ndarray2) < err

  @py_func_if_in_function
  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
    """Asserts that two numpy arrays have near values.

    Args:
      ndarray1: a numpy ndarray.
      ndarray2: a numpy ndarray.
      err: a float. The maximum absolute difference allowed.
      msg: Optional message to report on failure.
    """
    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)

  def _GetNdArray(self, a):
    # If a is tensor-like then convert it to ndarray
    if tensor_util.is_tf_type(a):
      if isinstance(a, ops._EagerTensorBase):
        a = a.numpy()
      else:
        a = self.evaluate(a)
    if not isinstance(a, np.ndarray):
      return np.array(a)
    return a

  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # When the array rank is small, print its contents. Numpy array printing is
    # implemented using inefficient recursion so prints can cause tests to
    # time out.
    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
                            "%s.") % (a.shape, b.shape, b)
    else:
      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
                                                                     b.shape)
    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)

    msgs = [msg]
    # np.allclose does not always work for our custom bfloat16 extension type
    # when type promotions are involved, so we first cast any bfloat16 arrays
    # to float32.
    a_dtype = a.dtype
    a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
    b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
    if not np.allclose(a, b, rtol=rtol, atol=atol):
      # Adds more details to np.testing.assert_allclose.
      #
      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
      # checks whether two arrays are element-wise equal within a
      # tolerance. The relative difference (rtol * abs(b)) and the
      # absolute difference atol are added together to compare against
      # the absolute difference between a and b. Here, we want to
      # tell user which elements violate such conditions.
      cond = np.logical_or(
          np.abs(a - b) > atol + rtol * np.abs(b),
          np.isnan(a) != np.isnan(b))
      if a.ndim:
        x = a[np.where(cond)]
        y = b[np.where(cond)]
        msgs.append("not close where = {}".format(np.where(cond)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not close lhs = {}".format(x))
      msgs.append("not close rhs = {}".format(y))
      msgs.append("not close dif = {}".format(np.abs(x - y)))
      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
      msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
      # TODO(xpan): There seems to be a bug:
      # tensorflow/compiler/tests:binary_ops_test pass with float32
      # nan even though the equal_nan is False by default internally.
      np.testing.assert_allclose(
          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)

  def _assertAllCloseRecursive(self,
                               a,
                               b,
                               rtol=1e-6,
                               atol=1e-6,
                               path=None,
                               msg=None):
    # Recursively compare nested structures (dicts, namedtuples, sequences,
    # array-likes), tracking `path` so failures point at the mismatched leaf.
    path = path or []
    path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
    msg = msg if msg else ""

    # Check if a and/or b are namedtuples.
    if hasattr(a, "_asdict"):
      a = a._asdict()
    if hasattr(b, "_asdict"):
      b = b._asdict()
    a_is_dict = isinstance(a, collections_abc.Mapping)
    if a_is_dict != isinstance(b, collections_abc.Mapping):
      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
                       (path_str, path_str, msg))
    if a_is_dict:
      self.assertItemsEqual(
          a.keys(),
          b.keys(),
          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
          (path_str, a.keys(), path_str, b.keys(), msg))
      for k in a:
        path.append(k)
        self._assertAllCloseRecursive(
            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
        del path[-1]
    elif isinstance(a, (list, tuple)):
      # Try to directly compare a, b as ndarrays; if not work, then traverse
      # through the sequence, which is more expensive.
      try:
        a_as_ndarray = self._GetNdArray(a)
        b_as_ndarray = self._GetNdArray(b)
        self._assertArrayLikeAllClose(
            a_as_ndarray,
            b_as_ndarray,
            rtol=rtol,
            atol=atol,
            msg="Mismatched value: a%s is different from b%s. %s" %
            (path_str, path_str, msg))
      except (ValueError, TypeError) as e:
        if len(a) != len(b):
          raise ValueError(
              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
              (path_str, len(a), path_str, len(b), msg))
        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
          path.append(str(idx))
          self._assertAllCloseRecursive(
              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
          del path[-1]
    # a and b are ndarray like objects
    else:
      try:
        self._assertArrayLikeAllClose(
            a,
            b,
            rtol=rtol,
            atol=atol,
            msg=("Mismatched value: a%s is different from b%s. %s" %
                 (path_str, path_str, msg)))
      except TypeError as e:
        msg = ("Error: a%s has %s, but b%s has %s. %s" %
               (path_str, type(a), path_str, type(b), msg))
        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
        raise

  @py_func_if_in_function
  def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    """Asserts that two structures of numpy arrays or Tensors, have near values.

    `a` and `b` can be arbitrarily nested structures. A layer of a nested
    structure can be a `dict`, `namedtuple`, `tuple` or `list`.

    Note: the implementation follows
    [`numpy.allclose`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html)
    (and numpy.testing.assert_allclose). It checks whether two arrays are
    element-wise equal within a tolerance. The relative difference
    (`rtol * abs(b)`) and the absolute difference `atol` are added together
    to compare against the absolute difference between `a` and `b`.

    Args:
      a: The expected numpy `ndarray`, or anything that can be converted into a
        numpy `ndarray` (including Tensor), or any arbitrarily nested of
        structure of these.
      b: The actual numpy `ndarray`, or anything that can be converted into a
        numpy `ndarray` (including Tensor), or any arbitrarily nested of
        structure of these.
      rtol: relative tolerance.
      atol: absolute tolerance.
      msg: Optional message to report on failure.

    Raises:
      ValueError: if only one of `a[p]` and `b[p]` is a dict or
        `a[p]` and `b[p]` have different length, where `[p]` denotes a path
        to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
        `[p] = [1]['d']`, then `a[p] = (6, 7)`.
    """
    # Ragged tensors get their own comparison path.
    if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
      return self._assertRaggedClose(a, b, rtol, atol, msg)
    self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)

  @py_func_if_in_function
  def assertAllCloseAccordingToType(self,
                                    a,
                                    b,
                                    rtol=1e-6,
                                    atol=1e-6,
                                    float_rtol=1e-6,
                                    float_atol=1e-6,
                                    half_rtol=1e-3,
                                    half_atol=1e-3,
                                    bfloat16_rtol=1e-2,
                                    bfloat16_atol=1e-2,
                                    msg=None):
    """Like assertAllClose, but also suitable for comparing fp16 arrays.

    In particular, the tolerance is reduced to 1e-3 if at least
    one of the arguments is of type float16.

    Args:
      a: the expected numpy ndarray or anything can be converted to one.
      b: the actual numpy ndarray or anything can be converted to one.
      rtol: relative tolerance.
      atol: absolute tolerance.
      float_rtol: relative tolerance for float32.
      float_atol: absolute tolerance for float32.
      half_rtol: relative tolerance for float16.
      half_atol: absolute tolerance for float16.
      bfloat16_rtol: relative tolerance for bfloat16.
      bfloat16_atol: absolute tolerance for bfloat16.
      msg: Optional message to report on failure.
    """
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # types with lower tol are put later to overwrite previous ones.
    if (a.dtype == np.float32 or b.dtype == np.float32 or
        a.dtype == np.complex64 or b.dtype == np.complex64):
      rtol = max(rtol, float_rtol)
      atol = max(atol, float_atol)
    if a.dtype == np.float16 or b.dtype == np.float16:
      rtol = max(rtol, half_rtol)
      atol = max(atol, half_atol)
    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
        b.dtype == dtypes.bfloat16.as_numpy_dtype):
      rtol = max(rtol, bfloat16_rtol)
      atol = max(atol, bfloat16_atol)

    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)

  @py_func_if_in_function
  def assertNotAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    """Assert that two numpy arrays, or Tensors, do not have near values.

    Args:
      a: The expected numpy `ndarray`, or anything that can be converted into a
        numpy `ndarray` (including Tensor), or any arbitrarily nested of
        structure of these.
      b: The actual numpy `ndarray`, or anything that can be converted into a
        numpy `ndarray` (including Tensor), or any arbitrarily nested of
        structure of these.
      rtol: relative tolerance.
      atol: absolute tolerance.
      msg: Optional message to report on failure.

    Raises:
      AssertionError: If `a` and `b` are unexpectedly close at all elements.
    """
    try:
      self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
    except AssertionError:
      # The values are NOT all close, which is exactly what this method wants.
      return
    msg = msg or ""
    raise AssertionError("The two values are close at all elements. %s" % msg)

  @py_func_if_in_function
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that two numpy arrays or Tensors have the same values.

    Args:
      a: the expected numpy ndarray or anything can be converted to one.
      b: the actual numpy ndarray or anything can be converted to one.
      msg: Optional message to report on failure.
    """
    if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
      return self._assertRaggedEqual(a, b, msg)
    msg = msg if msg else ""
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # Arbitrary bounds so that we don't print giant tensors.
    if (b.ndim <= 3 or b.size < 500):
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " Contents: %r. \n%s." % (a.shape, b.shape, b, msg))
    else:
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " %s" % (a.shape, b.shape, msg))

    same = (a == b)

    if (a.dtype in [
        np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
    ]):
      # NaN != NaN element-wise; treat matching NaN positions as equal.
      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
    msgs = [msg]
    if not np.all(same):
      # Adds more details to np.testing.assert_array_equal.
      diff = np.logical_not(same)
      if a.ndim:
        x = a[np.where(diff)]
        y = b[np.where(diff)]
        msgs.append("not equal where = {}".format(np.where(diff)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not equal lhs = %r" % x)
      msgs.append("not equal rhs = %r" % y)

      # Handle mixed string types as a result of PY2to3 migration. That is, the
      # mixing between bytes (b-prefix strings, PY2 default) and unicodes
      # (u-prefix strings, PY3 default).
2845 if six.PY3: 2846 if (a.dtype.kind != b.dtype.kind and 2847 {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})): 2848 a_list = [] 2849 b_list = [] 2850 # OK to flatten `a` and `b` because they are guaranteed to have the 2851 # same shape. 2852 for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]: 2853 for item in flat_arr: 2854 if isinstance(item, str): 2855 out_list.append(item.encode("utf-8")) 2856 else: 2857 out_list.append(item) 2858 a = np.array(a_list) 2859 b = np.array(b_list) 2860 2861 np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs)) 2862 2863 @py_func_if_in_function 2864 def assertNotAllEqual(self, a, b, msg=None): 2865 """Asserts that two numpy arrays or Tensors do not have the same values. 2866 2867 Args: 2868 a: the expected numpy ndarray or anything can be converted to one. 2869 b: the actual numpy ndarray or anything can be converted to one. 2870 msg: Optional message to report on failure. 2871 """ 2872 try: 2873 self.assertAllEqual(a, b) 2874 except AssertionError: 2875 return 2876 raise AssertionError("The two values are equal at all elements. %s" % msg) 2877 2878 @py_func_if_in_function 2879 def assertAllGreater(self, a, comparison_target): 2880 """Assert element values are all greater than a target value. 2881 2882 Args: 2883 a: The numpy `ndarray`, or anything that can be converted into a numpy 2884 `ndarray` (including Tensor). 2885 comparison_target: The target value of comparison. 2886 """ 2887 a = self._GetNdArray(a) 2888 self.assertGreater(np.min(a), comparison_target) 2889 2890 @py_func_if_in_function 2891 def assertAllLess(self, a, comparison_target): 2892 """Assert element values are all less than a target value. 2893 2894 Args: 2895 a: The numpy `ndarray`, or anything that can be converted into a numpy 2896 `ndarray` (including Tensor). 2897 comparison_target: The target value of comparison. 
2898 """ 2899 a = self._GetNdArray(a) 2900 self.assertLess(np.max(a), comparison_target) 2901 2902 @py_func_if_in_function 2903 def assertAllGreaterEqual(self, a, comparison_target): 2904 """Assert element values are all greater than or equal to a target value. 2905 2906 Args: 2907 a: The numpy `ndarray`, or anything that can be converted into a numpy 2908 `ndarray` (including Tensor). 2909 comparison_target: The target value of comparison. 2910 """ 2911 a = self._GetNdArray(a) 2912 self.assertGreaterEqual(np.min(a), comparison_target) 2913 2914 @py_func_if_in_function 2915 def assertAllLessEqual(self, a, comparison_target): 2916 """Assert element values are all less than or equal to a target value. 2917 2918 Args: 2919 a: The numpy `ndarray`, or anything that can be converted into a numpy 2920 `ndarray` (including Tensor). 2921 comparison_target: The target value of comparison. 2922 """ 2923 a = self._GetNdArray(a) 2924 self.assertLessEqual(np.max(a), comparison_target) 2925 2926 def _format_subscripts(self, subscripts, value, limit=10, indent=2): 2927 """Generate a summary of ndarray subscripts as a list of str. 2928 2929 If limit == N, this method will print up to the first N subscripts on 2930 separate 2931 lines. A line of ellipses (...) will be appended at the end if the number of 2932 subscripts exceeds N. 2933 2934 Args: 2935 subscripts: The tensor (np.ndarray) subscripts, of the same format as 2936 np.where()'s return value, i.e., a tuple of arrays with each array 2937 corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])). 2938 value: (np.ndarray) value of the tensor. 2939 limit: (int) The maximum number of indices to print. 2940 indent: (int) Number of characters to indent at the beginning of each 2941 line. 2942 2943 Returns: 2944 (list of str) the multi-line representation of the subscripts and values, 2945 potentially with omission at the end. 
2946 """ 2947 lines = [] 2948 subscripts = np.transpose(subscripts) 2949 prefix = " " * indent 2950 if np.ndim(value) == 0: 2951 return [prefix + "[0] : " + str(value)] 2952 for subscript in itertools.islice(subscripts, limit): 2953 lines.append(prefix + str(subscript) + " : " + 2954 str(value[tuple(subscript)])) 2955 if len(subscripts) > limit: 2956 lines.append(prefix + "...") 2957 return lines 2958 2959 @py_func_if_in_function 2960 def assertAllInRange(self, 2961 target, 2962 lower_bound, 2963 upper_bound, 2964 open_lower_bound=False, 2965 open_upper_bound=False): 2966 """Assert that elements in a Tensor are all in a given range. 2967 2968 Args: 2969 target: The numpy `ndarray`, or anything that can be converted into a 2970 numpy `ndarray` (including Tensor). 2971 lower_bound: lower bound of the range 2972 upper_bound: upper bound of the range 2973 open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather 2974 than the default >=) 2975 open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather 2976 than the default <=) 2977 2978 Raises: 2979 AssertionError: 2980 if the value tensor does not have an ordered numeric type (float* or 2981 int*), or 2982 if there are nan values, or 2983 if any of the elements do not fall in the specified range. 2984 """ 2985 target = self._GetNdArray(target) 2986 if not (np.issubdtype(target.dtype, np.floating) or 2987 np.issubdtype(target.dtype, np.integer)): 2988 raise AssertionError( 2989 "The value of %s does not have an ordered numeric type, instead it " 2990 "has type: %s" % (target, target.dtype)) 2991 2992 nan_subscripts = np.where(np.isnan(target)) 2993 if np.size(nan_subscripts): 2994 raise AssertionError( 2995 "%d of the %d element(s) are NaN. 
" 2996 "Subscripts(s) and value(s) of the NaN element(s):\n" % 2997 (len(nan_subscripts[0]), np.size(target)) + 2998 "\n".join(self._format_subscripts(nan_subscripts, target))) 2999 3000 range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " + 3001 str(upper_bound) + (")" if open_upper_bound else "]")) 3002 3003 violations = ( 3004 np.less_equal(target, lower_bound) if open_lower_bound else np.less( 3005 target, lower_bound)) 3006 violations = np.logical_or( 3007 violations, 3008 np.greater_equal(target, upper_bound) 3009 if open_upper_bound else np.greater(target, upper_bound)) 3010 violation_subscripts = np.where(violations) 3011 if np.size(violation_subscripts): 3012 raise AssertionError( 3013 "%d of the %d element(s) are outside the range %s. " % 3014 (len(violation_subscripts[0]), np.size(target), range_str) + 3015 "Subscript(s) and value(s) of the offending elements:\n" + 3016 "\n".join(self._format_subscripts(violation_subscripts, target))) 3017 3018 @py_func_if_in_function 3019 def assertAllInSet(self, target, expected_set): 3020 """Assert that elements of a Tensor are all in a given closed set. 3021 3022 Args: 3023 target: The numpy `ndarray`, or anything that can be converted into a 3024 numpy `ndarray` (including Tensor). 3025 expected_set: (`list`, `tuple` or `set`) The closed set that the elements 3026 of the value of `target` are expected to fall into. 3027 3028 Raises: 3029 AssertionError: 3030 if any of the elements do not fall into `expected_set`. 3031 """ 3032 target = self._GetNdArray(target) 3033 3034 # Elements in target that are not in expected_set. 3035 diff = np.setdiff1d(target.flatten(), list(expected_set)) 3036 if np.size(diff): 3037 raise AssertionError("%d unique element(s) are not in the set %s: %s" % 3038 (np.size(diff), expected_set, diff)) 3039 3040 @py_func_if_in_function 3041 def assertDTypeEqual(self, target, expected_dtype): 3042 """Assert ndarray data type is equal to expected. 
  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def assertRaisesWithPredicateMatch(self, exception_type,
                                     expected_err_re_or_predicate):
    """Returns a context manager to enclose code expected to raise an exception.

    If the exception is an OpError, the op stack is also included in the message
    predicate search.

    Args:
      exception_type: The expected type of exception that should be raised.
      expected_err_re_or_predicate: If this is callable, it should be a function
        of one argument that inspects the passed-in exception and returns True
        (success) or False (please fail the test). Otherwise, the error message
        is expected to match this regular expression partially.

    Returns:
      A context manager to surround code that is expected to raise an
      exception.
    """
    if callable(expected_err_re_or_predicate):
      predicate = expected_err_re_or_predicate
    else:

      def predicate(e):
        # For OpErrors, search the message plus the names of the originating
        # ops (followed through the _original_op chain), not just str(e).
        err_str = e.message if isinstance(e, errors.OpError) else str(e)
        op = e.op if isinstance(e, errors.OpError) else None
        while op is not None:
          err_str += "\nCaused by: " + op.name
          op = op._original_op  # pylint: disable=protected-access
        logging.info("Searching within error strings: '%s' within '%s'",
                     expected_err_re_or_predicate, err_str)
        # re.search returns a truthy match object on success, None otherwise.
        return re.search(expected_err_re_or_predicate, err_str)

    try:
      yield
      # Reaching this line means the enclosed block did not raise at all.
      # NOTE: self.fail raises; that exception is intentionally run through
      # the same type/predicate check below (original behavior).
      self.fail(exception_type.__name__ + " not raised")
    except Exception as e:  # pylint: disable=broad-except
      # Convert to a test failure unless the exception both has the expected
      # type and satisfies the predicate.
      if not isinstance(e, exception_type) or not predicate(e):
        raise AssertionError("Exception of type %s: %s" %
                             (str(type(e)), str(e)))
3074 """ 3075 if callable(expected_err_re_or_predicate): 3076 predicate = expected_err_re_or_predicate 3077 else: 3078 3079 def predicate(e): 3080 err_str = e.message if isinstance(e, errors.OpError) else str(e) 3081 op = e.op if isinstance(e, errors.OpError) else None 3082 while op is not None: 3083 err_str += "\nCaused by: " + op.name 3084 op = op._original_op # pylint: disable=protected-access 3085 logging.info("Searching within error strings: '%s' within '%s'", 3086 expected_err_re_or_predicate, err_str) 3087 return re.search(expected_err_re_or_predicate, err_str) 3088 3089 try: 3090 yield 3091 self.fail(exception_type.__name__ + " not raised") 3092 except Exception as e: # pylint: disable=broad-except 3093 if not isinstance(e, exception_type) or not predicate(e): 3094 raise AssertionError("Exception of type %s: %s" % 3095 (str(type(e)), str(e))) 3096 3097 # pylint: enable=g-doc-return-or-yield 3098 3099 def assertRaisesOpError(self, expected_err_re_or_predicate): 3100 return self.assertRaisesWithPredicateMatch(errors.OpError, 3101 expected_err_re_or_predicate) 3102 3103 def assertShapeEqual(self, np_array, tf_tensor, msg=None): 3104 """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape. 3105 3106 Args: 3107 np_array: A Numpy ndarray or Numpy scalar. 3108 tf_tensor: A Tensor. 3109 msg: Optional message to report on failure. 3110 3111 Raises: 3112 TypeError: If the arguments have the wrong type. 3113 """ 3114 if not isinstance(np_array, (np.ndarray, np.generic)): 3115 raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") 3116 if not isinstance(tf_tensor, ops.Tensor): 3117 raise TypeError("tf_tensor must be a Tensor") 3118 self.assertAllEqual( 3119 np_array.shape, tf_tensor.get_shape().as_list(), msg=msg) 3120 3121 def assertDeviceEqual(self, device1, device2, msg=None): 3122 """Asserts that the two given devices are the same. 3123 3124 Args: 3125 device1: A string device name or TensorFlow `DeviceSpec` object. 
3126 device2: A string device name or TensorFlow `DeviceSpec` object. 3127 msg: Optional message to report on failure. 3128 """ 3129 device1 = pydev.canonical_name(device1) 3130 device2 = pydev.canonical_name(device2) 3131 self.assertEqual( 3132 device1, device2, 3133 "Devices %s and %s are not equal. %s" % (device1, device2, msg)) 3134 3135 def _GetPyList(self, a): 3136 """Converts `a` to a nested python list.""" 3137 if isinstance(a, ragged_tensor.RaggedTensor): 3138 return self.evaluate(a).to_list() 3139 elif isinstance(a, ops.Tensor): 3140 a = self.evaluate(a) 3141 return a.tolist() if isinstance(a, np.ndarray) else a 3142 elif isinstance(a, np.ndarray): 3143 return a.tolist() 3144 elif isinstance(a, ragged_tensor_value.RaggedTensorValue): 3145 return a.to_list() 3146 else: 3147 return np.array(a).tolist() 3148 3149 def _assertRaggedEqual(self, a, b, msg): 3150 """Asserts that two ragged tensors are equal.""" 3151 a_list = self._GetPyList(a) 3152 b_list = self._GetPyList(b) 3153 self.assertEqual(a_list, b_list, msg) 3154 3155 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 3156 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 3157 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 3158 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 3159 3160 def _assertRaggedClose(self, a, b, rtol, atol, msg=None): 3161 a_list = self._GetPyList(a) 3162 b_list = self._GetPyList(b) 3163 self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg) 3164 3165 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 3166 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 3167 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 3168 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 3169 3170 def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"): 3171 self.assertEqual(type(a), type(b)) 3172 if isinstance(a, (list, tuple)): 3173 self.assertLen(a, len(b), 
"Length differs for %s" % path) 3174 for i in range(len(a)): 3175 self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg, 3176 "%s[%s]" % (path, i)) 3177 else: 3178 self._assertAllCloseRecursive(a, b, rtol, atol, path, msg) 3179 3180 # Fix Python 3+ compatibility issues 3181 if not six.PY2: 3182 # pylint: disable=invalid-name 3183 3184 # Silence a deprecation warning 3185 assertRaisesRegexp = googletest.TestCase.assertRaisesRegex 3186 3187 # assertItemsEqual is assertCountEqual as of 3.2. 3188 assertItemsEqual = googletest.TestCase.assertCountEqual 3189 3190 # pylint: enable=invalid-name 3191 3192 @contextlib.contextmanager 3193 def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): 3194 """Set the session and its graph to global default and constrain devices.""" 3195 if context.executing_eagerly(): 3196 yield None 3197 else: 3198 with sess.graph.as_default(), sess.as_default(): 3199 if force_gpu: 3200 # Use the name of an actual device if one is detected, or 3201 # '/device:GPU:0' otherwise 3202 gpu_name = gpu_device_name() 3203 if not gpu_name: 3204 gpu_name = "/device:GPU:0" 3205 with sess.graph.device(gpu_name): 3206 yield sess 3207 elif use_gpu: 3208 yield sess 3209 else: 3210 with sess.graph.device("/device:CPU:0"): 3211 yield sess 3212 3213 def _create_session(self, graph, config, force_gpu): 3214 """See session() for details.""" 3215 3216 def prepare_config(config): 3217 """Returns a config for sessions. 3218 3219 Args: 3220 config: An optional config_pb2.ConfigProto to use to configure the 3221 session. 3222 3223 Returns: 3224 A config_pb2.ConfigProto object. 3225 """ 3226 # TODO(b/114333779): Enforce allow_soft_placement=False when 3227 # use_gpu=False. Currently many tests rely on the fact that any device 3228 # will be used even when a specific device is supposed to be used. 
3229 allow_soft_placement = not force_gpu 3230 if config is None: 3231 config = context.context().config 3232 config.allow_soft_placement = allow_soft_placement 3233 elif not allow_soft_placement and config.allow_soft_placement: 3234 config_copy = context.context().config 3235 config = config_copy 3236 config.allow_soft_placement = False 3237 # Don't perform optimizations for tests so we don't inadvertently run 3238 # gpu ops on cpu 3239 config.graph_options.optimizer_options.opt_level = -1 3240 # Disable Grappler constant folding since some tests & benchmarks 3241 # use constant input and become meaningless after constant folding. 3242 # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE 3243 # GRAPPLER TEAM. 3244 config.graph_options.rewrite_options.constant_folding = ( 3245 rewriter_config_pb2.RewriterConfig.OFF) 3246 config.graph_options.rewrite_options.pin_to_host_optimization = ( 3247 rewriter_config_pb2.RewriterConfig.OFF) 3248 return config 3249 3250 return ErrorLoggingSession(graph=graph, config=prepare_config(config)) 3251 3252 def _get_cached_session(self, 3253 graph=None, 3254 config=None, 3255 force_gpu=False, 3256 crash_if_inconsistent_args=True): 3257 """See cached_session() for documentation.""" 3258 if self._cached_session is None: 3259 sess = self._create_session( 3260 graph=graph, config=config, force_gpu=force_gpu) 3261 self._cached_session = sess 3262 self._cached_graph = graph 3263 self._cached_config = config 3264 self._cached_force_gpu = force_gpu 3265 return sess 3266 else: 3267 if crash_if_inconsistent_args and self._cached_graph is not graph: 3268 raise ValueError("The graph used to get the cached session is " 3269 "different than the one that was used to create the " 3270 "session. 
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  "PS" stands for "parameter server": a task responsible for storing and
  updating the model's parameters. Other tasks send updates to these parameters
  as they work on optimizing the parameters. This particular division of labor
  between tasks is not required, but is common for distributed training.

  Read more at https://www.tensorflow.org/guide/extend/architecture

  ![components](https://www.tensorflow.org/images/diag1.svg "components")


  Figure illustrates the interaction of these components.
  "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services.


  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.compat.v1.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in the
      documentation of `tf.distribute.Server`.
    worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be
      used to instantiate multiple devices etc.
    ps_config: (optional) `tf.ConfigProto` to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`.  `worker_servers` is a list
    of `num_workers` objects of type `tf.distribute.Server` (all running
    locally); and `ps_servers` is a list of `num_ps` objects of similar type.

  Raises:
    ImportError: if portpicker module was not found at load time
  """
  # Deferred import: portpicker is a test-only dependency.
  import portpicker  # pylint: disable=g-import-not-at-top
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
  cluster_spec = server_lib.ClusterSpec({
      "worker": ["localhost:%s" % port for port in worker_ports],
      "ps": ["localhost:%s" % port for port in ps_ports],
  })

  def _start_servers(job_name, count, job_config):
    """Starts `count` in-process servers for `job_name` and returns them."""
    return [
        server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task_index,
            config=job_config,
            start=True) for task_index in range(count)
    ]

  # Start workers first, then PS servers, matching the original ordering.
  workers = _start_servers("worker", num_workers, worker_config)
  ps_servers = _start_servers("ps", num_ps, ps_config)
  return workers, ps_servers
def get_node_def_from_graph(node_name, graph_def):
  """Returns the `NodeDef` instance for given node name in the graph def.

  This method explores only the NodeDefs in `graph_def.node`.

  Args:
    node_name: Name of the NodeDef to search for.
    graph_def: An instance of `GraphDef` proto.

  Returns:
    the `NodeDef` instance whose name field matches the given node_name or None.
  """
  # Linear scan; returns the first matching NodeDef, None when absent.
  return next(
      (node_def for node_def in graph_def.node if node_def.name == node_name),
      None)
3384 3385 Returns: 3386 the `NodeDef` instance whose name field matches the given node_name or None. 3387 """ 3388 for node_def in graph_def.node: 3389 if node_def.name == node_name: 3390 return node_def 3391 return None 3392 3393 3394def set_producer_version(graph, producer_version): 3395 """Sets graph.graph_def_versions.producer to `producer_version`.""" 3396 # The C API doesn't expose altering GraphDefVersions. We can indirectly set 3397 # it via import_graph_def though. 3398 graph_def = graph_pb2.GraphDef() 3399 graph_def.versions.producer = producer_version 3400 with graph.as_default(): 3401 importer.import_graph_def(graph_def) 3402 assert graph.graph_def_versions.producer, producer_version 3403 3404 3405@contextlib.contextmanager 3406def _fake_gradient_tape_context_manager(): 3407 """tf.gradients(...) implemented as tf.GradientTape context manager interface. 3408 3409 This is useful to test tf.gradients() in tests that uses tf.GradientTape(). 3410 3411 Yields: 3412 gradient tape instance that's implemented by tf.gradients() underneath. 3413 """ 3414 try: 3415 class FakeGradientTape: 3416 3417 def watch(self, x): 3418 pass 3419 3420 def gradient(self, y, x, grad_ys=None): 3421 result = gradients_impl.gradients(y, x, grad_ys) 3422 3423 # Unlike `tape.gradient()`, `tf.gradients()` returns a list for a single 3424 # element. So unpack if needed to match `tape.gradient()` behavior. 3425 if not isinstance(x, (list, tuple)): 3426 assert len(result) == 1 3427 return result[0] 3428 3429 return result 3430 3431 yield FakeGradientTape() 3432 finally: 3433 pass 3434 3435 3436class AbstractGradientTape: 3437 """Abstract GradientTape context manager that has multiple implementations. 3438 3439 This is useful to test both tf.GradientTape() and tf.gradients() without 3440 duplicating tests. 
3441 """ 3442 3443 def __init__(self, use_tape, persistent=False): 3444 self._use_tape = use_tape 3445 self._persistent = persistent 3446 3447 def __enter__(self): 3448 if self._use_tape: 3449 self._tape_impl = backprop.GradientTape(persistent=self._persistent) 3450 else: 3451 self._tape_impl = _fake_gradient_tape_context_manager() 3452 return self._tape_impl.__enter__() 3453 3454 def __exit__(self, exc_type, exc_val, exc_tb): 3455 self._tape_impl.__exit__(exc_type, exc_val, exc_tb) 3456 3457 3458@contextlib.contextmanager 3459def run_functions_eagerly(run_eagerly): 3460 """Runs functions eagerly if `run_eagerly` is true. 3461 3462 WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode 3463 *WILL NOT* make the tf.function to run eagerly because eager is disabled by 3464 default in V1. Instead, tf.function will run as a traced graph function. 3465 3466 Ensures that the state (for running functions eagerly) is back to the initial 3467 `def_function.RUN_FUNCTIONS_EAGERLY` state. 3468 3469 Args: 3470 run_eagerly: Boolean determining whether to run the function eagerly or not. 3471 3472 Raises: 3473 ValueError if `run_eagerly` is not a boolean. 3474 3475 Yields: 3476 Nothing. 3477 """ 3478 if not isinstance(run_eagerly, bool): 3479 raise ValueError( 3480 "Expected bool for `run_eagerly` but got {}".format(run_eagerly)) 3481 3482 is_eager = context.executing_eagerly() 3483 if not is_eager and run_eagerly: 3484 logging.warning( 3485 "Running tf.function eagerly in V1 graph mode is not supported. " 3486 "tf.function will be run as a traced graph function.") 3487 3488 initial_state = def_function.functions_run_eagerly() 3489 def_function.run_functions_eagerly(run_eagerly) 3490 try: 3491 yield 3492 finally: 3493 def_function.run_functions_eagerly(initial_state) 3494