1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Test utils for tensorflow.""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23from collections import OrderedDict 24import contextlib 25import functools 26import gc 27import itertools 28import math 29import os 30import random 31import re 32import tempfile 33import threading 34import unittest 35 36from absl.testing import parameterized 37import numpy as np 38import six 39 40from google.protobuf import descriptor_pool 41from google.protobuf import text_format 42 43from tensorflow.core.framework import graph_pb2 44from tensorflow.core.protobuf import rewriter_config_pb2 45from tensorflow.python import _pywrap_stacktrace_handler 46from tensorflow.python import _pywrap_util_port 47from tensorflow.python import tf2 48from tensorflow.python.client import device_lib 49from tensorflow.python.client import pywrap_tf_session 50from tensorflow.python.client import session 51from tensorflow.python.compat.compat import forward_compatibility_horizon 52from tensorflow.python.eager import context 53from tensorflow.python.eager import def_function 54from tensorflow.python.eager import tape 55from tensorflow.python.framework import device as pydev 56from tensorflow.python.framework import dtypes 57from tensorflow.python.framework import errors 58from tensorflow.python.framework import errors_impl 59from tensorflow.python.framework import gpu_util 60from tensorflow.python.framework import importer 61from tensorflow.python.framework import ops 62from tensorflow.python.framework import random_seed 63from tensorflow.python.framework import sparse_tensor 64from tensorflow.python.framework import tensor_shape 65from tensorflow.python.framework import tensor_util 66from tensorflow.python.framework import versions 67from tensorflow.python.ops import array_ops 68from tensorflow.python.ops import control_flow_util 69from tensorflow.python.ops import control_flow_util_v2 70from tensorflow.python.ops.ragged import ragged_tensor 71from tensorflow.python.ops.ragged import ragged_tensor_value 72from tensorflow.python.ops import script_ops 73from tensorflow.python.ops import summary_ops_v2 74from tensorflow.python.ops import variables 75from tensorflow.python.platform import googletest 76from tensorflow.python.platform import tf_logging as logging 77from tensorflow.python.training import server_lib 78from tensorflow.python.util import compat 79from tensorflow.python.util import deprecation 80from tensorflow.python.util import nest 81from tensorflow.python.util import tf_decorator 82from tensorflow.python.util import tf_inspect 83from tensorflow.python.util.protobuf import compare 84from tensorflow.python.util.tf_export import tf_export 85from tensorflow.python.util.compat import collections_abc 86 87 88# If the below import is made available through the BUILD rule, then this 89# function is overridden and will instead return True and cause Tensorflow 90# graphs to be compiled with XLA. 91def is_xla_enabled(): 92 return False 93 94 95try: 96 from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top 97except: 98 pass 99 100def _get_object_count_by_type(): 101 return collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) 102 103@tf_export("test.gpu_device_name") 104def gpu_device_name(): 105 """Returns the name of a GPU device if available or the empty string.""" 106 for x in device_lib.list_local_devices(): 107 if x.device_type == "GPU" or x.device_type == "SYCL": 108 return compat.as_str(x.name) 109 return "" 110 111 112def assert_ops_in_graph(expected_ops, graph): 113 """Assert all expected operations are found. 114 115 Args: 116 expected_ops: `dict<string, string>` of op name to op type. 117 graph: Graph to check. 118 119 Returns: 120 `dict<string, node>` of node name to node. 121 122 Raises: 123 ValueError: If the expected ops are not present in the graph. 124 """ 125 actual_ops = {} 126 gd = graph.as_graph_def() 127 for node in gd.node: 128 if node.name in expected_ops: 129 if expected_ops[node.name] != node.op: 130 raise ValueError("Expected op for node %s is different. %s vs %s" % 131 (node.name, expected_ops[node.name], node.op)) 132 actual_ops[node.name] = node 133 if set(expected_ops.keys()) != set(actual_ops.keys()): 134 raise ValueError("Not all expected ops are present. Expected %s, found %s" % 135 (expected_ops.keys(), actual_ops.keys())) 136 return actual_ops 137 138 139@tf_export("test.assert_equal_graph_def", v1=[]) 140def assert_equal_graph_def_v2(expected, actual): 141 """Asserts that two `GraphDef`s are (mostly) the same. 142 143 Compares two `GraphDef` protos for equality, ignoring versions and ordering of 144 nodes, attrs, and control inputs. Node names are used to match up nodes 145 between the graphs, so the naming of nodes must be consistent. This function 146 ignores randomized attribute values that may appear in V2 checkpoints. 147 148 Args: 149 expected: The `GraphDef` we expected. 150 actual: The `GraphDef` we have. 151 152 Raises: 153 AssertionError: If the `GraphDef`s do not match. 154 TypeError: If either argument is not a `GraphDef`. 155 """ 156 assert_equal_graph_def(actual, expected, checkpoint_v2=True, 157 hash_table_shared_name=True) 158 159 160@tf_export(v1=["test.assert_equal_graph_def"]) 161def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False, 162 hash_table_shared_name=False): 163 """Asserts that two `GraphDef`s are (mostly) the same. 164 165 Compares two `GraphDef` protos for equality, ignoring versions and ordering of 166 nodes, attrs, and control inputs. Node names are used to match up nodes 167 between the graphs, so the naming of nodes must be consistent. 168 169 Args: 170 actual: The `GraphDef` we have. 171 expected: The `GraphDef` we expected. 172 checkpoint_v2: boolean determining whether to ignore randomized attribute 173 values that appear in V2 checkpoints. 174 hash_table_shared_name: boolean determining whether to ignore randomized 175 shared_names that appear in HashTableV2 op defs. 176 177 Raises: 178 AssertionError: If the `GraphDef`s do not match. 179 TypeError: If either argument is not a `GraphDef`. 180 """ 181 assert_equal_graph_def(actual, expected, checkpoint_v2, 182 hash_table_shared_name) 183 184 185def assert_equal_graph_def(actual, expected, checkpoint_v2=False, 186 hash_table_shared_name=False): 187 if not isinstance(actual, graph_pb2.GraphDef): 188 raise TypeError("Expected tf.GraphDef for actual, got %s" % 189 type(actual).__name__) 190 if not isinstance(expected, graph_pb2.GraphDef): 191 raise TypeError("Expected tf.GraphDef for expected, got %s" % 192 type(expected).__name__) 193 194 if checkpoint_v2: 195 _strip_checkpoint_v2_randomized(actual) 196 _strip_checkpoint_v2_randomized(expected) 197 198 if hash_table_shared_name: 199 _strip_hash_table_shared_name(actual) 200 _strip_hash_table_shared_name(expected) 201 202 diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(), 203 expected.SerializeToString()) 204 if diff: 205 raise AssertionError(compat.as_str(diff)) 206 207 208def assert_meta_graph_protos_equal(tester, a, b): 209 """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.""" 210 # Carefully check the collection_defs 211 tester.assertEqual(set(a.collection_def), set(b.collection_def)) 212 collection_keys = a.collection_def.keys() 213 for k in collection_keys: 214 a_value = a.collection_def[k] 215 b_value = b.collection_def[k] 216 proto_type = ops.get_collection_proto_type(k) 217 if proto_type: 218 a_proto = proto_type() 219 b_proto = proto_type() 220 # Number of entries in the collections is the same 221 tester.assertEqual( 222 len(a_value.bytes_list.value), len(b_value.bytes_list.value)) 223 for (a_value_item, b_value_item) in zip(a_value.bytes_list.value, 224 b_value.bytes_list.value): 225 a_proto.ParseFromString(a_value_item) 226 b_proto.ParseFromString(b_value_item) 227 tester.assertProtoEquals(a_proto, b_proto) 228 else: 229 tester.assertEquals(a_value, b_value) 230 # Compared the fields directly, remove their raw values from the 231 # proto comparison below. 232 a.ClearField("collection_def") 233 b.ClearField("collection_def") 234 235 # Check the graph_defs. 236 assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True) 237 # Check graph_def versions (ignored by assert_equal_graph_def). 238 tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions) 239 # Compared the fields directly, remove their raw values from the 240 # proto comparison below. 241 a.ClearField("graph_def") 242 b.ClearField("graph_def") 243 244 tester.assertProtoEquals(a, b) 245 246 247# Matches attributes named via _SHARDED_SUFFIX in 248# tensorflow/python/training/saver.py 249_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part" 250 251 252def _strip_checkpoint_v2_randomized(graph_def): 253 for node in graph_def.node: 254 delete_keys = [] 255 for attr_key in node.attr: 256 attr_tensor_value = node.attr[attr_key].tensor 257 if attr_tensor_value and len(attr_tensor_value.string_val) == 1: 258 attr_tensor_string_value = attr_tensor_value.string_val[0] 259 if (attr_tensor_string_value and 260 re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))): 261 delete_keys.append(attr_key) 262 for attr_key in delete_keys: 263 del node.attr[attr_key] 264 265 266_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+" 267 268 269def _strip_hash_table_shared_name(graph_def): 270 for node in graph_def.node: 271 delete_keys = [] 272 if node.op == "HashTableV2" and "shared_name" in node.attr: 273 if re.match(_TABLE_SHARED_NAME_PATTERN, str(node.attr["shared_name"].s)): 274 delete_keys.append("shared_name") 275 for attr_key in delete_keys: 276 del node.attr[attr_key] 277 278 279def IsGoogleCudaEnabled(): 280 return _pywrap_util_port.IsGoogleCudaEnabled() 281 282 283def IsBuiltWithROCm(): 284 return _pywrap_util_port.IsBuiltWithROCm() 285 286 287def IsBuiltWithXLA(): 288 return _pywrap_util_port.IsBuiltWithXLA() 289 290 291def IsBuiltWithNvcc(): 292 return _pywrap_util_port.IsBuiltWithNvcc() 293 294 295def GpuSupportsHalfMatMulAndConv(): 296 return _pywrap_util_port.GpuSupportsHalfMatMulAndConv() 297 298 299def IsMklEnabled(): 300 return _pywrap_util_port.IsMklEnabled() 301 302 303def InstallStackTraceHandler(): 304 _pywrap_stacktrace_handler.InstallStacktraceHandler() 305 306 307def NHWCToNCHW(input_tensor): 308 """Converts the input from the NHWC format to NCHW. 309 310 Args: 311 input_tensor: a 4- or 5-D tensor, or an array representing shape 312 313 Returns: 314 converted tensor or shape array 315 """ 316 # tensor dim -> new axis order 317 new_axes = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]} 318 if isinstance(input_tensor, ops.Tensor): 319 ndims = input_tensor.shape.ndims 320 return array_ops.transpose(input_tensor, new_axes[ndims]) 321 else: 322 ndims = len(input_tensor) 323 return [input_tensor[a] for a in new_axes[ndims]] 324 325 326def NHWCToNCHW_VECT_C(input_shape_or_tensor): 327 """Transforms the input from the NHWC layout to NCHW_VECT_C layout. 328 329 Note: Does not include quantization or type conversion steps, which should 330 be applied afterwards. 331 332 Args: 333 input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape 334 335 Returns: 336 tensor or shape array transformed into NCHW_VECT_C 337 338 Raises: 339 ValueError: if last dimension of `input_shape_or_tensor` is not evenly 340 divisible by 4. 341 """ 342 permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} 343 is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) 344 temp_shape = ( 345 input_shape_or_tensor.shape.as_list() 346 if is_tensor else input_shape_or_tensor) 347 if temp_shape[-1] % 4 != 0: 348 raise ValueError( 349 "Last dimension of input must be evenly divisible by 4 to convert to " 350 "NCHW_VECT_C.") 351 temp_shape[-1] //= 4 352 temp_shape.append(4) 353 permutation = permutations[len(temp_shape)] 354 if is_tensor: 355 t = array_ops.reshape(input_shape_or_tensor, temp_shape) 356 return array_ops.transpose(t, permutation) 357 else: 358 return [temp_shape[a] for a in permutation] 359 360 361def NCHW_VECT_CToNHWC(input_shape_or_tensor): 362 """Transforms the input from the NCHW_VECT_C layout to NHWC layout. 363 364 Note: Does not include de-quantization or type conversion steps, which should 365 be applied beforehand. 366 367 Args: 368 input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape 369 370 Returns: 371 tensor or shape array transformed into NHWC 372 373 Raises: 374 ValueError: if last dimension of `input_shape_or_tensor` is not 4. 375 """ 376 permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} 377 is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) 378 input_shape = ( 379 input_shape_or_tensor.shape.as_list() 380 if is_tensor else input_shape_or_tensor) 381 if input_shape[-1] != 4: 382 raise ValueError("Last dimension of NCHW_VECT_C must be 4.") 383 permutation = permutations[len(input_shape)] 384 nhwc_shape = [input_shape[a] for a in permutation[:-1]] 385 nhwc_shape[-1] *= input_shape[-1] 386 if is_tensor: 387 t = array_ops.transpose(input_shape_or_tensor, permutation) 388 return array_ops.reshape(t, nhwc_shape) 389 else: 390 return nhwc_shape 391 392 393def NCHWToNHWC(input_tensor): 394 """Converts the input from the NCHW format to NHWC. 395 396 Args: 397 input_tensor: a 4- or 5-D tensor, or an array representing shape 398 399 Returns: 400 converted tensor or shape array 401 """ 402 # tensor dim -> new axis order 403 new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]} 404 if isinstance(input_tensor, ops.Tensor): 405 ndims = input_tensor.shape.ndims 406 return array_ops.transpose(input_tensor, new_axes[ndims]) 407 else: 408 ndims = len(input_tensor) 409 return [input_tensor[a] for a in new_axes[ndims]] 410 411 412def skip_if(condition): 413 """Skips the decorated function if condition is or evaluates to True. 414 415 Args: 416 condition: Either an expression that can be used in "if not condition" 417 statement, or a callable whose result should be a boolean. 418 419 Returns: 420 The wrapped function 421 """ 422 423 def real_skip_if(fn): 424 425 def wrapper(*args, **kwargs): 426 if callable(condition): 427 skip = condition() 428 else: 429 skip = condition 430 if not skip: 431 return fn(*args, **kwargs) 432 433 return wrapper 434 435 return real_skip_if 436 437 438def enable_c_shapes(fn): 439 """No-op. TODO(b/74620627): Remove this.""" 440 return fn 441 442 443def with_c_shapes(cls): 444 """No-op. TODO(b/74620627): Remove this.""" 445 return cls 446 447 448def enable_control_flow_v2(fn): 449 """Decorator for enabling CondV2 and WhileV2 on a test. 450 451 Note this enables using CondV2 and WhileV2 after running the test class's 452 setup/teardown methods. 453 454 In addition to this, callers must import the while_v2 module in order to set 455 the _while_v2 module in control_flow_ops. 456 457 Args: 458 fn: the function to be wrapped 459 460 Returns: 461 The wrapped function 462 """ 463 464 def wrapper(*args, **kwargs): 465 enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2 466 control_flow_util.ENABLE_CONTROL_FLOW_V2 = True 467 try: 468 return fn(*args, **kwargs) 469 finally: 470 control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old 471 472 return wrapper 473 474 475def with_control_flow_v2(cls): 476 """Adds methods that call original methods with WhileV2 and CondV2 enabled. 477 478 Note this enables CondV2 and WhileV2 in new methods after running the test 479 class's setup method. 480 481 In addition to this, callers must import the while_v2 module in order to set 482 the _while_v2 module in control_flow_ops. 483 484 If a test function has _disable_control_flow_v2 attr set to True (using the 485 @disable_control_flow_v2 decorator), the v2 function is not generated for it. 486 487 Example: 488 489 @test_util.with_control_flow_v2 490 class ControlFlowTest(test.TestCase): 491 492 def testEnabledForV2(self): 493 ... 494 495 @test_util.disable_control_flow_v2("b/xyzabc") 496 def testDisabledForV2(self): 497 ... 498 499 Generated class: 500 class ControlFlowTest(test.TestCase): 501 502 def testEnabledForV2(self): 503 ... 504 505 def testEnabledForV2WithControlFlowV2(self): 506 // Enable V2 flags. 507 testEnabledForV2(self) 508 // Restore V2 flags. 509 510 def testDisabledForV2(self): 511 ... 512 513 Args: 514 cls: class to decorate 515 516 Returns: 517 cls with new test methods added 518 """ 519 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 520 return cls 521 522 for name, value in cls.__dict__.copy().items(): 523 if (callable(value) and 524 name.startswith(unittest.TestLoader.testMethodPrefix) and 525 not getattr(value, "_disable_control_flow_v2", False)): 526 setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value)) 527 return cls 528 529 530def disable_control_flow_v2(unused_msg): 531 """Decorator for a function in a with_control_flow_v2 enabled test class. 532 533 Blocks the function from being run with v2 control flow ops. 534 535 Args: 536 unused_msg: Reason for disabling. 537 538 Returns: 539 The wrapped function with _disable_control_flow_v2 attr set to True. 540 """ 541 542 def wrapper(func): 543 func._disable_control_flow_v2 = True 544 return func 545 546 return wrapper 547 548 549def enable_output_all_intermediates(fn): 550 """Force-enable outputing all intermediates from functional control flow ops. 551 552 Args: 553 fn: the function to be wrapped 554 555 Returns: 556 The wrapped function 557 """ 558 559 def wrapper(*args, **kwargs): 560 output_all_intermediates_old = \ 561 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE 562 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True 563 try: 564 return fn(*args, **kwargs) 565 finally: 566 control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \ 567 output_all_intermediates_old 568 569 return wrapper 570 571 572def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2): 573 """Decorator for asserting that no new Python objects persist after a test. 574 575 Runs the test multiple times executing eagerly, first as a warmup and then to 576 let objects accumulate. The warmup helps ignore caches which do not grow as 577 the test is run repeatedly. 578 579 Useful for checking that there are no missing Py_DECREFs in the C exercised by 580 a bit of Python. 581 582 Args: 583 func: The function to test. 584 warmup_iters: The numer of warmup iterations, excluded from measuring. 585 586 Returns: 587 The wrapped function performing the test. 588 """ 589 590 def wrap_f(f): 591 def decorator(self, *args, **kwargs): 592 """Warms up, gets object counts, runs the test, checks for new objects.""" 593 with context.eager_mode(): 594 gc.disable() 595 # Run the test 2 times as warmup, in an attempt to fill up caches, which 596 # should not grow as the test is run repeatedly below. 597 # 598 # TODO(b/117156879): Running warmup twice is black magic; we have seen 599 # tests that fail with 1 warmup run, and pass with 2, on various 600 # versions of python2.7.x. 601 for _ in range(warmup_iters): 602 f(self, *args, **kwargs) 603 604 # Some objects are newly created by _get_object_count_by_type(). So 605 # create and save as a dummy variable to include it as a baseline. 606 obj_count_by_type = _get_object_count_by_type() 607 gc.collect() 608 obj_count_by_type = _get_object_count_by_type() 609 610 if ops.has_default_graph(): 611 collection_sizes_before = { 612 collection: len(ops.get_collection(collection)) 613 for collection in ops.get_default_graph().collections 614 } 615 for _ in range(3): 616 f(self, *args, **kwargs) 617 # Note that gc.get_objects misses anything that isn't subject to garbage 618 # collection (C types). Collections are a common source of leaks, so we 619 # test for collection sizes explicitly. 620 if ops.has_default_graph(): 621 for collection_key in ops.get_default_graph().collections: 622 collection = ops.get_collection(collection_key) 623 size_before = collection_sizes_before.get(collection_key, 0) 624 if len(collection) > size_before: 625 raise AssertionError( 626 ("Collection %s increased in size from " 627 "%d to %d (current items %s).") % 628 (collection_key, size_before, len(collection), collection)) 629 # Make sure our collection checks don't show up as leaked memory by 630 # removing references to temporary variables. 631 del collection 632 del collection_key 633 del size_before 634 del collection_sizes_before 635 gc.collect() 636 637 # There should be no new Python objects hanging around. 638 obj_count_by_type = _get_object_count_by_type() - obj_count_by_type 639 # In some cases (specifacally on MacOS), new_count is somehow 640 # smaller than previous_count. 641 # Using plain assert because not all classes using this decorator 642 # have assertLessEqual 643 assert not obj_count_by_type, ( 644 "The following objects were newly created: %s" % 645 str(obj_count_by_type)) 646 gc.enable() 647 return decorator 648 649 if func is None: 650 return wrap_f 651 else: 652 return wrap_f(func) 653 654 655def assert_no_new_tensors(f): 656 """Decorator for asserting that no new Tensors persist after a test. 657 658 Mainly useful for checking that code using the Python C API has correctly 659 manipulated reference counts. 660 661 Clears the caches that it knows about, runs the garbage collector, then checks 662 that there are no Tensor or Tensor-like objects still around. This includes 663 Tensors to which something still has a reference (e.g. from missing 664 Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one 665 of the objects has __del__ defined). 666 667 Args: 668 f: The test case to run. 669 670 Returns: 671 The decorated test case. 672 """ 673 674 def decorator(self, **kwargs): 675 """Finds existing Tensors, runs the test, checks for new Tensors.""" 676 677 def _is_tensorflow_object(obj): 678 try: 679 return isinstance(obj, 680 (ops.Tensor, variables.Variable, 681 tensor_shape.Dimension, tensor_shape.TensorShape)) 682 except ReferenceError: 683 # If the object no longer exists, we don't care about it. 684 return False 685 686 tensors_before = set( 687 id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) 688 outside_executed_eagerly = context.executing_eagerly() 689 # Run the test in a new graph so that collections get cleared when it's 690 # done, but inherit the graph key so optimizers behave. 691 outside_graph_key = ops.get_default_graph()._graph_key 692 with ops.Graph().as_default(): 693 ops.get_default_graph()._graph_key = outside_graph_key 694 if outside_executed_eagerly: 695 with context.eager_mode(): 696 result = f(self, **kwargs) 697 else: 698 result = f(self, **kwargs) 699 # Make an effort to clear caches, which would otherwise look like leaked 700 # Tensors. 701 context.context()._clear_caches() # pylint: disable=protected-access 702 gc.collect() 703 tensors_after = [ 704 obj for obj in gc.get_objects() 705 if _is_tensorflow_object(obj) and id(obj) not in tensors_before 706 ] 707 if tensors_after: 708 raise AssertionError(("%d Tensors not deallocated after test: %s" % ( 709 len(tensors_after), 710 str(tensors_after), 711 ))) 712 return result 713 714 return decorator 715 716 717def _find_reference_cycle(objects, idx): 718 719 def get_ignore_reason(obj, blacklist): 720 """Tests whether an object should be omitted from the dependency graph.""" 721 if len(blacklist) > 100: 722 return "<depth limit>" 723 if tf_inspect.isframe(obj): 724 if "test_util.py" in tf_inspect.getframeinfo(obj)[0]: 725 return "<test code>" 726 for b in blacklist: 727 if b is obj: 728 return "<test code>" 729 if obj is blacklist: 730 return "<test code>" 731 return None 732 733 # Note: this function is meant to help with diagnostics. Its output is purely 734 # a human-readable representation, so you may freely modify it to suit your 735 # needs. 736 def describe(obj, blacklist, leaves_only=False): 737 """Returns a custom human-readable summary of obj. 738 739 Args: 740 obj: the value to describe. 741 blacklist: same as blacklist in get_ignore_reason. 742 leaves_only: boolean flag used when calling describe recursively. Useful 743 for summarizing collections. 744 """ 745 if get_ignore_reason(obj, blacklist): 746 return "{}{}".format(get_ignore_reason(obj, blacklist), type(obj)) 747 if tf_inspect.isframe(obj): 748 return "frame: {}".format(tf_inspect.getframeinfo(obj)) 749 elif tf_inspect.ismodule(obj): 750 return "module: {}".format(obj.__name__) 751 else: 752 if leaves_only: 753 return "{}, {}".format(type(obj), id(obj)) 754 elif isinstance(obj, list): 755 return "list({}): {}".format( 756 id(obj), [describe(e, blacklist, leaves_only=True) for e in obj]) 757 elif isinstance(obj, tuple): 758 return "tuple({}): {}".format( 759 id(obj), [describe(e, blacklist, leaves_only=True) for e in obj]) 760 elif isinstance(obj, dict): 761 return "dict({}): {} keys".format(id(obj), len(obj.keys())) 762 elif tf_inspect.isfunction(obj): 763 return "function({}) {}; globals ID: {}".format( 764 id(obj), obj.__name__, id(obj.__globals__)) 765 else: 766 return "{}, {}".format(type(obj), id(obj)) 767 768 def build_ref_graph(obj, graph, reprs, blacklist): 769 """Builds a reference graph as <referrer> -> <list of refferents>. 770 771 Args: 772 obj: The object to start from. The graph will be built by recursively 773 adding its referrers. 774 graph: Dict holding the graph to be built. To avoid creating extra 775 references, the graph holds object IDs rather than actual objects. 776 reprs: Auxiliary structure that maps object IDs to their human-readable 777 description. 778 blacklist: List of objects to ignore. 779 """ 780 referrers = gc.get_referrers(obj) 781 blacklist = blacklist + (referrers,) 782 783 obj_id = id(obj) 784 for r in referrers: 785 if get_ignore_reason(r, blacklist) is None: 786 r_id = id(r) 787 if r_id not in graph: 788 graph[r_id] = [] 789 if obj_id not in graph[r_id]: 790 graph[r_id].append(obj_id) 791 build_ref_graph(r, graph, reprs, blacklist) 792 reprs[r_id] = describe(r, blacklist) 793 794 def find_cycle(el, graph, reprs, path): 795 """Finds and prints a single cycle in the dependency graph.""" 796 if el not in graph: 797 return 798 for r in graph[el]: 799 if r in path: 800 logging.error("Reference cycle sample:") 801 for p in path + (r,): 802 logging.error(reprs.get(p, "unknown object " + str(p))) 803 return True 804 else: 805 if find_cycle(r, graph, reprs, path + (r,)): 806 return True 807 return False 808 809 obj = objects[idx] 810 graph = {} # referrer ID -> object ID 811 reprs = {} # object ID -> description 812 build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason, 813 describe, build_ref_graph, find_cycle)) 814 for k in graph: 815 if find_cycle(k, graph, reprs, ()): 816 return True 817 return False 818 819 820def assert_no_garbage_created(f): 821 """Test method decorator to assert that no garbage has been created. 822 823 Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters 824 cannot be un-set (i.e. will disable garbage collection for any other unit 825 tests in the same file/shard). 826 827 Args: 828 f: The function to decorate. 829 830 Returns: 831 The decorated function. 832 """ 833 834 def decorator(self, **kwargs): 835 """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" 836 # Force-load `distribution_strategy_context` to prevent GC at 837 # test time when using eager. Remove once b/117329403 is resolved. 838 tape.distribution_strategy_context.get_strategy() 839 840 gc.disable() 841 previous_debug_flags = gc.get_debug() 842 gc.set_debug(gc.DEBUG_SAVEALL) 843 gc.collect() 844 previous_garbage = len(gc.garbage) 845 result = f(self, **kwargs) 846 gc.collect() 847 new_garbage = len(gc.garbage) 848 if new_garbage > previous_garbage: 849 logging.error( 850 "The decorated test created work for Python's garbage collector, " 851 "likely due to a reference cycle. New objects in cycle(s):") 852 for i, obj in enumerate(gc.garbage[previous_garbage:]): 853 try: 854 logging.error("Object %d of %d", i, 855 len(gc.garbage) - previous_garbage) 856 857 def _safe_object_str(obj): 858 return "<%s %d>" % (obj.__class__.__name__, id(obj)) 859 860 logging.error(" Object type: %s", _safe_object_str(obj)) 861 logging.error( 862 " Referrer types: %s", ", ".join( 863 [_safe_object_str(ref) for ref in gc.get_referrers(obj)])) 864 logging.error( 865 " Referent types: %s", ", ".join( 866 [_safe_object_str(ref) for ref in gc.get_referents(obj)])) 867 logging.error(" Object attribute names: %s", dir(obj)) 868 logging.error(" Object __str__:") 869 logging.error(obj) 870 logging.error(" Object __repr__:") 871 logging.error(repr(obj)) 872 except Exception: # pylint: disable=broad-except 873 logging.error("(Exception while printing object)") 874 875 # When garbage is created, this call can help identify reference cycles, 876 # which are typically the cause of such garbage. 877 if new_garbage > previous_garbage: 878 for i in range(previous_garbage, new_garbage): 879 if _find_reference_cycle(gc.garbage, i): 880 break 881 882 # This will fail if any garbage has been created, typically because of a 883 # reference cycle. 884 self.assertEqual(previous_garbage, new_garbage) 885 # TODO(allenl): Figure out why this debug flag reset doesn't work. It would 886 # be nice to be able to decorate arbitrary tests in a large test suite and 887 # not hold on to every object in other tests. 888 gc.set_debug(previous_debug_flags) 889 gc.enable() 890 return result 891 892 return decorator 893 894 895def _combine_named_parameters(**kwargs): 896 """Generate combinations based on its keyword arguments. 897 898 Two sets of returned combinations can be concatenated using +. Their product 899 can be computed using `times()`. 900 901 Args: 902 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 903 `option=the_only_possibility`. 904 905 Returns: 906 a list of dictionaries for each combination. Keys in the dictionaries are 907 the keyword argument names. Each key has one value - one of the 908 corresponding keyword argument values. 909 """ 910 sort_by_key = lambda k: k[0] 911 combinations = [] 912 for key, values in sorted(kwargs.items(), key=sort_by_key): 913 if not isinstance(values, list): 914 values = [values] 915 combinations.append([(key, value) for value in values]) 916 917 return [OrderedDict(result) for result in itertools.product(*combinations)] 918 919 920def generate_combinations_with_testcase_name(**kwargs): 921 """Generate combinations based on its keyword arguments using combine(). 922 923 This function calls combine() and appends a testcase name to the list of 924 dictionaries returned. The 'testcase_name' key is a required for named 925 parameterized tests. 926 927 Args: 928 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 929 `option=the_only_possibility`. 930 931 Returns: 932 a list of dictionaries for each combination. Keys in the dictionaries are 933 the keyword argument names. Each key has one value - one of the 934 corresponding keyword argument values. 935 """ 936 combinations = _combine_named_parameters(**kwargs) 937 named_combinations = [] 938 for combination in combinations: 939 assert isinstance(combination, OrderedDict) 940 name = "".join([ 941 "_{}_{}".format("".join(filter(str.isalnum, key)), 942 "".join(filter(str.isalnum, str(value)))) 943 for key, value in combination.items() 944 ]) 945 named_combinations.append( 946 OrderedDict( 947 list(combination.items()) + 948 [("testcase_name", "_test{}".format(name))])) 949 950 return named_combinations 951 952 953def run_all_in_graph_and_eager_modes(cls): 954 """Execute all test methods in the given class with and without eager.""" 955 base_decorator = run_in_graph_and_eager_modes 956 for name in dir(cls): 957 if (not name.startswith(unittest.TestLoader.testMethodPrefix) or 958 name.startswith("testSkipEager") or 959 name.startswith("test_skip_eager") or 960 name == "test_session"): 961 continue 962 value = getattr(cls, name, None) 963 if callable(value): 964 setattr(cls, name, base_decorator(value)) 965 return cls 966 967 968def build_as_function_and_v1_graph(func=None): 969 """Run a test case in v1 graph mode and inside tf.function in eager mode. 970 971 WARNING: This decorator can only be used in test cases that statically checks 972 generated graph. Attempting to evaluate graph or function results via. 973 session.run() or self.evaluate() will fail. 974 975 WARNING: This decorator can only be used for test cases that inherit from 976 absl.testing.parameterized.TestCase. 977 978 Args: 979 func: Test case function to be decorated. 980 981 Returns: 982 Decorated test case function. 983 """ 984 985 def decorator(f): 986 if tf_inspect.isclass(f): 987 raise ValueError( 988 "`run_in_graph_mode_and_function` only supports test methods.") 989 990 @parameterized.named_parameters(("_v1_graph", "v1_graph"), 991 ("_function", "function")) 992 @functools.wraps(f) 993 def decorated(self, run_mode, *args, **kwargs): 994 if run_mode == "v1_graph": 995 with ops.Graph().as_default(): 996 f(self, *args, **kwargs) 997 elif run_mode == "function": 998 999 @def_function.function 1000 def function_in_eager(): 1001 f(self, *args, **kwargs) 1002 1003 # Create a new graph for the eagerly executed version of this test for 1004 # better isolation. 1005 graph_for_eager_test = ops.Graph() 1006 with graph_for_eager_test.as_default(), context.eager_mode(): 1007 function_in_eager() 1008 ops.dismantle_graph(graph_for_eager_test) 1009 else: 1010 return ValueError("Unknown run mode %s" % run_mode) 1011 1012 return decorated 1013 1014 if func is not None: 1015 return decorator(func) 1016 1017 return decorator 1018 1019 1020def eager_lazy_remote_copy_on_and_off(f): 1021 """Execute the test method w/o lazy tensor copy for function remote inputs.""" 1022 1023 @parameterized.named_parameters([("WithLazyRemoteCopy", True), ("", False)]) 1024 @functools.wraps(f) 1025 def decorator(self, lazily_remote_copy, *args, **kwargs): 1026 if lazily_remote_copy: 1027 context.context().lazy_remote_inputs_copy = True 1028 else: 1029 context.context().lazy_remote_inputs_copy = False 1030 f(self, *args, **kwargs) 1031 1032 return decorator 1033 1034 1035def run_in_graph_and_eager_modes(func=None, 1036 config=None, 1037 use_gpu=True, 1038 reset_test=True, 1039 assert_no_eager_garbage=False): 1040 """Execute the decorated test with and without enabling eager execution. 1041 1042 This function returns a decorator intended to be applied to test methods in 1043 a `tf.test.TestCase` class. Doing so will cause the contents of the test 1044 method to be executed twice - once normally, and once with eager execution 1045 enabled. This allows unittests to confirm the equivalence between eager 1046 and graph execution (see `tf.compat.v1.enable_eager_execution`). 1047 1048 For example, consider the following unittest: 1049 1050 ```python 1051 class MyTests(tf.test.TestCase): 1052 1053 @run_in_graph_and_eager_modes 1054 def test_foo(self): 1055 x = tf.constant([1, 2]) 1056 y = tf.constant([3, 4]) 1057 z = tf.add(x, y) 1058 self.assertAllEqual([4, 6], self.evaluate(z)) 1059 1060 if __name__ == "__main__": 1061 tf.test.main() 1062 ``` 1063 1064 This test validates that `tf.add()` has the same behavior when computed with 1065 eager execution enabled as it does when constructing a TensorFlow graph and 1066 executing the `z` tensor in a session. 1067 1068 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1069 `run_in_graph_and_eager_modes` are available decorators for different 1070 v1/v2/eager/graph combinations. 1071 1072 1073 Args: 1074 func: function to be annotated. If `func` is None, this method returns a 1075 decorator the can be applied to a function. If `func` is not None this 1076 returns the decorator applied to `func`. 1077 config: An optional config_pb2.ConfigProto to use to configure the session 1078 when executing graphs. 1079 use_gpu: If True, attempt to run as many operations as possible on GPU. 1080 reset_test: If True, tearDown and SetUp the test case between the two 1081 executions of the test (once with and once without eager execution). 1082 assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage 1083 collector and asserts that no extra garbage has been created when running 1084 the test with eager execution enabled. This will fail if there are 1085 reference cycles (e.g. a = []; a.append(a)). Off by default because some 1086 tests may create garbage for legitimate reasons (e.g. they define a class 1087 which inherits from `object`), and because DEBUG_SAVEALL is sticky in some 1088 Python interpreters (meaning that tests which rely on objects being 1089 collected elsewhere in the unit test file will not work). Additionally, 1090 checks that nothing still has a reference to Tensors that the test 1091 allocated. 1092 1093 Returns: 1094 Returns a decorator that will run the decorated test method twice: 1095 once by constructing and executing a graph in a session and once with 1096 eager execution enabled. 1097 """ 1098 1099 def decorator(f): 1100 if tf_inspect.isclass(f): 1101 raise ValueError( 1102 "`run_in_graph_and_eager_modes` only supports test methods. " 1103 "Did you mean to use `run_all_in_graph_and_eager_modes`?") 1104 1105 def decorated(self, *args, **kwargs): 1106 try: 1107 with context.graph_mode(): 1108 with self.test_session(use_gpu=use_gpu, config=config): 1109 f(self, *args, **kwargs) 1110 except unittest.case.SkipTest: 1111 pass 1112 1113 def run_eagerly(self, **kwargs): 1114 if not use_gpu: 1115 with ops.device("/device:CPU:0"): 1116 f(self, *args, **kwargs) 1117 else: 1118 f(self, *args, **kwargs) 1119 1120 if assert_no_eager_garbage: 1121 ops.reset_default_graph() 1122 run_eagerly = assert_no_new_tensors( 1123 assert_no_garbage_created(run_eagerly)) 1124 1125 if reset_test: 1126 # This decorator runs the wrapped test twice. 1127 # Reset the test environment between runs. 1128 self.tearDown() 1129 self._tempdir = None 1130 # Create a new graph for the eagerly executed version of this test for 1131 # better isolation. 1132 graph_for_eager_test = ops.Graph() 1133 with graph_for_eager_test.as_default(), context.eager_mode(): 1134 if reset_test: 1135 self.setUp() 1136 run_eagerly(self, **kwargs) 1137 ops.dismantle_graph(graph_for_eager_test) 1138 1139 return decorated 1140 1141 if func is not None: 1142 return decorator(func) 1143 1144 return decorator 1145 1146 1147def py_func_if_in_function(f): 1148 1149 def decorated(*args, **kwds): 1150 if not ops.get_default_graph()._building_function: 1151 return f(*args, **kwds) 1152 1153 tensor_args = [] 1154 tensor_indices = [] 1155 for i, arg in enumerate(args): 1156 if isinstance(arg, (ops.Tensor, variables.Variable)): 1157 tensor_args.append(arg) 1158 tensor_indices.append(i) 1159 1160 def inner_f(*inner_tensor_args): 1161 my_args = list(args) 1162 for i, n in zip(tensor_indices, inner_tensor_args): 1163 my_args[i] = n 1164 return f(*my_args, **kwds) 1165 1166 return script_ops.py_func(inner_f, tensor_args, []) 1167 1168 return tf_decorator.make_decorator(f, decorated) 1169 1170 1171def also_run_as_tf_function(f): 1172 """Runs the decorated test twice--once as is, once inside a tf.function. 1173 1174 This allows you to run a test both in eager execution and inside a 1175 tf.function, exercising the two execution modes supported in tf 2.0. The test 1176 assertions are automatically done inside tf.py_funcs, and tf.function ensures 1177 that they run in the proper order and with the proper side effects. 1178 1179 Currently variable creation is not supported in tests annotated with this 1180 decorator since it's tricky to ensure the variable doesn't get repeatedly 1181 created when retracing the tf.function. 1182 1183 Args: 1184 f: the test method to be decorated 1185 1186 Returns: 1187 The decorated test method, which will run both in eager and inside a 1188 tf.function. 1189 """ 1190 1191 def decorated(*args, **kwds): 1192 1193 def bound_f(): 1194 f(*args, **kwds) 1195 1196 with context.eager_mode(): 1197 # Running in eager mode 1198 bound_f() 1199 # Running as TF function 1200 # TODO(b/121143941): Remove the autograph override. 1201 def_function.function(bound_f, autograph=False)() 1202 1203 return decorated 1204 1205 1206def deprecated_graph_mode_only(func=None): 1207 """Execute the decorated test in graph mode. 1208 1209 This function returns a decorator intended to be applied to tests that are not 1210 compatible with eager mode. When this decorator is applied, the test body will 1211 be run in an environment where API calls construct graphs instead of executing 1212 eagerly. 1213 1214 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1215 `run_in_graph_and_eager_modes` are available decorators for different 1216 v1/v2/eager/graph combinations. 1217 1218 Args: 1219 func: function to be annotated. If `func` is None, this method returns a 1220 decorator the can be applied to a function. If `func` is not None this 1221 returns the decorator applied to `func`. 1222 1223 Returns: 1224 Returns a decorator that will run the decorated test method in graph mode. 1225 """ 1226 1227 def decorator(f): 1228 if tf_inspect.isclass(f): 1229 setup = f.__dict__.get("setUp") 1230 if setup is not None: 1231 setattr(f, "setUp", decorator(setup)) 1232 1233 for name, value in f.__dict__.copy().items(): 1234 if (callable(value) and 1235 name.startswith(unittest.TestLoader.testMethodPrefix)): 1236 setattr(f, name, decorator(value)) 1237 1238 return f 1239 1240 def decorated(self, *args, **kwargs): 1241 if context.executing_eagerly(): 1242 with context.graph_mode(): 1243 return f(self, *args, **kwargs) 1244 else: 1245 return f(self, *args, **kwargs) 1246 1247 return decorated 1248 1249 if func is not None: 1250 return decorator(func) 1251 1252 return decorator 1253 1254 1255run_deprecated_v1 = deprecated_graph_mode_only 1256 1257 1258def run_all_in_deprecated_graph_mode_only(cls): 1259 """Execute all tests in a class in graph mode.""" 1260 base_decorator = deprecated_graph_mode_only 1261 for name in dir(cls): 1262 if (not name.startswith(unittest.TestLoader.testMethodPrefix) or 1263 name == "test_session"): 1264 continue 1265 value = getattr(cls, name, None) 1266 if callable(value): 1267 setattr(cls, name, base_decorator(value)) 1268 return cls 1269 1270 1271def run_v1_only(reason, func=None): 1272 """Execute the decorated test only if running in v1 mode. 1273 1274 This function is intended to be applied to tests that exercise v1 only 1275 functionality. If the test is run in v2 mode it will simply be skipped. 1276 1277 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1278 `run_in_graph_and_eager_modes` are available decorators for different 1279 v1/v2/eager/graph combinations. 1280 1281 Args: 1282 reason: string giving a reason for limiting the test to v1 only. 1283 func: function to be annotated. If `func` is None, this method returns a 1284 decorator the can be applied to a function. If `func` is not None this 1285 returns the decorator applied to `func`. 1286 1287 Returns: 1288 Returns a decorator that will conditionally skip the decorated test method. 1289 """ 1290 if not isinstance(reason, str): 1291 raise ValueError("'reason' should be string, got {}".format(type(reason))) 1292 1293 def decorator(f): 1294 if tf_inspect.isclass(f): 1295 # To skip an entire test suite class, we only decorate the setUp method 1296 # to skip all tests. There are cases when setUp is not defined (not 1297 # overridden in subclasses of TestCase, so not available in f.__dict__ 1298 # below). For those cases, we walk the method resolution order list and 1299 # pick the first setUp method we find (usually this should be the one in 1300 # the parent class since that's the TestCase class). 1301 for cls in type.mro(f): 1302 setup = cls.__dict__.get("setUp") 1303 if setup is not None: 1304 setattr(f, "setUp", decorator(setup)) 1305 break 1306 1307 return f 1308 else: 1309 # If f is just a function, just create a decorator for it and return it 1310 def decorated(self, *args, **kwargs): 1311 if tf2.enabled(): 1312 self.skipTest(reason) 1313 1314 return f(self, *args, **kwargs) 1315 1316 return decorated 1317 1318 if func is not None: 1319 return decorator(func) 1320 1321 return decorator 1322 1323 1324def run_v2_only(func=None): 1325 """Execute the decorated test only if running in v2 mode. 1326 1327 This function is intended to be applied to tests that exercise v2 only 1328 functionality. If the test is run in v1 mode it will simply be skipped. 1329 1330 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1331 `run_in_graph_and_eager_modes` are available decorators for different 1332 v1/v2/eager/graph combinations. 1333 1334 Args: 1335 func: function to be annotated. If `func` is None, this method returns a 1336 decorator the can be applied to a function. If `func` is not None this 1337 returns the decorator applied to `func`. 1338 1339 Returns: 1340 Returns a decorator that will conditionally skip the decorated test method. 1341 """ 1342 1343 def decorator(f): 1344 if tf_inspect.isclass(f): 1345 raise ValueError("`run_v2_only` only supports test methods.") 1346 1347 def decorated(self, *args, **kwargs): 1348 if not tf2.enabled(): 1349 self.skipTest("Test is only compatible with v2") 1350 1351 return f(self, *args, **kwargs) 1352 1353 return decorated 1354 1355 if func is not None: 1356 return decorator(func) 1357 1358 return decorator 1359 1360 1361def run_gpu_only(func=None): 1362 """Execute the decorated test only if a GPU is available. 1363 1364 This function is intended to be applied to tests that require the presence 1365 of a GPU. If a GPU is absent, it will simply be skipped. 1366 1367 Args: 1368 func: function to be annotated. If `func` is None, this method returns a 1369 decorator the can be applied to a function. If `func` is not None this 1370 returns the decorator applied to `func`. 1371 1372 Returns: 1373 Returns a decorator that will conditionally skip the decorated test method. 1374 """ 1375 1376 def decorator(f): 1377 if tf_inspect.isclass(f): 1378 raise ValueError("`run_gpu_only` only supports test methods.") 1379 1380 def decorated(self, *args, **kwargs): 1381 if not is_gpu_available(): 1382 self.skipTest("Test requires GPU") 1383 1384 return f(self, *args, **kwargs) 1385 1386 return decorated 1387 1388 if func is not None: 1389 return decorator(func) 1390 1391 return decorator 1392 1393 1394def run_cuda_only(func=None): 1395 """Execute the decorated test only if a GPU is available. 1396 1397 This function is intended to be applied to tests that require the precense 1398 of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped. 1399 1400 Args: 1401 func: function to be annotated. If `func` is None, this method returns a 1402 decorator the can be applied to a function. If `func` is not None this 1403 returns the decorator applied to `func`. 1404 1405 Returns: 1406 Returns a decorator that will conditionally skip the decorated test method. 1407 """ 1408 1409 def decorator(f): 1410 if tf_inspect.isclass(f): 1411 raise ValueError("`run_cuda_only` only supports test methods.") 1412 1413 def decorated(self, *args, **kwargs): 1414 if not is_gpu_available(cuda_only=True): 1415 self.skipTest("Test requires CUDA GPU") 1416 1417 return f(self, *args, **kwargs) 1418 1419 return decorated 1420 1421 if func is not None: 1422 return decorator(func) 1423 1424 return decorator 1425 1426 1427def with_forward_compatibility_horizons(*horizons): 1428 """Executes the decorated test with the specified forward-compat horizons. 1429 1430 Args: 1431 *horizons: A list of (year, month, day) tuples. If the list includes 1432 `None`, then the test will also be run with no forward-compatibility 1433 horizon set. 1434 1435 Returns: 1436 A decorator that will execute the test with the specified horizons. 1437 """ 1438 if not horizons: 1439 raise ValueError("Expected at least one horizon.") 1440 for horizon in horizons: 1441 if not ((horizon is None) or 1442 (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))): 1443 raise ValueError("Bad horizon value: %r" % horizon) 1444 1445 def decorator(f): 1446 if tf_inspect.isclass(f): 1447 raise ValueError("`with_forward_compatibility_horizons` only " 1448 "supports test methods.") 1449 def decorated(self, *args, **kwargs): 1450 for horizon in horizons: 1451 if horizon is None: 1452 f(self, *args, **kwargs) 1453 else: 1454 (year, month, day) = horizon 1455 with forward_compatibility_horizon(year, month, day): 1456 f(self, *args, **kwargs) 1457 return decorated 1458 1459 return decorator 1460 1461 1462@deprecation.deprecated(None, 1463 "Use `tf.config.list_physical_devices('GPU')` instead.") 1464@tf_export("test.is_gpu_available") 1465def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None): 1466 """Returns whether TensorFlow can access a GPU. 1467 1468 Warning: if a non-GPU version of the package is installed, the function would 1469 also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow 1470 was build with CUDA support. 1471 1472 Args: 1473 cuda_only: limit the search to CUDA GPUs. 1474 min_cuda_compute_capability: a (major,minor) pair that indicates the minimum 1475 CUDA compute capability required, or None if no requirement. 1476 1477 Note that the keyword arg name "cuda_only" is misleading (since routine will 1478 return true when a GPU device is available irrespective of whether TF was 1479 built with CUDA support or ROCm support. However no changes here because 1480 1481 ++ Changing the name "cuda_only" to something more generic would break 1482 backward compatibility 1483 1484 ++ Adding an equivalent "rocm_only" would require the implementation check 1485 the build type. This in turn would require doing the same for CUDA and thus 1486 potentially break backward compatibility 1487 1488 ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility, 1489 but would require most (if not all) callers to update the call to use 1490 "cuda_or_rocm_only" instead of "cuda_only" 1491 1492 Returns: 1493 True if a GPU device of the requested kind is available. 1494 """ 1495 try: 1496 for local_device in device_lib.list_local_devices(): 1497 if local_device.device_type == "GPU": 1498 gpu_info = gpu_util.compute_capability_from_device_desc(local_device) 1499 cc = gpu_info.compute_capability or (0, 0) 1500 if not min_cuda_compute_capability or cc >= min_cuda_compute_capability: 1501 return True 1502 if local_device.device_type == "SYCL" and not cuda_only: 1503 return True 1504 return False 1505 except errors_impl.NotFoundError as e: 1506 if not all(x in str(e) for x in ["CUDA", "not find"]): 1507 raise e 1508 else: 1509 logging.error(str(e)) 1510 return False 1511 1512 1513@contextlib.contextmanager 1514def device(use_gpu): 1515 """Uses gpu when requested and available.""" 1516 if use_gpu and is_gpu_available(): 1517 dev = "/device:GPU:0" 1518 else: 1519 dev = "/device:CPU:0" 1520 with ops.device(dev): 1521 yield 1522 1523 1524@contextlib.contextmanager 1525def use_gpu(): 1526 """Uses gpu when requested and available.""" 1527 with device(use_gpu=True): 1528 yield 1529 1530 1531@contextlib.contextmanager 1532def force_gpu(): 1533 """Force the gpu to be used.""" 1534 with ops.device("/device:GPU:0"): 1535 yield 1536 1537 1538@contextlib.contextmanager 1539def force_cpu(): 1540 """Force the cpu to be used.""" 1541 with ops.device("/device:CPU:0"): 1542 yield 1543 1544 1545class CapturedWrites(object): 1546 """A utility class to load the captured writes made to a stream.""" 1547 1548 def __init__(self, capture_location): 1549 self.capture_location = capture_location 1550 1551 def contents(self): 1552 """Get the captured writes as a single string.""" 1553 with open(self.capture_location) as tmp_file: 1554 output_data = "".join(tmp_file.readlines()) 1555 return output_data 1556 1557 1558class FakeEagerSession(object): 1559 """Fake session so tests that conditionally use placeholders can use eager. 1560 1561 There are a number of tests that conditionally use placeholders for shape 1562 inference. The pattern is demonstrated here: 1563 1564 ```python 1565 with self.cached_session() as sess: 1566 if static_shape: 1567 y = math_ops.matmul(x, ...) 1568 feed_dict = {} 1569 else: 1570 x_ph = array_ops.placeholder(...) 1571 y = math_ops.matmul(x_ph, ...) 1572 feed_dict = {x_ph: x} 1573 val = sess.run(y, feed_dict=feed_dict) 1574 ``` 1575 1576 Since the feed_dict is empty when not using placeholders we should be able to 1577 call self.evaluate(), however this requires rewriting the test case. 1578 This class should be considered a stop-gap solution to get tests running with 1579 eager with minimal changes to the actual test. 1580 """ 1581 1582 def __init__(self, test_case): 1583 self._test_case = test_case 1584 1585 def run(self, fetches, *args, **kwargs): 1586 """Evalaute `fetches`. 1587 1588 Fail if additional args are specified. 1589 1590 Args: 1591 fetches: A Tensor or a nested list/tuple of Tensors. 1592 *args: Positional arguments 1593 **kwargs: Keyword arguments 1594 1595 Raises: 1596 RuntimeError: If args or kwargs are specified. 1597 1598 Returns: 1599 Tensors as numpy values. 1600 """ 1601 feed_dict = kwargs.pop("feed_dict", {}) 1602 if feed_dict: 1603 raise RuntimeError( 1604 "feed_dict is not supported when eager execution is enabled " 1605 "(in this case, sess.run(t) is shorthand for t.numpy()") 1606 1607 if args or kwargs: 1608 raise RuntimeError( 1609 "Optional args are not supported when eager execution is enabled " 1610 "(in this case, sess.run(t) is shorthand for t.numpy()") 1611 1612 return self._test_case.evaluate(fetches) 1613 1614 1615class ErrorLoggingSession(session.Session): 1616 """Wrapper around a Session that logs errors in run().""" 1617 1618 def run(self, *args, **kwargs): 1619 try: 1620 return super(ErrorLoggingSession, self).run(*args, **kwargs) 1621 except Exception as e: # pylint: disable=broad-except 1622 # Note: disable the logging for OutOfRangeError, which makes the output 1623 # of tf.data tests hard to read, because OutOfRangeError is used as the 1624 # signal completion 1625 if not isinstance(e, errors.OutOfRangeError): 1626 logging.error(str(e)) 1627 raise 1628 1629 1630def disable_cudnn_autotune(func): 1631 """Disable autotuning during the call to this function. 1632 1633 Some tests want to base assertions on a graph being isomorphic with a copy. 1634 To ensure this, this decorator disables autotuning. 1635 1636 Args: 1637 func: Function to run with CuDNN autotuning turned off. 1638 1639 Returns: 1640 Decorated function. 1641 """ 1642 1643 def decorator(f): 1644 1645 def decorated(self, *args, **kwargs): 1646 original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE") 1647 os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false" 1648 original_xla_flags = os.environ.get("XLA_FLAGS") 1649 new_xla_flags = "--xla_gpu_autotune_level=0" 1650 if original_xla_flags: 1651 new_xla_flags = original_xla_flags + " " + new_xla_flags 1652 os.environ["XLA_FLAGS"] = new_xla_flags 1653 1654 result = f(self, *args, **kwargs) 1655 1656 if (original_tf_cudnn_use_autotune is None): 1657 del os.environ["TF_CUDNN_USE_AUTOTUNE"] 1658 else: 1659 os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune 1660 if (original_xla_flags is None): 1661 del os.environ["XLA_FLAGS"] 1662 else: 1663 os.environ["XLA_FLAGS"] = original_xla_flags 1664 1665 return result 1666 1667 return decorated 1668 1669 if func is not None: 1670 return decorator(func) 1671 1672 return decorator 1673 1674 1675# The description is just for documentation purposes. 1676def enable_tf_xla_constant_folding(description): 1677 1678 if not isinstance(description, str): 1679 raise ValueError("'description' should be string, got {}".format( 1680 type(description))) 1681 1682 def enable_tf_xla_constant_folding_impl(func): 1683 """Enable constant folding during the call to this function. 1684 1685 Some tests fail without constant folding. 1686 1687 Args: 1688 func: Function to run with constant folding turned on. 1689 1690 Returns: 1691 Decorated function. 1692 """ 1693 1694 def decorator(f): 1695 1696 def decorated(self, *args, **kwargs): 1697 original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled() 1698 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False) 1699 result = f(self, *args, **kwargs) 1700 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var) 1701 return result 1702 1703 return decorated 1704 1705 if func is not None: 1706 return decorator(func) 1707 1708 return decorator 1709 1710 return enable_tf_xla_constant_folding_impl 1711 1712 1713# The description is just for documentation purposes. 1714def disable_xla(description): 1715 1716 def disable_xla_impl(func): 1717 """Execute the test method only if xla is not enabled.""" 1718 1719 def decorator(func): 1720 1721 def decorated(self, *args, **kwargs): 1722 if is_xla_enabled(): 1723 return 1724 else: 1725 return func(self, *args, **kwargs) 1726 1727 return decorated 1728 1729 if func is not None: 1730 return decorator(func) 1731 1732 return decorator 1733 1734 return disable_xla_impl 1735 1736 1737def for_all_test_methods(decorator, *args, **kwargs): 1738 """Generate class-level decorator from given method-level decorator. 1739 1740 It is expected for the given decorator to take some arguments and return 1741 a method that is then called on the test method to produce a decorated 1742 method. 1743 1744 Args: 1745 decorator: The decorator to apply. 1746 *args: Positional arguments 1747 **kwargs: Keyword arguments 1748 Returns: Function that will decorate a given classes test methods with the 1749 decorator. 1750 """ 1751 1752 def all_test_methods_impl(cls): 1753 """Apply decorator to all test methods in class.""" 1754 for name in dir(cls): 1755 value = getattr(cls, name) 1756 if callable(value) and name.startswith( 1757 "test") and (name != "test_session"): 1758 setattr(cls, name, decorator(*args, **kwargs)(value)) 1759 return cls 1760 1761 return all_test_methods_impl 1762 1763 1764# The description is just for documentation purposes. 1765def no_xla_auto_jit(description): # pylint: disable=unused-argument 1766 1767 def no_xla_auto_jit_impl(func): 1768 """This test is not intended to be run with XLA auto jit enabled.""" 1769 1770 def decorator(func): 1771 1772 def decorated(self, *args, **kwargs): 1773 if is_xla_enabled(): 1774 # Skip test if using XLA is forced. 1775 return 1776 else: 1777 return func(self, *args, **kwargs) 1778 1779 return decorated 1780 1781 if func is not None: 1782 return decorator(func) 1783 1784 return decorator 1785 1786 return no_xla_auto_jit_impl 1787 1788 1789# The description is just for documentation purposes. 1790def xla_allow_fallback(description): # pylint: disable=unused-argument 1791 1792 def xla_allow_fallback_impl(func): 1793 """Allow fallback to TF even though testing xla.""" 1794 1795 def decorator(func): 1796 1797 def decorated(self, *args, **kwargs): 1798 if is_xla_enabled(): 1799 # Update the global XLABuildOpsPassFlags to enable lazy compilation, 1800 # which allows the compiler to fall back to TF classic. Remember the 1801 # old value so that we can reset it. 1802 old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True) 1803 result = func(self, *args, **kwargs) 1804 pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value) 1805 return result 1806 else: 1807 return func(self, *args, **kwargs) 1808 1809 return decorated 1810 1811 if func is not None: 1812 return decorator(func) 1813 1814 return decorator 1815 1816 return xla_allow_fallback_impl 1817 1818 1819class EagerSessionWarner(object): 1820 1821 def __getattr__(self, attr): 1822 raise AttributeError( 1823 "Trying to access properties or call methods on the result of " 1824 "self.session(), self.cached_session(), etc while eager execution " 1825 "is enabled. If you're porting this test case to TF 2.0, either " 1826 "adapt the test to work with eager execution or insert a call to " 1827 "tf.disable_eager_execution() in the main() function of this test " 1828 "file.") 1829 1830 1831@tf_export("test.TestCase") 1832class TensorFlowTestCase(googletest.TestCase): 1833 """Base class for tests that need to test TensorFlow.""" 1834 1835 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name 1836 super(TensorFlowTestCase, self).__init__(methodName) 1837 if is_xla_enabled(): 1838 pywrap_tf_session.TF_SetXlaAutoJitMode("2") 1839 pywrap_tf_session.TF_SetXlaMinClusterSize(1) 1840 pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False) 1841 pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True) 1842 # Constant folding secretly runs code on TF:Classic CPU, so we also 1843 # disable it here. 1844 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True) 1845 1846 self._threads = [] 1847 self._tempdir = None 1848 self._cached_session = None 1849 1850 def setUp(self): 1851 self._ClearCachedSession() 1852 random.seed(random_seed.DEFAULT_GRAPH_SEED) 1853 np.random.seed(random_seed.DEFAULT_GRAPH_SEED) 1854 # Note: The following line is necessary because some test methods may error 1855 # out from within nested graph contexts (e.g., via assertRaises and 1856 # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty 1857 # under certain versions of Python. That would cause 1858 # ops.reset_default_graph() to throw an exception if the stack were not 1859 # cleared first. 1860 ops._default_graph_stack.reset() # pylint: disable=protected-access 1861 ops.reset_default_graph() 1862 random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED) 1863 # Reset summary writer in case another test used set_as_default() with their 1864 # summary writer. 1865 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 1866 summary_state.writer = None 1867 1868 # Avoiding calling setUp() for the poorly named test_session method. 1869 if self.id().endswith(".test_session"): 1870 self.skipTest("Not a test.") 1871 1872 def tearDown(self): 1873 for thread in self._threads: 1874 thread.check_termination() 1875 1876 self._ClearCachedSession() 1877 1878 def _ClearCachedSession(self): 1879 if self._cached_session is not None: 1880 self._cached_session.close() 1881 self._cached_session = None 1882 1883 def get_temp_dir(self): 1884 """Returns a unique temporary directory for the test to use. 1885 1886 If you call this method multiple times during in a test, it will return the 1887 same folder. However, across different runs the directories will be 1888 different. This will ensure that across different runs tests will not be 1889 able to pollute each others environment. 1890 If you need multiple unique directories within a single test, you should 1891 use tempfile.mkdtemp as follows: 1892 tempfile.mkdtemp(dir=self.get_temp_dir()): 1893 1894 Returns: 1895 string, the path to the unique temporary directory created for this test. 1896 """ 1897 if not self._tempdir: 1898 self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir()) 1899 return self._tempdir 1900 1901 @contextlib.contextmanager 1902 def captureWritesToStream(self, stream): 1903 """A context manager that captures the writes to a given stream. 1904 1905 This context manager captures all writes to a given stream inside of a 1906 `CapturedWrites` object. When this context manager is created, it yields 1907 the `CapturedWrites` object. The captured contents can be accessed by 1908 calling `.contents()` on the `CapturedWrites`. 1909 1910 For this function to work, the stream must have a file descriptor that 1911 can be modified using `os.dup` and `os.dup2`, and the stream must support 1912 a `.flush()` method. The default python sys.stdout and sys.stderr are 1913 examples of this. Note that this does not work in Colab or Jupyter 1914 notebooks, because those use alternate stdout streams. 1915 1916 Example: 1917 ```python 1918 class MyOperatorTest(test_util.TensorFlowTestCase): 1919 def testMyOperator(self): 1920 input = [1.0, 2.0, 3.0, 4.0, 5.0] 1921 with self.captureWritesToStream(sys.stdout) as captured: 1922 result = MyOperator(input).eval() 1923 self.assertStartsWith(captured.contents(), "This was printed.") 1924 ``` 1925 1926 Args: 1927 stream: The stream whose writes should be captured. This stream must have 1928 a file descriptor, support writing via using that file descriptor, and 1929 must have a `.flush()` method. 1930 1931 Yields: 1932 A `CapturedWrites` object that contains all writes to the specified stream 1933 made during this context. 1934 """ 1935 stream.flush() 1936 fd = stream.fileno() 1937 tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir()) 1938 tmp_file = open(tmp_file_path, "w") 1939 orig_fd = os.dup(fd) 1940 os.dup2(tmp_file.fileno(), fd) 1941 try: 1942 yield CapturedWrites(tmp_file_path) 1943 finally: 1944 tmp_file.close() 1945 os.dup2(orig_fd, fd) 1946 1947 def _AssertProtoEquals(self, a, b, msg=None): 1948 """Asserts that a and b are the same proto. 1949 1950 Uses ProtoEq() first, as it returns correct results 1951 for floating point attributes, and then use assertProtoEqual() 1952 in case of failure as it provides good error messages. 1953 1954 Args: 1955 a: a proto. 1956 b: another proto. 1957 msg: Optional message to report on failure. 1958 """ 1959 if not compare.ProtoEq(a, b): 1960 compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) 1961 1962 def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None): 1963 """Asserts that message is same as parsed expected_message_ascii. 1964 1965 Creates another prototype of message, reads the ascii message into it and 1966 then compares them using self._AssertProtoEqual(). 1967 1968 Args: 1969 expected_message_maybe_ascii: proto message in original or ascii form. 1970 message: the message to validate. 1971 msg: Optional message to report on failure. 1972 """ 1973 msg = msg if msg else "" 1974 if isinstance(expected_message_maybe_ascii, type(message)): 1975 expected_message = expected_message_maybe_ascii 1976 self._AssertProtoEquals(expected_message, message) 1977 elif isinstance(expected_message_maybe_ascii, str): 1978 expected_message = type(message)() 1979 text_format.Merge( 1980 expected_message_maybe_ascii, 1981 expected_message, 1982 descriptor_pool=descriptor_pool.Default()) 1983 self._AssertProtoEquals(expected_message, message, msg=msg) 1984 else: 1985 assert False, ("Can't compare protos of type %s and %s. %s" % 1986 (type(expected_message_maybe_ascii), type(message), msg)) 1987 1988 def assertProtoEqualsVersion( 1989 self, 1990 expected, 1991 actual, 1992 producer=versions.GRAPH_DEF_VERSION, 1993 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER, 1994 msg=None): 1995 expected = "versions { producer: %d min_consumer: %d };\n%s" % ( 1996 producer, min_consumer, expected) 1997 self.assertProtoEquals(expected, actual, msg=msg) 1998 1999 def assertStartsWith(self, actual, expected_start, msg=None): 2000 """Assert that actual.startswith(expected_start) is True. 2001 2002 Args: 2003 actual: str 2004 expected_start: str 2005 msg: Optional message to report on failure. 2006 """ 2007 if not actual.startswith(expected_start): 2008 fail_msg = "%r does not start with %r" % (actual, expected_start) 2009 fail_msg += " : %r" % (msg) if msg else "" 2010 self.fail(fail_msg) 2011 2012 def _eval_tensor(self, tensor): 2013 if tensor is None: 2014 return None 2015 elif callable(tensor): 2016 return self._eval_helper(tensor()) 2017 else: 2018 try: 2019 if sparse_tensor.is_sparse(tensor): 2020 return sparse_tensor.SparseTensorValue(tensor.indices.numpy(), 2021 tensor.values.numpy(), 2022 tensor.dense_shape.numpy()) 2023 elif ragged_tensor.is_ragged(tensor): 2024 return ragged_tensor_value.RaggedTensorValue( 2025 self._eval_tensor(tensor.values), 2026 self._eval_tensor(tensor.row_splits)) 2027 elif isinstance(tensor, ops.IndexedSlices): 2028 return ops.IndexedSlicesValue( 2029 values=tensor.values.numpy(), 2030 indices=tensor.indices.numpy(), 2031 dense_shape=tensor.dense_shape.numpy()) 2032 return tensor.numpy() 2033 except AttributeError as e: 2034 six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e) 2035 2036 def _eval_helper(self, tensors): 2037 if tensors is None: 2038 return None 2039 return nest.map_structure(self._eval_tensor, tensors) 2040 2041 def evaluate(self, tensors): 2042 """Evaluates tensors and returns numpy values. 2043 2044 Args: 2045 tensors: A Tensor or a nested list/tuple of Tensors. 2046 2047 Returns: 2048 tensors numpy values. 2049 """ 2050 if context.executing_eagerly(): 2051 return self._eval_helper(tensors) 2052 else: 2053 sess = ops.get_default_session() 2054 if sess is None: 2055 with self.test_session() as sess: 2056 return sess.run(tensors) 2057 else: 2058 return sess.run(tensors) 2059 2060 # pylint: disable=g-doc-return-or-yield 2061 @contextlib.contextmanager 2062 def session(self, graph=None, config=None, use_gpu=False, force_gpu=False): 2063 """A context manager for a TensorFlow Session for use in executing tests. 2064 2065 Note that this will set this session and the graph as global defaults. 2066 2067 Use the `use_gpu` and `force_gpu` options to control where ops are run. If 2068 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if 2069 `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as 2070 possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to 2071 the CPU. 2072 2073 Example: 2074 2075 ``` python 2076 class MyOperatorTest(test_util.TensorFlowTestCase): 2077 def testMyOperator(self): 2078 with self.session(use_gpu=True): 2079 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] 2080 result = MyOperator(valid_input).eval() 2081 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] 2082 invalid_input = [-1.0, 2.0, 7.0] 2083 with self.assertRaisesOpError("negative input not supported"): 2084 MyOperator(invalid_input).eval() 2085 ``` 2086 2087 Args: 2088 graph: Optional graph to use during the returned session. 2089 config: An optional config_pb2.ConfigProto to use to configure the 2090 session. 2091 use_gpu: If True, attempt to run as many ops as possible on GPU. 2092 force_gpu: If True, pin all ops to `/device:GPU:0`. 2093 2094 Yields: 2095 A Session object that should be used as a context manager to surround 2096 the graph building and execution code in a test case. 2097 """ 2098 if context.executing_eagerly(): 2099 yield EagerSessionWarner() 2100 else: 2101 with self._create_session(graph, config, force_gpu) as sess: 2102 with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu): 2103 yield sess 2104 2105 @contextlib.contextmanager 2106 def cached_session(self, 2107 graph=None, 2108 config=None, 2109 use_gpu=False, 2110 force_gpu=False): 2111 """Returns a TensorFlow Session for use in executing tests. 2112 2113 This method behaves differently than self.session(): for performance reasons 2114 `cached_session` will by default reuse the same session within the same 2115 test. The session returned by this function will only be closed at the end 2116 of the test (in the TearDown function). 2117 2118 Use the `use_gpu` and `force_gpu` options to control where ops are run. If 2119 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if 2120 `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as 2121 possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to 2122 the CPU. 2123 2124 Example: 2125 ```python 2126 class MyOperatorTest(test_util.TensorFlowTestCase): 2127 def testMyOperator(self): 2128 with self.cached_session(use_gpu=True) as sess: 2129 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] 2130 result = MyOperator(valid_input).eval() 2131 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] 2132 invalid_input = [-1.0, 2.0, 7.0] 2133 with self.assertRaisesOpError("negative input not supported"): 2134 MyOperator(invalid_input).eval() 2135 ``` 2136 2137 Args: 2138 graph: Optional graph to use during the returned session. 2139 config: An optional config_pb2.ConfigProto to use to configure the 2140 session. 2141 use_gpu: If True, attempt to run as many ops as possible on GPU. 2142 force_gpu: If True, pin all ops to `/device:GPU:0`. 2143 2144 Yields: 2145 A Session object that should be used as a context manager to surround 2146 the graph building and execution code in a test case. 2147 """ 2148 if context.executing_eagerly(): 2149 yield FakeEagerSession(self) 2150 else: 2151 sess = self._get_cached_session( 2152 graph, config, force_gpu, crash_if_inconsistent_args=True) 2153 with self._constrain_devices_and_set_default(sess, use_gpu, 2154 force_gpu) as cached: 2155 yield cached 2156 2157 @contextlib.contextmanager 2158 @deprecation.deprecated(None, "Use `self.session()` or " 2159 "`self.cached_session()` instead.") 2160 def test_session(self, 2161 graph=None, 2162 config=None, 2163 use_gpu=False, 2164 force_gpu=False): 2165 """Use cached_session instead.""" 2166 if self.id().endswith(".test_session"): 2167 self.skipTest( 2168 "Tests that have the name \"test_session\" are automatically skipped " 2169 "by TensorFlow test fixture, as the name is reserved for creating " 2170 "sessions within tests. Please rename your test if you have a test " 2171 "with this name.") 2172 if context.executing_eagerly(): 2173 yield None 2174 else: 2175 if graph is None: 2176 sess = self._get_cached_session( 2177 graph, config, force_gpu, crash_if_inconsistent_args=False) 2178 with self._constrain_devices_and_set_default(sess, use_gpu, 2179 force_gpu) as cached: 2180 yield cached 2181 else: 2182 with self.session(graph, config, use_gpu, force_gpu) as sess: 2183 yield sess 2184 2185 # pylint: enable=g-doc-return-or-yield 2186 2187 class _CheckedThread(object): 2188 """A wrapper class for Thread that asserts successful completion. 2189 2190 This class should be created using the TensorFlowTestCase.checkedThread() 2191 method. 2192 """ 2193 2194 def __init__(self, testcase, target, args=None, kwargs=None): 2195 """Constructs a new instance of _CheckedThread. 2196 2197 Args: 2198 testcase: The TensorFlowTestCase for which this thread is being created. 2199 target: A callable object representing the code to be executed in the 2200 thread. 2201 args: A tuple of positional arguments that will be passed to target. 2202 kwargs: A dictionary of keyword arguments that will be passed to target. 2203 """ 2204 self._testcase = testcase 2205 self._target = target 2206 self._args = () if args is None else args 2207 self._kwargs = {} if kwargs is None else kwargs 2208 self._thread = threading.Thread(target=self._protected_run) 2209 self._exception = None 2210 2211 self._is_thread_joined = False 2212 2213 def _protected_run(self): 2214 """Target for the wrapper thread. Sets self._exception on failure.""" 2215 try: 2216 self._target(*self._args, **self._kwargs) 2217 except Exception as e: # pylint: disable=broad-except 2218 self._exception = e 2219 2220 def start(self): 2221 """Starts the thread's activity. 2222 2223 This must be called at most once per _CheckedThread object. It arranges 2224 for the object's target to be invoked in a separate thread of control. 2225 """ 2226 self._thread.start() 2227 2228 def join(self): 2229 """Blocks until the thread terminates. 2230 2231 Raises: 2232 self._testcase.failureException: If the thread terminates with due to 2233 an exception. 2234 """ 2235 self._is_thread_joined = True 2236 self._thread.join() 2237 if self._exception is not None: 2238 self._testcase.fail("Error in checkedThread: %s" % str(self._exception)) 2239 2240 def is_alive(self): 2241 """Returns whether the thread is alive. 2242 2243 This method returns True just before the run() method starts 2244 until just after the run() method terminates. 2245 2246 Returns: 2247 True if the thread is alive, otherwise False. 2248 """ 2249 return self._thread.is_alive() 2250 2251 def check_termination(self): 2252 """Returns whether the checked thread was properly used and did terminate. 2253 2254 Every checked thread should be "join"ed after starting, and before the 2255 test tears down. If it is not joined, it is possible the thread will hang 2256 and cause flaky failures in tests. 2257 2258 Raises: 2259 self._testcase.failureException: If check_termination was called before 2260 thread was joined. 2261 2262 RuntimeError: If the thread is not terminated. This means thread was not 2263 joined with the main thread. 2264 """ 2265 if self._is_thread_joined: 2266 if self.is_alive(): 2267 raise RuntimeError( 2268 "Thread was not joined with main thread, and is still running " 2269 "when the test finished.") 2270 else: 2271 self._testcase.fail("A checked thread was not joined.") 2272 2273 def checkedThread(self, target, args=None, kwargs=None): 2274 """Returns a Thread wrapper that asserts 'target' completes successfully. 2275 2276 This method should be used to create all threads in test cases, as 2277 otherwise there is a risk that a thread will silently fail, and/or 2278 assertions made in the thread will not be respected. 2279 2280 Args: 2281 target: A callable object to be executed in the thread. 2282 args: The argument tuple for the target invocation. Defaults to (). 2283 kwargs: A dictionary of keyword arguments for the target invocation. 2284 Defaults to {}. 2285 2286 Returns: 2287 A wrapper for threading.Thread that supports start() and join() methods. 2288 """ 2289 ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs) 2290 self._threads.append(ret) 2291 return ret 2292 2293 # pylint: enable=invalid-name 2294 @py_func_if_in_function 2295 def assertNear(self, f1, f2, err, msg=None): 2296 """Asserts that two floats are near each other. 2297 2298 Checks that |f1 - f2| < err and asserts a test failure 2299 if not. 2300 2301 Args: 2302 f1: A float value. 2303 f2: A float value. 2304 err: A float value. 2305 msg: An optional string message to append to the failure message. 2306 """ 2307 # f1 == f2 is needed here as we might have: f1, f2 = inf, inf 2308 self.assertTrue( 2309 f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" % 2310 (f1, f2, err, " (%s)" % msg if msg is not None else "")) 2311 2312 @py_func_if_in_function 2313 def assertArrayNear(self, farray1, farray2, err, msg=None): 2314 """Asserts that two float arrays are near each other. 2315 2316 Checks that for all elements of farray1 and farray2 2317 |f1 - f2| < err. Asserts a test failure if not. 2318 2319 Args: 2320 farray1: a list of float values. 2321 farray2: a list of float values. 2322 err: a float value. 2323 msg: Optional message to report on failure. 2324 """ 2325 self.assertEqual(len(farray1), len(farray2), msg=msg) 2326 for f1, f2 in zip(farray1, farray2): 2327 self.assertNear(float(f1), float(f2), err, msg=msg) 2328 2329 def _NDArrayNear(self, ndarray1, ndarray2, err): 2330 return np.linalg.norm(ndarray1 - ndarray2) < err 2331 2332 @py_func_if_in_function 2333 def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None): 2334 """Asserts that two numpy arrays have near values. 2335 2336 Args: 2337 ndarray1: a numpy ndarray. 2338 ndarray2: a numpy ndarray. 2339 err: a float. The maximum absolute difference allowed. 2340 msg: Optional message to report on failure. 2341 """ 2342 self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg) 2343 2344 def _GetNdArray(self, a): 2345 # If a is tensor-like then convert it to ndarray 2346 if tensor_util.is_tensor(a): 2347 if isinstance(a, ops._EagerTensorBase): 2348 a = a.numpy() 2349 else: 2350 a = self.evaluate(a) 2351 if not isinstance(a, np.ndarray): 2352 return np.array(a) 2353 return a 2354 2355 def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 2356 a = self._GetNdArray(a) 2357 b = self._GetNdArray(b) 2358 # When the array rank is small, print its contents. Numpy array printing is 2359 # implemented using inefficient recursion so prints can cause tests to 2360 # time out. 2361 if a.shape != b.shape and (b.ndim <= 3 or b.size < 500): 2362 shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents " 2363 "%s.") % (a.shape, b.shape, b) 2364 else: 2365 shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape, 2366 b.shape) 2367 self.assertEqual(a.shape, b.shape, shape_mismatch_msg) 2368 2369 msgs = [msg] 2370 if not np.allclose(a, b, rtol=rtol, atol=atol): 2371 # Adds more details to np.testing.assert_allclose. 2372 # 2373 # NOTE: numpy.allclose (and numpy.testing.assert_allclose) 2374 # checks whether two arrays are element-wise equal within a 2375 # tolerance. The relative difference (rtol * abs(b)) and the 2376 # absolute difference atol are added together to compare against 2377 # the absolute difference between a and b. Here, we want to 2378 # tell user which elements violate such conditions. 2379 cond = np.logical_or( 2380 np.abs(a - b) > atol + rtol * np.abs(b), 2381 np.isnan(a) != np.isnan(b)) 2382 if a.ndim: 2383 x = a[np.where(cond)] 2384 y = b[np.where(cond)] 2385 msgs.append("not close where = {}".format(np.where(cond))) 2386 else: 2387 # np.where is broken for scalars 2388 x, y = a, b 2389 msgs.append("not close lhs = {}".format(x)) 2390 msgs.append("not close rhs = {}".format(y)) 2391 msgs.append("not close dif = {}".format(np.abs(x - y))) 2392 msgs.append("not close tol = {}".format(atol + rtol * np.abs(y))) 2393 msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape)) 2394 # TODO(xpan): There seems to be a bug: 2395 # tensorflow/compiler/tests:binary_ops_test pass with float32 2396 # nan even though the equal_nan is False by default internally. 2397 np.testing.assert_allclose( 2398 a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True) 2399 2400 def _assertAllCloseRecursive(self, 2401 a, 2402 b, 2403 rtol=1e-6, 2404 atol=1e-6, 2405 path=None, 2406 msg=None): 2407 path = path or [] 2408 path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "") 2409 msg = msg if msg else "" 2410 2411 # Check if a and/or b are namedtuples. 2412 if hasattr(a, "_asdict"): 2413 a = a._asdict() 2414 if hasattr(b, "_asdict"): 2415 b = b._asdict() 2416 a_is_dict = isinstance(a, collections_abc.Mapping) 2417 if a_is_dict != isinstance(b, collections_abc.Mapping): 2418 raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" % 2419 (path_str, path_str, msg)) 2420 if a_is_dict: 2421 self.assertItemsEqual( 2422 a.keys(), 2423 b.keys(), 2424 msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" % 2425 (path_str, a.keys(), path_str, b.keys(), msg)) 2426 for k in a: 2427 path.append(k) 2428 self._assertAllCloseRecursive( 2429 a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg) 2430 del path[-1] 2431 elif isinstance(a, (list, tuple)): 2432 # Try to directly compare a, b as ndarrays; if not work, then traverse 2433 # through the sequence, which is more expensive. 2434 try: 2435 a_as_ndarray = self._GetNdArray(a) 2436 b_as_ndarray = self._GetNdArray(b) 2437 self._assertArrayLikeAllClose( 2438 a_as_ndarray, 2439 b_as_ndarray, 2440 rtol=rtol, 2441 atol=atol, 2442 msg="Mismatched value: a%s is different from b%s. %s" % 2443 (path_str, path_str, msg)) 2444 except (ValueError, TypeError) as e: 2445 if len(a) != len(b): 2446 raise ValueError( 2447 "Mismatched length: a%s has %d items, but b%s has %d items. %s" % 2448 (path_str, len(a), path_str, len(b), msg)) 2449 for idx, (a_ele, b_ele) in enumerate(zip(a, b)): 2450 path.append(str(idx)) 2451 self._assertAllCloseRecursive( 2452 a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg) 2453 del path[-1] 2454 # a and b are ndarray like objects 2455 else: 2456 try: 2457 self._assertArrayLikeAllClose( 2458 a, 2459 b, 2460 rtol=rtol, 2461 atol=atol, 2462 msg=("Mismatched value: a%s is different from b%s. %s" % 2463 (path_str, path_str, msg))) 2464 except TypeError as e: 2465 msg = ("Error: a%s has %s, but b%s has %s. %s" % 2466 (path_str, type(a), path_str, type(b), msg)) 2467 e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) 2468 raise 2469 2470 @py_func_if_in_function 2471 def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 2472 """Asserts that two structures of numpy arrays or Tensors, have near values. 2473 2474 `a` and `b` can be arbitrarily nested structures. A layer of a nested 2475 structure can be a `dict`, `namedtuple`, `tuple` or `list`. 2476 2477 Args: 2478 a: The expected numpy `ndarray`, or anything that can be converted into a 2479 numpy `ndarray` (including Tensor), or any arbitrarily nested of 2480 structure of these. 2481 b: The actual numpy `ndarray`, or anything that can be converted into a 2482 numpy `ndarray` (including Tensor), or any arbitrarily nested of 2483 structure of these. 2484 rtol: relative tolerance. 2485 atol: absolute tolerance. 2486 msg: Optional message to report on failure. 2487 2488 Raises: 2489 ValueError: if only one of `a[p]` and `b[p]` is a dict or 2490 `a[p]` and `b[p]` have different length, where `[p]` denotes a path 2491 to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and 2492 `[p] = [1]['d']`, then `a[p] = (6, 7)`. 2493 """ 2494 if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b): 2495 return self._assertRaggedClose(a, b, rtol, atol, msg) 2496 self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) 2497 2498 @py_func_if_in_function 2499 def assertAllCloseAccordingToType(self, 2500 a, 2501 b, 2502 rtol=1e-6, 2503 atol=1e-6, 2504 float_rtol=1e-6, 2505 float_atol=1e-6, 2506 half_rtol=1e-3, 2507 half_atol=1e-3, 2508 bfloat16_rtol=1e-2, 2509 bfloat16_atol=1e-2, 2510 msg=None): 2511 """Like assertAllClose, but also suitable for comparing fp16 arrays. 2512 2513 In particular, the tolerance is reduced to 1e-3 if at least 2514 one of the arguments is of type float16. 2515 2516 Args: 2517 a: the expected numpy ndarray or anything can be converted to one. 2518 b: the actual numpy ndarray or anything can be converted to one. 2519 rtol: relative tolerance. 2520 atol: absolute tolerance. 2521 float_rtol: relative tolerance for float32. 2522 float_atol: absolute tolerance for float32. 2523 half_rtol: relative tolerance for float16. 2524 half_atol: absolute tolerance for float16. 2525 bfloat16_rtol: relative tolerance for bfloat16. 2526 bfloat16_atol: absolute tolerance for bfloat16. 2527 msg: Optional message to report on failure. 2528 """ 2529 a = self._GetNdArray(a) 2530 b = self._GetNdArray(b) 2531 # types with lower tol are put later to overwrite previous ones. 2532 if (a.dtype == np.float32 or b.dtype == np.float32 or 2533 a.dtype == np.complex64 or b.dtype == np.complex64): 2534 rtol = max(rtol, float_rtol) 2535 atol = max(atol, float_atol) 2536 if a.dtype == np.float16 or b.dtype == np.float16: 2537 rtol = max(rtol, half_rtol) 2538 atol = max(atol, half_atol) 2539 if (a.dtype == dtypes.bfloat16.as_numpy_dtype or 2540 b.dtype == dtypes.bfloat16.as_numpy_dtype): 2541 rtol = max(rtol, bfloat16_rtol) 2542 atol = max(atol, bfloat16_atol) 2543 2544 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) 2545 2546 @py_func_if_in_function 2547 def assertNotAllClose(self, a, b, **kwargs): 2548 """Assert that two numpy arrays, or Tensors, do not have near values. 2549 2550 Args: 2551 a: the first value to compare. 2552 b: the second value to compare. 2553 **kwargs: additional keyword arguments to be passed to the underlying 2554 `assertAllClose` call. 2555 2556 Raises: 2557 AssertionError: If `a` and `b` are unexpectedly close at all elements. 2558 """ 2559 try: 2560 self.assertAllClose(a, b, **kwargs) 2561 except AssertionError: 2562 return 2563 raise AssertionError("The two values are close at all elements") 2564 2565 @py_func_if_in_function 2566 def assertAllEqual(self, a, b, msg=None): 2567 """Asserts that two numpy arrays or Tensors have the same values. 2568 2569 Args: 2570 a: the expected numpy ndarray or anything can be converted to one. 2571 b: the actual numpy ndarray or anything can be converted to one. 2572 msg: Optional message to report on failure. 2573 """ 2574 if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)): 2575 return self._assertRaggedEqual(a, b, msg) 2576 msg = msg if msg else "" 2577 a = self._GetNdArray(a) 2578 b = self._GetNdArray(b) 2579 # Arbitrary bounds so that we don't print giant tensors. 2580 if (b.ndim <= 3 or b.size < 500): 2581 self.assertEqual( 2582 a.shape, b.shape, "Shape mismatch: expected %s, got %s." 2583 " Contents: %s. \n%s." % (a.shape, b.shape, b, msg)) 2584 else: 2585 self.assertEqual( 2586 a.shape, b.shape, "Shape mismatch: expected %s, got %s." 2587 " %s" % (a.shape, b.shape, msg)) 2588 2589 same = (a == b) 2590 2591 if (a.dtype in [ 2592 np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype 2593 ]): 2594 same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b))) 2595 msgs = [msg] 2596 if not np.all(same): 2597 # Adds more details to np.testing.assert_array_equal. 2598 diff = np.logical_not(same) 2599 if a.ndim: 2600 x = a[np.where(diff)] 2601 y = b[np.where(diff)] 2602 msgs.append("not equal where = {}".format(np.where(diff))) 2603 else: 2604 # np.where is broken for scalars 2605 x, y = a, b 2606 msgs.append("not equal lhs = {}".format(x)) 2607 msgs.append("not equal rhs = {}".format(y)) 2608 # With Python 3, we need to make sure the dtype matches between a and b. 2609 b = b.astype(a.dtype) 2610 np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs)) 2611 2612 @py_func_if_in_function 2613 def assertNotAllEqual(self, a, b, msg=None): 2614 """Asserts that two numpy arrays or Tensors do not have the same values. 2615 2616 Args: 2617 a: the expected numpy ndarray or anything can be converted to one. 2618 b: the actual numpy ndarray or anything can be converted to one. 2619 msg: Optional message to report on failure. 2620 """ 2621 try: 2622 self.assertAllEqual(a, b, msg) 2623 except AssertionError: 2624 return 2625 raise AssertionError("The two values are equal at all elements") 2626 2627 @py_func_if_in_function 2628 def assertAllGreater(self, a, comparison_target): 2629 """Assert element values are all greater than a target value. 2630 2631 Args: 2632 a: The numpy `ndarray`, or anything that can be converted into a numpy 2633 `ndarray` (including Tensor). 2634 comparison_target: The target value of comparison. 2635 """ 2636 a = self._GetNdArray(a) 2637 self.assertGreater(np.min(a), comparison_target) 2638 2639 @py_func_if_in_function 2640 def assertAllLess(self, a, comparison_target): 2641 """Assert element values are all less than a target value. 2642 2643 Args: 2644 a: The numpy `ndarray`, or anything that can be converted into a numpy 2645 `ndarray` (including Tensor). 2646 comparison_target: The target value of comparison. 2647 """ 2648 a = self._GetNdArray(a) 2649 self.assertLess(np.max(a), comparison_target) 2650 2651 @py_func_if_in_function 2652 def assertAllGreaterEqual(self, a, comparison_target): 2653 """Assert element values are all greater than or equal to a target value. 2654 2655 Args: 2656 a: The numpy `ndarray`, or anything that can be converted into a numpy 2657 `ndarray` (including Tensor). 2658 comparison_target: The target value of comparison. 2659 """ 2660 a = self._GetNdArray(a) 2661 self.assertGreaterEqual(np.min(a), comparison_target) 2662 2663 @py_func_if_in_function 2664 def assertAllLessEqual(self, a, comparison_target): 2665 """Assert element values are all less than or equal to a target value. 2666 2667 Args: 2668 a: The numpy `ndarray`, or anything that can be converted into a numpy 2669 `ndarray` (including Tensor). 2670 comparison_target: The target value of comparison. 2671 """ 2672 a = self._GetNdArray(a) 2673 self.assertLessEqual(np.max(a), comparison_target) 2674 2675 def _format_subscripts(self, subscripts, value, limit=10, indent=2): 2676 """Generate a summary of ndarray subscripts as a list of str. 2677 2678 If limit == N, this method will print up to the first N subscripts on 2679 separate 2680 lines. A line of ellipses (...) will be appended at the end if the number of 2681 subscripts exceeds N. 2682 2683 Args: 2684 subscripts: The tensor (np.ndarray) subscripts, of the same format as 2685 np.where()'s return value, i.e., a tuple of arrays with each array 2686 corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])). 2687 value: (np.ndarray) value of the tensor. 2688 limit: (int) The maximum number of indices to print. 2689 indent: (int) Number of characters to indent at the beginning of each 2690 line. 2691 2692 Returns: 2693 (list of str) the multi-line representation of the subscripts and values, 2694 potentially with omission at the end. 2695 """ 2696 lines = [] 2697 subscripts = np.transpose(subscripts) 2698 prefix = " " * indent 2699 for subscript in itertools.islice(subscripts, limit): 2700 lines.append(prefix + str(subscript) + " : " + 2701 str(value[tuple(subscript)])) 2702 if len(subscripts) > limit: 2703 lines.append(prefix + "...") 2704 return lines 2705 2706 @py_func_if_in_function 2707 def assertAllInRange(self, 2708 target, 2709 lower_bound, 2710 upper_bound, 2711 open_lower_bound=False, 2712 open_upper_bound=False): 2713 """Assert that elements in a Tensor are all in a given range. 2714 2715 Args: 2716 target: The numpy `ndarray`, or anything that can be converted into a 2717 numpy `ndarray` (including Tensor). 2718 lower_bound: lower bound of the range 2719 upper_bound: upper bound of the range 2720 open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather 2721 than the default >=) 2722 open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather 2723 than the default <=) 2724 2725 Raises: 2726 AssertionError: 2727 if the value tensor does not have an ordered numeric type (float* or 2728 int*), or 2729 if there are nan values, or 2730 if any of the elements do not fall in the specified range. 2731 """ 2732 target = self._GetNdArray(target) 2733 if not (np.issubdtype(target.dtype, np.floating) or 2734 np.issubdtype(target.dtype, np.integer)): 2735 raise AssertionError( 2736 "The value of %s does not have an ordered numeric type, instead it " 2737 "has type: %s" % (target, target.dtype)) 2738 2739 nan_subscripts = np.where(np.isnan(target)) 2740 if np.size(nan_subscripts): 2741 raise AssertionError( 2742 "%d of the %d element(s) are NaN. " 2743 "Subscripts(s) and value(s) of the NaN element(s):\n" % 2744 (len(nan_subscripts[0]), np.size(target)) + 2745 "\n".join(self._format_subscripts(nan_subscripts, target))) 2746 2747 range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " + 2748 str(upper_bound) + (")" if open_upper_bound else "]")) 2749 2750 violations = ( 2751 np.less_equal(target, lower_bound) if open_lower_bound else np.less( 2752 target, lower_bound)) 2753 violations = np.logical_or( 2754 violations, 2755 np.greater_equal(target, upper_bound) 2756 if open_upper_bound else np.greater(target, upper_bound)) 2757 violation_subscripts = np.where(violations) 2758 if np.size(violation_subscripts): 2759 raise AssertionError( 2760 "%d of the %d element(s) are outside the range %s. " % 2761 (len(violation_subscripts[0]), np.size(target), range_str) + 2762 "Subscript(s) and value(s) of the offending elements:\n" + 2763 "\n".join(self._format_subscripts(violation_subscripts, target))) 2764 2765 @py_func_if_in_function 2766 def assertAllInSet(self, target, expected_set): 2767 """Assert that elements of a Tensor are all in a given closed set. 2768 2769 Args: 2770 target: The numpy `ndarray`, or anything that can be converted into a 2771 numpy `ndarray` (including Tensor). 2772 expected_set: (`list`, `tuple` or `set`) The closed set that the elements 2773 of the value of `target` are expected to fall into. 2774 2775 Raises: 2776 AssertionError: 2777 if any of the elements do not fall into `expected_set`. 2778 """ 2779 target = self._GetNdArray(target) 2780 2781 # Elements in target that are not in expected_set. 2782 diff = np.setdiff1d(target.flatten(), list(expected_set)) 2783 if np.size(diff): 2784 raise AssertionError("%d unique element(s) are not in the set %s: %s" % 2785 (np.size(diff), expected_set, diff)) 2786 2787 @py_func_if_in_function 2788 def assertDTypeEqual(self, target, expected_dtype): 2789 """Assert ndarray data type is equal to expected. 2790 2791 Args: 2792 target: The numpy `ndarray`, or anything that can be converted into a 2793 numpy `ndarray` (including Tensor). 2794 expected_dtype: Expected data type. 2795 """ 2796 target = self._GetNdArray(target) 2797 if not isinstance(target, list): 2798 arrays = [target] 2799 for arr in arrays: 2800 self.assertEqual(arr.dtype, expected_dtype) 2801 2802 # pylint: disable=g-doc-return-or-yield 2803 @contextlib.contextmanager 2804 def assertRaisesWithPredicateMatch(self, exception_type, 2805 expected_err_re_or_predicate): 2806 """Returns a context manager to enclose code expected to raise an exception. 2807 2808 If the exception is an OpError, the op stack is also included in the message 2809 predicate search. 2810 2811 Args: 2812 exception_type: The expected type of exception that should be raised. 2813 expected_err_re_or_predicate: If this is callable, it should be a function 2814 of one argument that inspects the passed-in exception and returns True 2815 (success) or False (please fail the test). Otherwise, the error message 2816 is expected to match this regular expression partially. 2817 2818 Returns: 2819 A context manager to surround code that is expected to raise an 2820 exception. 2821 """ 2822 if callable(expected_err_re_or_predicate): 2823 predicate = expected_err_re_or_predicate 2824 else: 2825 2826 def predicate(e): 2827 err_str = e.message if isinstance(e, errors.OpError) else str(e) 2828 op = e.op if isinstance(e, errors.OpError) else None 2829 while op is not None: 2830 err_str += "\nCaused by: " + op.name 2831 op = op._original_op # pylint: disable=protected-access 2832 logging.info("Searching within error strings: '%s' within '%s'", 2833 expected_err_re_or_predicate, err_str) 2834 return re.search(expected_err_re_or_predicate, err_str) 2835 2836 try: 2837 yield 2838 self.fail(exception_type.__name__ + " not raised") 2839 except Exception as e: # pylint: disable=broad-except 2840 if not isinstance(e, exception_type) or not predicate(e): 2841 raise AssertionError("Exception of type %s: %s" % 2842 (str(type(e)), str(e))) 2843 2844 # pylint: enable=g-doc-return-or-yield 2845 2846 def assertRaisesOpError(self, expected_err_re_or_predicate): 2847 return self.assertRaisesWithPredicateMatch(errors.OpError, 2848 expected_err_re_or_predicate) 2849 2850 def assertShapeEqual(self, np_array, tf_tensor, msg=None): 2851 """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape. 2852 2853 Args: 2854 np_array: A Numpy ndarray or Numpy scalar. 2855 tf_tensor: A Tensor. 2856 msg: Optional message to report on failure. 2857 2858 Raises: 2859 TypeError: If the arguments have the wrong type. 2860 """ 2861 if not isinstance(np_array, (np.ndarray, np.generic)): 2862 raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") 2863 if not isinstance(tf_tensor, ops.Tensor): 2864 raise TypeError("tf_tensor must be a Tensor") 2865 self.assertAllEqual( 2866 np_array.shape, tf_tensor.get_shape().as_list(), msg=msg) 2867 2868 def assertDeviceEqual(self, device1, device2, msg=None): 2869 """Asserts that the two given devices are the same. 2870 2871 Args: 2872 device1: A string device name or TensorFlow `DeviceSpec` object. 2873 device2: A string device name or TensorFlow `DeviceSpec` object. 2874 msg: Optional message to report on failure. 2875 """ 2876 device1 = pydev.canonical_name(device1) 2877 device2 = pydev.canonical_name(device2) 2878 self.assertEqual( 2879 device1, device2, 2880 "Devices %s and %s are not equal. %s" % (device1, device2, msg)) 2881 2882 def _GetPyList(self, a): 2883 """Converts `a` to a nested python list.""" 2884 if isinstance(a, ragged_tensor.RaggedTensor): 2885 return self.evaluate(a).to_list() 2886 elif isinstance(a, ops.Tensor): 2887 a = self.evaluate(a) 2888 return a.tolist() if isinstance(a, np.ndarray) else a 2889 elif isinstance(a, np.ndarray): 2890 return a.tolist() 2891 elif isinstance(a, ragged_tensor_value.RaggedTensorValue): 2892 return a.to_list() 2893 else: 2894 return np.array(a).tolist() 2895 2896 def _assertRaggedEqual(self, a, b, msg): 2897 """Asserts that two ragged tensors are equal.""" 2898 a_list = self._GetPyList(a) 2899 b_list = self._GetPyList(b) 2900 self.assertEqual(a_list, b_list, msg) 2901 2902 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 2903 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 2904 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 2905 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 2906 2907 def _assertRaggedClose(self, a, b, rtol, atol, msg=None): 2908 a_list = self._GetPyList(a) 2909 b_list = self._GetPyList(b) 2910 self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg) 2911 2912 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 2913 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 2914 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 2915 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 2916 2917 def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"): 2918 self.assertEqual(type(a), type(b)) 2919 if isinstance(a, (list, tuple)): 2920 self.assertLen(a, len(b), "Length differs for %s" % path) 2921 for i in range(len(a)): 2922 self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg, 2923 "%s[%s]" % (path, i)) 2924 else: 2925 self._assertAllCloseRecursive(a, b, rtol, atol, path, msg) 2926 2927 # Fix Python 3+ compatibility issues 2928 if not six.PY2: 2929 # pylint: disable=invalid-name 2930 2931 # Silence a deprecation warning 2932 assertRaisesRegexp = googletest.TestCase.assertRaisesRegex 2933 2934 # assertItemsEqual is assertCountEqual as of 3.2. 2935 assertItemsEqual = googletest.TestCase.assertCountEqual 2936 2937 # pylint: enable=invalid-name 2938 2939 @contextlib.contextmanager 2940 def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): 2941 """Set the session and its graph to global default and constrain devices.""" 2942 if context.executing_eagerly(): 2943 yield None 2944 else: 2945 with sess.graph.as_default(), sess.as_default(): 2946 if force_gpu: 2947 # Use the name of an actual device if one is detected, or 2948 # '/device:GPU:0' otherwise 2949 gpu_name = gpu_device_name() 2950 if not gpu_name: 2951 gpu_name = "/device:GPU:0" 2952 with sess.graph.device(gpu_name): 2953 yield sess 2954 elif use_gpu: 2955 yield sess 2956 else: 2957 with sess.graph.device("/device:CPU:0"): 2958 yield sess 2959 2960 def _create_session(self, graph, config, force_gpu): 2961 """See session() for details.""" 2962 2963 def prepare_config(config): 2964 """Returns a config for sessions. 2965 2966 Args: 2967 config: An optional config_pb2.ConfigProto to use to configure the 2968 session. 2969 2970 Returns: 2971 A config_pb2.ConfigProto object. 2972 """ 2973 # TODO(b/114333779): Enforce allow_soft_placement=False when 2974 # use_gpu=False. Currently many tests rely on the fact that any device 2975 # will be used even when a specific device is supposed to be used. 2976 allow_soft_placement = not force_gpu 2977 if config is None: 2978 config = context.context().config 2979 config.allow_soft_placement = allow_soft_placement 2980 elif not allow_soft_placement and config.allow_soft_placement: 2981 config_copy = context.context().config 2982 config = config_copy 2983 config.allow_soft_placement = False 2984 # Don't perform optimizations for tests so we don't inadvertently run 2985 # gpu ops on cpu 2986 config.graph_options.optimizer_options.opt_level = -1 2987 # Disable Grappler constant folding since some tests & benchmarks 2988 # use constant input and become meaningless after constant folding. 2989 # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE 2990 # GRAPPLER TEAM. 2991 config.graph_options.rewrite_options.constant_folding = ( 2992 rewriter_config_pb2.RewriterConfig.OFF) 2993 config.graph_options.rewrite_options.pin_to_host_optimization = ( 2994 rewriter_config_pb2.RewriterConfig.OFF) 2995 return config 2996 2997 return ErrorLoggingSession(graph=graph, config=prepare_config(config)) 2998 2999 def _get_cached_session(self, 3000 graph=None, 3001 config=None, 3002 force_gpu=False, 3003 crash_if_inconsistent_args=True): 3004 """See cached_session() for documentation.""" 3005 if self._cached_session is None: 3006 sess = self._create_session( 3007 graph=graph, config=config, force_gpu=force_gpu) 3008 self._cached_session = sess 3009 self._cached_graph = graph 3010 self._cached_config = config 3011 self._cached_force_gpu = force_gpu 3012 return sess 3013 else: 3014 if crash_if_inconsistent_args and self._cached_graph is not graph: 3015 raise ValueError("The graph used to get the cached session is " 3016 "different than the one that was used to create the " 3017 "session. Maybe create a new session with " 3018 "self.session()") 3019 if crash_if_inconsistent_args and self._cached_config is not config: 3020 raise ValueError("The config used to get the cached session is " 3021 "different than the one that was used to create the " 3022 "session. Maybe create a new session with " 3023 "self.session()") 3024 if crash_if_inconsistent_args and (self._cached_force_gpu is 3025 not force_gpu): 3026 raise ValueError( 3027 "The force_gpu value used to get the cached session is " 3028 "different than the one that was used to create the " 3029 "session. Maybe create a new session with " 3030 "self.session()") 3031 return self._cached_session 3032 3033 3034@tf_export("test.create_local_cluster") 3035def create_local_cluster(num_workers, 3036 num_ps, 3037 protocol="grpc", 3038 worker_config=None, 3039 ps_config=None): 3040 """Create and start local servers and return the associated `Server` objects. 3041 3042 "PS" stands for "parameter server": a task responsible for storing and 3043 updating the model's parameters. Other tasks send updates to these parameters 3044 as they work on optimizing the parameters. This particular division of labor 3045 between tasks is not required, but is common for distributed training. 3046 3047 Read more at https://www.tensorflow.org/guide/extend/architecture 3048 3049  3050 3051 3052 Figure illustrates the interaction of these components. 3053 "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services. 3054 3055 3056 Example: 3057 ```python 3058 workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2) 3059 3060 worker_sessions = [tf.compat.v1.Session(w.target) for w in workers] 3061 3062 with tf.device("/job:ps/task:0"): 3063 ... 3064 with tf.device("/job:ps/task:1"): 3065 ... 3066 with tf.device("/job:worker/task:0"): 3067 ... 3068 with tf.device("/job:worker/task:1"): 3069 ... 3070 3071 worker_sessions[0].run(...) 3072 ``` 3073 3074 Args: 3075 num_workers: Number of worker servers to start. 3076 num_ps: Number of PS servers to start. 3077 protocol: Communication protocol. Allowed values are documented in the 3078 documentation of `tf.distribute.Server`. 3079 worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be 3080 used to instantiate multiple devices etc. 3081 ps_config: (optional) `tf.ConfigProto` to initialize PS servers. 3082 3083 Returns: 3084 A tuple `(worker_servers, ps_servers)`. `worker_servers` is a list 3085 of `num_workers` objects of type `tf.distribute.Server` (all running 3086 locally); 3087 and `ps_servers` is a list of `num_ps` objects of similar type. 3088 3089 Raises: 3090 ImportError: if portpicker module was not found at load time 3091 """ 3092 import portpicker # pylint: disable=g-import-not-at-top 3093 worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] 3094 ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] 3095 cluster_dict = { 3096 "worker": ["localhost:%s" % port for port in worker_ports], 3097 "ps": ["localhost:%s" % port for port in ps_ports] 3098 } 3099 cs = server_lib.ClusterSpec(cluster_dict) 3100 3101 workers = [ 3102 server_lib.Server( 3103 cs, 3104 job_name="worker", 3105 protocol=protocol, 3106 task_index=ix, 3107 config=worker_config, 3108 start=True) for ix in range(num_workers) 3109 ] 3110 ps_servers = [ 3111 server_lib.Server( 3112 cs, 3113 job_name="ps", 3114 protocol=protocol, 3115 task_index=ix, 3116 config=ps_config, 3117 start=True) for ix in range(num_ps) 3118 ] 3119 3120 return workers, ps_servers 3121 3122 3123def get_node_def_from_graph(node_name, graph_def): 3124 """Returns the `NodeDef` instance for given node name in the graph def. 3125 3126 This method explores only the NodeDefs in `graph_def.node`. 3127 3128 Args: 3129 node_name: Name of the NodeDef to search for. 3130 graph_def: An instance of `GraphDef` proto. 3131 3132 Returns: 3133 the `NodeDef` instance whose name field matches the given node_name or None. 3134 """ 3135 for node_def in graph_def.node: 3136 if node_def.name == node_name: 3137 return node_def 3138 return None 3139 3140 3141def set_producer_version(graph, producer_version): 3142 """Sets graph.graph_def_versions.producer to `producer_version`.""" 3143 # The C API doesn't expose altering GraphDefVersions. We can indirectly set 3144 # it via import_graph_def though. 3145 graph_def = graph_pb2.GraphDef() 3146 graph_def.versions.producer = producer_version 3147 with graph.as_default(): 3148 importer.import_graph_def(graph_def) 3149 assert graph.graph_def_versions.producer, producer_version 3150