1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""State management for eager execution.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import copy 24import os 25import random 26import threading 27 28from absl import logging 29import numpy as np 30import six 31 32from tensorflow.core.framework import function_pb2 33from tensorflow.core.protobuf import config_pb2 34from tensorflow.core.protobuf import rewriter_config_pb2 35from tensorflow.python import pywrap_tfe 36from tensorflow.python import tf2 37from tensorflow.python.client import pywrap_tf_session 38from tensorflow.python.eager import executor 39from tensorflow.python.eager import monitoring 40from tensorflow.python.framework import c_api_util 41from tensorflow.python.framework import device as pydev 42from tensorflow.python.framework import tfrt_utils 43from tensorflow.python.util import compat 44from tensorflow.python.util import is_in_graph_mode 45from tensorflow.python.util import tf_contextlib 46from tensorflow.python.util.deprecation import deprecated 47from tensorflow.python.util.tf_export import tf_export 48 49GRAPH_MODE = 0 50EAGER_MODE = 1 51 52default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE 53 54# Cache from (old_device_name, partial_new_device_name) -> (new_device_name, 55# new_device_spec). 56# Note that we do not protect this with a lock and instead rely on python's GIL 57# and the idempotent nature of writes to provide thread safety. 58_device_parsing_cache = {} 59_starting_device_spec = pydev.DeviceSpec.from_string("") 60 61_MAXINT32 = 2**31 - 1 62 63DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT 64DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN 65DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT 66DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( 67 pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) 68 69SYNC = 0 70ASYNC = 1 71 72_KEEP_ALIVE_SECS = 600 73 74_python_eager_context_create_counter = monitoring.Counter( 75 "/tensorflow/api/python/eager_context_create_counter", 76 "Counter for number of eager contexts created in Python.") 77 78# Re-exporting through context. 79is_tfrt_enabled = tfrt_utils.enabled 80 81# Expose it as internally public APIs for Keras use cases in b/171080602. 82tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled) 83 84 85class _EagerTensorCache(object): 86 """Simple cache which evicts items based on length in a FIFO manner.""" 87 88 __slots__ = ["_data", "_max_items", "_max_tensor_size"] 89 90 def __init__(self, max_items=256, max_tensor_size=10000): 91 self._data = collections.OrderedDict() 92 self._max_items = max_items 93 self._max_tensor_size = max_tensor_size 94 95 def put(self, key, value): 96 if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access 97 return 98 99 self._data[key] = value 100 101 if len(self._data) > self._max_items: 102 self._data.popitem(last=False) 103 104 def get(self, key): 105 return self._data.get(key, None) 106 107 def flush(self): 108 self._data.clear() 109 110 111class FunctionCallOptions(object): 112 """Options applied at call sites of eager functions. 113 114 Eager functions are functions decorated with tf.contrib.eager.defun. 115 """ 116 117 __slots__ = ["_config_proto_serialized", "_executor_type"] 118 119 def __init__(self, executor_type=None, config_proto=None): 120 """Constructor. 121 122 Args: 123 executor_type: (optional) name of the executor to be used to execute the 124 eager function. If None or an empty string, the default Tensorflow 125 executor will be used. 126 config_proto: (optional) a `config_pb2.ConfigProto` proto or 127 a serialized string of that proto. 128 The config used by Grappler when optimizing the function graph. 129 Each concrete function is optimized the first time is called. Changing 130 config_proto after the first call has no effect. 131 If config_proto is None, an empty RewriterConfig will be used. 132 """ 133 self.config_proto_serialized = config_proto 134 self.executor_type = executor_type 135 136 @property 137 def executor_type(self): 138 return self._executor_type 139 140 @executor_type.setter 141 def executor_type(self, executor_type): 142 self._executor_type = executor_type 143 144 @property 145 def config_proto_serialized(self): 146 return self._config_proto_serialized 147 148 @config_proto_serialized.setter 149 def config_proto_serialized(self, config): 150 if isinstance(config, config_pb2.ConfigProto): 151 self._config_proto_serialized = config.SerializeToString( 152 deterministic=True) 153 elif isinstance(config, str): 154 self._config_proto_serialized = config 155 elif config is None: 156 self._config_proto_serialized = ( 157 config_pb2.ConfigProto().SerializeToString()) 158 else: 159 raise ValueError("the rewriter config must be either a " 160 "config_pb2.ConfigProto, or a serialized string of that " 161 "proto or None. got: {}".format(type(config))) 162 163 164# Map from context_id (an int) to _TensorCaches. 165# Dicts are thread safe in CPython. 166# TODO(iga): Remove this once TensorCaches are moved to C++. 167_tensor_caches_map = {} 168 169 170class _TensorCaches(threading.local): 171 """Thread local tensor caches.""" 172 173 __slots__ = ["_ones_rank_cache", "_zeros_cache"] 174 175 def __init__(self): 176 super(_TensorCaches, self).__init__() 177 self._ones_rank_cache = None 178 self._zeros_cache = None 179 180 @property 181 def ones_rank_cache(self): 182 if not self._ones_rank_cache: 183 self._ones_rank_cache = _EagerTensorCache() 184 return self._ones_rank_cache 185 186 @property 187 def zeros_cache(self): 188 if not self._zeros_cache: 189 self._zeros_cache = _EagerTensorCache() 190 return self._zeros_cache 191 192 193ContextSwitch = collections.namedtuple( 194 "ContextSwitch", ["is_building_function", "enter_context_fn", 195 "device_stack"]) 196 197 198# `_ContextSwitchStack` is a `threading.local` to match the semantics of 199# ``DefaultGraphStack`, which is also a `threading.local`. 200class _ContextSwitchStack(threading.local): 201 """A thread-local stack of context switches.""" 202 203 def __init__(self, eager): 204 super(_ContextSwitchStack, self).__init__() 205 self.stack = [] 206 if eager: 207 # Initialize the stack with a pointer to enter the eager context; this 208 # ensures that the fact that eager execution was enabled is propagated 209 # across threads, since (1) `enable_eager_execution` modifies a 210 # process-level flag (`default_execution_mode`) and (2) `__init__` is 211 # called each time a threading.local object is used in a separate thread. 212 self.push(is_building_function=False, enter_context_fn=eager_mode, 213 device_stack=None) 214 215 def push(self, is_building_function, enter_context_fn, device_stack): 216 """Push metadata about a context switch onto the stack. 217 218 A context switch can take any one of the two forms: installing a graph as 219 the default graph, or entering the eager context. For each context switch, 220 we record whether or not the entered context is building a function. 221 222 Args: 223 is_building_function: (bool.) Whether the context is building a function. 224 enter_context_fn: (function.) A callable that executes the context switch. 225 For example, `graph.as_default` or `eager_mode`. 226 device_stack: If applicable, the device function stack for this 227 graph. When breaking out of graphs in init_scope, the innermost nonempty 228 device stack is used. Eager contexts put `None` here and the value is 229 never used. 230 """ 231 232 self.stack.append( 233 ContextSwitch(is_building_function, enter_context_fn, device_stack)) 234 235 def pop(self): 236 """Pop the stack.""" 237 238 self.stack.pop() 239 240 241@tf_export("config.LogicalDevice") 242class LogicalDevice( 243 collections.namedtuple("LogicalDevice", ["name", "device_type"])): 244 """Abstraction for a logical device initialized by the runtime. 245 246 A `tf.config.LogicalDevice` corresponds to an initialized logical device on a 247 `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors 248 and operations can be placed on a specific logical device by calling 249 `tf.device` with a specified `tf.config.LogicalDevice`. 250 251 Fields: 252 name: The fully qualified name of the device. Can be used for Op or function 253 placement. 254 device_type: String declaring the type of device such as "CPU" or "GPU". 255 """ 256 pass 257 258 259@tf_export("config.LogicalDeviceConfiguration", 260 "config.experimental.VirtualDeviceConfiguration") 261class LogicalDeviceConfiguration( 262 collections.namedtuple("LogicalDeviceConfiguration", 263 ["memory_limit", "experimental_priority"])): 264 """Configuration class for a logical devices. 265 266 The class specifies the parameters to configure a `tf.config.PhysicalDevice` 267 as it is initialized to a `tf.config.LogicalDevice` during runtime 268 initialization. Not all fields are valid for all device types. 269 270 See `tf.config.get_logical_device_configuration` and 271 `tf.config.set_logical_device_configuration` for usage examples. 272 273 Fields: 274 memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual 275 device. Currently only supported for GPUs. 276 experimental_priority: (optional) Priority to assign to a virtual device. 277 Lower values have higher priorities and 0 is the default. 278 Within a physical GPU, the GPU scheduler will prioritize ops on virtual 279 devices with higher priority. Currently only supported for Nvidia GPUs. 280 """ 281 282 def __new__(cls, memory_limit=None, experimental_priority=None): 283 return super(LogicalDeviceConfiguration, 284 cls).__new__(cls, memory_limit, experimental_priority) 285 286 287@tf_export("config.PhysicalDevice") 288class PhysicalDevice( 289 collections.namedtuple("PhysicalDevice", ["name", "device_type"])): 290 """Abstraction for a locally visible physical device. 291 292 TensorFlow can utilize various devices such as the CPU or multiple GPUs 293 for computation. Before initializing a local device for use, the user can 294 customize certain properties of the device such as it's visibility or memory 295 configuration. 296 297 Once a visible `tf.config.PhysicalDevice` is initialized one or more 298 `tf.config.LogicalDevice` objects are created. Use 299 `tf.config.set_visible_devices` to configure the visibility of a physical 300 device and `tf.config.set_logical_device_configuration` to configure multiple 301 `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is 302 useful when separation between models is needed or to simulate a multi-device 303 environment. 304 305 Fields: 306 name: Unique identifier for device. 307 device_type: String declaring the type of device such as "CPU" or "GPU". 308 """ 309 pass 310 311 312class _AtomicCounter(object): 313 """A simple atomic counter.""" 314 315 __slots__ = ["_value", "_lock"] 316 317 def __init__(self): 318 self._value = 0 319 self._lock = threading.Lock() 320 321 def increment_and_get(self): 322 with self._lock: 323 self._value += 1 324 return self._value 325 326 327_context_id_counter = _AtomicCounter() 328 329 330class _TensorCacheDeleter(object): 331 """Deletes tensor caches for a given context.""" 332 333 __slots__ = ["_context_id"] 334 335 def __init__(self, context_id): 336 self._context_id = context_id 337 338 def __del__(self): 339 if _tensor_caches_map is None: 340 return 341 if self._context_id in _tensor_caches_map: 342 del _tensor_caches_map[self._context_id] 343 344 345# TODO(agarwal): rename to EagerContext / EagerRuntime ? 346# TODO(agarwal): consider keeping the corresponding Graph here. 347class Context(object): 348 """Environment in which eager operations execute.""" 349 350 # TODO(agarwal): create and link in some documentation for `execution_mode`. 351 # pylint: disable=redefined-outer-name 352 def __init__(self, 353 config=None, 354 device_policy=None, 355 execution_mode=None, 356 server_def=None): 357 """Creates a new Context. 358 359 Args: 360 config: (Optional.) A `ConfigProto` protocol buffer with configuration 361 options for the Context. Note that a lot of these options may be 362 currently unimplemented or irrelevant when eager execution is enabled. 363 device_policy: (Optional.) What policy to use when trying to run an 364 operation on a device with inputs which are not on that device. 365 When set to None, an appropriate value will be picked automatically. 366 The value picked may change between TensorFlow releases. 367 368 Defaults to DEVICE_PLACEMENT_SILENT. 369 Valid values: 370 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is 371 not correct. 372 - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the 373 right device but raises a warning. 374 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might 375 hide performance problems. 376 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, 377 raising errors on the other ones. 378 execution_mode: (Optional.) Policy controlling how operations dispatched 379 are actually executed. When set to None, an appropriate value will be 380 picked automatically. The value picked may change between TensorFlow 381 releases. 382 Valid values: 383 - SYNC: executes each operation synchronously. 384 - ASYNC: executes each operation asynchronously. These 385 operations may return "non-ready" handles. 386 server_def: (Optional.) A tensorflow::ServerDef proto. 387 Enables execution on remote devices. GrpcServers need to be started by 388 creating an identical server_def to this, and setting the appropriate 389 task_indexes, so that the servers can communicate. It will then be 390 possible to execute operations on remote devices. 391 392 Raises: 393 ValueError: If execution_mode is not valid. 394 """ 395 # This _id is used only to index the tensor caches. 396 # TODO(iga): Remove this when tensor caches are moved to C++. 397 self._id = _context_id_counter.increment_and_get() 398 self._tensor_cache_deleter = _TensorCacheDeleter(self._id) 399 _tensor_caches_map[self._id] = _TensorCaches() 400 401 self._config = config 402 self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData( 403 self, 404 is_eager=lambda: default_execution_mode == EAGER_MODE, 405 device_spec=_starting_device_spec) 406 self._context_switches = _ContextSwitchStack(self.executing_eagerly()) 407 self._context_handle = None 408 self._context_devices = None 409 self._seed = None 410 self._initialize_lock = threading.Lock() 411 self._initialized = False 412 if device_policy is None: 413 device_policy = DEVICE_PLACEMENT_SILENT 414 self._device_policy = device_policy 415 self._mirroring_policy = None 416 if execution_mode not in (None, SYNC, ASYNC): 417 raise ValueError( 418 "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode) 419 if execution_mode is None: 420 execution_mode = SYNC 421 self._default_is_async = execution_mode == ASYNC 422 self._use_tfrt = is_tfrt_enabled() 423 self._server_def = server_def 424 self._collective_ops_server_def = None 425 self._collective_leader = None 426 self._collective_scoped_allocator_enabled_ops = None 427 self._collective_use_nccl_communication = None 428 self._collective_device_filters = None 429 430 self._device_lock = threading.Lock() 431 self._physical_devices = None 432 self._physical_device_to_index = None 433 self._visible_device_list = [] 434 self._memory_growth_map = None 435 self._virtual_device_map = {} 436 437 # Values set after construction 438 self._optimizer_jit = None 439 self._intra_op_parallelism_threads = None 440 self._inter_op_parallelism_threads = None 441 self._soft_device_placement = None 442 self._log_device_placement = None 443 self._enable_mlir_graph_optimization = None 444 self._optimizer_experimental_options = {} 445 446 _python_eager_context_create_counter.get_cell().increase_by(1) 447 # pylint: enable=redefined-outer-name 448 449 def _set_global_seed(self, seed): 450 """Set a global eager mode seed for random ops.""" 451 self._seed = seed 452 # `random.Random(seed)` needs `seed` to be hashable, while values of type 453 # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them 454 # to int. 455 try: 456 hash(seed) 457 except TypeError: 458 seed = int(np.array(seed)) 459 self._rng = random.Random(seed) 460 # Also clear the kernel cache, to reset any existing seeds 461 if self._context_handle is not None: 462 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 463 464 def _internal_operation_seed(self): 465 """Returns a fake operation seed. 466 467 In eager mode, user shouldn't set or depend on operation seed. 468 Here, we generate a random seed based on global seed to make 469 operation's randomness different and depend on the global seed. 470 471 Returns: 472 A fake operation seed based on global seed. 473 """ 474 return self._rng.randint(0, _MAXINT32) 475 476 def _initialize_logical_devices(self): 477 """Helper to initialize devices.""" 478 # Store list of devices 479 logical_devices = [] 480 context_devices = [] 481 device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle) 482 try: 483 self._num_gpus = 0 484 for i in range(pywrap_tfe.TF_DeviceListCount(device_list)): 485 dev_name = pywrap_tfe.TF_DeviceListName(device_list, i) 486 context_devices.append(pydev.canonical_name(dev_name)) 487 spec = pydev.DeviceSpec.from_string(dev_name) 488 # If the job is localhost, we assume that the cluster has not yet been 489 # configured and thus clear the job, replica & task. 490 if spec.job == "localhost": 491 spec = spec.replace(job=None, replica=None, task=None) 492 logical_devices.append( 493 LogicalDevice(name=spec.to_string(), device_type=spec.device_type)) 494 dev_type = pywrap_tfe.TF_DeviceListType(device_list, i) 495 if dev_type == "GPU": 496 self._num_gpus += 1 497 498 finally: 499 self._logical_devices = logical_devices 500 self._context_devices = context_devices 501 pywrap_tfe.TF_DeleteDeviceList(device_list) 502 503 def ensure_initialized(self): 504 """Initialize handle and devices if not already done so.""" 505 if self._initialized: 506 return 507 with self._initialize_lock: 508 if self._initialized: 509 return 510 assert self._context_devices is None 511 opts = pywrap_tfe.TFE_NewContextOptions() 512 try: 513 config_str = self.config.SerializeToString() 514 pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str) 515 if self._device_policy is not None: 516 pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy( 517 opts, self._device_policy) 518 if self._mirroring_policy is not None: 519 pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy( 520 opts, self._mirroring_policy) 521 if self._default_is_async == ASYNC: 522 pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True) 523 if self._use_tfrt is not None: 524 pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt) 525 context_handle = pywrap_tfe.TFE_NewContext(opts) 526 finally: 527 pywrap_tfe.TFE_DeleteContextOptions(opts) 528 assert not (self._server_def and self._collective_ops_server_def), ( 529 "Cannot enable remote execution as well as collective ops at the " 530 "moment. If this is important to you, please file an issue.") 531 if self._server_def is not None: 532 server_def_str = self._server_def.SerializeToString() 533 pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS, 534 server_def_str) 535 elif self._collective_ops_server_def is not None: 536 server_def_str = self._collective_ops_server_def.SerializeToString() 537 pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str) 538 539 self._context_handle = context_handle 540 self._initialize_logical_devices() 541 self._initialized = True 542 543 def _clear_caches(self): 544 self.ones_rank_cache().flush() 545 self.zeros_cache().flush() 546 pywrap_tfe.TFE_ClearScalarCache() 547 548 def get_server_def(self): 549 return self._server_def 550 551 def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 552 """Allow setting a server_def on the context. 553 554 When a server def is replaced, it effectively clears a bunch of caches 555 within the context. If you attempt to use a tensor object that was pointing 556 to a tensor on the remote device, it will raise an error. 557 558 Args: 559 server_def: A tensorflow::ServerDef proto. 560 Enables execution on remote devices. 561 keep_alive_secs: Num. seconds after which the remote end will hang up. 562 As long as the client is still alive, the server state for the context 563 will be kept alive. If the client is killed (or there is some failure), 564 the server will clean up its context keep_alive_secs after the final RPC 565 it receives. 566 567 Raises: 568 ValueError: if server_def is None. 569 """ 570 if not server_def: 571 raise ValueError("server_def is None.") 572 573 self._server_def = server_def 574 575 if self._context_handle: 576 server_def_str = server_def.SerializeToString() 577 pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs, 578 server_def_str) 579 self._initialize_logical_devices() 580 581 # Clear all the caches in case there are remote tensors in them. 582 self._clear_caches() 583 584 def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 585 """Update a server_def on the context. 586 587 Args: 588 server_def: A tensorflow::ServerDef proto. Enables execution on remote 589 devices. 590 keep_alive_secs: Num. seconds after which the remote end will hang up. As 591 long as the client is still alive, the server state for the context will 592 be kept alive. If the client is killed (or there is some failure), the 593 server will clean up its context keep_alive_secs after the final RPC it 594 receives. 595 596 Raises: 597 ValueError: if server_def is None. 598 """ 599 if not server_def: 600 raise ValueError("server_def is None.") 601 602 self._server_def = server_def 603 604 if self._context_handle: 605 server_def_str = server_def.SerializeToString() 606 pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle, 607 keep_alive_secs, server_def_str) 608 self._initialize_logical_devices() 609 610 self._clear_caches() 611 612 def check_alive(self, worker_name): 613 """Checks whether a remote worker is alive or not. 614 615 Args: 616 worker_name: a string representing the remote worker. It must be a fully 617 specified name like "/job:worker/replica:0/task:0". 618 619 Returns: 620 a boolean indicating whether the remote worker is alive or not. 621 622 Raises: 623 ValueError: if context is not initialized. 624 """ 625 # TODO(yuefengz): support checking multiple workers. 626 if self._context_handle: 627 return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name) 628 else: 629 raise ValueError("Context is not initialized.") 630 631 def sync_executors(self): 632 """Sync both local executors and the ones on remote workers. 633 634 In async execution mode, local function calls can return before the 635 corresponding remote op/function execution requests are completed. Calling 636 this method creates a synchronization barrier for remote executors. It only 637 returns when all remote pending nodes are finished, potentially with errors 638 if any remote executors are in error state. 639 640 Raises: 641 ValueError: if context is not initialized. 642 """ 643 if self._context_handle: 644 pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle) 645 else: 646 raise ValueError("Context is not initialized.") 647 648 def clear_executor_errors(self): 649 """Clear errors in both local executors and remote workers. 650 651 After receiving errors from remote workers, additional requests on the fly 652 could further taint the status on the remote workers due to the async nature 653 of remote execution. Calling this method block on waiting for all pending 654 nodes in remote executors to finish and clear their error statuses. 655 656 Raises: 657 ValueError: if context is not initialized. 658 """ 659 if self._context_handle: 660 pywrap_tfe.TFE_ContextClearExecutors(self._context_handle) 661 else: 662 raise ValueError("Context is not initialized.") 663 664 def clear_kernel_cache(self): 665 """Clear kernel cache and reset all stateful kernels. 666 667 Raises: 668 ValueError: if context is not initialized. 669 """ 670 if self._context_handle is not None: 671 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 672 else: 673 raise ValueError("Context is not initialized.") 674 675 def enable_collective_ops(self, server_def): 676 """Enable distributed collective ops with an appropriate server_def. 677 678 Args: 679 server_def: A tensorflow::ServerDef proto. Enables execution on remote 680 devices. 681 682 Raises: 683 ValueError: if server_def is None. 684 RuntimeError: if this method is not called at program startup. 685 """ 686 if not server_def: 687 raise ValueError("server_def is None.") 688 689 self._collective_ops_server_def = server_def 690 691 # TODO(b/129298253): Allow creating datasets/tensors before enabling 692 # collective ops. 693 if self._context_handle is not None: 694 logging.warning("Enabling collective ops after program startup may cause " 695 "error when accessing previously created tensors.") 696 with self._initialize_lock: 697 assert self._initialized 698 server_def_str = self._collective_ops_server_def.SerializeToString() 699 pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str) 700 self._initialize_logical_devices() 701 self._clear_caches() 702 703 def configure_collective_ops( 704 self, 705 collective_leader="", 706 scoped_allocator_enabled_ops=("CollectiveReduce",), 707 use_nccl_communication=False, 708 device_filters=None): 709 """Configure collective ops. 710 711 Collective group leader is necessary for collective ops to run, other 712 configurations are mainly for the purpose of performance. 713 714 Args: 715 collective_leader: a device string for collective leader, e.g. 716 "/job:worker/replica:0/task:0"; empty string means local execution of 717 collective ops. 718 scoped_allocator_enabled_ops: a tuple or a list of op names for scoped 719 allocator to run with. 720 use_nccl_communication: whether to use nccl communication for collective 721 ops. 722 device_filters: a tuple or a list of device strings. If set, corresponding 723 task can only see the devices filtered by these device filters. 724 725 Raises: 726 RuntimeError: if this method is not called at program startup. 727 """ 728 if self._collective_leader is not None: 729 if (self._collective_leader != collective_leader or 730 self._collective_scoped_allocator_enabled_ops != 731 scoped_allocator_enabled_ops or 732 self._collective_use_nccl_communication != use_nccl_communication or 733 self._collective_device_filters != device_filters): 734 raise ValueError("Collective ops are already configured.") 735 else: 736 return 737 738 if self._context_handle is not None: 739 raise RuntimeError("Collective ops must be configured at program startup") 740 741 self._collective_leader = collective_leader 742 self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops 743 self._collective_use_nccl_communication = use_nccl_communication 744 self._collective_device_filters = device_filters 745 746 def abort_collective_ops(self, code, message): 747 """Abort the collective ops. 748 749 This is intended to be used when a peer failure is detected, which allows 750 the user to handle the case instead of hanging. This aborts all on-going 751 collectives. After all subsequent collectives error immediately, and you 752 need to reset_context() to use collectives again. 753 754 Args: 755 code: a `tf.errors` error code. 756 message: a string. The error message. 757 """ 758 self.ensure_initialized() 759 pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message) 760 761 def check_collective_ops_peer_health(self, task, timeout_in_ms): 762 """Check collective peer health. 763 764 This probes each task to see if they're still alive. Note that restarted 765 tasks are considered a different one, and they're considered not healthy. 766 767 This should only be used in multi client multi worker training. 768 769 Args: 770 task: a task string, must be in the format of /job:xxx/replica:0/task:N. 771 timeout_in_ms: an integer, the timeout. If zero, there's no timeout. 772 773 Raises: 774 tf.errors.UnavailableError: when a peer is down. 775 tf.errors.FailedPreconditionError: when a peer is a different one from the 776 one this task has talked to, e.g. the peer has restarted. 777 tf.errors.InvalidArgumentError: when the task string is invalid. 778 """ 779 self.ensure_initialized() 780 pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task, 781 timeout_in_ms) 782 783 @property 784 def _handle(self): 785 if self._context_handle is None: 786 raise AssertionError("Context must be initialized first.") 787 788 return self._context_handle 789 790 @property 791 def _devices(self): 792 if self._context_devices is None: 793 raise AssertionError("Context must be initialized first.") 794 795 return self._context_devices 796 797 def __str__(self): 798 if self._context_handle is None: 799 return "Eager TensorFlow Context. Devices currently uninitialized." 800 else: 801 devices = self._devices 802 lines = ["Eager TensorFlow Context with %d devices" % (len(devices))] 803 for i, d in enumerate(devices): 804 lines.append(" Device %d: %s" % (i, d)) 805 return "\n".join(lines) 806 807 @tf_contextlib.contextmanager 808 def _mode(self, mode): 809 """A context manager to allow setting the mode to EAGER/GRAPH.""" 810 ctx = self._thread_local_data 811 old_is_eager = ctx.is_eager 812 ctx.is_eager = mode == EAGER_MODE 813 if mode == EAGER_MODE: 814 # Entering graph mode does not provide us with sufficient information to 815 # record a context switch; graph-based context switches are only logged 816 # when a graph is registered as the default graph. 817 self.context_switches.push(False, eager_mode, None) 818 try: 819 yield 820 finally: 821 ctx.is_eager = old_is_eager 822 if mode == EAGER_MODE: 823 self.context_switches.pop() 824 825 def executing_eagerly(self): 826 """Returns True if current thread has eager executing enabled.""" 827 return self._thread_local_data.is_eager 828 829 def ones_rank_cache(self): 830 """Per-device cache for scalars.""" 831 return _tensor_caches_map[self._id].ones_rank_cache 832 833 def zeros_cache(self): 834 """Per-device cache for scalars.""" 835 return _tensor_caches_map[self._id].zeros_cache 836 837 @property 838 def scope_name(self): 839 """Returns scope name for the current thread.""" 840 return self._thread_local_data.scope_name 841 842 @scope_name.setter 843 def scope_name(self, s): 844 """Sets scope name for the current thread.""" 845 self._thread_local_data.scope_name = s 846 847 @property 848 def device_name(self): 849 """Returns the device name for the current thread.""" 850 return self._thread_local_data.device_name 851 852 @property 853 def device_spec(self): 854 """Returns the device spec for the current thread.""" 855 return self._thread_local_data.device_spec 856 857 def _set_device(self, device_name, device_spec): 858 self._thread_local_data.device_name = device_name 859 self._thread_local_data.device_spec = device_spec 860 861 def device(self, name): 862 """Context-manager to force placement of operations and Tensors on a device. 863 864 Args: 865 name: Name of the device or None to get default placement. 866 867 Returns: 868 Context manager that forces device placement. 869 870 Raises: 871 ValueError: If name is not a string or is an invalid device name. 872 RuntimeError: If device scopes are not properly nested. 873 """ 874 if isinstance(name, LogicalDevice): 875 name = name.name 876 elif pydev.is_device_spec(name): 877 name = name.to_string() 878 return _EagerDeviceContext(self, name) 879 880 def devices(self): 881 """List of the names of devices available to execute operations.""" 882 return self._devices 883 884 def host_address_space(self): 885 self.ensure_initialized() 886 with c_api_util.tf_buffer() as buffer_: 887 pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_) 888 address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 889 return address_space 890 891 # TODO(fishx): remove this property. 892 @property 893 def execution_mode(self): 894 """Gets execution mode for current thread.""" 895 return ASYNC if self.is_async() else SYNC 896 897 @execution_mode.setter 898 def execution_mode(self, mode): 899 """Sets execution mode for current thread.""" 900 if mode not in (None, SYNC, ASYNC): 901 raise ValueError( 902 "Execution mode should be None/SYNC/ASYNC. Got %s" % mode) 903 904 if mode is None: 905 mode = SYNC 906 907 enable_async = (mode == ASYNC) 908 if self.is_async() != enable_async: 909 # Only set the execution mode if the context has already been initialized 910 if self._context_handle is not None: 911 self.executor.wait() 912 executor_new = executor.new_executor(enable_async) 913 self._thread_local_data.executor = executor_new 914 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, 915 executor_new.handle()) 916 else: 917 self._default_is_async = enable_async 918 919 def is_async(self): 920 if self._context_handle is not None: 921 return self.executor.is_async() 922 else: 923 return self._default_is_async 924 925 @property 926 def executor(self): 927 self.ensure_initialized() 928 return executor.Executor( 929 pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle)) 930 931 @executor.setter 932 def executor(self, e): 933 self.ensure_initialized() 934 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle()) 935 936 @property 937 def config(self): 938 """Return the ConfigProto with all runtime deltas applied.""" 939 # Ensure physical devices have been discovered and config has been imported 940 self._initialize_physical_devices() 941 942 config = config_pb2.ConfigProto() 943 if self._config is not None: 944 config.CopyFrom(self._config) 945 946 if self._optimizer_jit is not None: 947 config.graph_options.optimizer_options.global_jit_level = ( 948 config_pb2.OptimizerOptions.ON_1 949 if self._optimizer_jit else config_pb2.OptimizerOptions.OFF) 950 if self._intra_op_parallelism_threads is not None: 951 config.intra_op_parallelism_threads = self._intra_op_parallelism_threads 952 if self._inter_op_parallelism_threads is not None: 953 config.inter_op_parallelism_threads = self._inter_op_parallelism_threads 954 955 if self._soft_device_placement is not None: 956 config.allow_soft_placement = self._soft_device_placement 957 else: 958 config.allow_soft_placement = self.executing_eagerly() 959 960 if self._log_device_placement is not None: 961 config.log_device_placement = self._log_device_placement 962 963 is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled() 964 config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled 965 if (is_mlir_bridge_enabled == 966 config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED): 967 config.experimental.enable_mlir_bridge = True 968 969 if self._enable_mlir_graph_optimization is not None: 970 config.experimental.enable_mlir_graph_optimization = ( 971 self._enable_mlir_graph_optimization) 972 973 def rewriter_toggle(option): 974 toggle = self._optimizer_experimental_options.get(option, None) 975 if toggle is None: 976 return 977 978 setattr(config.graph_options.rewrite_options, 979 option, 980 (rewriter_config_pb2.RewriterConfig.ON 981 if toggle else rewriter_config_pb2.RewriterConfig.OFF)) 982 983 def rewriter_bool(option): 984 toggle = self._optimizer_experimental_options.get(option, None) 985 if toggle is None: 986 return 987 988 setattr(config.graph_options.rewrite_options, 989 option, 990 toggle) 991 992 rewriter_toggle("layout_optimizer") 993 rewriter_toggle("constant_folding") 994 rewriter_toggle("shape_optimization") 995 rewriter_toggle("remapping") 996 rewriter_toggle("arithmetic_optimization") 997 rewriter_toggle("dependency_optimization") 998 rewriter_toggle("loop_optimization") 999 rewriter_toggle("function_optimization") 1000 rewriter_toggle("debug_stripper") 1001 rewriter_bool("disable_model_pruning") 1002 rewriter_toggle("scoped_allocator_optimization") 1003 rewriter_toggle("pin_to_host_optimization") 1004 rewriter_toggle("implementation_selector") 1005 rewriter_toggle("auto_mixed_precision") 1006 rewriter_bool("disable_meta_optimizer") 1007 nodes = self._optimizer_experimental_options.get("min_graph_nodes", None) 1008 if nodes is not None: 1009 config.graph_options.rewrite_options.min_graph_nodes = nodes 1010 1011 # Compute device counts 1012 config.device_count["CPU"] = 0 1013 config.device_count["GPU"] = 0 1014 for dev in self._physical_devices: 1015 if dev not in self._visible_device_list: 1016 continue 1017 1018 virtual_devices = self._virtual_device_map.get(dev) 1019 if virtual_devices is None: 1020 config.device_count[dev.device_type] += 1 1021 else: 1022 config.device_count[dev.device_type] += len(virtual_devices) 1023 1024 # Configure gpu_options 1025 gpu_options = self._compute_gpu_options() 1026 config.gpu_options.MergeFrom(gpu_options) 1027 1028 # Configure collective ops 1029 if self._collective_leader: 1030 config.experimental.collective_group_leader = self._collective_leader 1031 if self._collective_scoped_allocator_enabled_ops: 1032 rewrite_options = config.graph_options.rewrite_options 1033 rewrite_options.scoped_allocator_optimization = ( 1034 rewriter_config_pb2.RewriterConfig.ON) 1035 del rewrite_options.scoped_allocator_opts.enable_op[:] 1036 for op in self._collective_scoped_allocator_enabled_ops: 1037 rewrite_options.scoped_allocator_opts.enable_op.append(op) 1038 if self._collective_use_nccl_communication: 1039 config.experimental.collective_nccl = True 1040 if self._collective_device_filters: 1041 del config.device_filters[:] 1042 for f in self._collective_device_filters: 1043 config.device_filters.append(f) 1044 1045 return config 1046 1047 def _compute_gpu_options(self): 1048 """Build the GPUOptions proto.""" 1049 visible_device_list = [] 1050 virtual_devices = [] 1051 gpu_index = -1 1052 memory_growths = set() 1053 for dev in self.list_physical_devices("GPU"): 1054 gpu_index += 1 1055 1056 if dev not in self._visible_device_list: 1057 continue 1058 1059 growth = self._memory_growth_map[dev] 1060 memory_growths.add(growth) 1061 visible_device_list.append(str(gpu_index)) 1062 1063 if self._virtual_device_map: 1064 vdevs = self._virtual_device_map.get(dev, []) 1065 device_limits = [] 1066 priority = [] 1067 for virt_dev in vdevs: 1068 device_limits.append(virt_dev.memory_limit) 1069 if virt_dev.experimental_priority is not None: 1070 priority.append(virt_dev.experimental_priority) 1071 # If priority is specified, it must be specified for all virtual 1072 # devices. 1073 if priority and len(device_limits) != len(priority): 1074 raise ValueError("priority must be specified for all virtual devices") 1075 1076 virtual_devices.append( 1077 config_pb2.GPUOptions.Experimental.VirtualDevices( 1078 memory_limit_mb=device_limits, priority=priority)) 1079 1080 # Only compute growth if virtual devices have not been configured and we 1081 # have GPUs 1082 if not virtual_devices and memory_growths: 1083 if len(memory_growths) > 1: 1084 raise ValueError("Memory growth cannot differ between GPU devices") 1085 allow_growth = memory_growths.pop() 1086 else: 1087 allow_growth = None 1088 1089 return config_pb2.GPUOptions( 1090 allow_growth=allow_growth, 1091 visible_device_list=",".join(visible_device_list), 1092 experimental=config_pb2.GPUOptions.Experimental( 1093 virtual_devices=virtual_devices)) 1094 1095 @property 1096 def function_call_options(self): 1097 """Returns function call options for current thread. 1098 1099 Note that the returned object is still referenced by the eager context. 1100 1101 Returns: the FunctionCallOptions for current thread. 1102 """ 1103 if self._thread_local_data.function_call_options is None: 1104 config = self.config 1105 1106 # Default to soft placement for functions unless specified 1107 if self._soft_device_placement is None: 1108 config.allow_soft_placement = True 1109 self._thread_local_data.function_call_options = FunctionCallOptions( 1110 config_proto=config) 1111 1112 return self._thread_local_data.function_call_options 1113 1114 @function_call_options.setter 1115 def function_call_options(self, options): 1116 """Returns function call options for current thread.""" 1117 self._thread_local_data.function_call_options = options 1118 1119 def num_gpus(self): 1120 """The number of GPUs available to execute operations.""" 1121 self.ensure_initialized() 1122 return self._num_gpus 1123 1124 def add_function(self, fn): 1125 """Add a function definition to the context. 1126 1127 Once added, the function (identified by its name) can be executed like any 1128 other operation. 1129 1130 Args: 1131 fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). 1132 """ 1133 self.ensure_initialized() 1134 pywrap_tfe.TFE_ContextAddFunction(self._handle, fn) 1135 1136 def add_function_def(self, fdef): 1137 """Add a function definition to the context. 1138 1139 Once added, the function (identified by its name) can be executed like any 1140 other operation. 1141 1142 Args: 1143 fdef: A FunctionDef protocol buffer message. 1144 """ 1145 self.ensure_initialized() 1146 fdef_string = fdef.SerializeToString() 1147 pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string, 1148 len(fdef_string)) 1149 1150 def get_function_def(self, name): 1151 """Get a function definition from the context. 1152 1153 Args: 1154 name: function signature name. 1155 1156 Returns: 1157 The requested FunctionDef. 1158 1159 Raises: 1160 tf.errors.NotFoundError: if name is not the name of a registered function. 1161 """ 1162 with c_api_util.tf_buffer() as buffer_: 1163 pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_) 1164 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 1165 function_def = function_pb2.FunctionDef() 1166 function_def.ParseFromString(proto_data) 1167 1168 return function_def 1169 1170 def register_custom_device(self, device_capsule, device_name, 1171 device_info_capsule): 1172 """Calls TFE_RegisterCustomDevice. See the non-member function.""" 1173 self.ensure_initialized() 1174 pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule, 1175 device_name, device_info_capsule) 1176 1177 def pack_eager_tensors(self, tensors): 1178 """Pack multiple `EagerTensor`s of the same dtype and shape. 1179 1180 Args: 1181 tensors: a list of EagerTensors to pack. 1182 1183 Returns: 1184 A packed EagerTensor. 1185 """ 1186 self.ensure_initialized() 1187 return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors) 1188 1189 def list_function_names(self): 1190 """Get a list of names of registered functions. 1191 1192 Returns: 1193 A set of names of all registered functions for the context. 1194 """ 1195 self.ensure_initialized() 1196 return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle)) 1197 1198 def remove_function(self, name): 1199 """Remove a function from the context. 1200 1201 Once removed, the function cannot be executed anymore. 1202 1203 Args: 1204 name: function signature name. 1205 """ 1206 self.ensure_initialized() 1207 pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name) 1208 1209 def has_function(self, name): 1210 """Check if a function `name` is registered.""" 1211 self.ensure_initialized() 1212 return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name)) 1213 1214 def add_op_callback(self, callback): 1215 """Add a post-op callback to the context. 1216 1217 A post-op callback is invoked immediately after an eager operation or 1218 function has finished execution or after a op has been added to a graph, 1219 providing access to the op's type, name input and output tensors. Multiple 1220 op callbacks can be added, in which case the callbacks will be invoked in 1221 the order in which they are added. 1222 1223 Args: 1224 callback: a callable of the signature 1225 `f(op_type, inputs, attrs, outputs, op_name=None, graph=None)`. 1226 See doc strings in `op_callbacks.py` for details on the function 1227 signature and its semantics. 1228 """ 1229 if callback not in self._thread_local_data.op_callbacks: 1230 self._thread_local_data.op_callbacks.append(callback) 1231 1232 def remove_op_callback(self, callback): 1233 """Remove an already-registered op callback. 1234 1235 Args: 1236 callback: The op callback to be removed. 1237 1238 Raises: 1239 KeyError: If `callback` is not already registered. 1240 """ 1241 if callback not in self._thread_local_data.op_callbacks: 1242 raise KeyError( 1243 "The specified op callback has not been registered, " 1244 "and hence cannot be removed.") 1245 del self._thread_local_data.op_callbacks[ 1246 self._thread_local_data.op_callbacks.index(callback)] 1247 1248 @property 1249 def op_callbacks(self): 1250 return self._thread_local_data.op_callbacks 1251 1252 @property 1253 def invoking_op_callbacks(self): 1254 return self._thread_local_data.invoking_op_callbacks 1255 1256 @invoking_op_callbacks.setter 1257 def invoking_op_callbacks(self, value): 1258 self._thread_local_data.invoking_op_callbacks = value 1259 1260 def _initialize_physical_devices(self, reinitialize=False): 1261 """Gets local devices visible to the system. 1262 1263 Args: 1264 reinitialize: If True, reinitializes self._physical_devices so that 1265 dynamic registered devices will also be visible to the python front-end. 1266 """ 1267 # We lazy initialize self._physical_devices since we do not want to do this 1268 # the constructor since the backend may not be initialized yet. 1269 with self._device_lock: 1270 if not reinitialize and self._physical_devices is not None: 1271 return 1272 1273 devs = pywrap_tfe.TF_ListPhysicalDevices() 1274 self._physical_devices = [ 1275 PhysicalDevice(name=d.decode(), 1276 device_type=d.decode().split(":")[1]) for d in devs] 1277 self._physical_device_to_index = { 1278 p: i for i, p in enumerate(self._physical_devices) 1279 } 1280 1281 self._visible_device_list = list(self._physical_devices) 1282 self._memory_growth_map = { 1283 d: None for d in self._physical_devices if d.device_type == "GPU" 1284 } 1285 1286 # Import device settings that may have been passed into the constructor 1287 self._import_config() 1288 1289 def reinitialize_physical_devices(self): 1290 """Gets local devices visible to the system.""" 1291 # Reinitialize the physical device list after registering 1292 # the pluggable device. 1293 self._initialize_physical_devices(True) 1294 1295 def list_physical_devices(self, device_type=None): 1296 """List local devices visible to the system. 1297 1298 This API allows a client to query the devices before they have been 1299 initialized by the eager runtime. Additionally a user can filter by device 1300 type, to get only CPUs or GPUs. 1301 1302 Args: 1303 device_type: Optional device type to limit results to 1304 1305 Returns: 1306 List of PhysicalDevice objects. 1307 """ 1308 self._initialize_physical_devices() 1309 1310 if device_type is None: 1311 return list(self._physical_devices) 1312 1313 return [d for d in self._physical_devices if d.device_type == device_type] 1314 1315 def get_device_details(self, device): # pylint: disable=redefined-outer-name 1316 """Returns details about a physical devices. 1317 1318 Args: 1319 device: A `tf.config.PhysicalDevice` returned by 1320 `tf.config.list_physical_devices` or `tf.config.get_visible_devices`. 1321 1322 Returns: 1323 A dict with string keys. 1324 """ 1325 if not isinstance(device, PhysicalDevice): 1326 raise ValueError("device must be a tf.config.PhysicalDevice, but got: " 1327 "%s" % (device,)) 1328 if (self._physical_device_to_index is None or 1329 device not in self._physical_device_to_index): 1330 raise ValueError("The PhysicalDevice must be one obtained from " 1331 "calling `tf.config.list_physical_devices`, but got: " 1332 "%s" % (device,)) 1333 index = self._physical_device_to_index[device] 1334 details = pywrap_tfe.TF_GetDeviceDetails(index) 1335 1336 # Change compute_capability from a string to a tuple 1337 if "compute_capability" in details: 1338 try: 1339 major, minor = details["compute_capability"].split(".") 1340 details["compute_capability"] = (int(major), int(minor)) 1341 except ValueError: 1342 raise RuntimeError("Device returned compute capability an in invalid " 1343 "format: %s" % details["compute_capability"]) 1344 return details 1345 1346 def _import_config(self): 1347 """Import config if passed in during construction. 1348 1349 If Context was created with a ConfigProto such as when calling 1350 tf.compat.v1.enable_eager_execution(), then we need to pull out the 1351 various pieces we might be replacing and import then into our internal 1352 class representation. 1353 """ 1354 if self._config is None: 1355 return 1356 1357 num_cpus = self._config.device_count.get("CPU", 1) 1358 if num_cpus != 1: 1359 cpus = [d for d in self._physical_devices if d.device_type == "CPU"] 1360 if num_cpus == 0: 1361 self.set_visible_devices([], "CPU") 1362 elif num_cpus > 1: 1363 self.set_logical_device_configuration( 1364 cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)]) 1365 1366 # Parse GPU options 1367 gpus = [d for d in self._physical_devices if d.device_type == "GPU"] 1368 1369 # If there are no GPUs detected, simply ignore all the GPU options passed in 1370 # rather than doing any validation checks. 1371 if not gpus: 1372 return 1373 1374 gpu_count = self._config.device_count.get("GPU", None) 1375 1376 visible_gpus = [] 1377 # TODO(gjn): Handle importing existing virtual GPU configuration 1378 visible_indices = self._config.gpu_options.visible_device_list 1379 if visible_indices: 1380 for index in visible_indices.split(","): 1381 if int(index) >= len(gpus): 1382 raise ValueError("Invalid visible device index: %s" % index) 1383 visible_gpus.append(gpus[int(index)]) 1384 else: 1385 visible_gpus = gpus 1386 1387 if gpu_count is not None: 1388 visible_gpus = visible_gpus[:gpu_count] 1389 1390 self.set_visible_devices(visible_gpus, "GPU") 1391 1392 def list_logical_devices(self, device_type=None): 1393 """Return logical devices.""" 1394 self.ensure_initialized() 1395 if device_type is None: 1396 return list(self._logical_devices) 1397 1398 return [d for d in self._logical_devices if d.device_type == device_type] 1399 1400 def get_visible_devices(self, device_type=None): 1401 """Get the list of visible devices.""" 1402 self._initialize_physical_devices() 1403 1404 if device_type is None: 1405 return list(self._visible_device_list) 1406 1407 return [ 1408 d for d in self._visible_device_list if d.device_type == device_type 1409 ] 1410 1411 def set_visible_devices(self, devices, device_type=None): 1412 """Set the list of visible devices.""" 1413 self._initialize_physical_devices() 1414 1415 if not isinstance(devices, list): 1416 devices = [devices] 1417 1418 for d in devices: 1419 if d not in self._physical_devices: 1420 raise ValueError("Unrecognized device: %s" % repr(d)) 1421 if device_type is not None and d.device_type != device_type: 1422 raise ValueError("Unrecognized device: %s" % repr(d)) 1423 1424 visible_device_list = [] 1425 if device_type is not None: 1426 visible_device_list = [ 1427 d for d in self._visible_device_list if d.device_type != device_type 1428 ] 1429 1430 visible_device_list += devices 1431 1432 if self._visible_device_list == visible_device_list: 1433 return 1434 1435 if self._context_handle is not None: 1436 raise RuntimeError( 1437 "Visible devices cannot be modified after being initialized") 1438 1439 self._visible_device_list = visible_device_list 1440 1441 def get_memory_info(self, dev): 1442 """Returns a dict of memory info for the device.""" 1443 self._initialize_physical_devices() 1444 self.ensure_initialized() 1445 return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev) 1446 1447 # TODO(reedwm): Remove this function 1448 def get_total_memory_usage(self, dev): 1449 """Returns total memory usage in bytes for the current device.""" 1450 return self.get_memory_info(dev)["current"] 1451 1452 def get_memory_growth(self, dev): 1453 """Get if memory growth is enabled for a PhysicalDevice.""" 1454 self._initialize_physical_devices() 1455 1456 if dev not in self._physical_devices: 1457 raise ValueError("Unrecognized device: %s" % repr(dev)) 1458 1459 return self._memory_growth_map[dev] 1460 1461 def set_memory_growth(self, dev, enable): 1462 """Set if memory growth should be enabled for a PhysicalDevice.""" 1463 self._initialize_physical_devices() 1464 1465 if dev not in self._physical_devices: 1466 raise ValueError("Unrecognized device: %s" % repr(dev)) 1467 1468 if dev in self._virtual_device_map: 1469 raise ValueError( 1470 "Cannot set memory growth on device when virtual devices configured") 1471 1472 if dev.device_type != "GPU": 1473 raise ValueError("Cannot set memory growth on non-GPU devices") 1474 1475 if self._memory_growth_map.get(dev) == enable: 1476 return 1477 1478 if self._context_handle is not None: 1479 raise RuntimeError( 1480 "Physical devices cannot be modified after being initialized") 1481 1482 self._memory_growth_map[dev] = enable 1483 1484 def get_logical_device_configuration(self, dev): 1485 """Get the virtual device configuration for a PhysicalDevice.""" 1486 self._initialize_physical_devices() 1487 1488 if dev not in self._physical_devices: 1489 raise ValueError("Unrecognized device: %s" % repr(dev)) 1490 1491 return self._virtual_device_map.get(dev) 1492 1493 def set_logical_device_configuration(self, dev, virtual_devices): 1494 """Set the virtual device configuration for a PhysicalDevice.""" 1495 self._initialize_physical_devices() 1496 1497 if dev not in self._physical_devices: 1498 raise ValueError("Unrecognized device: %s" % repr(dev)) 1499 1500 if dev.device_type == "CPU": 1501 for vdev in virtual_devices: 1502 if vdev.memory_limit is not None: 1503 raise ValueError("Setting memory limit on CPU virtual devices is " 1504 "currently not supported") 1505 if vdev.experimental_priority is not None: 1506 raise ValueError("Setting experimental_priority on CPU virtual " 1507 " devices is currently not supported") 1508 elif dev.device_type == "GPU": 1509 for vdev in virtual_devices: 1510 if vdev.memory_limit is None: 1511 raise ValueError( 1512 "Setting memory limit is required for GPU virtual devices") 1513 else: 1514 raise ValueError("Virtual devices are not supported for %s" % 1515 dev.device_type) 1516 1517 if self._virtual_device_map.get(dev) == virtual_devices: 1518 return 1519 1520 if self._context_handle is not None: 1521 raise RuntimeError( 1522 "Virtual devices cannot be modified after being initialized") 1523 1524 self._virtual_device_map[dev] = virtual_devices 1525 1526 def get_compiler_ir(self, device_name, function_name, args, stage="hlo"): 1527 return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name, 1528 stage, device_name, args) 1529 1530 @deprecated( 1531 None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True) 1532 def enable_xla_devices(self): 1533 """Enables XLA:CPU and XLA:GPU devices registration.""" 1534 pywrap_tfe.TF_EnableXlaDevices() 1535 1536 @property 1537 def enable_mlir_bridge(self): 1538 return pywrap_tfe.TF_IsMlirBridgeEnabled() 1539 1540 @property 1541 def enable_mlir_graph_optimization(self): 1542 return self._enable_mlir_graph_optimization 1543 1544 @enable_mlir_bridge.setter 1545 def enable_mlir_bridge(self, enabled): 1546 pywrap_tfe.TF_EnableMlirBridge(enabled) 1547 self._thread_local_data.function_call_options = None 1548 1549 @enable_mlir_graph_optimization.setter 1550 def enable_mlir_graph_optimization(self, enabled): 1551 self._enable_mlir_graph_optimization = enabled 1552 self._thread_local_data.function_call_options = None 1553 1554 @property 1555 def optimizer_jit(self): 1556 level = self.config.graph_options.optimizer_options.global_jit_level 1557 return (level == config_pb2.OptimizerOptions.ON_1 or 1558 level == config_pb2.OptimizerOptions.ON_2) 1559 1560 @optimizer_jit.setter 1561 def optimizer_jit(self, enabled): 1562 self._optimizer_jit = enabled 1563 1564 self._thread_local_data.function_call_options = None 1565 1566 def get_optimizer_experimental_options(self): 1567 """Get experimental options for the optimizer. 1568 1569 Returns: 1570 Dictionary of current option values 1571 """ 1572 rewrite_options = self.config.graph_options.rewrite_options 1573 options = {} 1574 1575 def rewriter_toggle(option): 1576 attr = getattr(rewrite_options, option) 1577 if attr != 0: 1578 options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON) 1579 1580 def rewriter_bool(option): 1581 options[option] = getattr(rewrite_options, option) 1582 1583 rewriter_toggle("layout_optimizer") 1584 rewriter_toggle("constant_folding") 1585 rewriter_toggle("shape_optimization") 1586 rewriter_toggle("remapping") 1587 rewriter_toggle("arithmetic_optimization") 1588 rewriter_toggle("dependency_optimization") 1589 rewriter_toggle("loop_optimization") 1590 rewriter_toggle("function_optimization") 1591 rewriter_toggle("debug_stripper") 1592 rewriter_bool("disable_model_pruning") 1593 rewriter_toggle("scoped_allocator_optimization") 1594 rewriter_toggle("pin_to_host_optimization") 1595 rewriter_toggle("implementation_selector") 1596 rewriter_toggle("auto_mixed_precision") 1597 rewriter_bool("disable_meta_optimizer") 1598 1599 if rewrite_options.min_graph_nodes != 0: 1600 options["min_graph_nodes"] = rewrite_options.min_graph_nodes 1601 1602 return options 1603 1604 def set_optimizer_experimental_options(self, options): 1605 """Set experimental options for the optimizer. 1606 1607 Args: 1608 options: Dictionary of options to modify 1609 """ 1610 self._optimizer_experimental_options.update(options) 1611 1612 self._thread_local_data.function_call_options = None 1613 1614 @property 1615 def intra_op_parallelism_threads(self): 1616 return self.config.intra_op_parallelism_threads 1617 1618 @intra_op_parallelism_threads.setter 1619 def intra_op_parallelism_threads(self, num_threads): 1620 if self._intra_op_parallelism_threads == num_threads: 1621 return 1622 1623 if self._context_handle is not None: 1624 raise RuntimeError( 1625 "Intra op parallelism cannot be modified after initialization.") 1626 1627 self._intra_op_parallelism_threads = num_threads 1628 1629 @property 1630 def inter_op_parallelism_threads(self): 1631 return self.config.inter_op_parallelism_threads 1632 1633 @inter_op_parallelism_threads.setter 1634 def inter_op_parallelism_threads(self, num_threads): 1635 if self._inter_op_parallelism_threads == num_threads: 1636 return 1637 1638 if self._context_handle is not None: 1639 raise RuntimeError( 1640 "Inter op parallelism cannot be modified after initialization.") 1641 1642 self._inter_op_parallelism_threads = num_threads 1643 1644 @property 1645 def soft_device_placement(self): 1646 return self.config.allow_soft_placement 1647 1648 @soft_device_placement.setter 1649 def soft_device_placement(self, enable): 1650 if self._context_handle is not None: 1651 pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable) 1652 1653 self._soft_device_placement = enable 1654 self._thread_local_data.function_call_options = None 1655 1656 @property 1657 def log_device_placement(self): 1658 return self.config.log_device_placement 1659 1660 @log_device_placement.setter 1661 def log_device_placement(self, enable): 1662 if self._context_handle is not None: 1663 pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable) 1664 1665 self._log_device_placement = enable 1666 self._thread_local_data.function_call_options = None 1667 1668 @property 1669 def device_policy(self): 1670 # Only get the policy from the context if it has already been initialized 1671 if self._context_handle is not None: 1672 return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle) 1673 1674 return self._device_policy 1675 1676 @device_policy.setter 1677 def device_policy(self, policy): 1678 if policy is None: 1679 policy = DEVICE_PLACEMENT_SILENT 1680 1681 if self._device_policy != policy: 1682 self._device_policy = policy 1683 1684 # Only set the policy if the context has already been initialized 1685 if self._context_handle is not None: 1686 pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy( 1687 self._handle, self._device_policy) 1688 1689 @property 1690 def use_tfrt(self): 1691 return self._use_tfrt 1692 1693 @use_tfrt.setter 1694 def use_tfrt(self, tfrt): 1695 """Sets whether to use TFRT.""" 1696 if not isinstance(tfrt, bool): 1697 raise ValueError("Expecting a boolean but got %s" % type(tfrt)) 1698 1699 if self._use_tfrt != tfrt: 1700 if self._initialized: 1701 raise ValueError("use_tfrt should be set before being initialized.") 1702 self._use_tfrt = tfrt 1703 1704 def enable_run_metadata(self): 1705 """Enables tracing of op execution via RunMetadata. 1706 1707 To retrieve the accumulated metadata call context.export_run_metadata() 1708 and to stop tracing call context.disable_run_metadata(). 1709 """ 1710 self.ensure_initialized() 1711 pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle) 1712 1713 def disable_run_metadata(self): 1714 """Disables tracing of op execution via RunMetadata.""" 1715 if not self._context_handle: 1716 return 1717 pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle) 1718 1719 def enable_graph_collection(self): 1720 """Enables graph collection of executed functions. 1721 1722 To retrieve the accumulated graphs call context.export_run_metadata() 1723 and to stop collecting graphs call context.disable_graph_collection(). 1724 """ 1725 self.ensure_initialized() 1726 pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle) 1727 1728 def disable_graph_collection(self): 1729 """Disables graph collection of executed functions.""" 1730 if not self._context_handle: 1731 return 1732 pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle) 1733 1734 def export_run_metadata(self): 1735 """Returns a RunMetadata proto with accumulated information. 1736 1737 The returned protocol buffer contains information since the most recent call 1738 to either enable_run_metadata or export_run_metadata. 1739 1740 Returns: 1741 A RunMetadata protocol buffer. Or None if not enabled. 1742 """ 1743 if not self._context_handle: 1744 return None 1745 with c_api_util.tf_buffer() as buffer_: 1746 pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_) 1747 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 1748 run_metadata = config_pb2.RunMetadata() 1749 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1750 return run_metadata 1751 1752 @property 1753 def context_switches(self): 1754 """Returns a stack of context switches.""" 1755 return self._context_switches 1756 1757 1758class _EagerDeviceContext(object): 1759 """Context-manager forcing placement of ops and Tensors on a device.""" 1760 1761 __slots__ = ["_device_name", "_ctx", "_stack"] 1762 1763 def __init__(self, ctx, device_name): 1764 self._device_name = device_name 1765 self._ctx = ctx 1766 self._stack = [] 1767 1768 def __enter__(self): 1769 ctx = self._ctx 1770 old_device_name = ctx.device_name 1771 old_device_spec = ctx.device_spec 1772 new_device_name = self._device_name 1773 cache_key = (old_device_name, new_device_name) 1774 try: 1775 new_device_name, new_device_spec = _device_parsing_cache[cache_key] 1776 except TypeError: 1777 # Error while trying to compute the cache key. 1778 raise ValueError("Expecting a string device name. Got %s(%s)" % 1779 (type(new_device_name), new_device_name)) 1780 except KeyError: 1781 # Handle a cache miss. 1782 if new_device_name is not None: 1783 if not isinstance(new_device_name, six.string_types): 1784 raise ValueError("Expecting a string device name. Got %s(%s)" % 1785 (type(new_device_name), new_device_name)) 1786 device_spec = pydev.DeviceSpec.from_string(new_device_name) 1787 if old_device_name: 1788 new_device_spec = copy.copy(old_device_spec) 1789 else: 1790 ctx.ensure_initialized() 1791 new_device_spec = pydev.DeviceSpec.from_string( 1792 ctx._context_devices[0]) # pylint: disable=protected-access 1793 new_device_spec = new_device_spec.make_merged_spec(device_spec) 1794 else: 1795 new_device_spec = pydev.DeviceSpec.from_string("") 1796 new_device_name = new_device_spec.to_string() 1797 _device_parsing_cache[cache_key] = (new_device_name, new_device_spec) 1798 1799 ctx._set_device(new_device_name, new_device_spec) # pylint: disable=protected-access 1800 self._stack.append((old_device_name, old_device_spec, new_device_spec)) 1801 1802 def __exit__(self, *ex_info): 1803 ctx = self._ctx 1804 old_device_name, old_device_spec, new_device_spec = self._stack[-1] 1805 if ctx.device_spec is not new_device_spec: 1806 raise RuntimeError( 1807 "Exiting device scope without proper scope nesting") 1808 del self._stack[-1] 1809 ctx._set_device(old_device_name, old_device_spec) # pylint: disable=protected-access 1810 1811 1812# Do not set directly. Use _set_context. 1813_context = None 1814_context_lock = threading.Lock() 1815 1816 1817def _set_context_locked(ctx): 1818 global _context 1819 pywrap_tfe.TFE_Py_SetEagerContext(ctx) 1820 _context = ctx 1821 1822 1823def _set_context(ctx): 1824 with _context_lock: 1825 _set_context_locked(ctx) 1826 1827 1828def _create_context(): 1829 with _context_lock: 1830 if _context is None: 1831 ctx = Context() 1832 _set_context_locked(ctx) 1833 1834 1835def _reset_context(): 1836 """Clears and re-initializes the singleton context. 1837 1838 Should only be used for testing. 1839 """ 1840 global _context 1841 global _device_parsing_cache 1842 with _context_lock: 1843 if _context is not None: 1844 _context._clear_caches() 1845 _context = None 1846 _create_context() 1847 _device_parsing_cache = {} 1848 pywrap_tfe.TFE_ClearScalarCache() 1849 1850 1851def context(): 1852 """Returns a singleton context object.""" 1853 if _context is None: 1854 _create_context() 1855 return _context 1856 1857 1858def context_safe(): 1859 """Returns current context (or None if one hasn't been initialized).""" 1860 return _context 1861 1862 1863def ensure_initialized(): 1864 """Initialize the context.""" 1865 context().ensure_initialized() 1866 1867 1868def set_global_seed(seed): 1869 """Sets the eager mode seed.""" 1870 context()._set_global_seed(seed) # pylint: disable=protected-access 1871 1872 1873def global_seed(): 1874 """Returns the eager mode seed.""" 1875 return context()._seed # pylint: disable=protected-access 1876 1877 1878def internal_operation_seed(): 1879 """Returns the operation seed generated based on global seed.""" 1880 return context()._internal_operation_seed() # pylint: disable=protected-access 1881 1882 1883@tf_export("executing_eagerly", v1=[]) 1884def executing_eagerly(): 1885 """Checks whether the current thread has eager execution enabled. 1886 1887 Eager execution is enabled by default and this API returns `True` 1888 in most of cases. However, this API might return `False` in the following use 1889 cases. 1890 1891 * Executing inside `tf.function`, unless under `tf.init_scope` or 1892 `tf.config.run_functions_eagerly(True)` is previously called. 1893 * Executing inside a transformation function for `tf.dataset`. 1894 * `tf.compat.v1.disable_eager_execution()` is called. 1895 1896 General case: 1897 1898 >>> print(tf.executing_eagerly()) 1899 True 1900 1901 Inside `tf.function`: 1902 1903 >>> @tf.function 1904 ... def fn(): 1905 ... with tf.init_scope(): 1906 ... print(tf.executing_eagerly()) 1907 ... print(tf.executing_eagerly()) 1908 >>> fn() 1909 True 1910 False 1911 1912 Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called: 1913 1914 >>> tf.config.run_functions_eagerly(True) 1915 >>> @tf.function 1916 ... def fn(): 1917 ... with tf.init_scope(): 1918 ... print(tf.executing_eagerly()) 1919 ... print(tf.executing_eagerly()) 1920 >>> fn() 1921 True 1922 True 1923 >>> tf.config.run_functions_eagerly(False) 1924 1925 Inside a transformation function for `tf.dataset`: 1926 1927 >>> def data_fn(x): 1928 ... print(tf.executing_eagerly()) 1929 ... return x 1930 >>> dataset = tf.data.Dataset.range(100) 1931 >>> dataset = dataset.map(data_fn) 1932 False 1933 1934 Returns: 1935 `True` if the current thread has eager execution enabled. 1936 """ 1937 ctx = context_safe() 1938 if ctx is None: 1939 return default_execution_mode == EAGER_MODE 1940 1941 return ctx.executing_eagerly() 1942 1943 1944@tf_export(v1=["executing_eagerly"]) 1945def executing_eagerly_v1(): 1946 """Checks whether the current thread has eager execution enabled. 1947 1948 Eager execution is typically enabled via 1949 `tf.compat.v1.enable_eager_execution`, but may also be enabled within the 1950 context of a Python function via tf.contrib.eager.py_func. 1951 1952 When eager execution is enabled, returns `True` in most cases. However, 1953 this API might return `False` in the following use cases. 1954 1955 * Executing inside `tf.function`, unless under `tf.init_scope` or 1956 `tf.config.run_functions_eagerly(True)` is previously called. 1957 * Executing inside a transformation function for `tf.dataset`. 1958 * `tf.compat.v1.disable_eager_execution()` is called. 1959 1960 >>> tf.compat.v1.enable_eager_execution() 1961 1962 General case: 1963 1964 >>> print(tf.executing_eagerly()) 1965 True 1966 1967 Inside `tf.function`: 1968 1969 >>> @tf.function 1970 ... def fn(): 1971 ... with tf.init_scope(): 1972 ... print(tf.executing_eagerly()) 1973 ... print(tf.executing_eagerly()) 1974 >>> fn() 1975 True 1976 False 1977 1978 Inside `tf.function` 1979 after `tf.config.run_functions_eagerly(True)` is called: 1980 1981 >>> tf.config.run_functions_eagerly(True) 1982 >>> @tf.function 1983 ... def fn(): 1984 ... with tf.init_scope(): 1985 ... print(tf.executing_eagerly()) 1986 ... print(tf.executing_eagerly()) 1987 >>> fn() 1988 True 1989 True 1990 >>> tf.config.run_functions_eagerly(False) 1991 1992 Inside a transformation function for `tf.dataset`: 1993 1994 >>> def data_fn(x): 1995 ... print(tf.executing_eagerly()) 1996 ... return x 1997 >>> dataset = tf.data.Dataset.range(100) 1998 >>> dataset = dataset.map(data_fn) 1999 False 2000 2001 Returns: 2002 `True` if the current thread has eager execution enabled. 2003 """ 2004 return executing_eagerly() 2005 2006 2007def in_eager_mode(): 2008 """Use executing_eagerly() instead. This function will be removed.""" 2009 return executing_eagerly() 2010 2011 2012def shared_name(name=None): 2013 """Returns the anonymous shared name GUID if no shared name is specified. 2014 2015 In eager mode we need to use a unique shared name to avoid spurious sharing 2016 issues. The runtime generates a unique name on our behalf when the reserved 2017 GUID is used as a shared name. 2018 2019 Args: 2020 name: Optional shared name 2021 2022 Returns: 2023 Eager compatible shared name. 2024 """ 2025 if name or not executing_eagerly(): 2026 return name 2027 2028 # Ensure a unique name when eager execution is enabled to avoid spurious 2029 # sharing issues. 2030 return "cd2c89b7-88b7-44c8-ad83-06c2a9158347" 2031 2032 2033def graph_mode(): 2034 """Context-manager to disable eager execution for the current thread.""" 2035 return context()._mode(GRAPH_MODE) # pylint: disable=protected-access 2036 2037 2038# Used by b/167638505 for keras backend API and Lambda layer. 2039@tf_export("__internal__.eager_context.eager_mode", v1=[]) 2040def eager_mode(): 2041 """Context-manager to enable eager execution for the current thread.""" 2042 return context()._mode(EAGER_MODE) # pylint: disable=protected-access 2043 2044 2045def scope_name(): 2046 """Name of the current scope.""" 2047 return context().scope_name 2048 2049 2050def device(name): 2051 """Context-manager to force placement of operations and Tensors on a device. 2052 2053 Example: 2054 ```python 2055 with tf.device('gpu:0'): 2056 with tf.device('cpu:0'): 2057 shape = tf.constant([], dtype=tf.int32) 2058 x = tf.random.truncated_normal(shape, tf.float32) 2059 ``` 2060 will ensure that the `shape` Tensor is on CPU but the `truncated_normal` 2061 operation runs on GPU 0. 2062 2063 Args: 2064 name: Name of the device (see context().devices()), or None to 2065 perform automatic placement. 2066 2067 Returns: 2068 Context manager for setting the device. 2069 """ 2070 ensure_initialized() 2071 return context().device(name) 2072 2073 2074# Expose some properties of Context as internally public APIs (b/160348781). 2075@tf_export("__internal__.eager_context.get_config", v1=[]) 2076def get_config(): 2077 """Get the ConfigProto of Context. 2078 2079 Returns: 2080 The ConfigProto of Context. 2081 """ 2082 return context().config 2083 2084 2085@tf_export("__internal__.eager_context.get_device_name", v1=[]) 2086def get_device_name(): 2087 """Get the device name for the current thread. 2088 2089 Returns: 2090 The device name for the current thread. 2091 """ 2092 return context().device_name 2093 2094 2095@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[]) 2096def set_soft_device_placement(enabled): 2097 """Set if soft device placements should be allowed. 2098 2099 Args: 2100 enabled: Whether to enable soft device placement. 2101 """ 2102 context().soft_device_placement = enabled 2103 2104 2105@tf_export("__internal__.eager_context.get_executor", v1=[]) 2106def get_executor(): 2107 """Get the Executor of the current thread. 2108 2109 Returns: 2110 The Executor of the current thread. 2111 """ 2112 return context().executor 2113 2114 2115@tf_export("debugging.get_log_device_placement") 2116def get_log_device_placement(): 2117 """Get if device placements are logged. 2118 2119 Returns: 2120 If device placements are logged. 2121 """ 2122 return context().log_device_placement 2123 2124 2125@tf_export("debugging.set_log_device_placement") 2126def set_log_device_placement(enabled): 2127 """Set if device placements should be logged. 2128 2129 Args: 2130 enabled: Whether to enabled device placement logging. 2131 """ 2132 context().log_device_placement = enabled 2133 2134 2135@tf_contextlib.contextmanager 2136def device_policy(policy): 2137 """Context manager for setting device placement policy for current thread.""" 2138 ctx = context() 2139 old_policy = ctx.device_policy 2140 try: 2141 ctx.device_policy = policy 2142 yield 2143 finally: 2144 ctx.device_policy = old_policy 2145 2146 2147def set_execution_mode(mode): 2148 """Sets execution mode for the current thread.""" 2149 context().execution_mode = mode 2150 2151 2152# TODO(fishx): remove this method. 2153@tf_contextlib.contextmanager 2154def execution_mode(mode): 2155 """Context manager for setting execution mode for current thread.""" 2156 if mode is None: 2157 yield 2158 else: 2159 ctx = context() 2160 executor_new = executor.new_executor(mode == ASYNC) 2161 executor_old = ctx.executor 2162 try: 2163 executor_old.wait() 2164 ctx.executor = executor_new 2165 yield 2166 finally: 2167 ctx.executor = executor_old 2168 executor_new.wait() 2169 2170 2171@tf_contextlib.contextmanager 2172def executor_scope(e): 2173 """Context manager for changing executor for current thread. 2174 2175 Args: 2176 e: A Executor to execute eager ops under this scope. Setting it to None will 2177 switch back to use the default executor for the context. 2178 2179 Yields: 2180 Context manager for setting the executor for current thread. 2181 """ 2182 ctx = context() 2183 executor_old = ctx.executor 2184 try: 2185 ctx.executor = e 2186 yield 2187 finally: 2188 ctx.executor = executor_old 2189 2190 2191@tf_export("experimental.function_executor_type") 2192@tf_contextlib.contextmanager 2193def function_executor_type(executor_type): 2194 """Context manager for setting the executor of eager defined functions. 2195 2196 Eager defined functions are functions decorated by tf.contrib.eager.defun. 2197 2198 Args: 2199 executor_type: a string for the name of the executor to be used to execute 2200 functions defined by tf.contrib.eager.defun. 2201 2202 Yields: 2203 Context manager for setting the executor of eager defined functions. 2204 """ 2205 current_options = context().function_call_options 2206 old_options = copy.copy(current_options) 2207 try: 2208 current_options.executor_type = executor_type 2209 yield 2210 finally: 2211 context().function_call_options = old_options 2212 2213 2214def is_async(): 2215 """Returns true if current thread is in async mode.""" 2216 return context().is_async() 2217 2218 2219def num_gpus(): 2220 """Get the number of available GPU devices. 2221 2222 Returns: 2223 The number of available GPU devices. 2224 """ 2225 return context().num_gpus() 2226 2227 2228def enable_run_metadata(): 2229 """Enables tracing of op execution via RunMetadata. 2230 2231 To retrieve the accumulated metadata call context.export_run_metadata() 2232 and to stop tracing call context.disable_run_metadata(). 2233 """ 2234 context().enable_run_metadata() 2235 2236 2237def disable_run_metadata(): 2238 """Disables tracing of op execution via RunMetadata.""" 2239 context().disable_run_metadata() 2240 2241 2242def enable_graph_collection(): 2243 """Enables graph collection of executed functions. 2244 2245 To retrieve the accumulated graphs call context.export_run_metadata() 2246 and to stop collecting graphs call context.disable_graph_collection(). 2247 """ 2248 context().enable_graph_collection() 2249 2250 2251def disable_graph_collection(): 2252 """Disables graph collection of executed functions.""" 2253 context().disable_graph_collection() 2254 2255 2256def export_run_metadata(): 2257 """Returns a RunMetadata proto with accumulated information. 2258 2259 The returned protocol buffer contains information since the most recent call 2260 to either enable_run_metadata or export_run_metadata. 2261 2262 Returns: 2263 A RunMetadata protocol buffer. 2264 """ 2265 return context().export_run_metadata() 2266 2267 2268@contextlib.contextmanager 2269def collect_graphs(optimized=True): 2270 """Collects a flat list of pre- or post-optimization graphs. 2271 2272 The collected graphs include device placements, which can be useful for 2273 testing. 2274 2275 Usage: 2276 2277 ``` 2278 @def_function.function 2279 def f(x): 2280 return x + constant_op.constant(1.) 2281 2282 with context.collect_graphs() as graphs: 2283 with ops.device("CPU:0"): 2284 f(constant_op.constant(1.)) 2285 2286 graph, = graphs # `graph` contains a single GraphDef for inspection 2287 ``` 2288 2289 Args: 2290 optimized: whether to collect optimized graphs or non-optimized graphs 2291 Yields: 2292 A list of GraphDefs, populated when the context manager exits. 2293 """ 2294 ctx = context() 2295 ctx.enable_graph_collection() 2296 try: 2297 graphs = [] 2298 yield graphs 2299 metadata = ctx.export_run_metadata() 2300 finally: 2301 ctx.disable_graph_collection() 2302 for graph in metadata.function_graphs: 2303 if optimized: 2304 graphs.append(graph.post_optimization_graph) 2305 else: 2306 graphs.append(graph.pre_optimization_graph) 2307 2308 2309def get_server_def(): 2310 return context().get_server_def() 2311 2312 2313def set_server_def(server_def): 2314 context().set_server_def(server_def) 2315 2316 2317def update_server_def(server_def): 2318 context().update_server_def(server_def) 2319 2320 2321def check_alive(worker_name): 2322 return context().check_alive(worker_name) 2323 2324 2325@tf_export("experimental.async_scope") 2326@tf_contextlib.contextmanager 2327def async_scope(): 2328 """Context manager for grouping async operations. 2329 2330 Ops/function calls inside the scope can return before finishing the actual 2331 execution. When exiting the async scope, a synchronization barrier will be 2332 automatically added to ensure the completion of all async op and function 2333 execution, potentially raising exceptions if async execution results in 2334 an error state. 2335 2336 Users may write the following code to asynchronously invoke `train_step_fn` 2337 and log the `loss` metric for every `num_steps` steps in a training loop. 2338 `train_step_fn` internally consumes data using `iterator.get_next()`, and may 2339 throw OutOfRangeError when running out of data. In the case: 2340 2341 ``` 2342 try: 2343 with tf.experimental.async_scope(): 2344 for _ in range(num_steps): 2345 # Step function updates the metric `loss` internally 2346 train_step_fn() 2347 except tf.errors.OutOfRangeError: 2348 tf.experimental.async_clear_error() 2349 logging.info('loss = %s', loss.numpy()) 2350 ``` 2351 2352 Yields: 2353 Context manager for grouping async operations. 2354 """ 2355 # TODO(haoyuzhang): replace env var once we have a config method to turn on 2356 # and off async streaming RPC 2357 remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE" 2358 old_policy = os.environ.get(remote_async_env_var) 2359 try: 2360 os.environ[remote_async_env_var] = str(True) 2361 yield 2362 # Note: sync local and remote executors iff the async block does not raise 2363 # an exception. Triggering sync after an exception may lead to derived 2364 # runtime errors and unexpected exception types. 2365 context().sync_executors() 2366 finally: 2367 if old_policy is None: 2368 del os.environ[remote_async_env_var] 2369 else: 2370 os.environ[remote_async_env_var] = old_policy 2371 2372 2373def async_wait(): 2374 """Sync all async operations and raise any errors during execution. 2375 2376 In async execution mode, an op/function call can return before finishing the 2377 actual execution. Calling this method creates a synchronization barrier for 2378 all async op and function execution. It only returns when all pending nodes 2379 are finished, potentially raising exceptions if async execution results in 2380 an error state. 2381 """ 2382 context().sync_executors() 2383 2384 2385@tf_export("experimental.async_clear_error") 2386def async_clear_error(): 2387 """Clear pending operations and error statuses in async execution. 2388 2389 In async execution mode, an error in op/function execution can lead to errors 2390 in subsequent ops/functions that are scheduled but not yet executed. Calling 2391 this method clears all pending operations and reset the async execution state. 2392 2393 Example: 2394 2395 ``` 2396 while True: 2397 try: 2398 # Step function updates the metric `loss` internally 2399 train_step_fn() 2400 except tf.errors.OutOfRangeError: 2401 tf.experimental.async_clear_error() 2402 break 2403 logging.info('loss = %s', loss.numpy()) 2404 ``` 2405 """ 2406 context().clear_executor_errors() 2407 2408 2409def add_function(fdef): 2410 """Add a function definition to the context.""" 2411 context().add_function(fdef) 2412 2413 2414def remove_function(name): 2415 """Remove a function from the context.""" 2416 context().remove_function(name) 2417 2418 2419def get_function_def(name): 2420 return context().get_function_def(name) 2421 2422 2423def register_custom_device(device_capsule, device_name, device_info_capsule): 2424 """Calls TFE_RegisterCustomDevice to register a custom device with Python. 2425 2426 Enables using C extensions specifying a custom device from Python. See the 2427 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for 2428 details. 2429 2430 Note that custom devices are not currently supported inside `tf.function`s. 2431 2432 Args: 2433 device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice' 2434 containing a pointer to a TFE_CustomDevice struct. The capsule retains 2435 ownership of the memory. 2436 device_name: A string indicating the name to register the custom device 2437 under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may 2438 subsequently be passed to `with tf.device(...):`. 2439 device_info_capsule: A PyCapsule with the name set to 2440 'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific 2441 struct with the initial state of the custom device (the void* device_info 2442 argument to TFE_RegisterCustomDevice). This method takes ownership of the 2443 memory and clears the capsule destructor. 2444 """ 2445 context().register_custom_device(device_capsule, device_name, 2446 device_info_capsule) 2447 2448 2449# Not every user creates a Context via context.context() 2450# (for example, enable_eager_execution in python/framework/ops.py), 2451# but they do all import this file. Note that IS_IN_GRAPH_MODE and 2452# in_graph_mode are both parameterless functions. 2453def _tmp_in_graph_mode(): 2454 if context_safe() is None: 2455 # Context not yet initialized. Assume graph mode following the 2456 # default implementation in `is_in_graph_mode`. 2457 return True 2458 return not executing_eagerly() 2459 2460 2461is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode 2462