1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Test utils for tensorflow.""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23from collections import OrderedDict 24import contextlib 25import gc 26import itertools 27import math 28import os 29import random 30import re 31import tempfile 32import threading 33import unittest 34 35import numpy as np 36import six 37 38_portpicker_import_error = None 39try: 40 import portpicker # pylint: disable=g-import-not-at-top 41except ImportError as _error: 42 _portpicker_import_error = _error 43 portpicker = None 44 45# pylint: disable=g-import-not-at-top 46from google.protobuf import descriptor_pool 47from google.protobuf import text_format 48 49from tensorflow.core.framework import graph_pb2 50from tensorflow.core.protobuf import config_pb2 51from tensorflow.core.protobuf import rewriter_config_pb2 52from tensorflow.python import pywrap_tensorflow 53from tensorflow.python import tf2 54from tensorflow.python.client import device_lib 55from tensorflow.python.client import session 56from tensorflow.python.eager import context 57from tensorflow.python.eager import def_function 58from tensorflow.python.eager import tape 59from tensorflow.python.framework import 
device as pydev 60from tensorflow.python.framework import dtypes 61from tensorflow.python.framework import errors 62from tensorflow.python.framework import errors_impl 63from tensorflow.python.framework import importer 64from tensorflow.python.framework import ops 65from tensorflow.python.framework import random_seed 66from tensorflow.python.framework import sparse_tensor 67from tensorflow.python.framework import tensor_shape 68from tensorflow.python.framework import versions 69from tensorflow.python.ops import array_ops 70from tensorflow.python.ops import control_flow_util 71from tensorflow.python.ops import script_ops 72from tensorflow.python.ops import variables 73from tensorflow.python.platform import googletest 74from tensorflow.python.platform import tf_logging as logging 75from tensorflow.python.training import server_lib 76from tensorflow.python.util import compat 77from tensorflow.python.util import deprecation 78from tensorflow.python.util import nest 79from tensorflow.python.util import tf_decorator 80from tensorflow.python.util import tf_inspect 81from tensorflow.python.util.protobuf import compare 82from tensorflow.python.util.tf_export import tf_export 83 84 85# If the above import is made available through the BUILD rule, then this 86# function is overridden and will instead return True and cause Tensorflow 87# graphs to be compiled with XLA. 88def is_xla_enabled(): 89 return False 90 91 92try: 93 from tensorflow.python.framework.is_xla_test_true import is_xla_enabled # pylint: disable=g-import-not-at-top 94except: 95 pass 96 97 98@tf_export("test.gpu_device_name") 99def gpu_device_name(): 100 """Returns the name of a GPU device if available or the empty string.""" 101 for x in device_lib.list_local_devices(): 102 if x.device_type == "GPU" or x.device_type == "SYCL": 103 return compat.as_str(x.name) 104 return "" 105 106 107def assert_ops_in_graph(expected_ops, graph): 108 """Assert all expected operations are found. 
109 110 Args: 111 expected_ops: `dict<string, string>` of op name to op type. 112 graph: Graph to check. 113 114 Returns: 115 `dict<string, node>` of node name to node. 116 117 Raises: 118 ValueError: If the expected ops are not present in the graph. 119 """ 120 actual_ops = {} 121 gd = graph.as_graph_def() 122 for node in gd.node: 123 if node.name in expected_ops: 124 if expected_ops[node.name] != node.op: 125 raise ValueError("Expected op for node %s is different. %s vs %s" % 126 (node.name, expected_ops[node.name], node.op)) 127 actual_ops[node.name] = node 128 if set(expected_ops.keys()) != set(actual_ops.keys()): 129 raise ValueError("Not all expected ops are present. Expected %s, found %s" % 130 (expected_ops.keys(), actual_ops.keys())) 131 return actual_ops 132 133 134@tf_export("test.assert_equal_graph_def", v1=[]) 135def assert_equal_graph_def_v2(expected, actual): 136 """Asserts that two `GraphDef`s are (mostly) the same. 137 138 Compares two `GraphDef` protos for equality, ignoring versions and ordering of 139 nodes, attrs, and control inputs. Node names are used to match up nodes 140 between the graphs, so the naming of nodes must be consistent. This function 141 ignores randomized attribute values that may appear in V2 checkpoints. 142 143 Args: 144 expected: The `GraphDef` we expected. 145 actual: The `GraphDef` we have. 146 147 Raises: 148 AssertionError: If the `GraphDef`s do not match. 149 TypeError: If either argument is not a `GraphDef`. 150 """ 151 assert_equal_graph_def(actual, expected, checkpoint_v2=True) 152 153 154@tf_export(v1=["test.assert_equal_graph_def"]) 155def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False): 156 """Asserts that two `GraphDef`s are (mostly) the same. 157 158 Compares two `GraphDef` protos for equality, ignoring versions and ordering of 159 nodes, attrs, and control inputs. Node names are used to match up nodes 160 between the graphs, so the naming of nodes must be consistent. 
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts two `GraphDef`s are (mostly) equal; shared v1/v2 implementation.

  Args:
    actual: The `GraphDef` produced by the code under test.
    expected: The `GraphDef` describing the expected graph.
    checkpoint_v2: If True, strip randomized sharded-save attribute values
      from both graphs before comparing, so graphs containing V2-checkpoint
      ops can compare equal.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError(
        "Expected tf.GraphDef for actual, got %s" % type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError(
        "Expected tf.GraphDef for expected, got %s" % type(expected).__name__)

  if checkpoint_v2:
    # Both protos are mutated in place; callers pass throwaway copies.
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  diff = pywrap_tensorflow.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))


def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`."""
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # Fix: `assertEquals` is a deprecated alias (removed in newer Python);
      # use the canonical `assertEqual`.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)


# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  """Deletes node attrs whose single string value is a sharded-save temp name.

  Mutates `graph_def` in place so that graphs containing randomized
  V2-checkpoint file names can be compared for equality.

  Args:
    graph_def: A `GraphDef` proto; modified in place.
  """
  for node in graph_def.node:
    delete_keys = []
    for attr_key in node.attr:
      attr_tensor_value = node.attr[attr_key].tensor
      if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
        attr_tensor_string_value = attr_tensor_value.string_val[0]
        # NOTE(review): on Python 3, `string_val` entries are `bytes`, and
        # `str(b"...")` yields "b'...'", which this anchored match would not
        # match -- confirm whether a decode is needed here.
        if (attr_tensor_string_value and
            re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))):
          delete_keys.append(attr_key)
    for attr_key in delete_keys:
      del node.attr[attr_key]


def IsGoogleCudaEnabled():
  """Forwards to the `pywrap_tensorflow` binding of the same name."""
  return pywrap_tensorflow.IsGoogleCudaEnabled()


def CudaSupportsHalfMatMulAndConv():
  """Forwards to the `pywrap_tensorflow` binding of the same name."""
  return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()


def IsMklEnabled():
  """Forwards to the `pywrap_tensorflow` binding of the same name."""
  return pywrap_tensorflow.IsMklEnabled()


def InstallStackTraceHandler():
  """Installs the native stack-trace handler via `pywrap_tensorflow`."""
  pywrap_tensorflow.InstallStacktraceHandler()


def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # tensor dim -> new axis order
  new_axes = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if isinstance(input_tensor, ops.Tensor):
    ndims = input_tensor.shape.ndims
    return array_ops.transpose(input_tensor, new_axes[ndims])
  else:
    ndims = len(input_tensor)
    return [input_tensor[a] for a in new_axes[ndims]]
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
      divisible by 4.
  """
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  # Fix: copy the shape so a list passed by the caller is not mutated in
  # place by the `//= 4` / `append` below (the return value is unchanged).
  temp_shape = list(
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]


def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  input_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = permutations[len(input_shape)]
  # Fold the trailing vector-of-4 dimension back into the channel dimension.
  nhwc_shape = [input_shape[a] for a in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if is_tensor:
    t = array_ops.transpose(input_shape_or_tensor, permutation)
    return array_ops.reshape(t, nhwc_shape)
  else:
    return nhwc_shape


def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # tensor dim -> new axis order
  new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if isinstance(input_tensor, ops.Tensor):
    ndims = input_tensor.shape.ndims
    return array_ops.transpose(input_tensor, new_axes[ndims])
  else:
    ndims = len(input_tensor)
    return [input_tensor[a] for a in new_axes[ndims]]


def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.

  Returns:
    The wrapped function
  """
  # Local import: the file-level import block is outside this definition.
  import functools  # pylint: disable=g-import-not-at-top

  def real_skip_if(fn):

    # Fix: preserve the wrapped test's name/docstring so test runners report
    # the real test instead of "wrapper".
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      if callable(condition):
        skip = condition()
      else:
        skip = condition
      if not skip:
        return fn(*args, **kwargs)

    return wrapper

  return real_skip_if
def enable_c_shapes(fn):
  """No-op. TODO(b/74620627): Remove this."""
  return fn


def with_c_shapes(cls):
  """No-op. TODO(b/74620627): Remove this."""
  return cls


def enable_control_flow_v2(fn):
  """Decorator for enabling CondV2 and WhileV2 on a test.

  Note this enables using CondV2 and WhileV2 after running the test class's
  setup/teardown methods.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    # Flip the flag for the duration of the call, restoring whatever value
    # was there before even if the wrapped test raises.
    saved_flag = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = saved_flag

  return wrapper


def with_control_flow_v2(cls):
  """Adds methods that call original methods with WhileV2 and CondV2 enabled.

  Note this enables CondV2 and WhileV2 in new methods after running the test
  class's setup method.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  If a test function has _disable_control_flow_v2 attr set to True (using the
  @disable_control_flow_v2 decorator), the v2 function is not generated for it.

  For every eligible test method `testFoo`, a sibling method named
  `testFooWithControlFlowV2` is added that runs the same body with v2 control
  flow enabled and the previous flags restored afterwards.

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    # v2 is already globally enabled; generating duplicates would be redundant.
    return cls

  for name, value in cls.__dict__.copy().items():
    if not callable(value):
      continue
    if not name.startswith(unittest.TestLoader.testMethodPrefix):
      continue
    if getattr(value, "_disable_control_flow_v2", False):
      continue
    setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
  return cls


def disable_control_flow_v2(unused_msg):
  """Decorator for a function in a with_control_flow_v2 enabled test class.

  Blocks the function from being run with v2 control flow ops.

  Args:
    unused_msg: Reason for disabling.

  Returns:
    The wrapped function with _disable_control_flow_v2 attr set to True.
  """

  def mark(func):
    func._disable_control_flow_v2 = True  # pylint: disable=protected-access
    return func

  return mark


def assert_no_new_pyobjects_executing_eagerly(f):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.
  """

  def decorator(self, **kwargs):
    """Warms up, gets an object count, runs the test, checks for new objects."""
    with context.eager_mode():
      gc.disable()
      # Run the test 2 times as warmup, in an attempt to fill up caches, which
      # should not grow as the test is run repeatedly below.
      #
      # TODO(b/117156879): Running warmup twice is black magic; we have seen
      # tests that fail with 1 warmup run, and pass with 2, on various versions
      # of python2.7.x.
      for _ in range(2):
        f(self, **kwargs)
      gc.collect()
      baseline_count = len(gc.get_objects())
      if ops.has_default_graph():
        sizes_before = {
            key: len(ops.get_collection(key))
            for key in ops.get_default_graph().collections
        }
      for _ in range(3):
        f(self, **kwargs)
      # Note that gc.get_objects misses anything that isn't subject to garbage
      # collection (C types). Collections are a common source of leaks, so we
      # test for collection sizes explicitly.
      if ops.has_default_graph():
        for key in ops.get_default_graph().collections:
          entries = ops.get_collection(key)
          size_before = sizes_before.get(key, 0)
          if len(entries) > size_before:
            raise AssertionError(
                ("Collection %s increased in size from "
                 "%d to %d (current items %s).") %
                (key, size_before, len(entries), entries))
          # Make sure our collection checks don't show up as leaked memory by
          # removing references to temporary variables.
          del entries
          del key
          del size_before
        del sizes_before
      gc.collect()
      # There should be no new Python objects hanging around.
      final_count = len(gc.get_objects())
      # In some cases (specifically on MacOS), final_count is somehow
      # smaller than baseline_count. Using a plain assert because not all
      # classes using this decorator have assertLessEqual.
      assert final_count <= baseline_count, (
          "new_count(%d) is not less than or equal to previous_count(%d)" %
          (final_count, baseline_count))
      gc.enable()

  return decorator
def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then
  checks that there are no Tensor or Tensor-like objects still around. This
  includes Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  Args:
    f: The test case to run.

  Returns:
    The decorated test case.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensorflow_object(candidate):
      try:
        return isinstance(candidate,
                          (ops.Tensor, variables.Variable,
                           tensor_shape.Dimension, tensor_shape.TensorShape))
      except ReferenceError:
        # A weakly-referenced object that has already gone away is irrelevant.
        return False

    preexisting_ids = set(
        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
    was_executing_eagerly = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    enclosing_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = enclosing_graph_key
      if was_executing_eagerly:
        with context.eager_mode():
          result = f(self, **kwargs)
      else:
        result = f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    leaked = [
        obj for obj in gc.get_objects()
        if _is_tensorflow_object(obj) and id(obj) not in preexisting_ids
    ]
    if leaked:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(leaked),
          str(leaked),
      )))
    return result

  return decorator
def _find_reference_cycle(objects, idx):
  """Diagnostic helper: logs one reference cycle reachable from objects[idx]."""

  def get_ignore_reason(obj, denylist):
    """Tests whether an object should be omitted from the dependency graph."""
    if len(denylist) > 100:
      return "<depth limit>"
    if tf_inspect.isframe(obj):
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    for entry in denylist:
      if entry is obj:
        return "<test code>"
    if obj is denylist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, denylist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      denylist: same as denylist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.
    """
    if get_ignore_reason(obj, denylist):
      return "{}{}".format(get_ignore_reason(obj, denylist), type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, denylist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      denylist: List of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    # Extend the denylist so the freshly-created referrer list is ignored too.
    denylist = denylist + (referrers,)

    obj_id = id(obj)
    for referrer in referrers:
      if get_ignore_reason(referrer, denylist) is None:
        referrer_id = id(referrer)
        if referrer_id not in graph:
          graph[referrer_id] = []
        if obj_id not in graph[referrer_id]:
          graph[referrer_id].append(obj_id)
          build_ref_graph(referrer, graph, reprs, denylist)
          reprs[referrer_id] = describe(referrer, denylist)

  def find_cycle(node_id, graph, reprs, trail):
    """Finds and prints a single cycle in the dependency graph."""
    if node_id not in graph:
      return
    for successor in graph[node_id]:
      if successor in trail:
        logging.error("Reference cycle sample:")
        for step in trail + (successor,):
          logging.error(reprs.get(step, "unknown object " + str(step)))
        return True
      else:
        if find_cycle(successor, graph, reprs, trail + (successor,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> object ID
  reprs = {}  # object ID -> description
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  Args:
    f: The function to decorate.

  Returns:
    The decorated function.
  """

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    # Force-load `distribution_strategy_context` to prevent GC at
    # test time when using eager. Remove once b/117329403 is resolved.
    tape.distribution_strategy_context.get_strategy()

    gc.disable()
    saved_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    gc.collect()
    garbage_before = len(gc.garbage)
    result = f(self, **kwargs)
    gc.collect()
    garbage_after = len(gc.garbage)

    def _safe_object_str(obj):
      return "<%s %d>" % (obj.__class__.__name__, id(obj))

    if garbage_after > garbage_before:
      logging.error(
          "The decorated test created work for Python's garbage collector, "
          "likely due to a reference cycle. New objects in cycle(s):")
      for i, obj in enumerate(gc.garbage[garbage_before:]):
        try:
          logging.error("Object %d of %d", i,
                        len(gc.garbage) - garbage_before)
          logging.error("  Object type: %s", _safe_object_str(obj))
          logging.error(
              "  Referrer types: %s", ", ".join(
                  [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
          logging.error(
              "  Referent types: %s", ", ".join(
                  [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
          logging.error("  Object attribute names: %s", dir(obj))
          logging.error("  Object __str__:")
          logging.error(obj)
          logging.error("  Object __repr__:")
          logging.error(repr(obj))
        except Exception:  # pylint: disable=broad-except
          logging.error("(Exception while printing object)")

    # When garbage is created, this call can help identify reference cycles,
    # which are typically the cause of such garbage.
    if garbage_after > garbage_before:
      for i in range(garbage_before, garbage_after):
        if _find_reference_cycle(gc.garbage, i):
          break

    # This will fail if any garbage has been created, typically because of a
    # reference cycle.
    self.assertEqual(garbage_before, garbage_after)
    # TODO(allenl): Figure out why this debug flag reset doesn't work. It would
    # be nice to be able to decorate arbitrary tests in a large test suite and
    # not hold on to every object in other tests.
    gc.set_debug(saved_debug_flags)
    gc.enable()
    return result

  return decorator
807 gc.set_debug(previous_debug_flags) 808 gc.enable() 809 return result 810 811 return decorator 812 813 814def _combine_named_parameters(**kwargs): 815 """Generate combinations based on its keyword arguments. 816 817 Two sets of returned combinations can be concatenated using +. Their product 818 can be computed using `times()`. 819 820 Args: 821 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 822 `option=the_only_possibility`. 823 824 Returns: 825 a list of dictionaries for each combination. Keys in the dictionaries are 826 the keyword argument names. Each key has one value - one of the 827 corresponding keyword argument values. 828 """ 829 if not kwargs: 830 return [OrderedDict()] 831 832 sort_by_key = lambda k: k[0][0] 833 kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) 834 first = list(kwargs.items())[0] 835 836 rest = dict(list(kwargs.items())[1:]) 837 rest_combined = _combine_named_parameters(**rest) 838 839 key = first[0] 840 values = first[1] 841 if not isinstance(values, list): 842 values = [values] 843 844 combinations = [ 845 OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) 846 for v in values 847 for combined in rest_combined 848 ] 849 return combinations 850 851 852def generate_combinations_with_testcase_name(**kwargs): 853 """Generate combinations based on its keyword arguments using combine(). 854 855 This function calls combine() and appends a testcase name to the list of 856 dictionaries returned. The 'testcase_name' key is a required for named 857 parameterized tests. 858 859 Args: 860 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 861 `option=the_only_possibility`. 862 863 Returns: 864 a list of dictionaries for each combination. Keys in the dictionaries are 865 the keyword argument names. Each key has one value - one of the 866 corresponding keyword argument values. 
867 """ 868 combinations = _combine_named_parameters(**kwargs) 869 named_combinations = [] 870 for combination in combinations: 871 assert isinstance(combination, OrderedDict) 872 name = "".join([ 873 "_{}_{}".format("".join(filter(str.isalnum, key)), "".join( 874 filter(str.isalnum, str(value)))) 875 for key, value in combination.items() 876 ]) 877 named_combinations.append( 878 OrderedDict( 879 list(combination.items()) + [("testcase_name", 880 "_test{}".format(name))])) 881 882 return named_combinations 883 884 885def run_all_in_graph_and_eager_modes(cls): 886 """Execute all test methods in the given class with and without eager.""" 887 base_decorator = run_in_graph_and_eager_modes 888 for name, value in cls.__dict__.copy().items(): 889 if callable(value) and name.startswith( 890 unittest.TestLoader.testMethodPrefix) and not ( 891 name.startswith("testSkipEager") or 892 name.startswith("test_skip_eager") or name == "test_session"): 893 setattr(cls, name, base_decorator(value)) 894 return cls 895 896 897def run_in_graph_and_eager_modes(func=None, 898 config=None, 899 use_gpu=True, 900 reset_test=True, 901 assert_no_eager_garbage=False): 902 """Execute the decorated test with and without enabling eager execution. 903 904 This function returns a decorator intended to be applied to test methods in 905 a `tf.test.TestCase` class. Doing so will cause the contents of the test 906 method to be executed twice - once normally, and once with eager execution 907 enabled. This allows unittests to confirm the equivalence between eager 908 and graph execution (see `tf.enable_eager_execution`). 
909 910 For example, consider the following unittest: 911 912 ```python 913 class MyTests(tf.test.TestCase): 914 915 @run_in_graph_and_eager_modes 916 def test_foo(self): 917 x = tf.constant([1, 2]) 918 y = tf.constant([3, 4]) 919 z = tf.add(x, y) 920 self.assertAllEqual([4, 6], self.evaluate(z)) 921 922 if __name__ == "__main__": 923 tf.test.main() 924 ``` 925 926 This test validates that `tf.add()` has the same behavior when computed with 927 eager execution enabled as it does when constructing a TensorFlow graph and 928 executing the `z` tensor in a session. 929 930 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 931 `run_in_graph_and_eager_modes` are available decorators for different 932 v1/v2/eager/graph combinations. 933 934 935 Args: 936 func: function to be annotated. If `func` is None, this method returns a 937 decorator the can be applied to a function. If `func` is not None this 938 returns the decorator applied to `func`. 939 config: An optional config_pb2.ConfigProto to use to configure the session 940 when executing graphs. 941 use_gpu: If True, attempt to run as many operations as possible on GPU. 942 reset_test: If True, tearDown and SetUp the test case between the two 943 executions of the test (once with and once without eager execution). 944 assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage 945 collector and asserts that no extra garbage has been created when running 946 the test with eager execution enabled. This will fail if there are 947 reference cycles (e.g. a = []; a.append(a)). Off by default because some 948 tests may create garbage for legitimate reasons (e.g. they define a class 949 which inherits from `object`), and because DEBUG_SAVEALL is sticky in some 950 Python interpreters (meaning that tests which rely on objects being 951 collected elsewhere in the unit test file will not work). Additionally, 952 checks that nothing still has a reference to Tensors that the test 953 allocated. 
954 955 Returns: 956 Returns a decorator that will run the decorated test method twice: 957 once by constructing and executing a graph in a session and once with 958 eager execution enabled. 959 """ 960 961 def decorator(f): 962 if tf_inspect.isclass(f): 963 raise ValueError( 964 "`run_in_graph_and_eager_modes` only supports test methods. " 965 "Did you mean to use `run_all_in_graph_and_eager_modes`?") 966 967 def decorated(self, *args, **kwargs): 968 try: 969 with context.graph_mode(): 970 with self.test_session(use_gpu=use_gpu, config=config): 971 f(self, *args, **kwargs) 972 except unittest.case.SkipTest: 973 pass 974 975 def run_eagerly(self, **kwargs): 976 if not use_gpu: 977 with ops.device("/device:CPU:0"): 978 f(self, *args, **kwargs) 979 else: 980 f(self, *args, **kwargs) 981 982 if assert_no_eager_garbage: 983 ops.reset_default_graph() 984 run_eagerly = assert_no_new_tensors( 985 assert_no_garbage_created(run_eagerly)) 986 987 if reset_test: 988 # This decorator runs the wrapped test twice. 989 # Reset the test environment between runs. 990 self.tearDown() 991 self._tempdir = None 992 # Create a new graph for the eagerly executed version of this test for 993 # better isolation. 
      graph_for_eager_test = ops.Graph()
      with graph_for_eager_test.as_default(), context.eager_mode():
        if reset_test:
          self.setUp()
        run_eagerly(self, **kwargs)
      ops.dismantle_graph(graph_for_eager_test)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator


def py_func_if_in_function(f):
  """Decorator: routes `f` through a `py_func` while a function is being built.

  If the default graph is not building a function, `f` is invoked directly.
  Otherwise, Tensor/Variable positional arguments are collected and `f` is run
  via `script_ops.py_func` so it executes as a py_func inside the function.
  """

  def decorated(*args, **kwds):
    if not ops.get_default_graph()._building_function:
      return f(*args, **kwds)

    tensor_args = []
    tensor_indices = []
    for i, arg in enumerate(args):
      if isinstance(arg, (ops.Tensor, variables.Variable)):
        tensor_args.append(arg)
        tensor_indices.append(i)

    def inner_f(*inner_tensor_args):
      # Substitute the py_func-provided tensors back into their original
      # argument positions before calling `f`.
      my_args = list(args)
      for i, n in zip(tensor_indices, inner_tensor_args):
        my_args[i] = n
      return f(*my_args, **kwds)

    return script_ops.py_func(inner_f, tensor_args, [])

  return tf_decorator.make_decorator(f, decorated)


def also_run_as_tf_function(f):
  """Runs the decorated test twice--once as is, once inside a tf.function.

  This allows you to run a test both in eager execution and inside a
  tf.function, exercising the two execution modes supported in tf 2.0. The test
  assertions are automatically done inside tf.py_funcs, and tf.function ensures
  that they run in the proper order and with the proper side effects.

  Currently variable creation is not supported in tests annotated with this
  decorator since it's tricky to ensure the variable doesn't get repeatedly
  created when retracing the tf.function.

  Args:
    f: the test method to be decorated

  Returns:
    The decorated test method, which will run both in eager and inside a
    tf.function.
1051 """ 1052 1053 def decorated(*args, **kwds): 1054 def bound_f(): 1055 f(*args, **kwds) 1056 with context.eager_mode(): 1057 # Running in eager mode 1058 bound_f() 1059 # Running as TF function 1060 # TODO(b/121143941): Remove the autograph override. 1061 def_function.function(bound_f, autograph=False)() 1062 1063 return decorated 1064 1065 1066def deprecated_graph_mode_only(func=None): 1067 """Execute the decorated test in graph mode. 1068 1069 This function returns a decorator intended to be applied to tests that are not 1070 compatible with eager mode. When this decorator is applied, the test body will 1071 be run in an environment where API calls construct graphs instead of executing 1072 eagerly. 1073 1074 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1075 `run_in_graph_and_eager_modes` are available decorators for different 1076 v1/v2/eager/graph combinations. 1077 1078 Args: 1079 func: function to be annotated. If `func` is None, this method returns a 1080 decorator the can be applied to a function. If `func` is not None this 1081 returns the decorator applied to `func`. 1082 1083 Returns: 1084 Returns a decorator that will run the decorated test method in graph mode. 
1085 """ 1086 1087 def decorator(f): 1088 if tf_inspect.isclass(f): 1089 setup = f.__dict__.get("setUp") 1090 if setup is not None: 1091 setattr(f, "setUp", decorator(setup)) 1092 1093 for name, value in f.__dict__.copy().items(): 1094 if (callable(value) and 1095 name.startswith(unittest.TestLoader.testMethodPrefix)): 1096 setattr(f, name, decorator(value)) 1097 1098 return f 1099 1100 def decorated(self, *args, **kwargs): 1101 if tf2.enabled(): 1102 with context.graph_mode(): 1103 return f(self, *args, **kwargs) 1104 else: 1105 return f(self, *args, **kwargs) 1106 1107 return decorated 1108 1109 if func is not None: 1110 return decorator(func) 1111 1112 return decorator 1113 1114 1115run_deprecated_v1 = deprecated_graph_mode_only 1116 1117 1118def run_v1_only(reason, func=None): 1119 """Execute the decorated test only if running in v1 mode. 1120 1121 This function is intended to be applied to tests that exercise v1 only 1122 functionality. If the test is run in v2 mode it will simply be skipped. 1123 1124 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1125 `run_in_graph_and_eager_modes` are available decorators for different 1126 v1/v2/eager/graph combinations. 1127 1128 Args: 1129 reason: string giving a reason for limiting the test to v1 only. 1130 func: function to be annotated. If `func` is None, this method returns a 1131 decorator the can be applied to a function. If `func` is not None this 1132 returns the decorator applied to `func`. 1133 1134 Returns: 1135 Returns a decorator that will conditionally skip the decorated test method. 
1136 """ 1137 1138 def decorator(f): 1139 if tf_inspect.isclass(f): 1140 setup = f.__dict__.get("setUp") 1141 if setup is not None: 1142 setattr(f, "setUp", decorator(setup)) 1143 1144 for name, value in f.__dict__.copy().items(): 1145 if (callable(value) and 1146 name.startswith(unittest.TestLoader.testMethodPrefix)): 1147 setattr(f, name, decorator(value)) 1148 1149 return f 1150 1151 def decorated(self, *args, **kwargs): 1152 if tf2.enabled(): 1153 self.skipTest(reason) 1154 1155 return f(self, *args, **kwargs) 1156 1157 return decorated 1158 1159 if func is not None: 1160 return decorator(func) 1161 1162 return decorator 1163 1164 1165def run_v2_only(func=None): 1166 """Execute the decorated test only if running in v2 mode. 1167 1168 This function is intended to be applied to tests that exercise v2 only 1169 functionality. If the test is run in v1 mode it will simply be skipped. 1170 1171 `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and 1172 `run_in_graph_and_eager_modes` are available decorators for different 1173 v1/v2/eager/graph combinations. 1174 1175 Args: 1176 func: function to be annotated. If `func` is None, this method returns a 1177 decorator the can be applied to a function. If `func` is not None this 1178 returns the decorator applied to `func`. 1179 1180 Returns: 1181 Returns a decorator that will conditionally skip the decorated test method. 1182 """ 1183 1184 def decorator(f): 1185 if tf_inspect.isclass(f): 1186 raise ValueError("`run_v2_only` only supports test methods.") 1187 1188 def decorated(self, *args, **kwargs): 1189 if not tf2.enabled(): 1190 self.skipTest("Test is only comptaible in v2") 1191 1192 return f(self, *args, **kwargs) 1193 1194 return decorated 1195 1196 if func is not None: 1197 return decorator(func) 1198 1199 return decorator 1200 1201 1202def run_gpu_only(func=None): 1203 """Execute the decorated test only if a GPU is available. 

  This function is intended to be applied to tests that require the presence
  of a GPU. If a GPU is absent, it will simply be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      if not is_gpu_available():
        self.skipTest("Test requires GPU")

      return f(self, *args, **kwargs)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator


def run_cuda_only(func=None):
  """Execute the decorated test only if a CUDA GPU is available.

  This function is intended to be applied to tests that require the presence
  of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_cuda_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      if not is_gpu_available(cuda_only=True):
        self.skipTest("Test requires CUDA GPU")

      return f(self, *args, **kwargs)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator


@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Args:
    cuda_only: limit the search to CUDA gpus.
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement.

  Returns:
    True if a gpu device of the requested kind is available.
  """

  def compute_capability_from_device_desc(device_desc):
    # TODO(jingyue): The device description generator has to be in sync with
    # this file. Another option is to put compute capability in
    # DeviceAttributes, but I avoided that to keep DeviceAttributes
    # target-independent. Reconsider this option when we have more things like
    # this to keep in sync.
1287 # LINT.IfChange 1288 match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc) 1289 # LINT.ThenChange(//tensorflow/core/\ 1290 # common_runtime/gpu/gpu_device.cc) 1291 if not match: 1292 return 0, 0 1293 return int(match.group(1)), int(match.group(2)) 1294 1295 try: 1296 for local_device in device_lib.list_local_devices(): 1297 if local_device.device_type == "GPU": 1298 if (min_cuda_compute_capability is None or 1299 compute_capability_from_device_desc( 1300 local_device.physical_device_desc) >= 1301 min_cuda_compute_capability): 1302 return True 1303 if local_device.device_type == "SYCL" and not cuda_only: 1304 return True 1305 return False 1306 except errors_impl.NotFoundError as e: 1307 if not all(x in str(e) for x in ["CUDA", "not find"]): 1308 raise e 1309 else: 1310 logging.error(str(e)) 1311 return False 1312 1313 1314@contextlib.contextmanager 1315def device(use_gpu): 1316 """Uses gpu when requested and available.""" 1317 if use_gpu and is_gpu_available(): 1318 dev = "/device:GPU:0" 1319 else: 1320 dev = "/device:CPU:0" 1321 with ops.device(dev): 1322 yield 1323 1324 1325@contextlib.contextmanager 1326def use_gpu(): 1327 """Uses gpu when requested and available.""" 1328 with device(use_gpu=True): 1329 yield 1330 1331 1332@contextlib.contextmanager 1333def force_gpu(): 1334 """Force the gpu to be used.""" 1335 with ops.device("/device:GPU:0"): 1336 yield 1337 1338 1339@contextlib.contextmanager 1340def force_cpu(): 1341 """Force the cpu to be used.""" 1342 with ops.device("/device:CPU:0"): 1343 yield 1344 1345 1346class CapturedWrites(object): 1347 """A utility class to load the captured writes made to a stream.""" 1348 1349 def __init__(self, capture_location): 1350 self.capture_location = capture_location 1351 1352 def contents(self): 1353 """Get the captured writes as a single string.""" 1354 with open(self.capture_location) as tmp_file: 1355 output_data = "".join(tmp_file.readlines()) 1356 return output_data 1357 1358 1359class 
FakeEagerSession(object): 1360 """Fake session so tests that conditionally use placeholders can use eager. 1361 1362 There are a number of tests that conditionally use placeholders for shape 1363 inference. The pattern is demonstrated here: 1364 1365 ```python 1366 with self.cached_session() as sess: 1367 if static_shape: 1368 y = math_ops.matmul(x, ...) 1369 feed_dict = {} 1370 else: 1371 x_ph = array_ops.placeholder(...) 1372 y = math_ops.matmul(x_ph, ...) 1373 feed_dict = {x_ph: x} 1374 val = sess.run(y, feed_dict=feed_dict) 1375 ``` 1376 1377 Since the feed_dict is empty when not using placeholders we should be able to 1378 call self.evaluate(), however this requires rewriting the test case. 1379 This class should be considered a stop-gap solution to get tests running with 1380 eager with minimal changes to the actual test. 1381 """ 1382 1383 def __init__(self, test_case): 1384 self._test_case = test_case 1385 1386 def run(self, fetches, *args, **kwargs): 1387 """Evalaute `fetches`. 1388 1389 Fail if additional args are specified. 1390 1391 Args: 1392 fetches: A Tensor or a nested list/tuple of Tensors. 1393 *args: Positional arguments 1394 **kwargs: Keyword arguments 1395 1396 Raises: 1397 RuntimeError: If args or kwargs are specified. 1398 1399 Returns: 1400 Tensors as numpy values. 
1401 """ 1402 feed_dict = kwargs.pop("feed_dict", {}) 1403 if feed_dict: 1404 raise RuntimeError( 1405 "feed_dict is not supported when eager execution is enabled " 1406 "(in this case, sess.run(t) is shorthand for t.numpy()") 1407 1408 if args or kwargs: 1409 raise RuntimeError( 1410 "Optional args are not supported when eager execution is enabled " 1411 "(in this case, sess.run(t) is shorthand for t.numpy()") 1412 1413 return self._test_case.evaluate(fetches) 1414 1415 1416class ErrorLoggingSession(session.Session): 1417 """Wrapper around a Session that logs errors in run().""" 1418 1419 def run(self, *args, **kwargs): 1420 try: 1421 return super(ErrorLoggingSession, self).run(*args, **kwargs) 1422 except Exception as e: # pylint: disable=broad-except 1423 # Note: disable the logging for OutOfRangeError, which makes the output 1424 # of tf.data tests hard to read, because OutOfRangeError is used as the 1425 # signal completion 1426 if not isinstance(e, errors.OutOfRangeError): 1427 logging.error(str(e)) 1428 raise 1429 1430 1431def use_deterministic_cudnn(func): 1432 """Disable autotuning during the call to this function. 1433 1434 Some tests want to base assertions on a graph being isomorphic with a copy. 1435 To ensure this, this decorator disables autotuning. 1436 1437 Args: 1438 func: Function to run with CUDNN autotuning turned off. 1439 1440 Returns: 1441 Decorated function. 1442 """ 1443 1444 def decorator(f): 1445 1446 def decorated(self, *args, **kwargs): 1447 original_var = os.environ.get("TF_CUDNN_DETERMINISTIC", "") 1448 os.environ["TF_CUDNN_DETERMINISTIC"] = "true" 1449 result = f(self, *args, **kwargs) 1450 os.environ["TF_CUDNN_DETERMINISTIC"] = original_var 1451 return result 1452 1453 return decorated 1454 1455 if func is not None: 1456 return decorator(func) 1457 1458 return decorator 1459 1460 1461# The description is just for documentation purposes. 
def disable_xla(description):
  """Decorator that makes the wrapped test a no-op when XLA is enabled."""

  def disable_xla_impl(func):
    """Execute the test method only if xla is not enabled."""

    def decorator(func):

      def decorated(self, *args, **kwargs):
        # Skip the body entirely (return None) under XLA auto-jit.
        if is_xla_enabled():
          return
        else:
          return func(self, *args, **kwargs)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return disable_xla_impl


# The description is just for documentation purposes.
def disable_all_xla(description):
  """Class decorator: applies `disable_xla` to every test method."""

  def disable_all_impl(cls):
    """Execute all test methods in this class only if xla is not enabled."""
    base_decorator = disable_xla
    for name in dir(cls):
      value = getattr(cls, name)
      if callable(value) and name.startswith(
          "test") and not name == "test_session":
        setattr(cls, name, base_decorator(description)(value))
    return cls

  return disable_all_impl


class EagerSessionWarner(object):
  """Stand-in for a session that raises a porting hint on any attribute use."""

  def __getattr__(self, attr):
    raise AttributeError(
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")


@tf_export("test.TestCase")
class TensorFlowTestCase(googletest.TestCase):
  """Base class for tests that need to test TensorFlow."""

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(TensorFlowTestCase, self).__init__(methodName)
    if is_xla_enabled():
      os.putenv(
          "TF_XLA_FLAGS", "--tf_xla_auto_jit=2 --tf_xla_min_cluster_size=1 "
          "--tf_xla_enable_lazy_compilation=false " +
          os.getenv("TF_XLA_FLAGS", ""))
    self._threads = []  # _CheckedThread instances created via checkedThread().
    self._tempdir = None  # Lazily created by get_temp_dir().
    self._cached_session = None  # Session reused by cached_session().

  def setUp(self):
    self._ClearCachedSession()
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)

    # Avoiding calling setUp() for the poorly named test_session method.
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

  def tearDown(self):
    # Fail the test if a thread created via checkedThread() was started but
    # never joined.
    for thread in self._threads:
      thread.check_termination()

    self._ClearCachedSession()

  def _ClearCachedSession(self):
    # Close and drop the session created by cached_session(), if any.
    if self._cached_session is not None:
      self._cached_session.close()
      self._cached_session = None

  def get_temp_dir(self):
    """Returns a unique temporary directory for the test to use.

    If you call this method multiple times during a test, it will return the
    same folder. However, across different runs the directories will be
    different. This will ensure that across different runs tests will not be
    able to pollute each others environment.
    If you need multiple unique directories within a single test, you should
    use tempfile.mkdtemp as follows:
      tempfile.mkdtemp(dir=self.get_temp_dir()):

    Returns:
      string, the path to the unique temporary directory created for this test.
    """
    if not self._tempdir:
      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    return self._tempdir

  @contextlib.contextmanager
  def captureWritesToStream(self, stream):
    """A context manager that captures the writes to a given stream.

    This context manager captures all writes to a given stream inside of a
    `CapturedWrites` object. When this context manager is created, it yields
    the `CapturedWrites` object. The captured contents can be accessed by
    calling `.contents()` on the `CapturedWrites`.

    For this function to work, the stream must have a file descriptor that
    can be modified using `os.dup` and `os.dup2`, and the stream must support
    a `.flush()` method. The default python sys.stdout and sys.stderr are
    examples of this. Note that this does not work in Colab or Jupyter
    notebooks, because those use alternate stdout streams.
1589 1590 Example: 1591 ```python 1592 class MyOperatorTest(test_util.TensorFlowTestCase): 1593 def testMyOperator(self): 1594 input = [1.0, 2.0, 3.0, 4.0, 5.0] 1595 with self.captureWritesToStream(sys.stdout) as captured: 1596 result = MyOperator(input).eval() 1597 self.assertStartsWith(captured.contents(), "This was printed.") 1598 ``` 1599 1600 Args: 1601 stream: The stream whose writes should be captured. This stream must have 1602 a file descriptor, support writing via using that file descriptor, and 1603 must have a `.flush()` method. 1604 1605 Yields: 1606 A `CapturedWrites` object that contains all writes to the specified stream 1607 made during this context. 1608 """ 1609 stream.flush() 1610 fd = stream.fileno() 1611 tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir()) 1612 tmp_file = open(tmp_file_path, "w") 1613 orig_fd = os.dup(fd) 1614 os.dup2(tmp_file.fileno(), fd) 1615 try: 1616 yield CapturedWrites(tmp_file_path) 1617 finally: 1618 tmp_file.close() 1619 os.dup2(orig_fd, fd) 1620 1621 def _AssertProtoEquals(self, a, b, msg=None): 1622 """Asserts that a and b are the same proto. 1623 1624 Uses ProtoEq() first, as it returns correct results 1625 for floating point attributes, and then use assertProtoEqual() 1626 in case of failure as it provides good error messages. 1627 1628 Args: 1629 a: a proto. 1630 b: another proto. 1631 msg: Optional message to report on failure. 1632 """ 1633 if not compare.ProtoEq(a, b): 1634 compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg) 1635 1636 def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None): 1637 """Asserts that message is same as parsed expected_message_ascii. 1638 1639 Creates another prototype of message, reads the ascii message into it and 1640 then compares them using self._AssertProtoEqual(). 1641 1642 Args: 1643 expected_message_maybe_ascii: proto message in original or ascii form. 1644 message: the message to validate. 
1645 msg: Optional message to report on failure. 1646 """ 1647 msg = msg if msg else "" 1648 if isinstance(expected_message_maybe_ascii, type(message)): 1649 expected_message = expected_message_maybe_ascii 1650 self._AssertProtoEquals(expected_message, message) 1651 elif isinstance(expected_message_maybe_ascii, str): 1652 expected_message = type(message)() 1653 text_format.Merge( 1654 expected_message_maybe_ascii, 1655 expected_message, 1656 descriptor_pool=descriptor_pool.Default()) 1657 self._AssertProtoEquals(expected_message, message, msg=msg) 1658 else: 1659 assert False, ("Can't compare protos of type %s and %s. %s" % 1660 (type(expected_message_maybe_ascii), type(message), msg)) 1661 1662 def assertProtoEqualsVersion( 1663 self, 1664 expected, 1665 actual, 1666 producer=versions.GRAPH_DEF_VERSION, 1667 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER, 1668 msg=None): 1669 expected = "versions { producer: %d min_consumer: %d };\n%s" % ( 1670 producer, min_consumer, expected) 1671 self.assertProtoEquals(expected, actual, msg=msg) 1672 1673 def assertStartsWith(self, actual, expected_start, msg=None): 1674 """Assert that actual.startswith(expected_start) is True. 1675 1676 Args: 1677 actual: str 1678 expected_start: str 1679 msg: Optional message to report on failure. 
1680 """ 1681 if not actual.startswith(expected_start): 1682 fail_msg = "%r does not start with %r" % (actual, expected_start) 1683 fail_msg += " : %r" % (msg) if msg else "" 1684 self.fail(fail_msg) 1685 1686 def _eval_tensor(self, tensor): 1687 if tensor is None: 1688 return None 1689 elif callable(tensor): 1690 return self._eval_helper(tensor()) 1691 else: 1692 try: 1693 if sparse_tensor.is_sparse(tensor): 1694 return sparse_tensor.SparseTensorValue(tensor.indices.numpy(), 1695 tensor.values.numpy(), 1696 tensor.dense_shape.numpy()) 1697 elif isinstance(tensor, ops.IndexedSlices): 1698 return ops.IndexedSlicesValue(values=tensor.values.numpy(), 1699 indices=tensor.indices.numpy(), 1700 dense_shape=tensor.dense_shape.numpy()) 1701 return tensor.numpy() 1702 except AttributeError as e: 1703 six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e) 1704 1705 def _eval_helper(self, tensors): 1706 if tensors is None: 1707 return None 1708 return nest.map_structure(self._eval_tensor, tensors) 1709 1710 def evaluate(self, tensors): 1711 """Evaluates tensors and returns numpy values. 1712 1713 Args: 1714 tensors: A Tensor or a nested list/tuple of Tensors. 1715 1716 Returns: 1717 tensors numpy values. 1718 """ 1719 if context.executing_eagerly(): 1720 return self._eval_helper(tensors) 1721 else: 1722 sess = ops.get_default_session() 1723 if sess is None: 1724 with self.test_session() as sess: 1725 return sess.run(tensors) 1726 else: 1727 return sess.run(tensors) 1728 1729 # pylint: disable=g-doc-return-or-yield 1730 @contextlib.contextmanager 1731 def session(self, graph=None, config=None, use_gpu=False, force_gpu=False): 1732 """Returns a TensorFlow Session for use in executing tests. 1733 1734 Note that this will set this session and the graph as global defaults. 1735 1736 Use the `use_gpu` and `force_gpu` options to control where ops are run. If 1737 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. 
Otherwise, if 1738 `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as 1739 possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to 1740 the CPU. 1741 1742 Example: 1743 ```python 1744 class MyOperatorTest(test_util.TensorFlowTestCase): 1745 def testMyOperator(self): 1746 with self.session(use_gpu=True): 1747 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] 1748 result = MyOperator(valid_input).eval() 1749 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] 1750 invalid_input = [-1.0, 2.0, 7.0] 1751 with self.assertRaisesOpError("negative input not supported"): 1752 MyOperator(invalid_input).eval() 1753 ``` 1754 1755 Args: 1756 graph: Optional graph to use during the returned session. 1757 config: An optional config_pb2.ConfigProto to use to configure the 1758 session. 1759 use_gpu: If True, attempt to run as many ops as possible on GPU. 1760 force_gpu: If True, pin all ops to `/device:GPU:0`. 1761 1762 Yields: 1763 A Session object that should be used as a context manager to surround 1764 the graph building and execution code in a test case. 1765 """ 1766 if context.executing_eagerly(): 1767 yield EagerSessionWarner() 1768 else: 1769 with self._create_session(graph, config, force_gpu) as sess: 1770 with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu): 1771 yield sess 1772 1773 @contextlib.contextmanager 1774 def cached_session(self, 1775 graph=None, 1776 config=None, 1777 use_gpu=False, 1778 force_gpu=False): 1779 """Returns a TensorFlow Session for use in executing tests. 1780 1781 This method behaves differently than self.session(): for performance reasons 1782 `cached_session` will by default reuse the same session within the same 1783 test. The session returned by this function will only be closed at the end 1784 of the test (in the TearDown function). 1785 1786 Use the `use_gpu` and `force_gpu` options to control where ops are run. If 1787 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. 
Otherwise, if 1788 `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as 1789 possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to 1790 the CPU. 1791 1792 Example: 1793 ```python 1794 class MyOperatorTest(test_util.TensorFlowTestCase): 1795 def testMyOperator(self): 1796 with self.cached_session(use_gpu=True) as sess: 1797 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] 1798 result = MyOperator(valid_input).eval() 1799 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] 1800 invalid_input = [-1.0, 2.0, 7.0] 1801 with self.assertRaisesOpError("negative input not supported"): 1802 MyOperator(invalid_input).eval() 1803 ``` 1804 1805 Args: 1806 graph: Optional graph to use during the returned session. 1807 config: An optional config_pb2.ConfigProto to use to configure the 1808 session. 1809 use_gpu: If True, attempt to run as many ops as possible on GPU. 1810 force_gpu: If True, pin all ops to `/device:GPU:0`. 1811 1812 Yields: 1813 A Session object that should be used as a context manager to surround 1814 the graph building and execution code in a test case. 
    """
    if context.executing_eagerly():
      # In eager mode no real session exists; FakeEagerSession gives helpful
      # errors if legacy `sess.run`-style usage sneaks into an eager test.
      yield FakeEagerSession(self)
    else:
      sess = self._get_cached_session(
          graph, config, force_gpu, crash_if_inconsistent_args=True)
      with self._constrain_devices_and_set_default(sess, use_gpu,
                                                   force_gpu) as cached:
        yield cached

  @contextlib.contextmanager
  @deprecation.deprecated(None, "Use `self.session()` or "
                          "`self.cached_session()` instead.")
  def test_session(self,
                   graph=None,
                   config=None,
                   use_gpu=False,
                   force_gpu=False):
    """Use cached_session instead."""
    if self.id().endswith(".test_session"):
      # The name "test_session" makes the unittest loader collect this method
      # as if it were a test; skip it in that case.
      self.skipTest("Not a test.")
    if context.executing_eagerly():
      yield None
    else:
      if graph is None:
        # Unlike cached_session(), mismatched graph/config/force_gpu arguments
        # are tolerated here (crash_if_inconsistent_args=False) for backward
        # compatibility with old tests.
        sess = self._get_cached_session(
            graph, config, force_gpu, crash_if_inconsistent_args=False)
        with self._constrain_devices_and_set_default(sess, use_gpu,
                                                     force_gpu) as cached:
          yield cached
      else:
        # An explicit graph always gets a fresh (non-cached) session.
        with self.session(graph, config, use_gpu, force_gpu) as sess:
          yield sess

  # pylint: enable=g-doc-return-or-yield

  class _CheckedThread(object):
    """A wrapper class for Thread that asserts successful completion.

    This class should be created using the TensorFlowTestCase.checkedThread()
    method.
    """

    def __init__(self, testcase, target, args=None, kwargs=None):
      """Constructs a new instance of _CheckedThread.

      Args:
        testcase: The TensorFlowTestCase for which this thread is being
          created.
        target: A callable object representing the code to be executed in the
          thread.
        args: A tuple of positional arguments that will be passed to target.
        kwargs: A dictionary of keyword arguments that will be passed to
          target.
      """
      self._testcase = testcase
      self._target = target
      self._args = () if args is None else args
      self._kwargs = {} if kwargs is None else kwargs
      self._thread = threading.Thread(target=self._protected_run)
      self._exception = None

      self._is_thread_joined = False

    def _protected_run(self):
      """Target for the wrapper thread. Sets self._exception on failure."""
      try:
        self._target(*self._args, **self._kwargs)
      except Exception as e:  # pylint: disable=broad-except
        # Stash the exception so join() can re-raise it as a test failure on
        # the main thread instead of letting it die silently in this thread.
        self._exception = e

    def start(self):
      """Starts the thread's activity.

      This must be called at most once per _CheckedThread object. It arranges
      for the object's target to be invoked in a separate thread of control.
      """
      self._thread.start()

    def join(self):
      """Blocks until the thread terminates.

      Raises:
        self._testcase.failureException: If the thread terminates with due to
          an exception.
      """
      self._is_thread_joined = True
      self._thread.join()
      if self._exception is not None:
        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))

    def is_alive(self):
      """Returns whether the thread is alive.

      This method returns True just before the run() method starts
      until just after the run() method terminates.

      Returns:
        True if the thread is alive, otherwise False.
      """
      return self._thread.is_alive()

    def check_termination(self):
      """Returns whether the checked thread was properly used and did terminate.

      Every checked thread should be "join"ed after starting, and before the
      test tears down. If it is not joined, it is possible the thread will hang
      and cause flaky failures in tests.

      Raises:
        self._testcase.failureException: If check_termination was called before
          thread was joined.

        RuntimeError: If the thread is not terminated. This means thread was not
          joined with the main thread.
      """
      if self._is_thread_joined:
        if self.is_alive():
          raise RuntimeError(
              "Thread was not joined with main thread, and is still running "
              "when the test finished.")
      else:
        self._testcase.fail("A checked thread was not joined.")

  def checkedThread(self, target, args=None, kwargs=None):
    """Returns a Thread wrapper that asserts 'target' completes successfully.

    This method should be used to create all threads in test cases, as
    otherwise there is a risk that a thread will silently fail, and/or
    assertions made in the thread will not be respected.

    Args:
      target: A callable object to be executed in the thread.
      args: The argument tuple for the target invocation. Defaults to ().
      kwargs: A dictionary of keyword arguments for the target invocation.
        Defaults to {}.

    Returns:
      A wrapper for threading.Thread that supports start() and join() methods.
    """
    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
    # Registered so teardown logic can verify every thread was joined.
    self._threads.append(ret)
    return ret

  # pylint: enable=invalid-name
  @py_func_if_in_function
  def assertNear(self, f1, f2, err, msg=None):
    """Asserts that two floats are near each other.

    Checks that |f1 - f2| < err and asserts a test failure
    if not.

    Args:
      f1: A float value.
      f2: A float value.
      err: A float value.
      msg: An optional string message to append to the failure message.
    """
    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
    self.assertTrue(
        f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
        (f1, f2, err, " (%s)" % msg if msg is not None else ""))

  @py_func_if_in_function
  def assertArrayNear(self, farray1, farray2, err, msg=None):
    """Asserts that two float arrays are near each other.

    Checks that for all elements of farray1 and farray2
    |f1 - f2| < err. Asserts a test failure if not.

    Args:
      farray1: a list of float values.
      farray2: a list of float values.
      err: a float value.
      msg: Optional message to report on failure.
    """
    self.assertEqual(len(farray1), len(farray2), msg=msg)
    for f1, f2 in zip(farray1, farray2):
      self.assertNear(float(f1), float(f2), err, msg=msg)

  def _NDArrayNear(self, ndarray1, ndarray2, err):
    """Returns True if the Euclidean norm of the difference is below `err`."""
    return np.linalg.norm(ndarray1 - ndarray2) < err

  @py_func_if_in_function
  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
    """Asserts that two numpy arrays have near values.

    Args:
      ndarray1: a numpy ndarray.
      ndarray2: a numpy ndarray.
      err: a float. The maximum absolute difference allowed.
      msg: Optional message to report on failure.
    """
    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)

  def _GetNdArray(self, a):
    """Converts `a` (a Tensor or array-like) to a numpy ndarray."""
    # If a is a tensor then convert it to ndarray
    if isinstance(a, ops.Tensor):
      if isinstance(a, ops._EagerTensorBase):
        a = a.numpy()
      else:
        a = self.evaluate(a)
    if not isinstance(a, np.ndarray):
      return np.array(a)
    return a

  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    """Checks that `a` and `b` have the same shape and are element-wise close.

    Args:
      a: the expected value; a Tensor or anything convertible to an ndarray.
      b: the actual value; a Tensor or anything convertible to an ndarray.
      rtol: relative tolerance.
      atol: absolute tolerance.
      msg: Optional message to prepend to the failure details.
    """
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # When the array rank is small, print its contents. Numpy array printing is
    # implemented using inefficient recursion so prints can cause tests to
    # time out.
    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
                            "%s.") % (a.shape, b.shape, b)
    else:
      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
                                                                     b.shape)
    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)

    msgs = [msg]
    if not np.allclose(a, b, rtol=rtol, atol=atol):
      # Adds more details to np.testing.assert_allclose.
      #
      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
      # checks whether two arrays are element-wise equal within a
      # tolerance. The relative difference (rtol * abs(b)) and the
      # absolute difference atol are added together to compare against
      # the absolute difference between a and b. Here, we want to
      # tell user which elements violate such conditions.
      cond = np.logical_or(
          np.abs(a - b) > atol + rtol * np.abs(b),
          np.isnan(a) != np.isnan(b))
      if a.ndim:
        x = a[np.where(cond)]
        y = b[np.where(cond)]
        msgs.append("not close where = {}".format(np.where(cond)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not close lhs = {}".format(x))
      msgs.append("not close rhs = {}".format(y))
      msgs.append("not close dif = {}".format(np.abs(x - y)))
      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
      msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
      # TODO(xpan): There seems to be a bug:
      # tensorflow/compiler/tests:binary_ops_test pass with float32
      # nan even though the equal_nan is False by default internally.
      np.testing.assert_allclose(
          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)

  def _assertAllCloseRecursive(self,
                               a,
                               b,
                               rtol=1e-6,
                               atol=1e-6,
                               path=None,
                               msg=None):
    """Recursive helper for assertAllClose over nested structures.

    `path` accumulates the keys/indices traversed so far, so failure messages
    can pinpoint the mismatching leaf, e.g. a[1]['d'].
    """
    path = path or []
    path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "")
    msg = msg if msg else ""

    # Check if a and/or b are namedtuples.
2076 if hasattr(a, "_asdict"): 2077 a = a._asdict() 2078 if hasattr(b, "_asdict"): 2079 b = b._asdict() 2080 a_is_dict = isinstance(a, collections.Mapping) 2081 if a_is_dict != isinstance(b, collections.Mapping): 2082 raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" % 2083 (path_str, path_str, msg)) 2084 if a_is_dict: 2085 self.assertItemsEqual( 2086 a.keys(), 2087 b.keys(), 2088 msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" % 2089 (path_str, a.keys(), path_str, b.keys(), msg)) 2090 for k in a: 2091 path.append(k) 2092 self._assertAllCloseRecursive( 2093 a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg) 2094 del path[-1] 2095 elif isinstance(a, (list, tuple)): 2096 # Try to directly compare a, b as ndarrays; if not work, then traverse 2097 # through the sequence, which is more expensive. 2098 try: 2099 a_as_ndarray = self._GetNdArray(a) 2100 b_as_ndarray = self._GetNdArray(b) 2101 self._assertArrayLikeAllClose( 2102 a_as_ndarray, 2103 b_as_ndarray, 2104 rtol=rtol, 2105 atol=atol, 2106 msg="Mismatched value: a%s is different from b%s. %s" % 2107 (path_str, path_str, msg)) 2108 except (ValueError, TypeError) as e: 2109 if len(a) != len(b): 2110 raise ValueError( 2111 "Mismatched length: a%s has %d items, but b%s has %d items. %s" % 2112 (path_str, len(a), path_str, len(b), msg)) 2113 for idx, (a_ele, b_ele) in enumerate(zip(a, b)): 2114 path.append(str(idx)) 2115 self._assertAllCloseRecursive( 2116 a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg) 2117 del path[-1] 2118 # a and b are ndarray like objects 2119 else: 2120 try: 2121 self._assertArrayLikeAllClose( 2122 a, 2123 b, 2124 rtol=rtol, 2125 atol=atol, 2126 msg=("Mismatched value: a%s is different from b%s. %s" % 2127 (path_str, path_str, msg))) 2128 except TypeError as e: 2129 msg = ("Error: a%s has %s, but b%s has %s. 
%s" % 2130 (path_str, type(a), path_str, type(b), msg)) 2131 e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) 2132 raise 2133 2134 @py_func_if_in_function 2135 def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 2136 """Asserts that two structures of numpy arrays or Tensors, have near values. 2137 2138 `a` and `b` can be arbitrarily nested structures. A layer of a nested 2139 structure can be a `dict`, `namedtuple`, `tuple` or `list`. 2140 2141 Args: 2142 a: The expected numpy `ndarray`, or anything that can be converted into a 2143 numpy `ndarray` (including Tensor), or any arbitrarily nested of 2144 structure of these. 2145 b: The actual numpy `ndarray`, or anything that can be converted into a 2146 numpy `ndarray` (including Tensor), or any arbitrarily nested of 2147 structure of these. 2148 rtol: relative tolerance. 2149 atol: absolute tolerance. 2150 msg: Optional message to report on failure. 2151 2152 Raises: 2153 ValueError: if only one of `a[p]` and `b[p]` is a dict or 2154 `a[p]` and `b[p]` have different length, where `[p]` denotes a path 2155 to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and 2156 `[p] = [1]['d']`, then `a[p] = (6, 7)`. 2157 """ 2158 self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) 2159 2160 @py_func_if_in_function 2161 def assertAllCloseAccordingToType(self, 2162 a, 2163 b, 2164 rtol=1e-6, 2165 atol=1e-6, 2166 float_rtol=1e-6, 2167 float_atol=1e-6, 2168 half_rtol=1e-3, 2169 half_atol=1e-3, 2170 bfloat16_rtol=1e-2, 2171 bfloat16_atol=1e-2, 2172 msg=None): 2173 """Like assertAllClose, but also suitable for comparing fp16 arrays. 2174 2175 In particular, the tolerance is reduced to 1e-3 if at least 2176 one of the arguments is of type float16. 2177 2178 Args: 2179 a: the expected numpy ndarray or anything can be converted to one. 2180 b: the actual numpy ndarray or anything can be converted to one. 2181 rtol: relative tolerance. 2182 atol: absolute tolerance. 
      float_rtol: relative tolerance for float32.
      float_atol: absolute tolerance for float32.
      half_rtol: relative tolerance for float16.
      half_atol: absolute tolerance for float16.
      bfloat16_rtol: relative tolerance for bfloat16.
      bfloat16_atol: absolute tolerance for bfloat16.
      msg: Optional message to report on failure.
    """
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # types with lower tol are put later to overwrite previous ones.
    if (a.dtype == np.float32 or b.dtype == np.float32 or
        a.dtype == np.complex64 or b.dtype == np.complex64):
      rtol = max(rtol, float_rtol)
      atol = max(atol, float_atol)
    if a.dtype == np.float16 or b.dtype == np.float16:
      rtol = max(rtol, half_rtol)
      atol = max(atol, half_atol)
    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
        b.dtype == dtypes.bfloat16.as_numpy_dtype):
      rtol = max(rtol, bfloat16_rtol)
      atol = max(atol, bfloat16_atol)

    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)

  @py_func_if_in_function
  def assertNotAllClose(self, a, b, **kwargs):
    """Assert that two numpy arrays, or Tensors, do not have near values.

    Args:
      a: the first value to compare.
      b: the second value to compare.
      **kwargs: additional keyword arguments to be passed to the underlying
        `assertAllClose` call.

    Raises:
      AssertionError: If `a` and `b` are unexpectedly close at all elements.
    """
    # Inverts assertAllClose: success there means failure here.
    try:
      self.assertAllClose(a, b, **kwargs)
    except AssertionError:
      return
    raise AssertionError("The two values are close at all elements")

  @py_func_if_in_function
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that two numpy arrays or Tensors have the same values.

    Args:
      a: the expected numpy ndarray or anything can be converted to one.
      b: the actual numpy ndarray or anything can be converted to one.
      msg: Optional message to report on failure.
    """
    msg = msg if msg else ""
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # Arbitrary bounds so that we don't print giant tensors.
    if (b.ndim <= 3 or b.size < 500):
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " Contents: %s. \n%s." % (a.shape, b.shape, b, msg))
    else:
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " %s" % (a.shape, b.shape, msg))

    same = (a == b)

    # For floating-point dtypes, treat NaN == NaN as equal (element-wise).
    if (a.dtype in [
        np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
    ]):
      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
    msgs = [msg]
    if not np.all(same):
      # Adds more details to np.testing.assert_array_equal.
      diff = np.logical_not(same)
      if a.ndim:
        x = a[np.where(diff)]
        y = b[np.where(diff)]
        msgs.append("not equal where = {}".format(np.where(diff)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not equal lhs = {}".format(x))
      msgs.append("not equal rhs = {}".format(y))
      np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))

  @py_func_if_in_function
  def assertAllGreater(self, a, comparison_target):
    """Assert element values are all greater than a target value.

    Args:
      a: The numpy `ndarray`, or anything that can be converted into a numpy
        `ndarray` (including Tensor).
      comparison_target: The target value of comparison.
    """
    a = self._GetNdArray(a)
    self.assertGreater(np.min(a), comparison_target)

  @py_func_if_in_function
  def assertAllLess(self, a, comparison_target):
    """Assert element values are all less than a target value.

    Args:
      a: The numpy `ndarray`, or anything that can be converted into a numpy
        `ndarray` (including Tensor).
      comparison_target: The target value of comparison.
    """
    a = self._GetNdArray(a)
    self.assertLess(np.max(a), comparison_target)

  @py_func_if_in_function
  def assertAllGreaterEqual(self, a, comparison_target):
    """Assert element values are all greater than or equal to a target value.

    Args:
      a: The numpy `ndarray`, or anything that can be converted into a numpy
        `ndarray` (including Tensor).
      comparison_target: The target value of comparison.
    """
    a = self._GetNdArray(a)
    self.assertGreaterEqual(np.min(a), comparison_target)

  @py_func_if_in_function
  def assertAllLessEqual(self, a, comparison_target):
    """Assert element values are all less than or equal to a target value.

    Args:
      a: The numpy `ndarray`, or anything that can be converted into a numpy
        `ndarray` (including Tensor).
      comparison_target: The target value of comparison.
    """
    a = self._GetNdArray(a)
    self.assertLessEqual(np.max(a), comparison_target)

  def _format_subscripts(self, subscripts, value, limit=10, indent=2):
    """Generate a summary of ndarray subscripts as a list of str.

    If limit == N, this method will print up to the first N subscripts on
    separate lines. A line of ellipses (...) will be appended at the end if
    the number of subscripts exceeds N.

    Args:
      subscripts: The tensor (np.ndarray) subscripts, of the same format as
        np.where()'s return value, i.e., a tuple of arrays with each array
        corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])).
      value: (np.ndarray) value of the tensor.
      limit: (int) The maximum number of indices to print.
      indent: (int) Number of characters to indent at the beginning of each
        line.

    Returns:
      (list of str) the multi-line representation of the subscripts and values,
        potentially with omission at the end.
    """
    lines = []
    # Transpose so each row is one full index tuple (one per element).
    subscripts = np.transpose(subscripts)
    prefix = " " * indent
    for subscript in itertools.islice(subscripts, limit):
      lines.append(prefix + str(subscript) + " : " +
                   str(value[tuple(subscript)]))
    if len(subscripts) > limit:
      lines.append(prefix + "...")
    return lines

  @py_func_if_in_function
  def assertAllInRange(self,
                       target,
                       lower_bound,
                       upper_bound,
                       open_lower_bound=False,
                       open_upper_bound=False):
    """Assert that elements in a Tensor are all in a given range.

    Args:
      target: The numpy `ndarray`, or anything that can be converted into a
        numpy `ndarray` (including Tensor).
      lower_bound: lower bound of the range
      upper_bound: upper bound of the range
      open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
        than the default >=)
      open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
        than the default <=)

    Raises:
      AssertionError:
        if the value tensor does not have an ordered numeric type (float* or
          int*), or
        if there are nan values, or
        if any of the elements do not fall in the specified range.
    """
    target = self._GetNdArray(target)
    if not (np.issubdtype(target.dtype, np.floating) or
            np.issubdtype(target.dtype, np.integer)):
      raise AssertionError(
          "The value of %s does not have an ordered numeric type, instead it "
          "has type: %s" % (target, target.dtype))

    # NaN compares false with everything, so it must be rejected explicitly
    # before the range checks below.
    nan_subscripts = np.where(np.isnan(target))
    if np.size(nan_subscripts):
      raise AssertionError(
          "%d of the %d element(s) are NaN. "
          "Subscripts(s) and value(s) of the NaN element(s):\n" %
          (len(nan_subscripts[0]), np.size(target)) +
          "\n".join(self._format_subscripts(nan_subscripts, target)))

    range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
                 str(upper_bound) + (")" if open_upper_bound else "]"))

    violations = (
        np.less_equal(target, lower_bound) if open_lower_bound else np.less(
            target, lower_bound))
    violations = np.logical_or(
        violations,
        np.greater_equal(target, upper_bound)
        if open_upper_bound else np.greater(target, upper_bound))
    violation_subscripts = np.where(violations)
    if np.size(violation_subscripts):
      raise AssertionError(
          "%d of the %d element(s) are outside the range %s. " %
          (len(violation_subscripts[0]), np.size(target), range_str) +
          "Subscript(s) and value(s) of the offending elements:\n" +
          "\n".join(self._format_subscripts(violation_subscripts, target)))

  @py_func_if_in_function
  def assertAllInSet(self, target, expected_set):
    """Assert that elements of a Tensor are all in a given closed set.

    Args:
      target: The numpy `ndarray`, or anything that can be converted into a
        numpy `ndarray` (including Tensor).
      expected_set: (`list`, `tuple` or `set`) The closed set that the elements
        of the value of `target` are expected to fall into.

    Raises:
      AssertionError:
        if any of the elements do not fall into `expected_set`.
    """
    target = self._GetNdArray(target)

    # Elements in target that are not in expected_set.
    diff = np.setdiff1d(target.flatten(), list(expected_set))
    if np.size(diff):
      raise AssertionError("%d unique element(s) are not in the set %s: %s" %
                           (np.size(diff), expected_set, diff))

  @py_func_if_in_function
  def assertDTypeEqual(self, target, expected_dtype):
    """Assert ndarray data type is equal to expected.
2433 2434 Args: 2435 target: The numpy `ndarray`, or anything that can be converted into a 2436 numpy `ndarray` (including Tensor). 2437 expected_dtype: Expected data type. 2438 """ 2439 target = self._GetNdArray(target) 2440 if not isinstance(target, list): 2441 arrays = [target] 2442 for arr in arrays: 2443 self.assertEqual(arr.dtype, expected_dtype) 2444 2445 # pylint: disable=g-doc-return-or-yield 2446 @contextlib.contextmanager 2447 def assertRaisesWithPredicateMatch(self, exception_type, 2448 expected_err_re_or_predicate): 2449 """Returns a context manager to enclose code expected to raise an exception. 2450 2451 If the exception is an OpError, the op stack is also included in the message 2452 predicate search. 2453 2454 Args: 2455 exception_type: The expected type of exception that should be raised. 2456 expected_err_re_or_predicate: If this is callable, it should be a function 2457 of one argument that inspects the passed-in exception and returns True 2458 (success) or False (please fail the test). Otherwise, the error message 2459 is expected to match this regular expression partially. 2460 2461 Returns: 2462 A context manager to surround code that is expected to raise an 2463 exception. 
2464 """ 2465 if callable(expected_err_re_or_predicate): 2466 predicate = expected_err_re_or_predicate 2467 else: 2468 2469 def predicate(e): 2470 err_str = e.message if isinstance(e, errors.OpError) else str(e) 2471 op = e.op if isinstance(e, errors.OpError) else None 2472 while op is not None: 2473 err_str += "\nCaused by: " + op.name 2474 op = op._original_op # pylint: disable=protected-access 2475 logging.info("Searching within error strings: '%s' within '%s'", 2476 expected_err_re_or_predicate, err_str) 2477 return re.search(expected_err_re_or_predicate, err_str) 2478 2479 try: 2480 yield 2481 self.fail(exception_type.__name__ + " not raised") 2482 except Exception as e: # pylint: disable=broad-except 2483 if not isinstance(e, exception_type) or not predicate(e): 2484 raise AssertionError( 2485 "Exception of type %s: %s" % (str(type(e)), str(e))) 2486 2487 # pylint: enable=g-doc-return-or-yield 2488 2489 def assertRaisesOpError(self, expected_err_re_or_predicate): 2490 return self.assertRaisesWithPredicateMatch(errors.OpError, 2491 expected_err_re_or_predicate) 2492 2493 def assertShapeEqual(self, np_array, tf_tensor, msg=None): 2494 """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape. 2495 2496 Args: 2497 np_array: A Numpy ndarray or Numpy scalar. 2498 tf_tensor: A Tensor. 2499 msg: Optional message to report on failure. 2500 2501 Raises: 2502 TypeError: If the arguments have the wrong type. 2503 """ 2504 if not isinstance(np_array, (np.ndarray, np.generic)): 2505 raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") 2506 if not isinstance(tf_tensor, ops.Tensor): 2507 raise TypeError("tf_tensor must be a Tensor") 2508 self.assertAllEqual( 2509 np_array.shape, tf_tensor.get_shape().as_list(), msg=msg) 2510 2511 def assertDeviceEqual(self, device1, device2, msg=None): 2512 """Asserts that the two given devices are the same. 2513 2514 Args: 2515 device1: A string device name or TensorFlow `DeviceSpec` object. 
      device2: A string device name or TensorFlow `DeviceSpec` object.
      msg: Optional message to report on failure.
    """
    # Canonicalize both names so logically-equal specs compare equal.
    device1 = pydev.canonical_name(device1)
    device2 = pydev.canonical_name(device2)
    self.assertEqual(
        device1, device2,
        "Devices %s and %s are not equal. %s" % (device1, device2, msg))

  # Fix Python 3 compatibility issues
  if six.PY3:
    # pylint: disable=invalid-name

    # Silence a deprecation warning
    assertRaisesRegexp = googletest.TestCase.assertRaisesRegex

    # assertItemsEqual is assertCountEqual as of 3.2.
    assertItemsEqual = googletest.TestCase.assertCountEqual

    # pylint: enable=invalid-name

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Set the session and its graph to global default and constrain devices."""
    if context.executing_eagerly():
      yield None
    else:
      with sess.graph.as_default(), sess.as_default():
        if force_gpu:
          # Use the name of an actual device if one is detected, or
          # '/device:GPU:0' otherwise
          gpu_name = gpu_device_name()
          if not gpu_name:
            gpu_name = "/device:GPU:0"
          with sess.graph.device(gpu_name):
            yield sess
        elif use_gpu:
          # No device constraint: ops may be placed on GPU or CPU.
          yield sess
        else:
          with sess.graph.device("/device:CPU:0"):
            yield sess

  def _create_session(self, graph, config, force_gpu):
    """See session() for details."""

    def prepare_config(config):
      """Returns a config for sessions.

      Args:
        config: An optional config_pb2.ConfigProto to use to configure the
          session.

      Returns:
        A config_pb2.ConfigProto object.
      """
      # TODO(b/114333779): Enforce allow_soft_placement=False when
      # use_gpu=False. Currently many tests rely on the fact that any device
      # will be used even when a specific device is supposed to be used.
      allow_soft_placement = not force_gpu
      if config is None:
        config = config_pb2.ConfigProto()
        config.allow_soft_placement = allow_soft_placement
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
      elif not allow_soft_placement and config.allow_soft_placement:
        # Copy before mutating so the caller's config is left untouched.
        config_copy = config_pb2.ConfigProto()
        config_copy.CopyFrom(config)
        config = config_copy
        config.allow_soft_placement = False
      # Don't perform optimizations for tests so we don't inadvertently run
      # gpu ops on cpu
      config.graph_options.optimizer_options.opt_level = -1
      # Disable Grappler constant folding since some tests & benchmarks
      # use constant input and become meaningless after constant folding.
      # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
      # GRAPPLER TEAM.
      config.graph_options.rewrite_options.constant_folding = (
          rewriter_config_pb2.RewriterConfig.OFF)
      config.graph_options.rewrite_options.pin_to_host_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
      return config

    return ErrorLoggingSession(graph=graph, config=prepare_config(config))

  def _get_cached_session(self,
                          graph=None,
                          config=None,
                          force_gpu=False,
                          crash_if_inconsistent_args=True):
    """See cached_session() for documentation."""
    if self._cached_session is None:
      # First use: create and remember the session plus the arguments that
      # created it, so later calls can detect inconsistent reuse.
      sess = self._create_session(
          graph=graph, config=config, force_gpu=force_gpu)
      self._cached_session = sess
      self._cached_graph = graph
      self._cached_config = config
      self._cached_force_gpu = force_gpu
      return sess
    else:
      # Identity (`is`) comparisons are intentional: the cache is only valid
      # for the exact same graph/config objects it was created with.
      if crash_if_inconsistent_args and self._cached_graph is not graph:
        raise ValueError("The graph used to get the cached session is "
                         "different than the one that was used to create the "
                         "session. Maybe create a new session with "
                         "self.session()")
      if crash_if_inconsistent_args and self._cached_config is not config:
        raise ValueError("The config used to get the cached session is "
                         "different than the one that was used to create the "
                         "session. Maybe create a new session with "
                         "self.session()")
      if crash_if_inconsistent_args and (self._cached_force_gpu is
                                         not force_gpu):
        raise ValueError(
            "The force_gpu value used to get the cached session is "
            "different than the one that was used to create the "
            "session. Maybe create a new session with "
            "self.session()")
      return self._cached_session


@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in the
      documentation of `tf.train.Server`.
    worker_config: (optional) ConfigProto to initialize workers. Can be used to
      instantiate multiple devices etc.
    ps_config: (optional) ConfigProto to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`.
`worker_servers` is a list 2671 of `num_workers` objects of type `tf.train.Server` (all running locally); 2672 and `ps_servers` is a list of `num_ps` objects of similar type. 2673 2674 Raises: 2675 ImportError: if portpicker module was not found at load time 2676 """ 2677 if _portpicker_import_error: 2678 raise _portpicker_import_error # pylint: disable=raising-bad-type 2679 worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] 2680 ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] 2681 cluster_dict = { 2682 "worker": ["localhost:%s" % port for port in worker_ports], 2683 "ps": ["localhost:%s" % port for port in ps_ports] 2684 } 2685 cs = server_lib.ClusterSpec(cluster_dict) 2686 2687 workers = [ 2688 server_lib.Server( 2689 cs, 2690 job_name="worker", 2691 protocol=protocol, 2692 task_index=ix, 2693 config=worker_config, 2694 start=True) for ix in range(num_workers) 2695 ] 2696 ps_servers = [ 2697 server_lib.Server( 2698 cs, 2699 job_name="ps", 2700 protocol=protocol, 2701 task_index=ix, 2702 config=ps_config, 2703 start=True) for ix in range(num_ps) 2704 ] 2705 2706 return workers, ps_servers 2707 2708 2709def get_node_def_from_graph(node_name, graph_def): 2710 """Returns the `NodeDef` instance for given node name in the graph def. 2711 2712 This method explores only the NodeDefs in `graph_def.node`. 2713 2714 Args: 2715 node_name: Name of the NodeDef to search for. 2716 graph_def: An instance of `GraphDef` proto. 2717 2718 Returns: 2719 the `NodeDef` instance whose name field matches the given node_name or None. 2720 """ 2721 for node_def in graph_def.node: 2722 if node_def.name == node_name: 2723 return node_def 2724 return None 2725 2726 2727def set_producer_version(graph, producer_version): 2728 """Sets graph.graph_def_versions.producer to `producer_version`.""" 2729 # The C API doesn't expose altering GraphDefVersions. We can indirectly set 2730 # it via import_graph_def though. 
2731 graph_def = graph_pb2.GraphDef() 2732 graph_def.versions.producer = producer_version 2733 with graph.as_default(): 2734 importer.import_graph_def(graph_def) 2735 assert graph.graph_def_versions.producer, producer_version 2736