1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""State management for eager execution.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import copy 24import gc 25import os 26import random 27import threading 28 29from absl import logging 30import numpy as np 31import six 32 33from tensorflow.core.framework import function_pb2 34from tensorflow.core.protobuf import config_pb2 35from tensorflow.core.protobuf import rewriter_config_pb2 36from tensorflow.python import pywrap_tfe 37from tensorflow.python import tf2 38from tensorflow.python.client import pywrap_tf_session 39from tensorflow.python.eager import executor 40from tensorflow.python.eager import monitoring 41from tensorflow.python.framework import c_api_util 42from tensorflow.python.framework import device as pydev 43from tensorflow.python.framework import tfrt_utils 44from tensorflow.python.util import compat 45from tensorflow.python.util import is_in_graph_mode 46from tensorflow.python.util import tf_contextlib 47from tensorflow.python.util.deprecation import deprecated 48from tensorflow.python.util.tf_export import tf_export 49 50GRAPH_MODE = 0 51EAGER_MODE = 1 52 53default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE 54 55# Cache from (old_device_name, partial_new_device_name) -> (new_device_name, 56# new_device_spec). 57# Note that we do not protect this with a lock and instead rely on python's GIL 58# and the idempotent nature of writes to provide thread safety. 59_device_parsing_cache = {} 60_starting_device_spec = pydev.DeviceSpec.from_string("") 61 62_MAXINT32 = 2**31 - 1 63 64DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT 65DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN 66DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT 67DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( 68 pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) 69 70SYNC = 0 71ASYNC = 1 72 73_KEEP_ALIVE_SECS = 600 74 75_python_eager_context_create_counter = monitoring.Counter( 76 "/tensorflow/api/python/eager_context_create_counter", 77 "Counter for number of eager contexts created in Python.") 78 79# Re-exporting through context. 80is_tfrt_enabled = tfrt_utils.enabled 81 82_RUN_EAGER_OP_AS_FUNCTION_ENABLED = False 83 84 85def enable_run_eager_op_as_function(): 86 """Execute elementary eager ops (non-function) wrapped in a call op. 87 88 This should be functionally equivalent to running the eager op's kernel 89 directly (the default) but reduces the number of codepaths for executing 90 TF2 programs in the runtime, thereby improving consistency (in terms of 91 optimizations and rewrites for instance) and maintainability. 92 """ 93 # Must be called before context is actually built. 
94 global _RUN_EAGER_OP_AS_FUNCTION_ENABLED 95 _RUN_EAGER_OP_AS_FUNCTION_ENABLED = True 96 97 98def run_eager_op_as_function_enabled(): 99 return _RUN_EAGER_OP_AS_FUNCTION_ENABLED 100 101 102# Expose it as internally public APIs for Keras use cases in b/171080602. 103tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled) 104 105 106class _EagerTensorCache(object): 107 """Simple cache which evicts items based on length in a FIFO manner.""" 108 109 __slots__ = ["_data", "_max_items", "_max_tensor_size"] 110 111 def __init__(self, max_items=256, max_tensor_size=10000): 112 self._data = collections.OrderedDict() 113 self._max_items = max_items 114 self._max_tensor_size = max_tensor_size 115 116 def put(self, key, value): 117 if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access 118 return 119 120 self._data[key] = value 121 122 if len(self._data) > self._max_items: 123 self._data.popitem(last=False) 124 125 def get(self, key): 126 return self._data.get(key, None) 127 128 def flush(self): 129 self._data.clear() 130 131 132class FunctionCallOptions(object): 133 """Options applied at call sites of eager functions. 134 135 Eager functions are functions decorated with tf.contrib.eager.defun. 136 """ 137 138 __slots__ = ["_config_proto_serialized", "_executor_type"] 139 140 def __init__(self, executor_type=None, config_proto=None): 141 """Constructor. 142 143 Args: 144 executor_type: (optional) name of the executor to be used to execute the 145 eager function. If None or an empty string, the default Tensorflow 146 executor will be used. 147 config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized 148 string of that proto. The config used by Grappler when optimizing the 149 function graph. Each concrete function is optimized the first time is 150 called. Changing config_proto after the first call has no effect. If 151 config_proto is None, an empty RewriterConfig will be used. 152 """ 153 self.config_proto_serialized = config_proto 154 self.executor_type = executor_type 155 156 @property 157 def executor_type(self): 158 return self._executor_type 159 160 @executor_type.setter 161 def executor_type(self, executor_type): 162 self._executor_type = executor_type 163 164 @property 165 def config_proto_serialized(self): 166 return self._config_proto_serialized 167 168 @config_proto_serialized.setter 169 def config_proto_serialized(self, config): 170 if isinstance(config, config_pb2.ConfigProto): 171 self._config_proto_serialized = config.SerializeToString( 172 deterministic=True) 173 elif isinstance(config, str): 174 self._config_proto_serialized = config 175 elif config is None: 176 self._config_proto_serialized = ( 177 config_pb2.ConfigProto().SerializeToString()) 178 else: 179 raise ValueError("the rewriter config must be either a " 180 "config_pb2.ConfigProto, or a serialized string of that " 181 "proto or None. got: {}".format(type(config))) 182 183 184# Map from context_id (an int) to _TensorCaches. 185# Dicts are thread safe in CPython. 186# TODO(iga): Remove this once TensorCaches are moved to C++. 
187_tensor_caches_map = {} 188 189 190class _TensorCaches(threading.local): 191 """Thread local tensor caches.""" 192 193 __slots__ = ["_ones_rank_cache", "_zeros_cache"] 194 195 def __init__(self): 196 super(_TensorCaches, self).__init__() 197 self._ones_rank_cache = None 198 self._zeros_cache = None 199 200 @property 201 def ones_rank_cache(self): 202 if not self._ones_rank_cache: 203 self._ones_rank_cache = _EagerTensorCache() 204 return self._ones_rank_cache 205 206 @property 207 def zeros_cache(self): 208 if not self._zeros_cache: 209 self._zeros_cache = _EagerTensorCache() 210 return self._zeros_cache 211 212 213ContextSwitch = collections.namedtuple( 214 "ContextSwitch", 215 ["is_building_function", "enter_context_fn", "device_stack"]) 216 217 218# `_ContextSwitchStack` is a `threading.local` to match the semantics of 219# ``DefaultGraphStack`, which is also a `threading.local`. 220class _ContextSwitchStack(threading.local): 221 """A thread-local stack of context switches.""" 222 223 def __init__(self, eager): 224 super(_ContextSwitchStack, self).__init__() 225 self.stack = [] 226 if eager: 227 # Initialize the stack with a pointer to enter the eager context; this 228 # ensures that the fact that eager execution was enabled is propagated 229 # across threads, since (1) `enable_eager_execution` modifies a 230 # process-level flag (`default_execution_mode`) and (2) `__init__` is 231 # called each time a threading.local object is used in a separate thread. 232 self.push( 233 is_building_function=False, 234 enter_context_fn=eager_mode, 235 device_stack=None) 236 237 def push(self, is_building_function, enter_context_fn, device_stack): 238 """Push metadata about a context switch onto the stack. 239 240 A context switch can take any one of the two forms: installing a graph as 241 the default graph, or entering the eager context. For each context switch, 242 we record whether or not the entered context is building a function. 243 244 Args: 245 is_building_function: (bool.) Whether the context is building a function. 246 enter_context_fn: (function.) A callable that executes the context switch. 247 For example, `graph.as_default` or `eager_mode`. 248 device_stack: If applicable, the device function stack for this graph. 249 When breaking out of graphs in init_scope, the innermost nonempty device 250 stack is used. Eager contexts put `None` here and the value is never 251 used. 252 """ 253 254 self.stack.append( 255 ContextSwitch(is_building_function, enter_context_fn, device_stack)) 256 257 def pop(self): 258 """Pop the stack.""" 259 260 self.stack.pop() 261 262 263@tf_export("config.LogicalDevice") 264class LogicalDevice( 265 collections.namedtuple("LogicalDevice", ["name", "device_type"])): 266 """Abstraction for a logical device initialized by the runtime. 267 268 A `tf.config.LogicalDevice` corresponds to an initialized logical device on a 269 `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors 270 and operations can be placed on a specific logical device by calling 271 `tf.device` with a specified `tf.config.LogicalDevice`. 272 273 Fields: 274 name: The fully qualified name of the device. Can be used for Op or function 275 placement. 276 device_type: String declaring the type of device such as "CPU" or "GPU". 
277 """ 278 pass 279 280 281@tf_export("config.LogicalDeviceConfiguration", 282 "config.experimental.VirtualDeviceConfiguration") 283class LogicalDeviceConfiguration( 284 collections.namedtuple("LogicalDeviceConfiguration", 285 ["memory_limit", "experimental_priority"])): 286 """Configuration class for a logical devices. 287 288 The class specifies the parameters to configure a `tf.config.PhysicalDevice` 289 as it is initialized to a `tf.config.LogicalDevice` during runtime 290 initialization. Not all fields are valid for all device types. 291 292 See `tf.config.get_logical_device_configuration` and 293 `tf.config.set_logical_device_configuration` for usage examples. 294 295 Fields: 296 memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual 297 device. Currently only supported for GPUs. 298 experimental_priority: (optional) Priority to assign to a virtual device. 299 Lower values have higher priorities and 0 is the default. 300 Within a physical GPU, the GPU scheduler will prioritize ops on virtual 301 devices with higher priority. Currently only supported for Nvidia GPUs. 302 """ 303 304 def __new__(cls, memory_limit=None, experimental_priority=None): 305 return super(LogicalDeviceConfiguration, 306 cls).__new__(cls, memory_limit, experimental_priority) 307 308 309@tf_export("config.PhysicalDevice") 310class PhysicalDevice( 311 collections.namedtuple("PhysicalDevice", ["name", "device_type"])): 312 """Abstraction for a locally visible physical device. 313 314 TensorFlow can utilize various devices such as the CPU or multiple GPUs 315 for computation. Before initializing a local device for use, the user can 316 customize certain properties of the device such as it's visibility or memory 317 configuration. 318 319 Once a visible `tf.config.PhysicalDevice` is initialized one or more 320 `tf.config.LogicalDevice` objects are created. Use 321 `tf.config.set_visible_devices` to configure the visibility of a physical 322 device and `tf.config.set_logical_device_configuration` to configure multiple 323 `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is 324 useful when separation between models is needed or to simulate a multi-device 325 environment. 326 327 Fields: 328 name: Unique identifier for device. 329 device_type: String declaring the type of device such as "CPU" or "GPU". 330 """ 331 pass 332 333 334class _AtomicCounter(object): 335 """A simple atomic counter.""" 336 337 __slots__ = ["_value", "_lock"] 338 339 def __init__(self): 340 self._value = 0 341 self._lock = threading.Lock() 342 343 def increment_and_get(self): 344 with self._lock: 345 self._value += 1 346 return self._value 347 348 349_context_id_counter = _AtomicCounter() 350 351 352class _TensorCacheDeleter(object): 353 """Deletes tensor caches for a given context.""" 354 355 __slots__ = ["_context_id"] 356 357 def __init__(self, context_id): 358 self._context_id = context_id 359 360 def __del__(self): 361 if _tensor_caches_map is None: 362 return 363 if self._context_id in _tensor_caches_map: 364 del _tensor_caches_map[self._context_id] 365 366 367# TODO(agarwal): rename to EagerContext / EagerRuntime ? 368# TODO(agarwal): consider keeping the corresponding Graph here. 369class Context(object): 370 """Environment in which eager operations execute.""" 371 372 # TODO(agarwal): create and link in some documentation for `execution_mode`. 
373 # pylint: disable=redefined-outer-name 374 def __init__(self, 375 config=None, 376 device_policy=None, 377 execution_mode=None, 378 server_def=None): 379 """Creates a new Context. 380 381 Args: 382 config: (Optional.) A `ConfigProto` protocol buffer with configuration 383 options for the Context. Note that a lot of these options may be 384 currently unimplemented or irrelevant when eager execution is enabled. 385 device_policy: (Optional.) What policy to use when trying to run an 386 operation on a device with inputs which are not on that device. When set 387 to None, an appropriate value will be picked automatically. The value 388 picked may change between TensorFlow releases. Defaults to 389 DEVICE_PLACEMENT_SILENT. 390 Valid values: 391 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not 392 correct. 393 - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right 394 device but raises a warning. 395 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide 396 performance problems. 397 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, 398 raising errors on the other ones. 399 execution_mode: (Optional.) Policy controlling how operations dispatched 400 are actually executed. When set to None, an appropriate value will be 401 picked automatically. The value picked may change between TensorFlow 402 releases. 403 Valid values: 404 - SYNC: executes each operation synchronously. 405 - ASYNC: executes each operation asynchronously. These operations may 406 return "non-ready" handles. 407 server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution 408 on remote devices. GrpcServers need to be started by creating an 409 identical server_def to this, and setting the appropriate task_indexes, 410 so that the servers can communicate. It will then be possible to execute 411 operations on remote devices. 412 413 Raises: 414 ValueError: If execution_mode is not valid. 415 """ 416 # This _id is used only to index the tensor caches. 417 # TODO(iga): Remove this when tensor caches are moved to C++. 418 self._id = _context_id_counter.increment_and_get() 419 self._tensor_cache_deleter = _TensorCacheDeleter(self._id) 420 _tensor_caches_map[self._id] = _TensorCaches() 421 422 self._config = config 423 self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData( 424 self, 425 is_eager=lambda: default_execution_mode == EAGER_MODE, 426 device_spec=_starting_device_spec) 427 self._context_switches = _ContextSwitchStack(self.executing_eagerly()) 428 self._context_handle = None 429 self._context_devices = None 430 self._seed = None 431 self._initialize_lock = threading.Lock() 432 self._initialized = False 433 if device_policy is None: 434 device_policy = DEVICE_PLACEMENT_SILENT 435 self._device_policy = device_policy 436 self._mirroring_policy = None 437 if execution_mode not in (None, SYNC, ASYNC): 438 raise ValueError("execution_mode should be None/SYNC/ASYNC. 
Got %s" % 439 execution_mode) 440 if execution_mode is None: 441 execution_mode = SYNC 442 self._default_is_async = execution_mode == ASYNC 443 self._use_tfrt = is_tfrt_enabled() 444 self._use_tfrt_distributed_runtime = None 445 self._run_eager_op_as_function = run_eager_op_as_function_enabled() 446 self._server_def = server_def 447 self._collective_ops_server_def = None 448 self._collective_leader = None 449 self._collective_scoped_allocator_enabled_ops = None 450 self._collective_use_nccl_communication = None 451 self._collective_device_filters = None 452 self._coordination_service = None 453 454 self._device_lock = threading.Lock() 455 self._physical_devices = None 456 self._physical_device_to_index = None 457 self._visible_device_list = [] 458 self._memory_growth_map = None 459 self._virtual_device_map = {} 460 461 # Values set after construction 462 self._optimizer_jit = None 463 self._intra_op_parallelism_threads = None 464 self._inter_op_parallelism_threads = None 465 self._soft_device_placement = None 466 self._log_device_placement = None 467 self._enable_mlir_graph_optimization = None 468 self._optimizer_experimental_options = {} 469 470 _python_eager_context_create_counter.get_cell().increase_by(1) 471 472 # pylint: enable=redefined-outer-name 473 474 def _set_global_seed(self, seed): 475 """Set a global eager mode seed for random ops.""" 476 self._seed = seed 477 # `random.Random(seed)` needs `seed` to be hashable, while values of type 478 # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them 479 # to int. 480 try: 481 hash(seed) 482 except TypeError: 483 seed = int(np.array(seed)) 484 self._rng = random.Random(seed) 485 # Also clear the kernel cache, to reset any existing seeds 486 if self._context_handle is not None: 487 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 488 489 def _internal_operation_seed(self): 490 """Returns a fake operation seed. 491 492 In eager mode, user shouldn't set or depend on operation seed. 493 Here, we generate a random seed based on global seed to make 494 operation's randomness different and depend on the global seed. 495 496 Returns: 497 A fake operation seed based on global seed. 498 """ 499 return self._rng.randint(0, _MAXINT32) 500 501 def _initialize_logical_devices(self): 502 """Helper to initialize devices.""" 503 # Store list of devices 504 logical_devices = [] 505 context_devices = [] 506 device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle) 507 try: 508 self._num_gpus = 0 509 for i in range(pywrap_tfe.TF_DeviceListCount(device_list)): 510 dev_name = pywrap_tfe.TF_DeviceListName(device_list, i) 511 context_devices.append(pydev.canonical_name(dev_name)) 512 spec = pydev.DeviceSpec.from_string(dev_name) 513 # If the job is localhost, we assume that the cluster has not yet been 514 # configured and thus clear the job, replica & task. 
515 if spec.job == "localhost": 516 spec = spec.replace(job=None, replica=None, task=None) 517 logical_devices.append( 518 LogicalDevice(name=spec.to_string(), device_type=spec.device_type)) 519 dev_type = pywrap_tfe.TF_DeviceListType(device_list, i) 520 if dev_type == "GPU": 521 self._num_gpus += 1 522 523 finally: 524 self._logical_devices = logical_devices 525 self._context_devices = context_devices 526 pywrap_tfe.TF_DeleteDeviceList(device_list) 527 528 def ensure_initialized(self): 529 """Initialize handle and devices if not already done so.""" 530 if self._initialized: 531 return 532 with self._initialize_lock: 533 if self._initialized: 534 return 535 assert self._context_devices is None 536 opts = pywrap_tfe.TFE_NewContextOptions() 537 try: 538 config_str = self.config.SerializeToString() 539 pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str) 540 if self._device_policy is not None: 541 pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy( 542 opts, self._device_policy) 543 if self._mirroring_policy is not None: 544 pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy( 545 opts, self._mirroring_policy) 546 if self._default_is_async == ASYNC: 547 pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True) 548 if self._use_tfrt is not None: 549 pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt) 550 # pylint: disable=g-backslash-continuation 551 if self._use_tfrt is not None and \ 552 self._use_tfrt_distributed_runtime is not None: 553 pywrap_tfe.TFE_ContextOptionsSetTfrtDistributedRuntime( 554 opts, self._use_tfrt_distributed_runtime) 555 pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction( 556 opts, self._run_eager_op_as_function) 557 context_handle = pywrap_tfe.TFE_NewContext(opts) 558 finally: 559 pywrap_tfe.TFE_DeleteContextOptions(opts) 560 assert not (self._server_def and self._collective_ops_server_def), ( 561 "Cannot enable remote execution as well as collective ops at the " 562 "moment. If this is important to you, please file an issue.") 563 if self._server_def is not None: 564 server_def_str = self._server_def.SerializeToString() 565 pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS, 566 server_def_str) 567 elif self._collective_ops_server_def is not None: 568 server_def_str = self._collective_ops_server_def.SerializeToString() 569 pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str) 570 571 self._context_handle = context_handle 572 self._initialize_logical_devices() 573 self._initialized = True 574 575 def _clear_caches(self): 576 self.ones_rank_cache().flush() 577 self.zeros_cache().flush() 578 pywrap_tfe.TFE_ClearScalarCache() 579 580 def get_server_def(self): 581 return self._server_def 582 583 def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 584 """Allow setting a server_def on the context. 585 586 When a server def is replaced, it effectively clears a bunch of caches 587 within the context. If you attempt to use a tensor object that was pointing 588 to a tensor on the remote device, it will raise an error. 589 590 Args: 591 server_def: A tensorflow::ServerDef proto. Enables execution on remote 592 devices. 593 keep_alive_secs: Num. seconds after which the remote end will hang up. As 594 long as the client is still alive, the server state for the context will 595 be kept alive. If the client is killed (or there is some failure), the 596 server will clean up its context keep_alive_secs after the final RPC it 597 receives. 598 599 Raises: 600 ValueError: if server_def is None. 
601 """ 602 if not server_def: 603 raise ValueError("server_def is None.") 604 605 self._server_def = server_def 606 607 if self._context_handle: 608 server_def_str = server_def.SerializeToString() 609 pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs, 610 server_def_str) 611 self._initialize_logical_devices() 612 613 # Clear all the caches in case there are remote tensors in them. 614 self._clear_caches() 615 616 def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 617 """Update a server_def on the context. 618 619 Args: 620 server_def: A tensorflow::ServerDef proto. Enables execution on remote 621 devices. 622 keep_alive_secs: Num. seconds after which the remote end will hang up. As 623 long as the client is still alive, the server state for the context will 624 be kept alive. If the client is killed (or there is some failure), the 625 server will clean up its context keep_alive_secs after the final RPC it 626 receives. 627 628 Raises: 629 ValueError: if server_def is None. 630 """ 631 if not server_def: 632 raise ValueError("server_def is None.") 633 634 self._server_def = server_def 635 636 if self._context_handle: 637 server_def_str = server_def.SerializeToString() 638 pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle, 639 keep_alive_secs, server_def_str) 640 self._initialize_logical_devices() 641 642 self._clear_caches() 643 644 def check_alive(self, worker_name): 645 """Checks whether a remote worker is alive or not. 646 647 Args: 648 worker_name: a string representing the remote worker. It must be a fully 649 specified name like "/job:worker/replica:0/task:0". 650 651 Returns: 652 a boolean indicating whether the remote worker is alive or not. 653 654 Raises: 655 ValueError: if context is not initialized. 656 """ 657 # TODO(yuefengz): support checking multiple workers. 658 if self._context_handle: 659 return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name) 660 else: 661 raise ValueError("Context is not initialized.") 662 663 def sync_executors(self): 664 """Sync both local executors and the ones on remote workers. 665 666 In async execution mode, local function calls can return before the 667 corresponding remote op/function execution requests are completed. Calling 668 this method creates a synchronization barrier for remote executors. It only 669 returns when all remote pending nodes are finished, potentially with errors 670 if any remote executors are in error state. 671 672 Raises: 673 ValueError: if context is not initialized. 674 """ 675 if self._context_handle: 676 pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle) 677 else: 678 raise ValueError("Context is not initialized.") 679 680 def clear_executor_errors(self): 681 """Clear errors in both local executors and remote workers. 682 683 After receiving errors from remote workers, additional requests on the fly 684 could further taint the status on the remote workers due to the async nature 685 of remote execution. Calling this method block on waiting for all pending 686 nodes in remote executors to finish and clear their error statuses. 687 688 Raises: 689 ValueError: if context is not initialized. 
690 """ 691 if self._context_handle: 692 pywrap_tfe.TFE_ContextClearExecutors(self._context_handle) 693 else: 694 raise ValueError("Context is not initialized.") 695 696 def enable_coordination_service(self, service_type): 697 if self._context_handle: 698 logging.warn("Configuring coordination service type may not be effective " 699 "because the context is already initialized.") 700 self._coordination_service = service_type 701 702 @property 703 def coordination_service(self): 704 return self._coordination_service 705 706 def set_config_key_value(self, key, value): 707 ensure_initialized() 708 pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value) 709 710 def get_config_key_value(self, key): 711 ensure_initialized() 712 with c_api_util.tf_buffer() as buffer_: 713 pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key, buffer_) 714 value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 715 return value 716 717 def delete_config_key_value(self, key): 718 ensure_initialized() 719 pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key) 720 721 def report_error_to_cluster(self, error_code, error_message): 722 """Report error to other members in a multi-client cluster. 723 724 Args: 725 error_code: a `tf.errors` error code. 726 error_message: a string. The error message. 727 """ 728 if self._context_handle: 729 pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code, 730 error_message) 731 else: 732 raise ValueError("Context is not initialized.") 733 734 def clear_kernel_cache(self): 735 """Clear kernel cache and reset all stateful kernels.""" 736 if self._context_handle is not None: 737 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 738 739 def enable_collective_ops(self, server_def): 740 """Enable distributed collective ops with an appropriate server_def. 741 742 Args: 743 server_def: A tensorflow::ServerDef proto. Enables execution on remote 744 devices. 745 746 Raises: 747 ValueError: if server_def is None. 748 RuntimeError: if this method is not called at program startup. 749 """ 750 if not server_def: 751 raise ValueError("server_def is None.") 752 753 self._collective_ops_server_def = server_def 754 755 # TODO(b/129298253): Allow creating datasets/tensors before enabling 756 # collective ops. 757 if self._context_handle is not None: 758 logging.warning("Enabling collective ops after program startup may cause " 759 "error when accessing previously created tensors.") 760 with self._initialize_lock: 761 assert self._initialized 762 server_def_str = self._collective_ops_server_def.SerializeToString() 763 pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str) 764 self._initialize_logical_devices() 765 self._clear_caches() 766 767 def configure_collective_ops( 768 self, 769 collective_leader="", 770 scoped_allocator_enabled_ops=("CollectiveReduce",), 771 use_nccl_communication=False, 772 device_filters=None): 773 """Configure collective ops. 774 775 Collective group leader is necessary for collective ops to run, other 776 configurations are mainly for the purpose of performance. 777 778 Args: 779 collective_leader: a device string for collective leader, e.g. 780 "/job:worker/replica:0/task:0"; empty string means local execution of 781 collective ops. 782 scoped_allocator_enabled_ops: a tuple or a list of op names for scoped 783 allocator to run with. 784 use_nccl_communication: whether to use nccl communication for collective 785 ops. 786 device_filters: a tuple or a list of device strings. 
If set, corresponding 787 task can only see the devices filtered by these device filters. 788 789 Raises: 790 RuntimeError: if this method is not called at program startup. 791 """ 792 if self._collective_leader is not None: 793 if (self._collective_leader != collective_leader or 794 self._collective_scoped_allocator_enabled_ops != 795 scoped_allocator_enabled_ops or 796 self._collective_use_nccl_communication != use_nccl_communication or 797 self._collective_device_filters != device_filters): 798 raise ValueError("Collective ops are already configured.") 799 else: 800 return 801 802 if self._context_handle is not None: 803 raise RuntimeError("Collective ops must be configured at program startup") 804 805 self._collective_leader = collective_leader 806 self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops 807 self._collective_use_nccl_communication = use_nccl_communication 808 self._collective_device_filters = device_filters 809 810 def abort_collective_ops(self, code, message): 811 """Abort the collective ops. 812 813 This is intended to be used when a peer failure is detected, which allows 814 the user to handle the case instead of hanging. This aborts all on-going 815 collectives. After all subsequent collectives error immediately, and you 816 need to reset_context() to use collectives again. 817 818 Args: 819 code: a `tf.errors` error code. 820 message: a string. The error message. 821 """ 822 self.ensure_initialized() 823 pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message) 824 825 def check_collective_ops_peer_health(self, task, timeout_in_ms): 826 """Check collective peer health. 827 828 This probes each task to see if they're still alive. Note that restarted 829 tasks are considered a different one, and they're considered not healthy. 830 831 This should only be used in multi client multi worker training. 832 833 Args: 834 task: a task string, must be in the format of /job:xxx/replica:0/task:N. 835 timeout_in_ms: an integer, the timeout. If zero, there's no timeout. 836 837 Raises: 838 tf.errors.UnavailableError: when a peer is down. 839 tf.errors.FailedPreconditionError: when a peer is a different one from the 840 one this task has talked to, e.g. the peer has restarted. 841 tf.errors.InvalidArgumentError: when the task string is invalid. 842 """ 843 self.ensure_initialized() 844 pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task, 845 timeout_in_ms) 846 847 @property 848 def _handle(self): 849 if self._context_handle is None: 850 raise AssertionError("Context must be initialized first.") 851 852 return self._context_handle 853 854 @property 855 def _devices(self): 856 if self._context_devices is None: 857 raise AssertionError("Context must be initialized first.") 858 859 return self._context_devices 860 861 def __str__(self): 862 if self._context_handle is None: 863 return "Eager TensorFlow Context. Devices currently uninitialized." 
864 else: 865 devices = self._devices 866 lines = ["Eager TensorFlow Context with %d devices" % (len(devices))] 867 for i, d in enumerate(devices): 868 lines.append(" Device %d: %s" % (i, d)) 869 return "\n".join(lines) 870 871 @tf_contextlib.contextmanager 872 def _mode(self, mode): 873 """A context manager to allow setting the mode to EAGER/GRAPH.""" 874 ctx = self._thread_local_data 875 old_is_eager = ctx.is_eager 876 ctx.is_eager = mode == EAGER_MODE 877 if mode == EAGER_MODE: 878 # Entering graph mode does not provide us with sufficient information to 879 # record a context switch; graph-based context switches are only logged 880 # when a graph is registered as the default graph. 881 self.context_switches.push(False, eager_mode, None) 882 try: 883 yield 884 finally: 885 ctx.is_eager = old_is_eager 886 if mode == EAGER_MODE: 887 self.context_switches.pop() 888 889 def executing_eagerly(self): 890 """Returns True if current thread has eager executing enabled.""" 891 return self._thread_local_data.is_eager 892 893 def ones_rank_cache(self): 894 """Per-device cache for scalars.""" 895 return _tensor_caches_map[self._id].ones_rank_cache 896 897 def zeros_cache(self): 898 """Per-device cache for scalars.""" 899 return _tensor_caches_map[self._id].zeros_cache 900 901 @property 902 def scope_name(self): 903 """Returns scope name for the current thread.""" 904 return self._thread_local_data.scope_name 905 906 @scope_name.setter 907 def scope_name(self, s): 908 """Sets scope name for the current thread.""" 909 self._thread_local_data.scope_name = s 910 911 @property 912 def device_name(self): 913 """Returns the device name for the current thread.""" 914 return self._thread_local_data.device_name 915 916 @property 917 def device_spec(self): 918 """Returns the device spec for the current thread.""" 919 return self._thread_local_data.device_spec 920 921 def _set_device(self, device_name, device_spec): 922 self._thread_local_data.device_name = device_name 923 self._thread_local_data.device_spec = device_spec 924 925 def device(self, name): 926 """Context-manager to force placement of operations and Tensors on a device. 927 928 Args: 929 name: Name of the device or None to get default placement. 930 931 Returns: 932 Context manager that forces device placement. 933 934 Raises: 935 ValueError: If name is not a string or is an invalid device name. 936 RuntimeError: If device scopes are not properly nested. 937 """ 938 if isinstance(name, LogicalDevice): 939 name = name.name 940 elif pydev.is_device_spec(name): 941 name = name.to_string() 942 return _EagerDeviceContext(self, name) 943 944 def devices(self): 945 """List of the names of devices available to execute operations.""" 946 return self._devices 947 948 def host_address_space(self): 949 self.ensure_initialized() 950 with c_api_util.tf_buffer() as buffer_: 951 pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_) 952 address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 953 return address_space 954 955 # TODO(fishx): remove this property. 956 @property 957 def execution_mode(self): 958 """Gets execution mode for current thread.""" 959 return ASYNC if self.is_async() else SYNC 960 961 @execution_mode.setter 962 def execution_mode(self, mode): 963 """Sets execution mode for current thread.""" 964 if mode not in (None, SYNC, ASYNC): 965 raise ValueError("Execution mode should be None/SYNC/ASYNC. 
Got %s" % 966 mode) 967 968 if mode is None: 969 mode = SYNC 970 971 enable_async = (mode == ASYNC) 972 if self.is_async() != enable_async: 973 # Only set the execution mode if the context has already been initialized 974 if self._context_handle is not None: 975 self.executor.wait() 976 executor_new = executor.new_executor(enable_async) 977 self._thread_local_data.executor = executor_new 978 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, 979 executor_new.handle()) 980 else: 981 self._default_is_async = enable_async 982 983 def is_async(self): 984 if self._context_handle is not None: 985 return self.executor.is_async() 986 else: 987 return self._default_is_async 988 989 @property 990 def executor(self): 991 self.ensure_initialized() 992 return executor.Executor( 993 pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle)) 994 995 @executor.setter 996 def executor(self, e): 997 self.ensure_initialized() 998 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle()) 999 1000 @property 1001 def config(self): 1002 """Return the ConfigProto with all runtime deltas applied.""" 1003 # Ensure physical devices have been discovered and config has been imported 1004 self._initialize_physical_devices() 1005 1006 config = config_pb2.ConfigProto() 1007 if self._config is not None: 1008 config.CopyFrom(self._config) 1009 1010 if self._optimizer_jit is not None: 1011 config.graph_options.optimizer_options.global_jit_level = ( 1012 config_pb2.OptimizerOptions.ON_1 1013 if self._optimizer_jit else config_pb2.OptimizerOptions.OFF) 1014 if self._intra_op_parallelism_threads is not None: 1015 config.intra_op_parallelism_threads = self._intra_op_parallelism_threads 1016 if self._inter_op_parallelism_threads is not None: 1017 config.inter_op_parallelism_threads = self._inter_op_parallelism_threads 1018 1019 if self._soft_device_placement is not None: 1020 config.allow_soft_placement = self._soft_device_placement 1021 else: 1022 config.allow_soft_placement = self.executing_eagerly() 1023 1024 if self._log_device_placement is not None: 1025 config.log_device_placement = self._log_device_placement 1026 1027 is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled() 1028 config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled 1029 if (is_mlir_bridge_enabled == 1030 config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED): 1031 config.experimental.enable_mlir_bridge = True 1032 1033 if self._enable_mlir_graph_optimization is not None: 1034 config.experimental.enable_mlir_graph_optimization = ( 1035 self._enable_mlir_graph_optimization) 1036 1037 def rewriter_toggle(option): 1038 toggle = self._optimizer_experimental_options.get(option, None) 1039 if toggle is None: 1040 return 1041 1042 setattr(config.graph_options.rewrite_options, option, 1043 (rewriter_config_pb2.RewriterConfig.ON 1044 if toggle else rewriter_config_pb2.RewriterConfig.OFF)) 1045 1046 def rewriter_bool(option): 1047 toggle = self._optimizer_experimental_options.get(option, None) 1048 if toggle is None: 1049 return 1050 1051 setattr(config.graph_options.rewrite_options, option, toggle) 1052 1053 rewriter_toggle("layout_optimizer") 1054 rewriter_toggle("constant_folding") 1055 rewriter_toggle("shape_optimization") 1056 rewriter_toggle("remapping") 1057 rewriter_toggle("arithmetic_optimization") 1058 rewriter_toggle("dependency_optimization") 1059 rewriter_toggle("loop_optimization") 1060 rewriter_toggle("function_optimization") 1061 rewriter_toggle("debug_stripper") 1062 
rewriter_bool("disable_model_pruning") 1063 rewriter_toggle("scoped_allocator_optimization") 1064 rewriter_toggle("pin_to_host_optimization") 1065 rewriter_toggle("implementation_selector") 1066 rewriter_toggle("auto_mixed_precision") 1067 rewriter_toggle("use_plugin_optimizers") 1068 rewriter_bool("disable_meta_optimizer") 1069 nodes = self._optimizer_experimental_options.get("min_graph_nodes", None) 1070 if nodes is not None: 1071 config.graph_options.rewrite_options.min_graph_nodes = nodes 1072 1073 # Compute device counts 1074 config.device_count["CPU"] = 0 1075 config.device_count["GPU"] = 0 1076 for dev in self._physical_devices: 1077 if dev not in self._visible_device_list: 1078 continue 1079 1080 virtual_devices = self._virtual_device_map.get(dev) 1081 if virtual_devices is None: 1082 config.device_count[dev.device_type] += 1 1083 else: 1084 config.device_count[dev.device_type] += len(virtual_devices) 1085 1086 # Configure gpu_options 1087 gpu_options = self._compute_gpu_options() 1088 config.gpu_options.MergeFrom(gpu_options) 1089 1090 # Configure collective ops 1091 if self._collective_leader: 1092 config.experimental.collective_group_leader = self._collective_leader 1093 if self._collective_scoped_allocator_enabled_ops: 1094 rewrite_options = config.graph_options.rewrite_options 1095 rewrite_options.scoped_allocator_optimization = ( 1096 rewriter_config_pb2.RewriterConfig.ON) 1097 del rewrite_options.scoped_allocator_opts.enable_op[:] 1098 for op in self._collective_scoped_allocator_enabled_ops: 1099 rewrite_options.scoped_allocator_opts.enable_op.append(op) 1100 if self._collective_use_nccl_communication: 1101 config.experimental.collective_nccl = True 1102 if self._collective_device_filters: 1103 del config.device_filters[:] 1104 for f in self._collective_device_filters: 1105 config.device_filters.append(f) 1106 1107 # Configure coordination service 1108 if self._coordination_service: 1109 config.experimental.coordination_service = self._coordination_service 1110 1111 return config 1112 1113 def _compute_gpu_options(self): 1114 """Build the GPUOptions proto.""" 1115 visible_device_list = [] 1116 virtual_devices = [] 1117 gpu_index = -1 1118 memory_growths = set() 1119 for dev in self.list_physical_devices("GPU"): 1120 gpu_index += 1 1121 1122 if dev not in self._visible_device_list: 1123 continue 1124 1125 growth = self._memory_growth_map[dev] 1126 memory_growths.add(growth) 1127 visible_device_list.append(str(gpu_index)) 1128 1129 if self._virtual_device_map: 1130 vdevs = self._virtual_device_map.get(dev, []) 1131 device_limits = [] 1132 priority = [] 1133 for virt_dev in vdevs: 1134 device_limits.append(virt_dev.memory_limit) 1135 if virt_dev.experimental_priority is not None: 1136 priority.append(virt_dev.experimental_priority) 1137 # If priority is specified, it must be specified for all virtual 1138 # devices. 
1139 if priority and len(device_limits) != len(priority): 1140 raise ValueError("priority must be specified for all virtual devices") 1141 1142 virtual_devices.append( 1143 config_pb2.GPUOptions.Experimental.VirtualDevices( 1144 memory_limit_mb=device_limits, priority=priority)) 1145 1146 # Only compute growth if virtual devices have not been configured and we 1147 # have GPUs 1148 if not virtual_devices and memory_growths: 1149 if len(memory_growths) > 1: 1150 raise ValueError("Memory growth cannot differ between GPU devices") 1151 allow_growth = memory_growths.pop() 1152 else: 1153 allow_growth = None 1154 1155 return config_pb2.GPUOptions( 1156 allow_growth=allow_growth, 1157 visible_device_list=",".join(visible_device_list), 1158 experimental=config_pb2.GPUOptions.Experimental( 1159 virtual_devices=virtual_devices)) 1160 1161 @property 1162 def function_call_options(self): 1163 """Returns function call options for current thread. 1164 1165 Note that the returned object is still referenced by the eager context. 1166 1167 Returns: the FunctionCallOptions for current thread. 1168 """ 1169 if self._thread_local_data.function_call_options is None: 1170 config = self.config 1171 1172 # Default to soft placement for functions unless specified 1173 if self._soft_device_placement is None: 1174 config.allow_soft_placement = True 1175 self._thread_local_data.function_call_options = FunctionCallOptions( 1176 config_proto=config) 1177 1178 return self._thread_local_data.function_call_options 1179 1180 @function_call_options.setter 1181 def function_call_options(self, options): 1182 """Returns function call options for current thread.""" 1183 self._thread_local_data.function_call_options = options 1184 1185 def num_gpus(self): 1186 """The number of GPUs available to execute operations.""" 1187 self.ensure_initialized() 1188 return self._num_gpus 1189 1190 def add_function(self, fn): 1191 """Add a function definition to the context. 1192 1193 Once added, the function (identified by its name) can be executed like any 1194 other operation. 1195 1196 Args: 1197 fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). 1198 """ 1199 self.ensure_initialized() 1200 pywrap_tfe.TFE_ContextAddFunction(self._handle, fn) 1201 1202 def add_function_def(self, fdef): 1203 """Add a function definition to the context. 1204 1205 Once added, the function (identified by its name) can be executed like any 1206 other operation. 1207 1208 Args: 1209 fdef: A FunctionDef protocol buffer message. 1210 """ 1211 self.ensure_initialized() 1212 fdef_string = fdef.SerializeToString() 1213 pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string, 1214 len(fdef_string)) 1215 1216 def get_function_def(self, name): 1217 """Get a function definition from the context. 1218 1219 Args: 1220 name: function signature name. 1221 1222 Returns: 1223 The requested FunctionDef. 1224 1225 Raises: 1226 tf.errors.NotFoundError: if name is not the name of a registered function. 1227 """ 1228 with c_api_util.tf_buffer() as buffer_: 1229 pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_) 1230 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 1231 function_def = function_pb2.FunctionDef() 1232 function_def.ParseFromString(proto_data) 1233 1234 return function_def 1235 1236 def register_custom_device(self, device_capsule, device_name, 1237 device_info_capsule): 1238 """Calls TFE_RegisterCustomDevice. 
See the non-member function.""" 1239 self.ensure_initialized() 1240 pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule, 1241 device_name, device_info_capsule) 1242 1243 def pack_eager_tensors(self, tensors): 1244 """Pack multiple `EagerTensor`s of the same dtype and shape. 1245 1246 Args: 1247 tensors: a list of EagerTensors to pack. 1248 1249 Returns: 1250 A packed EagerTensor. 1251 """ 1252 self.ensure_initialized() 1253 return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors) 1254 1255 def list_function_names(self): 1256 """Get a list of names of registered functions. 1257 1258 Returns: 1259 A set of names of all registered functions for the context. 1260 """ 1261 self.ensure_initialized() 1262 return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle)) 1263 1264 def remove_function(self, name): 1265 """Remove a function from the context. 1266 1267 Once removed, the function cannot be executed anymore. 1268 1269 Args: 1270 name: function signature name. 1271 """ 1272 self.ensure_initialized() 1273 pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name) 1274 1275 def has_function(self, name): 1276 """Check if a function `name` is registered.""" 1277 self.ensure_initialized() 1278 return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name)) 1279 1280 def add_op_callback(self, callback): 1281 """Add a post-op callback to the context. 1282 1283 A post-op callback is invoked immediately after an eager operation or 1284 function has finished execution or after a op has been added to a graph, 1285 providing access to the op's type, name input and output tensors. Multiple 1286 op callbacks can be added, in which case the callbacks will be invoked in 1287 the order in which they are added. 1288 1289 Args: 1290 callback: a callable of the signature `f(op_type, inputs, attrs, outputs, 1291 op_name=None, graph=None)`. See doc strings in `op_callbacks.py` for 1292 details on the function signature and its semantics. 1293 """ 1294 if callback not in self._thread_local_data.op_callbacks: 1295 self._thread_local_data.op_callbacks.append(callback) 1296 1297 def remove_op_callback(self, callback): 1298 """Remove an already-registered op callback. 1299 1300 Args: 1301 callback: The op callback to be removed. 1302 1303 Raises: 1304 KeyError: If `callback` is not already registered. 1305 """ 1306 if callback not in self._thread_local_data.op_callbacks: 1307 raise KeyError("The specified op callback has not been registered, " 1308 "and hence cannot be removed.") 1309 del self._thread_local_data.op_callbacks[ 1310 self._thread_local_data.op_callbacks.index(callback)] 1311 1312 @property 1313 def op_callbacks(self): 1314 return self._thread_local_data.op_callbacks 1315 1316 @property 1317 def invoking_op_callbacks(self): 1318 return self._thread_local_data.invoking_op_callbacks 1319 1320 @invoking_op_callbacks.setter 1321 def invoking_op_callbacks(self, value): 1322 self._thread_local_data.invoking_op_callbacks = value 1323 1324 def _initialize_physical_devices(self, reinitialize=False): 1325 """Gets local devices visible to the system. 1326 1327 Args: 1328 reinitialize: If True, reinitializes self._physical_devices so that 1329 dynamic registered devices will also be visible to the python front-end. 1330 """ 1331 # We lazy initialize self._physical_devices since we do not want to do this 1332 # the constructor since the backend may not be initialized yet. 
1333 with self._device_lock: 1334 if not reinitialize and self._physical_devices is not None: 1335 return 1336 1337 devs = pywrap_tfe.TF_ListPhysicalDevices() 1338 self._physical_devices = [ 1339 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1]) 1340 for d in devs 1341 ] 1342 self._physical_device_to_index = { 1343 p: i for i, p in enumerate(self._physical_devices) 1344 } 1345 1346 self._visible_device_list = list(self._physical_devices) 1347 self._memory_growth_map = { 1348 d: None for d in self._physical_devices if d.device_type == "GPU" 1349 } 1350 1351 # Import device settings that may have been passed into the constructor 1352 self._import_config() 1353 1354 def reinitialize_physical_devices(self): 1355 """Gets local devices visible to the system.""" 1356 # Reinitialize the physical device list after registering 1357 # the pluggable device. 1358 self._initialize_physical_devices(True) 1359 1360 def list_physical_devices(self, device_type=None): 1361 """List local devices visible to the system. 1362 1363 This API allows a client to query the devices before they have been 1364 initialized by the eager runtime. Additionally a user can filter by device 1365 type, to get only CPUs or GPUs. 1366 1367 Args: 1368 device_type: Optional device type to limit results to 1369 1370 Returns: 1371 List of PhysicalDevice objects. 1372 """ 1373 self._initialize_physical_devices() 1374 1375 if device_type is None: 1376 return list(self._physical_devices) 1377 1378 return [d for d in self._physical_devices if d.device_type == device_type] 1379 1380 def get_device_details(self, device): # pylint: disable=redefined-outer-name 1381 """Returns details about a physical devices. 1382 1383 Args: 1384 device: A `tf.config.PhysicalDevice` returned by 1385 `tf.config.list_physical_devices` or `tf.config.get_visible_devices`. 1386 1387 Returns: 1388 A dict with string keys. 1389 """ 1390 if not isinstance(device, PhysicalDevice): 1391 raise ValueError("device must be a tf.config.PhysicalDevice, but got: " 1392 "%s" % (device,)) 1393 if (self._physical_device_to_index is None or 1394 device not in self._physical_device_to_index): 1395 raise ValueError("The PhysicalDevice must be one obtained from " 1396 "calling `tf.config.list_physical_devices`, but got: " 1397 "%s" % (device,)) 1398 index = self._physical_device_to_index[device] 1399 details = pywrap_tfe.TF_GetDeviceDetails(index) 1400 1401 # Change compute_capability from a string to a tuple 1402 if "compute_capability" in details: 1403 try: 1404 major, minor = details["compute_capability"].split(".") 1405 details["compute_capability"] = (int(major), int(minor)) 1406 except ValueError: 1407 raise RuntimeError("Device returned compute capability an in invalid " 1408 "format: %s" % details["compute_capability"]) 1409 return details 1410 1411 def _import_config(self): 1412 """Import config if passed in during construction. 1413 1414 If Context was created with a ConfigProto such as when calling 1415 tf.compat.v1.enable_eager_execution(), then we need to pull out the 1416 various pieces we might be replacing and import then into our internal 1417 class representation. 
1418 """ 1419 if self._config is None: 1420 return 1421 1422 num_cpus = self._config.device_count.get("CPU", 1) 1423 if num_cpus != 1: 1424 cpus = [d for d in self._physical_devices if d.device_type == "CPU"] 1425 if num_cpus == 0: 1426 self.set_visible_devices([], "CPU") 1427 elif num_cpus > 1: 1428 self.set_logical_device_configuration( 1429 cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)]) 1430 1431 # Parse GPU options 1432 gpus = [d for d in self._physical_devices if d.device_type == "GPU"] 1433 1434 # If there are no GPUs detected, simply ignore all the GPU options passed in 1435 # rather than doing any validation checks. 1436 if not gpus: 1437 return 1438 1439 gpu_count = self._config.device_count.get("GPU", None) 1440 1441 visible_gpus = [] 1442 # TODO(gjn): Handle importing existing virtual GPU configuration 1443 visible_indices = self._config.gpu_options.visible_device_list 1444 if visible_indices: 1445 for index in visible_indices.split(","): 1446 if int(index) >= len(gpus): 1447 raise ValueError("Invalid visible device index: %s" % index) 1448 visible_gpus.append(gpus[int(index)]) 1449 else: 1450 visible_gpus = gpus 1451 1452 if gpu_count is not None: 1453 visible_gpus = visible_gpus[:gpu_count] 1454 1455 self.set_visible_devices(visible_gpus, "GPU") 1456 1457 def list_logical_devices(self, device_type=None): 1458 """Return logical devices.""" 1459 self.ensure_initialized() 1460 if device_type is None: 1461 return list(self._logical_devices) 1462 1463 return [d for d in self._logical_devices if d.device_type == device_type] 1464 1465 def get_visible_devices(self, device_type=None): 1466 """Get the list of visible devices.""" 1467 self._initialize_physical_devices() 1468 1469 if device_type is None: 1470 return list(self._visible_device_list) 1471 1472 return [ 1473 d for d in self._visible_device_list if d.device_type == device_type 1474 ] 1475 1476 def set_visible_devices(self, devices, device_type=None): 1477 """Set the list of visible devices.""" 1478 self._initialize_physical_devices() 1479 1480 if not isinstance(devices, list): 1481 devices = [devices] 1482 1483 for d in devices: 1484 if d not in self._physical_devices: 1485 raise ValueError("Unrecognized device: %s" % repr(d)) 1486 if device_type is not None and d.device_type != device_type: 1487 raise ValueError("Unrecognized device: %s" % repr(d)) 1488 1489 visible_device_list = [] 1490 if device_type is not None: 1491 visible_device_list = [ 1492 d for d in self._visible_device_list if d.device_type != device_type 1493 ] 1494 1495 visible_device_list += devices 1496 1497 if self._visible_device_list == visible_device_list: 1498 return 1499 1500 if self._context_handle is not None: 1501 raise RuntimeError( 1502 "Visible devices cannot be modified after being initialized") 1503 1504 self._visible_device_list = visible_device_list 1505 1506 def get_memory_info(self, dev): 1507 """Returns a dict of memory info for the device.""" 1508 self._initialize_physical_devices() 1509 self.ensure_initialized() 1510 return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev) 1511 1512 def reset_memory_stats(self, dev): 1513 """Resets the tracked memory stats for the device.""" 1514 self._initialize_physical_devices() 1515 self.ensure_initialized() 1516 pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev) 1517 1518 def get_memory_growth(self, dev): 1519 """Get if memory growth is enabled for a PhysicalDevice.""" 1520 self._initialize_physical_devices() 1521 1522 if dev not in self._physical_devices: 1523 raise 
ValueError("Unrecognized device: %s" % repr(dev)) 1524 1525 return self._memory_growth_map[dev] 1526 1527 def set_memory_growth(self, dev, enable): 1528 """Set if memory growth should be enabled for a PhysicalDevice.""" 1529 self._initialize_physical_devices() 1530 1531 if dev not in self._physical_devices: 1532 raise ValueError("Unrecognized device: %s" % repr(dev)) 1533 1534 if dev in self._virtual_device_map: 1535 raise ValueError( 1536 "Cannot set memory growth on device when virtual devices configured") 1537 1538 if dev.device_type != "GPU": 1539 raise ValueError("Cannot set memory growth on non-GPU devices") 1540 1541 if self._memory_growth_map.get(dev) == enable: 1542 return 1543 1544 if self._context_handle is not None: 1545 raise RuntimeError( 1546 "Physical devices cannot be modified after being initialized") 1547 1548 self._memory_growth_map[dev] = enable 1549 1550 def get_logical_device_configuration(self, dev): 1551 """Get the virtual device configuration for a PhysicalDevice.""" 1552 self._initialize_physical_devices() 1553 1554 if dev not in self._physical_devices: 1555 raise ValueError("Unrecognized device: %s" % repr(dev)) 1556 1557 return self._virtual_device_map.get(dev) 1558 1559 def set_logical_device_configuration(self, dev, virtual_devices): 1560 """Set the virtual device configuration for a PhysicalDevice.""" 1561 self._initialize_physical_devices() 1562 1563 if dev not in self._physical_devices: 1564 raise ValueError("Unrecognized device: %s" % repr(dev)) 1565 1566 if dev.device_type == "CPU": 1567 for vdev in virtual_devices: 1568 if vdev.memory_limit is not None: 1569 raise ValueError("Setting memory limit on CPU virtual devices is " 1570 "currently not supported") 1571 if vdev.experimental_priority is not None: 1572 raise ValueError("Setting experimental_priority on CPU virtual " 1573 " devices is currently not supported") 1574 elif dev.device_type == "GPU": 1575 for vdev in virtual_devices: 1576 if vdev.memory_limit is None: 1577 raise ValueError( 1578 "Setting memory limit is required for GPU virtual devices") 1579 else: 1580 raise ValueError("Virtual devices are not supported for %s" % 1581 dev.device_type) 1582 1583 if self._virtual_device_map.get(dev) == virtual_devices: 1584 return 1585 1586 if self._context_handle is not None: 1587 raise RuntimeError( 1588 "Virtual devices cannot be modified after being initialized") 1589 1590 self._virtual_device_map[dev] = virtual_devices 1591 1592 def set_logical_cpu_devices(self, num_cpus, prefix=""): 1593 """Set virtual CPU devices in context. 1594 1595 If virtual CPU devices are already configured at context initialization 1596 by tf.config.set_logical_device_configuration(), this method should not be 1597 called. 1598 1599 Args: 1600 num_cpus: Number of virtual CPUs. 1601 prefix: Device name prefix. 1602 1603 Raises: 1604 RuntimeError: If virtual CPUs are already configured at context 1605 initialization. 1606 """ 1607 self.ensure_initialized() 1608 # Error out if there are already multiple logical CPU in the context. 

  def set_logical_cpu_devices(self, num_cpus, prefix=""):
    """Set virtual CPU devices in context.

    If virtual CPU devices are already configured at context initialization
    by tf.config.set_logical_device_configuration(), this method should not be
    called.

    Args:
      num_cpus: Number of virtual CPUs.
      prefix: Device name prefix.

    Raises:
      RuntimeError: If virtual CPUs are already configured at context
        initialization.
    """
    self.ensure_initialized()
    # Error out if there are already multiple logical CPUs in the context.
    if len(self.list_logical_devices("CPU")) > 1:
      raise RuntimeError("Virtual CPUs already set, cannot modify again.")

    pywrap_tfe.TFE_SetLogicalCpuDevices(self._context_handle, num_cpus, prefix)
    self._initialize_logical_devices()

  def get_compiler_ir(self, device_name, function_name, args, stage="hlo"):
    return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name,
                                       stage, device_name, args)

  @deprecated(
      None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
  def enable_xla_devices(self):
    """Enables XLA:CPU and XLA:GPU devices registration."""
    pywrap_tfe.TF_EnableXlaDevices()

  @property
  def enable_mlir_bridge(self):
    return pywrap_tfe.TF_IsMlirBridgeEnabled()

  @property
  def enable_mlir_graph_optimization(self):
    return self._enable_mlir_graph_optimization

  @enable_mlir_bridge.setter
  def enable_mlir_bridge(self, enabled):
    pywrap_tfe.TF_EnableMlirBridge(enabled)
    self._thread_local_data.function_call_options = None

  @enable_mlir_graph_optimization.setter
  def enable_mlir_graph_optimization(self, enabled):
    self._enable_mlir_graph_optimization = enabled
    self._thread_local_data.function_call_options = None

  @property
  def optimizer_jit(self):
    level = self.config.graph_options.optimizer_options.global_jit_level
    return (level == config_pb2.OptimizerOptions.ON_1 or
            level == config_pb2.OptimizerOptions.ON_2)

  @optimizer_jit.setter
  def optimizer_jit(self, enabled):
    self._optimizer_jit = enabled

    self._thread_local_data.function_call_options = None

  def get_optimizer_experimental_options(self):
    """Get experimental options for the optimizer.

    Returns:
      Dictionary of current option values
    """
    rewrite_options = self.config.graph_options.rewrite_options
    options = {}

    def rewriter_toggle(option):
      attr = getattr(rewrite_options, option)
      if attr != 0:
        options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON)

    def rewriter_bool(option):
      options[option] = getattr(rewrite_options, option)

    rewriter_toggle("layout_optimizer")
    rewriter_toggle("constant_folding")
    rewriter_toggle("shape_optimization")
    rewriter_toggle("remapping")
    rewriter_toggle("arithmetic_optimization")
    rewriter_toggle("dependency_optimization")
    rewriter_toggle("loop_optimization")
    rewriter_toggle("function_optimization")
    rewriter_toggle("debug_stripper")
    rewriter_bool("disable_model_pruning")
    rewriter_toggle("scoped_allocator_optimization")
    rewriter_toggle("pin_to_host_optimization")
    rewriter_toggle("implementation_selector")
    rewriter_toggle("auto_mixed_precision")
    rewriter_toggle("use_plugin_optimizers")
    rewriter_bool("disable_meta_optimizer")

    if rewrite_options.min_graph_nodes != 0:
      options["min_graph_nodes"] = rewrite_options.min_graph_nodes

    return options

  def set_optimizer_experimental_options(self, options):
    """Set experimental options for the optimizer.

    Args:
      options: Dictionary of options to modify
    """
    self._optimizer_experimental_options.update(options)

    self._thread_local_data.function_call_options = None
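
  # Illustrative sketch for the getter/setter pair above: the returned dict
  # only contains toggles that have been explicitly moved away from their
  # defaults. "constant_folding" is a real rewriter toggle; the chosen value is
  # just an example.
  #
  #   ctx = context()
  #   ctx.set_optimizer_experimental_options({"constant_folding": False})
  #   opts = ctx.get_optimizer_experimental_options()
  #   # opts.get("constant_folding") should now be False; unset toggles are
  #   # absent from the dict.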

  @property
  def intra_op_parallelism_threads(self):
    return self.config.intra_op_parallelism_threads

  @intra_op_parallelism_threads.setter
  def intra_op_parallelism_threads(self, num_threads):
    if self._intra_op_parallelism_threads == num_threads:
      return

    if self._context_handle is not None:
      raise RuntimeError(
          "Intra op parallelism cannot be modified after initialization.")

    self._intra_op_parallelism_threads = num_threads

  @property
  def inter_op_parallelism_threads(self):
    return self.config.inter_op_parallelism_threads

  @inter_op_parallelism_threads.setter
  def inter_op_parallelism_threads(self, num_threads):
    if self._inter_op_parallelism_threads == num_threads:
      return

    if self._context_handle is not None:
      raise RuntimeError(
          "Inter op parallelism cannot be modified after initialization.")

    self._inter_op_parallelism_threads = num_threads

  @property
  def soft_device_placement(self):
    return self.config.allow_soft_placement

  @soft_device_placement.setter
  def soft_device_placement(self, enable):
    if self._context_handle is not None:
      pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)

    self._soft_device_placement = enable
    self._thread_local_data.function_call_options = None

  @property
  def log_device_placement(self):
    return self.config.log_device_placement

  @log_device_placement.setter
  def log_device_placement(self, enable):
    if self._context_handle is not None:
      pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)

    self._log_device_placement = enable
    self._thread_local_data.function_call_options = None

  @property
  def device_policy(self):
    # Only get the policy from the context if it has already been initialized
    if self._context_handle is not None:
      return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)

    return self._device_policy

  @device_policy.setter
  def device_policy(self, policy):
    if policy is None:
      policy = DEVICE_PLACEMENT_SILENT

    if self._device_policy != policy:
      self._device_policy = policy

      # Only set the policy if the context has already been initialized
      if self._context_handle is not None:
        pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
            self._handle, self._device_policy)

  @property
  def use_tfrt(self):
    return self._use_tfrt

  @use_tfrt.setter
  def use_tfrt(self, tfrt):
    """Sets whether to use TFRT."""
    if not isinstance(tfrt, bool):
      raise ValueError("Expecting a boolean but got %s" % type(tfrt))

    if self._use_tfrt != tfrt:
      if self._initialized:
        raise ValueError("use_tfrt should be set before being initialized.")
      self._use_tfrt = tfrt
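
  # Illustrative sketch: thread-pool sizes feed into the ConfigProto and must
  # therefore be chosen before the context handle is built (the counts below
  # are arbitrary example values).
  #
  #   ctx = context()
  #   ctx.intra_op_parallelism_threads = 2
  #   ctx.inter_op_parallelism_threads = 4
  #   ctx.ensure_initialized()
  #   # Changing either property now raises RuntimeError.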

  @property
  def use_tfrt_distributed_runtime(self):
    return self._use_tfrt_distributed_runtime

  @use_tfrt_distributed_runtime.setter
  def use_tfrt_distributed_runtime(self, enable):
    """Sets whether to use TFRT distributed runtime.

    This is only effective when use_tfrt is also true. Note that the TFRT
    distributed runtime is currently not feature complete and this config is
    for testing only.

    Args:
      enable: A boolean to set whether to use TFRT distributed runtime.
    """
    if not isinstance(enable, bool):
      raise ValueError("Expecting a boolean but got %s" % type(enable))

    if self._use_tfrt_distributed_runtime != enable:
      if self._initialized:
        raise ValueError("use_tfrt_distributed_runtime should be set before "
                         "being initialized.")
      self._use_tfrt_distributed_runtime = enable

  def enable_run_metadata(self):
    """Enables tracing of op execution via RunMetadata.

    To retrieve the accumulated metadata call context.export_run_metadata()
    and to stop tracing call context.disable_run_metadata().
    """
    self.ensure_initialized()
    pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)

  def disable_run_metadata(self):
    """Disables tracing of op execution via RunMetadata."""
    if not self._context_handle:
      return
    pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)

  def enable_graph_collection(self):
    """Enables graph collection of executed functions.

    To retrieve the accumulated graphs call context.export_run_metadata()
    and to stop collecting graphs call context.disable_graph_collection().
    """
    self.ensure_initialized()
    pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)

  def disable_graph_collection(self):
    """Disables graph collection of executed functions."""
    if not self._context_handle:
      return
    pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)

  def export_run_metadata(self):
    """Returns a RunMetadata proto with accumulated information.

    The returned protocol buffer contains information since the most recent
    call to either enable_run_metadata or export_run_metadata.

    Returns:
      A RunMetadata protocol buffer. Or None if not enabled.
    """
    if not self._context_handle:
      return None
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
    run_metadata = config_pb2.RunMetadata()
    run_metadata.ParseFromString(compat.as_bytes(proto_data))
    return run_metadata

  @property
  def context_switches(self):
    """Returns a stack of context switches."""
    return self._context_switches
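
  # Illustrative sketch of the RunMetadata tracing cycle above (assumes eager
  # execution and that `constant_op` from tensorflow.python.framework is
  # available; the op is purely for demonstration):
  #
  #   ctx = context()
  #   ctx.enable_run_metadata()
  #   constant_op.constant(1.) + constant_op.constant(2.)  # traced op
  #   metadata = ctx.export_run_metadata()
  #   ctx.disable_run_metadata()
  #   # `metadata` is a config_pb2.RunMetadata proto, or None if tracing was
  #   # never enabled.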


class _EagerDeviceContext(object):
  """Context-manager forcing placement of ops and Tensors on a device."""

  __slots__ = ["_device_name", "_ctx", "_stack"]

  def __init__(self, ctx, device_name):
    self._device_name = device_name
    self._ctx = ctx
    self._stack = []

  # TODO(b/189233748): Consolidate the device string parsing logic with
  # tensorflow/core/util/device_name_utils.cc.
  def __enter__(self):
    ctx = self._ctx
    old_device_name = ctx.device_name
    old_device_spec = ctx.device_spec
    new_device_name = self._device_name
    cache_key = (old_device_name, new_device_name)
    try:
      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
    except TypeError:
      # Error while trying to compute the cache key.
      raise ValueError("Expecting a string device name. Got %s(%s)" %
                       (type(new_device_name), new_device_name))
    except KeyError:
      # Handle a cache miss.
      if new_device_name is not None:
        if not isinstance(new_device_name, six.string_types):
          raise ValueError("Expecting a string device name. Got %s(%s)" %
                           (type(new_device_name), new_device_name))
        device_spec = pydev.DeviceSpec.from_string(new_device_name)
        if old_device_name:
          new_device_spec = copy.copy(old_device_spec)
        else:
          ctx.ensure_initialized()
          new_device_spec = pydev.DeviceSpec.from_string(
              ctx._context_devices[0])  # pylint: disable=protected-access
        new_device_spec = new_device_spec.make_merged_spec(device_spec)
      else:
        new_device_spec = pydev.DeviceSpec.from_string("")
      new_device_name = new_device_spec.to_string()
      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)

    ctx._set_device(new_device_name, new_device_spec)  # pylint: disable=protected-access
    self._stack.append((old_device_name, old_device_spec, new_device_spec))

  def __exit__(self, *ex_info):
    ctx = self._ctx
    old_device_name, old_device_spec, new_device_spec = self._stack[-1]
    if ctx.device_spec is not new_device_spec:
      raise RuntimeError("Exiting device scope without proper scope nesting")
    del self._stack[-1]
    ctx._set_device(old_device_name, old_device_spec)  # pylint: disable=protected-access


# Do not set directly. Use _set_context.
_context = None
_context_lock = threading.Lock()


def _set_context_locked(ctx):
  global _context
  pywrap_tfe.TFE_Py_SetEagerContext(ctx)
  _context = ctx


def _set_context(ctx):
  with _context_lock:
    _set_context_locked(ctx)


def _create_context():
  with _context_lock:
    if _context is None:
      ctx = Context()
      _set_context_locked(ctx)


def _reset_context():
  """Clears and re-initializes the singleton context.

  Should only be used for testing.
  """
  global _context
  global _device_parsing_cache

  # Garbage collect and clear scalar cache to avoid Tensor from current context
  # polluting next context.
  gc.collect()
  pywrap_tfe.TFE_ClearScalarCache()
  with _context_lock:
    if _context is not None:
      _context._clear_caches()
      _context = None
  _create_context()
  _device_parsing_cache = {}


def context():
  """Returns a singleton context object."""
  if _context is None:
    _create_context()
  return _context


def context_safe():
  """Returns current context (or None if one hasn't been initialized)."""
  return _context


def ensure_initialized():
  """Initialize the context."""
  context().ensure_initialized()


def initialize_logical_devices():
  """Initialize the virtual devices."""
  context()._initialize_logical_devices()  # pylint: disable=protected-access


def set_global_seed(seed):
  """Sets the eager mode seed."""
  context()._set_global_seed(seed)  # pylint: disable=protected-access


def global_seed():
  """Returns the eager mode seed."""
  return context()._seed  # pylint: disable=protected-access


def internal_operation_seed():
  """Returns the operation seed generated based on global seed."""
  return context()._internal_operation_seed()  # pylint: disable=protected-access
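

# Illustrative sketch of the module-level singleton helpers above: `context()`
# lazily creates the `Context`, while `context_safe()` never triggers creation.
#
#   assert context_safe() is None or isinstance(context_safe(), Context)
#   ctx = context()          # creates the singleton on first use
#   ensure_initialized()     # builds the underlying TFE context handle
#   assert context_safe() is ctx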


@tf_export("executing_eagerly", v1=[])
def executing_eagerly():
  """Checks whether the current thread has eager execution enabled.

  Eager execution is enabled by default and this API returns `True`
  in most cases. However, this API might return `False` in the following use
  cases.

  *  Executing inside `tf.function`, unless under `tf.init_scope` or
     `tf.config.run_functions_eagerly(True)` is previously called.
  *  Executing inside a transformation function for `tf.dataset`.
  *  `tf.compat.v1.disable_eager_execution()` is called.

  General case:

  >>> print(tf.executing_eagerly())
  True

  Inside `tf.function`:

  >>> @tf.function
  ... def fn():
  ...   with tf.init_scope():
  ...     print(tf.executing_eagerly())
  ...   print(tf.executing_eagerly())
  >>> fn()
  True
  False

  Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called:

  >>> tf.config.run_functions_eagerly(True)
  >>> @tf.function
  ... def fn():
  ...   with tf.init_scope():
  ...     print(tf.executing_eagerly())
  ...   print(tf.executing_eagerly())
  >>> fn()
  True
  True
  >>> tf.config.run_functions_eagerly(False)

  Inside a transformation function for `tf.dataset`:

  >>> def data_fn(x):
  ...   print(tf.executing_eagerly())
  ...   return x
  >>> dataset = tf.data.Dataset.range(100)
  >>> dataset = dataset.map(data_fn)
  False

  Returns:
    `True` if the current thread has eager execution enabled.
  """
  ctx = context_safe()
  if ctx is None:
    return default_execution_mode == EAGER_MODE

  return ctx.executing_eagerly()


@tf_export(v1=["executing_eagerly"])
def executing_eagerly_v1():
  """Checks whether the current thread has eager execution enabled.

  Eager execution is typically enabled via
  `tf.compat.v1.enable_eager_execution`, but may also be enabled within the
  context of a Python function via tf.contrib.eager.py_func.

  When eager execution is enabled, returns `True` in most cases. However,
  this API might return `False` in the following use cases.

  *  Executing inside `tf.function`, unless under `tf.init_scope` or
     `tf.config.run_functions_eagerly(True)` is previously called.
  *  Executing inside a transformation function for `tf.dataset`.
  *  `tf.compat.v1.disable_eager_execution()` is called.

  >>> tf.compat.v1.enable_eager_execution()

  General case:

  >>> print(tf.executing_eagerly())
  True

  Inside `tf.function`:

  >>> @tf.function
  ... def fn():
  ...   with tf.init_scope():
  ...     print(tf.executing_eagerly())
  ...   print(tf.executing_eagerly())
  >>> fn()
  True
  False

  Inside `tf.function`
  after `tf.config.run_functions_eagerly(True)` is called:

  >>> tf.config.run_functions_eagerly(True)
  >>> @tf.function
  ... def fn():
  ...   with tf.init_scope():
  ...     print(tf.executing_eagerly())
  ...   print(tf.executing_eagerly())
  >>> fn()
  True
  True
  >>> tf.config.run_functions_eagerly(False)

  Inside a transformation function for `tf.dataset`:

  >>> def data_fn(x):
  ...   print(tf.executing_eagerly())
  ...   return x
  >>> dataset = tf.data.Dataset.range(100)
  >>> dataset = dataset.map(data_fn)
  False

  Returns:
    `True` if the current thread has eager execution enabled.
  """
  return executing_eagerly()


def in_eager_mode():
  """Use executing_eagerly() instead. This function will be removed."""
  return executing_eagerly()


def shared_name(name=None):
  """Returns the anonymous shared name GUID if no shared name is specified.

  In eager mode we need to use a unique shared name to avoid spurious sharing
  issues. The runtime generates a unique name on our behalf when the reserved
  GUID is used as a shared name.

  Args:
    name: Optional shared name

  Returns:
    Eager compatible shared name.
  """
  if name or not executing_eagerly():
    return name

  # Ensure a unique name when eager execution is enabled to avoid spurious
  # sharing issues.
  return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"


def graph_mode():
  """Context-manager to disable eager execution for the current thread."""
  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access


# Used by b/167638505 for keras backend API and Lambda layer.
@tf_export("__internal__.eager_context.eager_mode", v1=[])
def eager_mode():
  """Context-manager to enable eager execution for the current thread."""
  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access
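

# Illustrative sketch of the thread-local mode switches above, using this
# module's own names:
#
#   with eager_mode():
#     assert executing_eagerly()
#     with graph_mode():
#       assert not executing_eagerly()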


def scope_name():
  """Name of the current scope."""
  return context().scope_name


def device(name):
  """Context-manager to force placement of operations and Tensors on a device.

  Example:
  ```python
  with tf.device('gpu:0'):
    with tf.device('cpu:0'):
      shape = tf.constant([], dtype=tf.int32)
    x = tf.random.truncated_normal(shape, tf.float32)
  ```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.

  Args:
    name: Name of the device (see context().devices()), or None to perform
      automatic placement.

  Returns:
    Context manager for setting the device.
  """
  ensure_initialized()
  return context().device(name)


# Expose some properties of Context as internally public APIs (b/160348781).
@tf_export("__internal__.eager_context.get_config", v1=[])
def get_config():
  """Get the ConfigProto of Context.

  Returns:
    The ConfigProto of Context.
  """
  return context().config


@tf_export("__internal__.eager_context.get_device_name", v1=[])
def get_device_name():
  """Get the device name for the current thread.

  Returns:
    The device name for the current thread.
  """
  return context().device_name


@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[])
def set_soft_device_placement(enabled):
  """Set if soft device placements should be allowed.

  Args:
    enabled: Whether to enable soft device placement.
  """
  context().soft_device_placement = enabled


@tf_export("__internal__.eager_context.get_executor", v1=[])
def get_executor():
  """Get the Executor of the current thread.

  Returns:
    The Executor of the current thread.
  """
  return context().executor


@tf_export("debugging.get_log_device_placement")
def get_log_device_placement():
  """Get if device placements are logged.

  Returns:
    If device placements are logged.
  """
  return context().log_device_placement


@tf_export("debugging.set_log_device_placement")
def set_log_device_placement(enabled):
  """Turns logging for device placement decisions on or off.

  Operations execute on a particular device, producing and consuming tensors on
  that device. This may change the performance of the operation or require
  TensorFlow to copy data to or from an accelerator, so knowing where operations
  execute is useful for debugging performance issues.

  For more advanced profiling, use the [TensorFlow
  profiler](https://www.tensorflow.org/guide/profiler).

  Device placement for operations is typically controlled by a `tf.device`
  scope, but there are exceptions, for example operations on a `tf.Variable`
  which follow the initial placement of the variable. Turning off soft device
  placement (with `tf.config.set_soft_device_placement`) provides more explicit
  control.

  >>> tf.debugging.set_log_device_placement(True)
  >>> tf.ones([])
  >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:GPU:0
  >>> with tf.device("CPU"):
  ...   tf.ones([])
  >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:CPU:0
  >>> tf.debugging.set_log_device_placement(False)

  Turning on `tf.debugging.set_log_device_placement` also logs the placement of
  ops inside `tf.function` when the function is called.

  Args:
    enabled: Whether to enable device placement logging.
  """
  context().log_device_placement = enabled


@tf_contextlib.contextmanager
def device_policy(policy):
  """Context manager for setting device placement policy for current thread."""
  ctx = context()
  old_policy = ctx.device_policy
  try:
    ctx.device_policy = policy
    yield
  finally:
    ctx.device_policy = old_policy


def set_execution_mode(mode):
  """Sets execution mode for the current thread."""
  context().execution_mode = mode


# TODO(fishx): remove this method.
@tf_contextlib.contextmanager
def execution_mode(mode):
  """Context manager for setting execution mode for current thread."""
  if mode is None:
    yield
  else:
    ctx = context()
    executor_new = executor.new_executor(mode == ASYNC)
    executor_old = ctx.executor
    try:
      executor_old.wait()
      ctx.executor = executor_new
      yield
    finally:
      ctx.executor = executor_old
      executor_new.wait()


@tf_contextlib.contextmanager
def executor_scope(e):
  """Context manager for changing executor for current thread.

  Args:
    e: An Executor to execute eager ops under this scope. Setting it to None
      will switch back to use the default executor for the context.

  Yields:
    Context manager for setting the executor for current thread.
  """
  ctx = context()
  executor_old = ctx.executor
  try:
    ctx.executor = e
    yield
  finally:
    ctx.executor = executor_old
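

# Illustrative sketch of `executor_scope`: temporarily dispatch ops on a fresh
# async executor, then restore the thread's default one. This mirrors what
# `execution_mode(ASYNC)` above does internally.
#
#   async_executor = executor.new_executor(True)  # True => async dispatch
#   with executor_scope(async_executor):
#     ...  # ops dispatched here may return before their kernels finish
#   async_executor.wait()  # surface any deferred errors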


@tf_export("experimental.function_executor_type")
@tf_contextlib.contextmanager
def function_executor_type(executor_type):
  """Context manager for setting the executor of eager defined functions.

  Eager defined functions are functions decorated by tf.contrib.eager.defun.

  Args:
    executor_type: a string for the name of the executor to be used to execute
      functions defined by tf.contrib.eager.defun.

  Yields:
    Context manager for setting the executor of eager defined functions.
  """
  current_options = context().function_call_options
  old_options = copy.copy(current_options)
  try:
    current_options.executor_type = executor_type
    yield
  finally:
    context().function_call_options = old_options
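

# Illustrative sketch of `function_executor_type`: the executor name string is
# forwarded to the runtime through FunctionCallOptions. The name
# "SINGLE_THREADED_EXECUTOR" is assumed here as an example of a registered
# executor type, and `f` is a hypothetical traced function.
#
#   @def_function.function
#   def f(x):
#     return x + 1
#
#   with function_executor_type("SINGLE_THREADED_EXECUTOR"):
#     f(constant_op.constant(1))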


def is_async():
  """Returns true if current thread is in async mode."""
  return context().is_async()


def num_gpus():
  """Get the number of available GPU devices.

  Returns:
    The number of available GPU devices.
  """
  return context().num_gpus()


def enable_run_metadata():
  """Enables tracing of op execution via RunMetadata.

  To retrieve the accumulated metadata call context.export_run_metadata()
  and to stop tracing call context.disable_run_metadata().
  """
  context().enable_run_metadata()


def disable_run_metadata():
  """Disables tracing of op execution via RunMetadata."""
  context().disable_run_metadata()


def enable_graph_collection():
  """Enables graph collection of executed functions.

  To retrieve the accumulated graphs call context.export_run_metadata()
  and to stop collecting graphs call context.disable_graph_collection().
  """
  context().enable_graph_collection()


def disable_graph_collection():
  """Disables graph collection of executed functions."""
  context().disable_graph_collection()


def export_run_metadata():
  """Returns a RunMetadata proto with accumulated information.

  The returned protocol buffer contains information since the most recent call
  to either enable_run_metadata or export_run_metadata.

  Returns:
    A RunMetadata protocol buffer.
  """
  return context().export_run_metadata()


@contextlib.contextmanager
def collect_graphs(optimized=True):
  """Collects a flat list of pre- or post-optimization graphs.

  The collected graphs include device placements, which can be useful for
  testing.

  Usage:

  ```
  @def_function.function
  def f(x):
    return x + constant_op.constant(1.)

  with context.collect_graphs() as graphs:
    with ops.device("CPU:0"):
      f(constant_op.constant(1.))

  graph, = graphs  # `graph` contains a single GraphDef for inspection
  ```

  Args:
    optimized: whether to collect optimized graphs or non-optimized graphs

  Yields:
    A list of GraphDefs, populated when the context manager exits.
  """
  ctx = context()
  ctx.enable_graph_collection()
  try:
    graphs = []
    yield graphs
    metadata = ctx.export_run_metadata()
  finally:
    ctx.disable_graph_collection()
  for graph in metadata.function_graphs:
    if optimized:
      graphs.append(graph.post_optimization_graph)
    else:
      graphs.append(graph.pre_optimization_graph)


def get_server_def():
  return context().get_server_def()


def set_server_def(server_def):
  context().set_server_def(server_def)


def update_server_def(server_def):
  context().update_server_def(server_def)


def check_alive(worker_name):
  return context().check_alive(worker_name)


@tf_export("experimental.async_scope")
@tf_contextlib.contextmanager
def async_scope():
  """Context manager for grouping async operations.

  Ops/function calls inside the scope can return before finishing the actual
  execution. When exiting the async scope, a synchronization barrier will be
  automatically added to ensure the completion of all async op and function
  execution, potentially raising exceptions if async execution results in
  an error state.

  Users may write the following code to asynchronously invoke `train_step_fn`
  and log the `loss` metric for every `num_steps` steps in a training loop.
  `train_step_fn` internally consumes data using `iterator.get_next()`, and may
  throw OutOfRangeError when running out of data. In that case:

  ```
  try:
    with tf.experimental.async_scope():
      for _ in range(num_steps):
        # Step function updates the metric `loss` internally
        train_step_fn()
  except tf.errors.OutOfRangeError:
    tf.experimental.async_clear_error()
  logging.info('loss = %s', loss.numpy())
  ```

  Yields:
    Context manager for grouping async operations.
  """
  # TODO(haoyuzhang): replace env var once we have a config method to turn on
  # and off async streaming RPC
  remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
  old_policy = os.environ.get(remote_async_env_var)
  try:
    os.environ[remote_async_env_var] = str(True)
    yield
    # Note: sync local and remote executors iff the async block does not raise
    # an exception. Triggering sync after an exception may lead to derived
    # runtime errors and unexpected exception types.
    context().sync_executors()
  finally:
    if old_policy is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = old_policy


def async_wait():
  """Sync all async operations and raise any errors during execution.

  In async execution mode, an op/function call can return before finishing the
  actual execution. Calling this method creates a synchronization barrier for
  all async op and function execution. It only returns when all pending nodes
  are finished, potentially raising exceptions if async execution results in
  an error state. It is a no-op if the context is not initialized.
  """
  if context()._context_handle is not None:  # pylint: disable=protected-access
    context().sync_executors()


@tf_export("experimental.async_clear_error")
def async_clear_error():
  """Clear pending operations and error statuses in async execution.

  In async execution mode, an error in op/function execution can lead to errors
  in subsequent ops/functions that are scheduled but not yet executed. Calling
  this method clears all pending operations and resets the async execution
  state.

  Example:

  ```
  while True:
    try:
      # Step function updates the metric `loss` internally
      train_step_fn()
    except tf.errors.OutOfRangeError:
      tf.experimental.async_clear_error()
      break
  logging.info('loss = %s', loss.numpy())
  ```
  """
  context().clear_executor_errors()


def add_function(fdef):
  """Add a function definition to the context."""
  context().add_function(fdef)


def remove_function(name):
  """Remove a function from the context."""
  context().remove_function(name)


def get_function_def(name):
  return context().get_function_def(name)


def register_custom_device(device_capsule, device_name, device_info_capsule):
  """Calls TFE_RegisterCustomDevice to register a custom device with Python.

  Enables using C extensions specifying a custom device from Python. See the
  experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
  details.

  Note that custom devices are not currently supported inside `tf.function`s.

  Args:
    device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
      containing a pointer to a TFE_CustomDevice struct. The capsule retains
      ownership of the memory.
    device_name: A string indicating the name to register the custom device
      under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
      subsequently be passed to `with tf.device(...):`.
    device_info_capsule: A PyCapsule with the name set to
      'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
      struct with the initial state of the custom device (the void* device_info
      argument to TFE_RegisterCustomDevice). This method takes ownership of the
      memory and clears the capsule destructor.
  """
  context().register_custom_device(device_capsule, device_name,
                                   device_info_capsule)


# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions.
def _tmp_in_graph_mode():
  if context_safe() is None:
    # Context not yet initialized. Assume graph mode following the
    # default implementation in `is_in_graph_mode`.
    return True
  return not executing_eagerly()


is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode