1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""State management for eager execution.""" 16 17import collections 18import contextlib 19import copy 20import gc 21import os 22import random 23import threading 24 25from absl import logging 26import numpy as np 27 28from tensorflow.core.framework import function_pb2 29from tensorflow.core.protobuf import config_pb2 30from tensorflow.core.protobuf import coordination_config_pb2 31from tensorflow.core.protobuf import rewriter_config_pb2 32from tensorflow.python import pywrap_tfe 33from tensorflow.python import tf2 34from tensorflow.python.client import pywrap_tf_session 35from tensorflow.python.eager import executor 36from tensorflow.python.eager import monitoring 37from tensorflow.python.framework import c_api_util 38from tensorflow.python.framework import device as pydev 39from tensorflow.python.framework import tfrt_utils 40from tensorflow.python.util import compat 41from tensorflow.python.util import is_in_graph_mode 42from tensorflow.python.util import tf_contextlib 43from tensorflow.python.util.deprecation import deprecated 44from tensorflow.python.util.tf_export import tf_export 45 46GRAPH_MODE = 0 47EAGER_MODE = 1 48 49default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE 50 51# Cache from (old_device_name, partial_new_device_name) -> (new_device_name, 52# new_device_spec). 53# Note that we do not protect this with a lock and instead rely on python's GIL 54# and the idempotent nature of writes to provide thread safety. 55_device_parsing_cache = {} 56_starting_device_spec = pydev.DeviceSpec.from_string("") 57 58_MAXINT32 = 2**31 - 1 59 60DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT 61DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN 62DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT 63DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( 64 pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) 65 66SYNC = 0 67ASYNC = 1 68 69_KEEP_ALIVE_SECS = 600 70 71_python_eager_context_create_counter = monitoring.Counter( 72 "/tensorflow/api/python/eager_context_create_counter", 73 "Counter for number of eager contexts created in Python.") 74 75# Re-exporting through context. 76is_tfrt_enabled = tfrt_utils.enabled 77 78# This flag and the associated environment var are transient and will eventually 79# be removed, once this experiment is enabled by default. 80_JIT_COMPILE_REWRITE_ENABLED = os.getenv("TF_JIT_COMPILE_REWRITE") == "1" 81 82 83def run_eager_op_as_function_enabled(): 84 return True 85 86 87# This method should only be called after the context has beein initialized. 88def enable_jit_compile_rewrite(): 89 """Run jit_compile functions through rewrite pass. 90 91 This runs jit_compile functions through all of the multidevice function 92 rewrite passes. 93 """ 94 global _JIT_COMPILE_REWRITE_ENABLED 95 _JIT_COMPILE_REWRITE_ENABLED = True 96 if context_safe() is not None: 97 context_safe().jit_compile_rewrite = True 98 99 100# This method should only be called after the context has been initialized. 101def disable_jit_compile_rewrite(): 102 global _JIT_COMPILE_REWRITE_ENABLED 103 _JIT_COMPILE_REWRITE_ENABLED = False 104 if context_safe() is not None: 105 context_safe().jit_compile_rewrite = False 106 107 108def jit_compile_rewrite_enabled(): 109 if context_safe() is not None: 110 return context_safe().jit_compile_rewrite 111 return _JIT_COMPILE_REWRITE_ENABLED 112 113 114# Expose it as internally public APIs for Keras use cases in b/171080602. 115tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled) 116 117 118class _EagerTensorCache(object): 119 """Simple cache which evicts items based on length in a FIFO manner.""" 120 121 __slots__ = ["_data", "_max_items", "_max_tensor_size"] 122 123 def __init__(self, max_items=256, max_tensor_size=10000): 124 self._data = collections.OrderedDict() 125 self._max_items = max_items 126 self._max_tensor_size = max_tensor_size 127 128 def put(self, key, value): 129 if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access 130 return 131 132 self._data[key] = value 133 134 if len(self._data) > self._max_items: 135 self._data.popitem(last=False) 136 137 def get(self, key): 138 return self._data.get(key, None) 139 140 def flush(self): 141 self._data.clear() 142 143 144class FunctionCallOptions: 145 """Options applied at call sites of eager functions. 146 147 Eager functions are functions decorated with tf.contrib.eager.defun. 148 """ 149 150 __slots__ = ["_config_proto_serialized", "_executor_type"] 151 152 def __init__(self, executor_type=None, config_proto=None): 153 """Constructor. 154 155 Args: 156 executor_type: (optional) name of the executor to be used to execute the 157 eager function. If None or an empty string, the default Tensorflow 158 executor will be used. 159 config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized 160 string of that proto. The config used by Grappler when optimizing the 161 function graph. Each concrete function is optimized the first time is 162 called. Changing config_proto after the first call has no effect. If 163 config_proto is None, an empty RewriterConfig will be used. 164 """ 165 self.config_proto_serialized = config_proto 166 self.executor_type = executor_type 167 168 @property 169 def executor_type(self): 170 return self._executor_type 171 172 @executor_type.setter 173 def executor_type(self, executor_type): 174 self._executor_type = executor_type 175 176 @property 177 def config_proto_serialized(self): 178 return self._config_proto_serialized 179 180 @config_proto_serialized.setter 181 def config_proto_serialized(self, config): 182 if isinstance(config, config_pb2.ConfigProto): 183 self._config_proto_serialized = config.SerializeToString( 184 deterministic=True) 185 elif isinstance(config, str): 186 self._config_proto_serialized = config 187 elif config is None: 188 self._config_proto_serialized = ( 189 config_pb2.ConfigProto().SerializeToString()) 190 else: 191 raise ValueError("the rewriter config must be either a " 192 "config_pb2.ConfigProto, or a serialized string of that " 193 "proto or None. got: {}".format(type(config))) 194 195 196# Map from context_id (an int) to _TensorCaches. 197# Dicts are thread safe in CPython. 198# TODO(iga): Remove this once TensorCaches are moved to C++. 199_tensor_caches_map = {} 200 201 202class _TensorCaches(threading.local): 203 """Thread local tensor caches.""" 204 205 __slots__ = ["_ones_rank_cache", "_zeros_cache"] 206 207 def __init__(self): 208 super().__init__() 209 self._ones_rank_cache = None 210 self._zeros_cache = None 211 212 @property 213 def ones_rank_cache(self): 214 if not self._ones_rank_cache: 215 self._ones_rank_cache = _EagerTensorCache() 216 return self._ones_rank_cache 217 218 @property 219 def zeros_cache(self): 220 if not self._zeros_cache: 221 self._zeros_cache = _EagerTensorCache() 222 return self._zeros_cache 223 224 225ContextSwitch = collections.namedtuple( 226 "ContextSwitch", 227 ["is_building_function", "enter_context_fn", "device_stack"]) 228 229 230# `_ContextSwitchStack` is a `threading.local` to match the semantics of 231# ``DefaultGraphStack`, which is also a `threading.local`. 232class _ContextSwitchStack(threading.local): 233 """A thread-local stack of context switches.""" 234 235 def __init__(self, eager): 236 super().__init__() 237 self.stack = [] 238 if eager: 239 # Initialize the stack with a pointer to enter the eager context; this 240 # ensures that the fact that eager execution was enabled is propagated 241 # across threads, since (1) `enable_eager_execution` modifies a 242 # process-level flag (`default_execution_mode`) and (2) `__init__` is 243 # called each time a threading.local object is used in a separate thread. 244 self.push( 245 is_building_function=False, 246 enter_context_fn=eager_mode, 247 device_stack=None) 248 249 def push(self, is_building_function, enter_context_fn, device_stack): 250 """Push metadata about a context switch onto the stack. 251 252 A context switch can take any one of the two forms: installing a graph as 253 the default graph, or entering the eager context. For each context switch, 254 we record whether or not the entered context is building a function. 255 256 Args: 257 is_building_function: (bool.) Whether the context is building a function. 258 enter_context_fn: (function.) A callable that executes the context switch. 259 For example, `graph.as_default` or `eager_mode`. 260 device_stack: If applicable, the device function stack for this graph. 261 When breaking out of graphs in init_scope, the innermost nonempty device 262 stack is used. Eager contexts put `None` here and the value is never 263 used. 264 """ 265 266 self.stack.append( 267 ContextSwitch(is_building_function, enter_context_fn, device_stack)) 268 269 def pop(self): 270 """Pop the stack.""" 271 272 self.stack.pop() 273 274 275@tf_export("config.LogicalDevice") 276class LogicalDevice( 277 collections.namedtuple("LogicalDevice", ["name", "device_type"])): 278 """Abstraction for a logical device initialized by the runtime. 279 280 A `tf.config.LogicalDevice` corresponds to an initialized logical device on a 281 `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors 282 and operations can be placed on a specific logical device by calling 283 `tf.device` with a specified `tf.config.LogicalDevice`. 284 285 Fields: 286 name: The fully qualified name of the device. Can be used for Op or function 287 placement. 288 device_type: String declaring the type of device such as "CPU" or "GPU". 289 """ 290 pass 291 292 293@tf_export("config.LogicalDeviceConfiguration", 294 "config.experimental.VirtualDeviceConfiguration") 295class LogicalDeviceConfiguration( 296 collections.namedtuple("LogicalDeviceConfiguration", [ 297 "memory_limit", "experimental_priority", "experimental_device_ordinal" 298 ])): 299 """Configuration class for a logical devices. 300 301 The class specifies the parameters to configure a `tf.config.PhysicalDevice` 302 as it is initialized to a `tf.config.LogicalDevice` during runtime 303 initialization. Not all fields are valid for all device types. 304 305 See `tf.config.get_logical_device_configuration` and 306 `tf.config.set_logical_device_configuration` for usage examples. 307 308 Fields: 309 memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual 310 device. Currently only supported for GPUs. 311 experimental_priority: (optional) Priority to assign to a virtual device. 312 Lower values have higher priorities and 0 is the default. 313 Within a physical GPU, the GPU scheduler will prioritize ops on virtual 314 devices with higher priority. Currently only supported for Nvidia GPUs. 315 experimental_device_ordinal: (optional) Ordinal number to order the virtual 316 device. 317 LogicalDevice with lower ordinal number will receive a lower device id. 318 Physical device id and location in the list is used to break ties. 319 Currently only supported for Nvidia GPUs. 320 """ 321 322 def __new__(cls, 323 memory_limit=None, 324 experimental_priority=None, 325 experimental_device_ordinal=0): 326 return super().__new__(cls, memory_limit, experimental_priority, 327 experimental_device_ordinal) 328 329 330@tf_export("config.PhysicalDevice") 331class PhysicalDevice( 332 collections.namedtuple("PhysicalDevice", ["name", "device_type"])): 333 """Abstraction for a locally visible physical device. 334 335 TensorFlow can utilize various devices such as the CPU or multiple GPUs 336 for computation. Before initializing a local device for use, the user can 337 customize certain properties of the device such as it's visibility or memory 338 configuration. 339 340 Once a visible `tf.config.PhysicalDevice` is initialized one or more 341 `tf.config.LogicalDevice` objects are created. Use 342 `tf.config.set_visible_devices` to configure the visibility of a physical 343 device and `tf.config.set_logical_device_configuration` to configure multiple 344 `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is 345 useful when separation between models is needed or to simulate a multi-device 346 environment. 347 348 Fields: 349 name: Unique identifier for device. 350 device_type: String declaring the type of device such as "CPU" or "GPU". 351 """ 352 pass 353 354 355class _AtomicCounter(object): 356 """A simple atomic counter.""" 357 358 __slots__ = ["_value", "_lock"] 359 360 def __init__(self): 361 self._value = 0 362 self._lock = threading.Lock() 363 364 def increment_and_get(self): 365 with self._lock: 366 self._value += 1 367 return self._value 368 369 370_context_id_counter = _AtomicCounter() 371 372 373class _TensorCacheDeleter(object): 374 """Deletes tensor caches for a given context.""" 375 376 __slots__ = ["_context_id"] 377 378 def __init__(self, context_id): 379 self._context_id = context_id 380 381 def __del__(self): 382 if _tensor_caches_map is None: 383 return 384 if self._context_id in _tensor_caches_map: 385 del _tensor_caches_map[self._context_id] 386 387 388# TODO(agarwal): rename to EagerContext / EagerRuntime ? 389# TODO(agarwal): consider keeping the corresponding Graph here. 390class Context: 391 """Environment in which eager operations execute.""" 392 393 # TODO(agarwal): create and link in some documentation for `execution_mode`. 394 # pylint: disable=redefined-outer-name 395 def __init__(self, 396 config=None, 397 device_policy=None, 398 execution_mode=None, 399 server_def=None): 400 """Creates a new Context. 401 402 Args: 403 config: (Optional.) A `ConfigProto` protocol buffer with configuration 404 options for the Context. Note that a lot of these options may be 405 currently unimplemented or irrelevant when eager execution is enabled. 406 device_policy: (Optional.) What policy to use when trying to run an 407 operation on a device with inputs which are not on that device. When set 408 to None, an appropriate value will be picked automatically. The value 409 picked may change between TensorFlow releases. Defaults to 410 DEVICE_PLACEMENT_SILENT. 411 Valid values: 412 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not 413 correct. 414 - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right 415 device but raises a warning. 416 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide 417 performance problems. 418 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, 419 raising errors on the other ones. 420 execution_mode: (Optional.) Policy controlling how operations dispatched 421 are actually executed. When set to None, an appropriate value will be 422 picked automatically. The value picked may change between TensorFlow 423 releases. 424 Valid values: 425 - SYNC: executes each operation synchronously. 426 - ASYNC: executes each operation asynchronously. These operations may 427 return "non-ready" handles. 428 server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution 429 on remote devices. GrpcServers need to be started by creating an 430 identical server_def to this, and setting the appropriate task_indexes, 431 so that the servers can communicate. It will then be possible to execute 432 operations on remote devices. 433 434 Raises: 435 ValueError: If execution_mode is not valid. 436 """ 437 # This _id is used only to index the tensor caches. 438 # TODO(iga): Remove this when tensor caches are moved to C++. 439 self._id = _context_id_counter.increment_and_get() 440 self._tensor_cache_deleter = _TensorCacheDeleter(self._id) 441 _tensor_caches_map[self._id] = _TensorCaches() 442 443 self._config = config 444 self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData( 445 self, 446 is_eager=lambda: default_execution_mode == EAGER_MODE, 447 device_spec=_starting_device_spec) 448 self._context_switches = _ContextSwitchStack(self.executing_eagerly()) 449 self._context_handle = None 450 self._context_devices = None 451 self._seed = None 452 self._initialize_lock = threading.Lock() 453 self._initialized = False 454 if device_policy is None: 455 device_policy = DEVICE_PLACEMENT_SILENT 456 self._device_policy = device_policy 457 self._mirroring_policy = None 458 if execution_mode not in (None, SYNC, ASYNC): 459 raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" % 460 execution_mode) 461 if execution_mode is None: 462 execution_mode = SYNC 463 self._default_is_async = execution_mode == ASYNC 464 self._use_tfrt = is_tfrt_enabled() 465 self._use_tfrt_distributed_runtime = None 466 self._jit_compile_rewrite = jit_compile_rewrite_enabled() 467 self._server_def = server_def 468 self._collective_ops_server_def = None 469 self._collective_leader = None 470 self._collective_scoped_allocator_enabled_ops = None 471 self._collective_use_nccl_communication = None 472 self._collective_device_filters = None 473 self._coordination_service_config = None 474 475 self._device_lock = threading.Lock() 476 self._physical_devices = None 477 self._physical_device_to_index = None 478 self._pluggable_devices = None 479 self._visible_device_list = [] 480 self._memory_growth_map = None 481 self._virtual_device_map = {} 482 483 # Values set after construction 484 self._optimizer_jit = None 485 self._intra_op_parallelism_threads = None 486 self._inter_op_parallelism_threads = None 487 self._soft_device_placement = None 488 self._log_device_placement = None 489 self._operation_timeout_in_ms = None 490 self._enable_mlir_graph_optimization = None 491 self._optimizer_experimental_options = {} 492 493 _python_eager_context_create_counter.get_cell().increase_by(1) 494 495 self._is_global_context = False 496 497 # pylint: enable=redefined-outer-name 498 499 def _set_global_seed(self, seed): 500 """Set a global eager mode seed for random ops.""" 501 self._seed = seed 502 # `random.Random(seed)` needs `seed` to be hashable, while values of type 503 # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them 504 # to int. 505 try: 506 hash(seed) 507 except TypeError: 508 seed = int(np.array(seed)) 509 self._rng = random.Random(seed) 510 # Also clear the kernel cache, to reset any existing seeds 511 if self._context_handle is not None: 512 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 513 514 def _internal_operation_seed(self): 515 """Returns a fake operation seed. 516 517 In eager mode, user shouldn't set or depend on operation seed. 518 Here, we generate a random seed based on global seed to make 519 operation's randomness different and depend on the global seed. 520 521 Returns: 522 A fake operation seed based on global seed. 523 """ 524 return self._rng.randint(0, _MAXINT32) 525 526 def _initialize_logical_devices(self): 527 """Helper to initialize devices.""" 528 # Store list of devices 529 logical_devices = [] 530 context_devices = [] 531 device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle) 532 try: 533 self._num_gpus = 0 534 current_job, current_task = None, None 535 server_def = self._server_def or self._collective_ops_server_def 536 if server_def is not None: 537 current_job, current_task = server_def.job_name, server_def.task_index 538 for i in range(pywrap_tfe.TF_DeviceListCount(device_list)): 539 dev_name = pywrap_tfe.TF_DeviceListName(device_list, i) 540 context_devices.append(pydev.canonical_name(dev_name)) 541 spec = pydev.DeviceSpec.from_string(dev_name) 542 # If the job is localhost, we assume that the cluster has not yet been 543 # configured and thus clear the job, replica & task. 544 if spec.job == "localhost": 545 spec = spec.replace(job=None, replica=None, task=None) 546 logical_devices.append( 547 LogicalDevice(name=spec.to_string(), device_type=spec.device_type)) 548 dev_type = pywrap_tfe.TF_DeviceListType(device_list, i) 549 if (dev_type == "GPU" and spec.job == current_job and 550 spec.task == current_task): 551 self._num_gpus += 1 552 553 finally: 554 self._logical_devices = logical_devices 555 self._context_devices = context_devices 556 pywrap_tfe.TF_DeleteDeviceList(device_list) 557 558 def ensure_initialized(self): 559 """Initialize handle and devices if not already done so.""" 560 if self._initialized: 561 return 562 with self._initialize_lock: 563 if self._initialized: 564 return 565 assert self._context_devices is None 566 opts = pywrap_tfe.TFE_NewContextOptions() 567 try: 568 config_str = self.config.SerializeToString() 569 pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str) 570 if self._device_policy is not None: 571 pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy( 572 opts, self._device_policy) 573 if self._mirroring_policy is not None: 574 pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy( 575 opts, self._mirroring_policy) 576 if self._default_is_async == ASYNC: 577 pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True) 578 if self._use_tfrt is not None: 579 pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt) 580 # pylint: disable=g-backslash-continuation 581 if self._use_tfrt is not None and \ 582 self._use_tfrt_distributed_runtime is not None: 583 pywrap_tfe.TFE_ContextOptionsSetTfrtDistributedRuntime( 584 opts, self._use_tfrt_distributed_runtime) 585 pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True) 586 pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite( 587 opts, self._jit_compile_rewrite) 588 context_handle = pywrap_tfe.TFE_NewContext(opts) 589 finally: 590 pywrap_tfe.TFE_DeleteContextOptions(opts) 591 assert not (self._server_def and self._collective_ops_server_def), ( 592 "Cannot enable remote execution as well as collective ops at the " 593 "moment. If this is important to you, please file an issue.") 594 if self._server_def is not None: 595 server_def_str = self._server_def.SerializeToString() 596 pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS, 597 server_def_str) 598 elif self._collective_ops_server_def is not None: 599 server_def_str = self._collective_ops_server_def.SerializeToString() 600 pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str) 601 602 self._context_handle = context_handle 603 self._initialize_logical_devices() 604 self._initialized = True 605 606 if self._is_global_context: 607 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle) 608 609 def mark_as_global_context(self): 610 # If the context was already initialized, publish it. Otherwise wait with 611 # publication until it's initialized. 612 if self._initialized: 613 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle) 614 self._is_global_context = True 615 616 def _clear_caches(self): 617 self.ones_rank_cache().flush() 618 self.zeros_cache().flush() 619 pywrap_tfe.TFE_ClearScalarCache() 620 621 def get_server_def(self): 622 return self._server_def 623 624 def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 625 """Allow setting a server_def on the context. 626 627 When a server def is replaced, it effectively clears a bunch of caches 628 within the context. If you attempt to use a tensor object that was pointing 629 to a tensor on the remote device, it will raise an error. 630 631 Args: 632 server_def: A tensorflow::ServerDef proto. Enables execution on remote 633 devices. 634 keep_alive_secs: Num. seconds after which the remote end will hang up. As 635 long as the client is still alive, the server state for the context will 636 be kept alive. If the client is killed (or there is some failure), the 637 server will clean up its context keep_alive_secs after the final RPC it 638 receives. 639 640 Raises: 641 ValueError: if server_def is None. 642 """ 643 if not server_def: 644 raise ValueError("server_def is None.") 645 646 self._server_def = server_def 647 648 if self._context_handle: 649 server_def_str = server_def.SerializeToString() 650 pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs, 651 server_def_str) 652 self._initialize_logical_devices() 653 654 # Clear all the caches in case there are remote tensors in them. 655 self._clear_caches() 656 657 def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 658 """Update a server_def on the context. 659 660 Args: 661 server_def: A tensorflow::ServerDef proto. Enables execution on remote 662 devices. 663 keep_alive_secs: Num. seconds after which the remote end will hang up. As 664 long as the client is still alive, the server state for the context will 665 be kept alive. If the client is killed (or there is some failure), the 666 server will clean up its context keep_alive_secs after the final RPC it 667 receives. 668 669 Raises: 670 ValueError: if server_def is None. 671 """ 672 if not server_def: 673 raise ValueError("server_def is None.") 674 675 self._server_def = server_def 676 677 if self._context_handle: 678 server_def_str = server_def.SerializeToString() 679 pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle, 680 keep_alive_secs, server_def_str) 681 self._initialize_logical_devices() 682 683 self._clear_caches() 684 685 def check_alive(self, worker_name): 686 """Checks whether a remote worker is alive or not. 687 688 Args: 689 worker_name: a string representing the remote worker. It must be a fully 690 specified name like "/job:worker/replica:0/task:0". 691 692 Returns: 693 a boolean indicating whether the remote worker is alive or not. 694 695 Raises: 696 ValueError: if context is not initialized. 697 """ 698 # TODO(yuefengz): support checking multiple workers. 699 if self._context_handle: 700 return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name) 701 else: 702 raise ValueError("Context is not initialized.") 703 704 def sync_executors(self): 705 """Sync both local executors and the ones on remote workers. 706 707 In async execution mode, local function calls can return before the 708 corresponding remote op/function execution requests are completed. Calling 709 this method creates a synchronization barrier for remote executors. It only 710 returns when all remote pending nodes are finished, potentially with errors 711 if any remote executors are in error state. 712 713 Raises: 714 ValueError: if context is not initialized. 715 """ 716 if self._context_handle: 717 pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle) 718 else: 719 raise ValueError("Context is not initialized.") 720 721 def clear_executor_errors(self): 722 """Clear errors in both local executors and remote workers. 723 724 After receiving errors from remote workers, additional requests on the fly 725 could further taint the status on the remote workers due to the async nature 726 of remote execution. Calling this method block on waiting for all pending 727 nodes in remote executors to finish and clear their error statuses. 728 729 Raises: 730 ValueError: if context is not initialized. 731 """ 732 if self._context_handle: 733 pywrap_tfe.TFE_ContextClearExecutors(self._context_handle) 734 else: 735 raise ValueError("Context is not initialized.") 736 737 def configure_coordination_service(self, 738 service_type, 739 service_leader="", 740 enable_health_check=True, 741 cluster_register_timeout_in_ms=0, 742 heartbeat_timeout_in_ms=0, 743 coordinated_jobs=None): 744 """Enable distributed coordination service with specified configs.""" 745 if self._context_handle: 746 logging.warning("Configuring coordination service type may not be " 747 "effective because the context is already initialized.") 748 config = coordination_config_pb2.CoordinationServiceConfig() 749 config.service_type = service_type 750 if service_leader: 751 config.service_leader = pydev.canonical_name(service_leader) 752 config.enable_health_check = enable_health_check 753 config.cluster_register_timeout_in_ms = cluster_register_timeout_in_ms 754 config.heartbeat_timeout_in_ms = heartbeat_timeout_in_ms 755 if coordinated_jobs is not None: 756 if isinstance(coordinated_jobs, list): 757 config.coordinated_jobs.extend(coordinated_jobs) 758 else: 759 raise ValueError("`coordinated_jobs` must be a list of job names or " 760 "None, but got: %s" % (coordinated_jobs,)) 761 self._coordination_service_config = config 762 763 @property 764 def coordination_service(self): 765 return self._coordination_service_config 766 767 def set_config_key_value(self, key, value): 768 ensure_initialized() 769 pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value) 770 771 def get_config_key_value(self, key): 772 ensure_initialized() 773 with c_api_util.tf_buffer() as buffer_: 774 pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key, buffer_) 775 value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 776 return value 777 778 def delete_config_key_value(self, key): 779 ensure_initialized() 780 pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key) 781 782 def report_error_to_cluster(self, error_code, error_message): 783 """Report error to other members in a multi-client cluster. 784 785 Args: 786 error_code: a `tf.errors` error code. 787 error_message: a string. The error message. 788 """ 789 if self._context_handle: 790 pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code, 791 error_message) 792 else: 793 raise ValueError("Context is not initialized.") 794 795 def clear_kernel_cache(self): 796 """Clear kernel cache and reset all stateful kernels.""" 797 if self._context_handle is not None: 798 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 799 800 def enable_collective_ops(self, server_def): 801 """Enable distributed collective ops with an appropriate server_def. 802 803 Args: 804 server_def: A tensorflow::ServerDef proto. Enables execution on remote 805 devices. 806 807 Raises: 808 ValueError: if server_def is None. 809 RuntimeError: if this method is not called at program startup. 810 """ 811 if not server_def: 812 raise ValueError("server_def is None.") 813 814 self._collective_ops_server_def = server_def 815 816 # TODO(b/129298253): Allow creating datasets/tensors before enabling 817 # collective ops. 818 if self._context_handle is not None: 819 logging.warning("Enabling collective ops after program startup may cause " 820 "error when accessing previously created tensors.") 821 with self._initialize_lock: 822 assert self._initialized 823 server_def_str = self._collective_ops_server_def.SerializeToString() 824 pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str) 825 self._initialize_logical_devices() 826 self._clear_caches() 827 828 def configure_collective_ops( 829 self, 830 collective_leader="", 831 scoped_allocator_enabled_ops=("CollectiveReduce",), 832 use_nccl_communication=False, 833 device_filters=None): 834 """Configure collective ops. 835 836 Collective group leader is necessary for collective ops to run, other 837 configurations are mainly for the purpose of performance. 838 839 Args: 840 collective_leader: a device string for collective leader, e.g. 841 "/job:worker/replica:0/task:0"; empty string means local execution of 842 collective ops. 843 scoped_allocator_enabled_ops: a tuple or a list of op names for scoped 844 allocator to run with. 845 use_nccl_communication: whether to use nccl communication for collective 846 ops. 847 device_filters: a tuple or a list of device strings. If set, corresponding 848 task can only see the devices filtered by these device filters. 849 850 Raises: 851 RuntimeError: if this method is not called at program startup. 852 """ 853 if self._collective_leader is not None: 854 if (self._collective_leader != collective_leader or 855 self._collective_scoped_allocator_enabled_ops != 856 scoped_allocator_enabled_ops or 857 self._collective_use_nccl_communication != use_nccl_communication or 858 self._collective_device_filters != device_filters): 859 raise ValueError("Collective ops are already configured.") 860 else: 861 return 862 863 if self._context_handle is not None: 864 raise RuntimeError("Collective ops must be configured at program startup") 865 866 self._collective_leader = collective_leader 867 self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops 868 self._collective_use_nccl_communication = use_nccl_communication 869 self._collective_device_filters = device_filters 870 871 def abort_collective_ops(self, code, message): 872 """Abort the collective ops. 873 874 This is intended to be used when a peer failure is detected, which allows 875 the user to handle the case instead of hanging. This aborts all on-going 876 collectives. After all subsequent collectives error immediately, and you 877 need to reset_context() to use collectives again. 878 879 Args: 880 code: a `tf.errors` error code. 881 message: a string. The error message. 882 """ 883 self.ensure_initialized() 884 pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message) 885 886 def check_collective_ops_peer_health(self, task, timeout_in_ms): 887 """Check collective peer health. 888 889 This probes each task to see if they're still alive. Note that restarted 890 tasks are considered a different one, and they're considered not healthy. 891 892 This should only be used in multi client multi worker training. 893 894 Args: 895 task: a task string, must be in the format of /job:xxx/replica:0/task:N. 896 timeout_in_ms: an integer, the timeout. If zero, there's no timeout. 897 898 Raises: 899 tf.errors.UnavailableError: when a peer is down. 900 tf.errors.FailedPreconditionError: when a peer is a different one from the 901 one this task has talked to, e.g. the peer has restarted. 902 tf.errors.InvalidArgumentError: when the task string is invalid. 903 """ 904 self.ensure_initialized() 905 pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task, 906 timeout_in_ms) 907 908 @property 909 def _handle(self): 910 if self._context_handle is None: 911 raise AssertionError("Context must be initialized first.") 912 913 return self._context_handle 914 915 @property 916 def _devices(self): 917 if self._context_devices is None: 918 raise AssertionError("Context must be initialized first.") 919 920 return self._context_devices 921 922 def __str__(self): 923 if self._context_handle is None: 924 return "Eager TensorFlow Context. Devices currently uninitialized." 925 else: 926 devices = self._devices 927 lines = ["Eager TensorFlow Context with %d devices" % (len(devices))] 928 for i, d in enumerate(devices): 929 lines.append(" Device %d: %s" % (i, d)) 930 return "\n".join(lines) 931 932 @tf_contextlib.contextmanager 933 def _mode(self, mode): 934 """A context manager to allow setting the mode to EAGER/GRAPH.""" 935 ctx = self._thread_local_data 936 old_is_eager = ctx.is_eager 937 ctx.is_eager = mode == EAGER_MODE 938 if mode == EAGER_MODE: 939 # Entering graph mode does not provide us with sufficient information to 940 # record a context switch; graph-based context switches are only logged 941 # when a graph is registered as the default graph. 942 self.context_switches.push(False, eager_mode, None) 943 try: 944 yield 945 finally: 946 ctx.is_eager = old_is_eager 947 if mode == EAGER_MODE: 948 self.context_switches.pop() 949 950 def executing_eagerly(self): 951 """Returns True if current thread has eager executing enabled.""" 952 return self._thread_local_data.is_eager 953 954 def ones_rank_cache(self): 955 """Per-device cache for scalars.""" 956 return _tensor_caches_map[self._id].ones_rank_cache 957 958 def zeros_cache(self): 959 """Per-device cache for scalars.""" 960 return _tensor_caches_map[self._id].zeros_cache 961 962 @property 963 def scope_name(self): 964 """Returns scope name for the current thread.""" 965 return self._thread_local_data.scope_name 966 967 @scope_name.setter 968 def scope_name(self, s): 969 """Sets scope name for the current thread.""" 970 self._thread_local_data.scope_name = s 971 972 @property 973 def device_name(self): 974 """Returns the device name for the current thread.""" 975 return self._thread_local_data.device_name 976 977 @property 978 def device_spec(self): 979 """Returns the device spec for the current thread.""" 980 return self._thread_local_data.device_spec 981 982 def _set_device(self, device_name, device_spec): 983 self._thread_local_data.device_name = device_name 984 self._thread_local_data.device_spec = device_spec 985 986 def device(self, name): 987 """Context-manager to force placement of operations and Tensors on a device. 988 989 Args: 990 name: Name of the device or None to get default placement. 991 992 Returns: 993 Context manager that forces device placement. 994 995 Raises: 996 ValueError: If name is not a string or is an invalid device name. 997 RuntimeError: If device scopes are not properly nested. 998 """ 999 if isinstance(name, LogicalDevice): 1000 name = name.name 1001 elif pydev.is_device_spec(name): 1002 name = name.to_string() 1003 return _EagerDeviceContext(self, name) 1004 1005 def devices(self): 1006 """List of the names of devices available to execute operations.""" 1007 return self._devices 1008 1009 def host_address_space(self): 1010 self.ensure_initialized() 1011 with c_api_util.tf_buffer() as buffer_: 1012 pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_) 1013 address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 1014 return address_space 1015 1016 # TODO(fishx): remove this property. 1017 @property 1018 def execution_mode(self): 1019 """Gets execution mode for current thread.""" 1020 return ASYNC if self.is_async() else SYNC 1021 1022 @execution_mode.setter 1023 def execution_mode(self, mode): 1024 """Sets execution mode for current thread.""" 1025 if mode not in (None, SYNC, ASYNC): 1026 raise ValueError("Execution mode should be None/SYNC/ASYNC. Got %s" % 1027 mode) 1028 1029 if mode is None: 1030 mode = SYNC 1031 1032 enable_async = (mode == ASYNC) 1033 if self.is_async() != enable_async: 1034 # Only set the execution mode if the context has already been initialized 1035 if self._context_handle is not None: 1036 self.executor.wait() 1037 executor_new = executor.new_executor(enable_async) 1038 self._thread_local_data.executor = executor_new 1039 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, 1040 executor_new.handle()) 1041 else: 1042 self._default_is_async = enable_async 1043 1044 def is_async(self): 1045 if self._context_handle is not None: 1046 return self.executor.is_async() 1047 else: 1048 return self._default_is_async 1049 1050 @property 1051 def executor(self): 1052 self.ensure_initialized() 1053 return executor.Executor( 1054 pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle)) 1055 1056 @executor.setter 1057 def executor(self, e): 1058 self.ensure_initialized() 1059 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle()) 1060 1061 @property 1062 def config(self): 1063 """Return the ConfigProto with all runtime deltas applied.""" 1064 # Ensure physical devices have been discovered and config has been imported 1065 self._initialize_physical_devices() 1066 1067 config = config_pb2.ConfigProto() 1068 if self._config is not None: 1069 config.CopyFrom(self._config) 1070 1071 if self._optimizer_jit is not None: 1072 config.graph_options.optimizer_options.global_jit_level = ( 1073 config_pb2.OptimizerOptions.ON_1 1074 if self._optimizer_jit else config_pb2.OptimizerOptions.OFF) 1075 if self._intra_op_parallelism_threads is not None: 1076 config.intra_op_parallelism_threads = self._intra_op_parallelism_threads 1077 if self._inter_op_parallelism_threads is not None: 1078 config.inter_op_parallelism_threads = self._inter_op_parallelism_threads 1079 1080 if self._soft_device_placement is not None: 1081 config.allow_soft_placement = self._soft_device_placement 1082 else: 1083 config.allow_soft_placement = self.executing_eagerly() 1084 1085 if self._log_device_placement is not None: 1086 config.log_device_placement = self._log_device_placement 1087 1088 if self._operation_timeout_in_ms is not None: 1089 config.operation_timeout_in_ms = self._operation_timeout_in_ms 1090 1091 is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled() 1092 config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled 1093 if (is_mlir_bridge_enabled == 1094 config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED): 1095 config.experimental.enable_mlir_bridge = True 1096 1097 if self._enable_mlir_graph_optimization is not None: 1098 config.experimental.enable_mlir_graph_optimization = ( 1099 self._enable_mlir_graph_optimization) 1100 1101 def rewriter_toggle(option): 1102 toggle = self._optimizer_experimental_options.get(option, None) 1103 if toggle is None: 1104 return 1105 1106 setattr(config.graph_options.rewrite_options, option, 1107 (rewriter_config_pb2.RewriterConfig.ON 1108 if toggle else rewriter_config_pb2.RewriterConfig.OFF)) 1109 1110 def rewriter_bool(option): 1111 toggle = self._optimizer_experimental_options.get(option, None) 1112 if toggle is None: 1113 return 1114 1115 setattr(config.graph_options.rewrite_options, option, toggle) 1116 1117 rewriter_toggle("layout_optimizer") 1118 rewriter_toggle("constant_folding") 1119 rewriter_toggle("shape_optimization") 1120 rewriter_toggle("remapping") 1121 rewriter_toggle("arithmetic_optimization") 1122 rewriter_toggle("dependency_optimization") 1123 rewriter_toggle("loop_optimization") 1124 rewriter_toggle("function_optimization") 1125 rewriter_toggle("debug_stripper") 1126 rewriter_bool("disable_model_pruning") 1127 rewriter_toggle("scoped_allocator_optimization") 1128 rewriter_toggle("pin_to_host_optimization") 1129 rewriter_toggle("implementation_selector") 1130 rewriter_toggle("auto_mixed_precision") 1131 rewriter_toggle("use_plugin_optimizers") 1132 rewriter_bool("disable_meta_optimizer") 1133 rewriter_toggle("auto_mixed_precision_onednn_bfloat16") 1134 rewriter_toggle("auto_mixed_precision_mkl") 1135 nodes = self._optimizer_experimental_options.get("min_graph_nodes", None) 1136 if nodes is not None: 1137 config.graph_options.rewrite_options.min_graph_nodes = nodes 1138 1139 # Compute device counts 1140 config.device_count["CPU"] = 0 1141 config.device_count["GPU"] = 0 1142 for dev in self._physical_devices: 1143 if dev not in self._visible_device_list: 1144 continue 1145 1146 virtual_devices = self._virtual_device_map.get(dev) 1147 if virtual_devices is None: 1148 config.device_count[dev.device_type] += 1 1149 else: 1150 config.device_count[dev.device_type] += len(virtual_devices) 1151 1152 # Configure gpu_options 1153 gpu_options = self._compute_gpu_options() 1154 config.gpu_options.MergeFrom(gpu_options) 1155 1156 # Configure collective ops 1157 if self._collective_leader: 1158 config.experimental.collective_group_leader = self._collective_leader 1159 if self._collective_scoped_allocator_enabled_ops: 1160 rewrite_options = config.graph_options.rewrite_options 1161 rewrite_options.scoped_allocator_optimization = ( 1162 rewriter_config_pb2.RewriterConfig.ON) 1163 del rewrite_options.scoped_allocator_opts.enable_op[:] 1164 for op in self._collective_scoped_allocator_enabled_ops: 1165 rewrite_options.scoped_allocator_opts.enable_op.append(op) 1166 if self._collective_use_nccl_communication: 1167 config.experimental.collective_nccl = True 1168 if self._collective_device_filters: 1169 del config.device_filters[:] 1170 for f in self._collective_device_filters: 1171 config.device_filters.append(f) 1172 1173 # Configure coordination service 1174 if self._coordination_service_config: 1175 config.experimental.coordination_config.CopyFrom( 1176 self._coordination_service_config) 1177 1178 return config 1179 1180 def _compute_gpu_options(self): 1181 """Build the GPUOptions proto.""" 1182 visible_device_list = [] 1183 virtual_devices = [] 1184 gpu_index = -1 1185 memory_growths = set() 1186 gpu_devices = self.list_physical_devices("GPU") 1187 pluggable_devices = self._pluggable_devices 1188 compatible_devices = gpu_devices 1189 for dev in pluggable_devices: 1190 if dev not in gpu_devices: 1191 compatible_devices.append(dev) 1192 for dev in compatible_devices: 1193 gpu_index += 1 1194 1195 if dev not in self._visible_device_list: 1196 continue 1197 1198 growth = self._memory_growth_map[dev] 1199 memory_growths.add(growth) 1200 visible_device_list.append(str(gpu_index)) 1201 1202 if self._virtual_device_map: 1203 vdevs = self._virtual_device_map.get(dev, []) 1204 device_ordinals = [] 1205 device_limits = [] 1206 priority = [] 1207 for virt_dev in vdevs: 1208 device_ordinals.append(virt_dev.experimental_device_ordinal) 1209 device_limits.append(virt_dev.memory_limit) 1210 if virt_dev.experimental_priority is not None: 1211 priority.append(virt_dev.experimental_priority) 1212 # If priority is specified, it must be specified for all virtual 1213 # devices. 1214 if priority and len(device_limits) != len(priority): 1215 raise ValueError("priority must be specified for all virtual devices") 1216 1217 virtual_devices.append( 1218 config_pb2.GPUOptions.Experimental.VirtualDevices( 1219 memory_limit_mb=device_limits, 1220 priority=priority, 1221 device_ordinal=device_ordinals)) 1222 1223 # Only compute growth if virtual devices have not been configured and we 1224 # have GPUs 1225 if not virtual_devices and memory_growths: 1226 if len(memory_growths) > 1: 1227 raise ValueError("Memory growth cannot differ between GPU devices") 1228 allow_growth = memory_growths.pop() 1229 else: 1230 allow_growth = None 1231 1232 return config_pb2.GPUOptions( 1233 allow_growth=allow_growth, 1234 visible_device_list=",".join(visible_device_list), 1235 experimental=config_pb2.GPUOptions.Experimental( 1236 virtual_devices=virtual_devices)) 1237 1238 @property 1239 def function_call_options(self): 1240 """Returns function call options for current thread. 1241 1242 Note that the returned object is still referenced by the eager context. 1243 1244 Returns: the FunctionCallOptions for current thread. 1245 """ 1246 if self._thread_local_data.function_call_options is None: 1247 config = self.config 1248 1249 # Default to soft placement for functions unless specified 1250 if self._soft_device_placement is None: 1251 config.allow_soft_placement = True 1252 self._thread_local_data.function_call_options = FunctionCallOptions( 1253 config_proto=config) 1254 1255 return self._thread_local_data.function_call_options 1256 1257 @function_call_options.setter 1258 def function_call_options(self, options): 1259 """Returns function call options for current thread.""" 1260 self._thread_local_data.function_call_options = options 1261 1262 def num_gpus(self): 1263 """The number of GPUs available to execute operations.""" 1264 self.ensure_initialized() 1265 return self._num_gpus 1266 1267 def add_function(self, fn): 1268 """Add a function definition to the context. 1269 1270 Once added, the function (identified by its name) can be executed like any 1271 other operation. 1272 1273 Args: 1274 fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). 1275 """ 1276 self.ensure_initialized() 1277 pywrap_tfe.TFE_ContextAddFunction(self._handle, fn) 1278 1279 def add_function_def(self, fdef): 1280 """Add a function definition to the context. 1281 1282 Once added, the function (identified by its name) can be executed like any 1283 other operation. 1284 1285 Args: 1286 fdef: A FunctionDef protocol buffer message. 1287 """ 1288 self.ensure_initialized() 1289 fdef_string = fdef.SerializeToString() 1290 pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string, 1291 len(fdef_string)) 1292 1293 def get_function_def(self, name): 1294 """Get a function definition from the context. 1295 1296 Args: 1297 name: function signature name. 1298 1299 Returns: 1300 The requested FunctionDef. 1301 1302 Raises: 1303 tf.errors.NotFoundError: if name is not the name of a registered function. 1304 """ 1305 with c_api_util.tf_buffer() as buffer_: 1306 pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_) 1307 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 1308 function_def = function_pb2.FunctionDef() 1309 function_def.ParseFromString(proto_data) 1310 1311 return function_def 1312 1313 def register_custom_device(self, device_capsule, device_name, 1314 device_info_capsule): 1315 """Calls TFE_RegisterCustomDevice. See the non-member function.""" 1316 self.ensure_initialized() 1317 pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule, 1318 device_name, device_info_capsule) 1319 1320 def pack_eager_tensors(self, tensors): 1321 """Pack multiple `EagerTensor`s of the same dtype and shape. 1322 1323 Args: 1324 tensors: a list of EagerTensors to pack. 1325 1326 Returns: 1327 A packed EagerTensor. 1328 """ 1329 self.ensure_initialized() 1330 return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors) 1331 1332 def list_function_names(self): 1333 """Get a list of names of registered functions. 1334 1335 Returns: 1336 A set of names of all registered functions for the context. 1337 """ 1338 self.ensure_initialized() 1339 return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle)) 1340 1341 def remove_function(self, name): 1342 """Remove a function from the context. 1343 1344 Once removed, the function cannot be executed anymore. 1345 1346 Args: 1347 name: function signature name. 1348 """ 1349 self.ensure_initialized() 1350 pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name) 1351 1352 def has_function(self, name): 1353 """Check if a function `name` is registered.""" 1354 self.ensure_initialized() 1355 return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name)) 1356 1357 def add_op_callback(self, callback): 1358 """Add a post-op callback to the context. 1359 1360 A post-op callback is invoked immediately after an eager operation or 1361 function has finished execution or after a op has been added to a graph, 1362 providing access to the op's type, name input and output tensors. Multiple 1363 op callbacks can be added, in which case the callbacks will be invoked in 1364 the order in which they are added. 1365 1366 Args: 1367 callback: a callable of the signature `f(op_type, inputs, attrs, outputs, 1368 op_name=None, graph=None)`. See doc strings in `op_callbacks.py` for 1369 details on the function signature and its semantics. 1370 """ 1371 if callback not in self._thread_local_data.op_callbacks: 1372 self._thread_local_data.op_callbacks.append(callback) 1373 1374 def remove_op_callback(self, callback): 1375 """Remove an already-registered op callback. 1376 1377 Args: 1378 callback: The op callback to be removed. 1379 1380 Raises: 1381 KeyError: If `callback` is not already registered. 1382 """ 1383 if callback not in self._thread_local_data.op_callbacks: 1384 raise KeyError("The specified op callback has not been registered, " 1385 "and hence cannot be removed.") 1386 del self._thread_local_data.op_callbacks[ 1387 self._thread_local_data.op_callbacks.index(callback)] 1388 1389 @property 1390 def op_callbacks(self): 1391 return self._thread_local_data.op_callbacks 1392 1393 @property 1394 def invoking_op_callbacks(self): 1395 return self._thread_local_data.invoking_op_callbacks 1396 1397 @invoking_op_callbacks.setter 1398 def invoking_op_callbacks(self, value): 1399 self._thread_local_data.invoking_op_callbacks = value 1400 1401 def _initialize_physical_devices(self, reinitialize=False): 1402 """Gets local devices visible to the system. 1403 1404 Args: 1405 reinitialize: If True, reinitializes self._physical_devices so that 1406 dynamic registered devices will also be visible to the python front-end. 1407 """ 1408 # We lazy initialize self._physical_devices since we do not want to do this 1409 # the constructor since the backend may not be initialized yet. 1410 with self._device_lock: 1411 if not reinitialize and self._physical_devices is not None: 1412 return 1413 1414 devs = pywrap_tfe.TF_ListPhysicalDevices() 1415 self._physical_devices = [ 1416 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1]) 1417 for d in devs 1418 ] 1419 self._physical_device_to_index = { 1420 p: i for i, p in enumerate(self._physical_devices) 1421 } 1422 # We maintain a separate list just so we can check whether the device in 1423 # _physical_devices is a PluggableDevice. 1424 pluggable_devs = pywrap_tfe.TF_ListPluggablePhysicalDevices() 1425 self._pluggable_devices = [ 1426 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1]) 1427 for d in pluggable_devs 1428 ] 1429 1430 self._visible_device_list = list(self._physical_devices) 1431 self._memory_growth_map = { 1432 d: None 1433 for d in self._physical_devices 1434 if d.device_type == "GPU" or d in self._pluggable_devices 1435 } 1436 1437 # Import device settings that may have been passed into the constructor 1438 self._import_config() 1439 1440 def reinitialize_physical_devices(self): 1441 """Gets local devices visible to the system.""" 1442 # Reinitialize the physical device list after registering 1443 # the pluggable device. 1444 self._initialize_physical_devices(True) 1445 1446 def list_physical_devices(self, device_type=None): 1447 """List local devices visible to the system. 1448 1449 This API allows a client to query the devices before they have been 1450 initialized by the eager runtime. Additionally a user can filter by device 1451 type, to get only CPUs or GPUs. 1452 1453 Args: 1454 device_type: Optional device type to limit results to 1455 1456 Returns: 1457 List of PhysicalDevice objects. 1458 """ 1459 self._initialize_physical_devices() 1460 1461 if device_type is None: 1462 return list(self._physical_devices) 1463 1464 return [d for d in self._physical_devices if d.device_type == device_type] 1465 1466 def get_device_details(self, device): # pylint: disable=redefined-outer-name 1467 """Returns details about a physical devices. 1468 1469 Args: 1470 device: A `tf.config.PhysicalDevice` returned by 1471 `tf.config.list_physical_devices` or `tf.config.get_visible_devices`. 1472 1473 Returns: 1474 A dict with string keys. 1475 """ 1476 if not isinstance(device, PhysicalDevice): 1477 raise ValueError("device must be a tf.config.PhysicalDevice, but got: " 1478 "%s" % (device,)) 1479 if (self._physical_device_to_index is None or 1480 device not in self._physical_device_to_index): 1481 raise ValueError("The PhysicalDevice must be one obtained from " 1482 "calling `tf.config.list_physical_devices`, but got: " 1483 "%s" % (device,)) 1484 index = self._physical_device_to_index[device] 1485 details = pywrap_tfe.TF_GetDeviceDetails(index) 1486 1487 # Change compute_capability from a string to a tuple 1488 if "compute_capability" in details: 1489 try: 1490 major, minor = details["compute_capability"].split(".") 1491 details["compute_capability"] = (int(major), int(minor)) 1492 except ValueError: 1493 raise RuntimeError("Device returned compute capability an in invalid " 1494 "format: %s" % details["compute_capability"]) 1495 return details 1496 1497 def _import_config(self): 1498 """Import config if passed in during construction. 1499 1500 If Context was created with a ConfigProto such as when calling 1501 tf.compat.v1.enable_eager_execution(), then we need to pull out the 1502 various pieces we might be replacing and import then into our internal 1503 class representation. 1504 """ 1505 if self._config is None: 1506 return 1507 1508 num_cpus = self._config.device_count.get("CPU", 1) 1509 if num_cpus != 1: 1510 cpus = [d for d in self._physical_devices if d.device_type == "CPU"] 1511 if num_cpus == 0: 1512 self.set_visible_devices([], "CPU") 1513 elif num_cpus > 1: 1514 self.set_logical_device_configuration( 1515 cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)]) 1516 1517 # Parse GPU options 1518 gpus = [d for d in self._physical_devices if d.device_type == "GPU"] 1519 1520 # If there are no GPUs detected, simply ignore all the GPU options passed in 1521 # rather than doing any validation checks. 1522 if not gpus: 1523 return 1524 1525 gpu_count = self._config.device_count.get("GPU", None) 1526 1527 visible_gpus = [] 1528 # TODO(gjn): Handle importing existing virtual GPU configuration 1529 visible_indices = self._config.gpu_options.visible_device_list 1530 if visible_indices: 1531 for index in visible_indices.split(","): 1532 if int(index) >= len(gpus): 1533 raise ValueError("Invalid visible device index: %s" % index) 1534 visible_gpus.append(gpus[int(index)]) 1535 else: 1536 visible_gpus = gpus 1537 1538 if gpu_count is not None: 1539 visible_gpus = visible_gpus[:gpu_count] 1540 1541 self.set_visible_devices(visible_gpus, "GPU") 1542 1543 def list_logical_devices(self, device_type=None): 1544 """Return logical devices.""" 1545 self.ensure_initialized() 1546 if device_type is None: 1547 return list(self._logical_devices) 1548 1549 return [d for d in self._logical_devices if d.device_type == device_type] 1550 1551 def get_visible_devices(self, device_type=None): 1552 """Get the list of visible devices.""" 1553 self._initialize_physical_devices() 1554 1555 if device_type is None: 1556 return list(self._visible_device_list) 1557 1558 return [ 1559 d for d in self._visible_device_list if d.device_type == device_type 1560 ] 1561 1562 def set_visible_devices(self, devices, device_type=None): 1563 """Set the list of visible devices.""" 1564 self._initialize_physical_devices() 1565 1566 if not isinstance(devices, list): 1567 devices = [devices] 1568 1569 for d in devices: 1570 if d not in self._physical_devices: 1571 raise ValueError("Unrecognized device: %s" % repr(d)) 1572 if device_type is not None and d.device_type != device_type: 1573 raise ValueError("Unrecognized device: %s" % repr(d)) 1574 1575 visible_device_list = [] 1576 if device_type is not None: 1577 visible_device_list = [ 1578 d for d in self._visible_device_list if d.device_type != device_type 1579 ] 1580 1581 visible_device_list += devices 1582 1583 if self._visible_device_list == visible_device_list: 1584 return 1585 1586 if self._context_handle is not None: 1587 raise RuntimeError( 1588 "Visible devices cannot be modified after being initialized") 1589 1590 self._visible_device_list = visible_device_list 1591 1592 def get_memory_info(self, dev): 1593 """Returns a dict of memory info for the device.""" 1594 self._initialize_physical_devices() 1595 self.ensure_initialized() 1596 return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev) 1597 1598 def reset_memory_stats(self, dev): 1599 """Resets the tracked memory stats for the device.""" 1600 self._initialize_physical_devices() 1601 self.ensure_initialized() 1602 pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev) 1603 1604 def get_memory_growth(self, dev): 1605 """Get if memory growth is enabled for a PhysicalDevice.""" 1606 self._initialize_physical_devices() 1607 1608 if dev not in self._physical_devices: 1609 raise ValueError("Unrecognized device: %s" % repr(dev)) 1610 1611 return self._memory_growth_map[dev] 1612 1613 def set_memory_growth(self, dev, enable): 1614 """Set if memory growth should be enabled for a PhysicalDevice.""" 1615 self._initialize_physical_devices() 1616 1617 if dev not in self._physical_devices: 1618 raise ValueError("Unrecognized device: %s" % repr(dev)) 1619 1620 if dev in self._virtual_device_map: 1621 raise ValueError( 1622 "Cannot set memory growth on device when virtual devices configured") 1623 1624 if dev.device_type != "GPU" and dev not in self._pluggable_devices: 1625 raise ValueError( 1626 "Cannot set memory growth on non-GPU and non-Pluggable devices") 1627 1628 if self._memory_growth_map.get(dev) == enable: 1629 return 1630 1631 if self._context_handle is not None: 1632 raise RuntimeError( 1633 "Physical devices cannot be modified after being initialized") 1634 1635 self._memory_growth_map[dev] = enable 1636 1637 def get_logical_device_configuration(self, dev): 1638 """Get the virtual device configuration for a PhysicalDevice.""" 1639 self._initialize_physical_devices() 1640 1641 if dev not in self._physical_devices: 1642 raise ValueError("Unrecognized device: %s" % repr(dev)) 1643 1644 return self._virtual_device_map.get(dev) 1645 1646 def set_logical_device_configuration(self, dev, virtual_devices): 1647 """Set the virtual device configuration for a PhysicalDevice.""" 1648 self._initialize_physical_devices() 1649 1650 if dev not in self._physical_devices: 1651 raise ValueError("Unrecognized device: %s" % repr(dev)) 1652 1653 if dev.device_type == "CPU": 1654 for vdev in virtual_devices: 1655 if vdev.memory_limit is not None: 1656 raise ValueError("Setting memory limit on CPU virtual devices is " 1657 "currently not supported") 1658 if vdev.experimental_priority is not None: 1659 raise ValueError("Setting experimental_priority on CPU virtual " 1660 " devices is currently not supported") 1661 if vdev.experimental_device_ordinal != 0: 1662 raise ValueError("Setting experimental_device_ordinal on CPU virtual " 1663 " devices is currently not supported") 1664 elif dev.device_type == "GPU": 1665 for vdev in virtual_devices: 1666 if vdev.memory_limit is None: 1667 raise ValueError( 1668 "Setting memory limit is required for GPU virtual devices") 1669 else: 1670 raise ValueError("Virtual devices are not supported for %s" % 1671 dev.device_type) 1672 1673 if self._virtual_device_map.get(dev) == virtual_devices: 1674 return 1675 1676 if self._context_handle is not None: 1677 raise RuntimeError( 1678 "Virtual devices cannot be modified after being initialized") 1679 1680 self._virtual_device_map[dev] = virtual_devices 1681 1682 def set_logical_cpu_devices(self, num_cpus, prefix=""): 1683 """Set virtual CPU devices in context. 1684 1685 If virtual CPU devices are already configured at context initialization 1686 by tf.config.set_logical_device_configuration(), this method should not be 1687 called. 1688 1689 Args: 1690 num_cpus: Number of virtual CPUs. 1691 prefix: Device name prefix. 1692 1693 Raises: 1694 RuntimeError: If virtual CPUs are already configured at context 1695 initialization. 1696 """ 1697 server_def = self._server_def or self._collective_ops_server_def 1698 local_prefix = ["/device"] 1699 if server_def is not None: 1700 local_prefix.append("/job:%s/replica:0/task:%d" % (server_def.job_name, 1701 server_def.task_index)) 1702 logical_local_devices = [d for d in self.list_logical_devices("CPU") if 1703 d.name.startswith(tuple(local_prefix))] 1704 self.ensure_initialized() 1705 # Error out if there are already multiple logical CPU in the context. 1706 if len(logical_local_devices) > 1: 1707 raise RuntimeError("Virtual CPUs already set, cannot modify again.") 1708 1709 pywrap_tfe.TFE_SetLogicalCpuDevices(self._context_handle, num_cpus, prefix) 1710 self._initialize_logical_devices() 1711 1712 def get_compiler_ir(self, device_name, function_name, args, stage="hlo"): 1713 return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name, 1714 stage, device_name, args) 1715 1716 @deprecated( 1717 None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True) 1718 def enable_xla_devices(self): 1719 """Enables XLA:CPU and XLA:GPU devices registration.""" 1720 pywrap_tfe.TF_EnableXlaDevices() 1721 1722 @property 1723 def enable_mlir_bridge(self): 1724 return pywrap_tfe.TF_IsMlirBridgeEnabled() 1725 1726 @property 1727 def enable_mlir_graph_optimization(self): 1728 return self._enable_mlir_graph_optimization 1729 1730 @enable_mlir_bridge.setter 1731 def enable_mlir_bridge(self, enabled): 1732 pywrap_tfe.TF_EnableMlirBridge(enabled) 1733 self._thread_local_data.function_call_options = None 1734 1735 @enable_mlir_graph_optimization.setter 1736 def enable_mlir_graph_optimization(self, enabled): 1737 self._enable_mlir_graph_optimization = enabled 1738 self._thread_local_data.function_call_options = None 1739 1740 @property 1741 def optimizer_jit(self): 1742 level = self.config.graph_options.optimizer_options.global_jit_level 1743 return (level == config_pb2.OptimizerOptions.ON_1 or 1744 level == config_pb2.OptimizerOptions.ON_2) 1745 1746 @optimizer_jit.setter 1747 def optimizer_jit(self, enabled): 1748 self._optimizer_jit = enabled 1749 1750 self._thread_local_data.function_call_options = None 1751 1752 def get_optimizer_experimental_options(self): 1753 """Get experimental options for the optimizer. 1754 1755 Returns: 1756 Dictionary of current option values 1757 """ 1758 rewrite_options = self.config.graph_options.rewrite_options 1759 options = {} 1760 1761 def rewriter_toggle(option): 1762 attr = getattr(rewrite_options, option) 1763 if attr != 0: 1764 options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON) 1765 1766 def rewriter_bool(option): 1767 options[option] = getattr(rewrite_options, option) 1768 1769 rewriter_toggle("layout_optimizer") 1770 rewriter_toggle("constant_folding") 1771 rewriter_toggle("shape_optimization") 1772 rewriter_toggle("remapping") 1773 rewriter_toggle("arithmetic_optimization") 1774 rewriter_toggle("dependency_optimization") 1775 rewriter_toggle("loop_optimization") 1776 rewriter_toggle("function_optimization") 1777 rewriter_toggle("debug_stripper") 1778 rewriter_bool("disable_model_pruning") 1779 rewriter_toggle("scoped_allocator_optimization") 1780 rewriter_toggle("pin_to_host_optimization") 1781 rewriter_toggle("implementation_selector") 1782 rewriter_toggle("auto_mixed_precision") 1783 rewriter_toggle("use_plugin_optimizers") 1784 rewriter_bool("disable_meta_optimizer") 1785 rewriter_toggle("auto_mixed_precision_onednn_bfloat16") 1786 rewriter_toggle("auto_mixed_precision_mkl") 1787 1788 if rewrite_options.min_graph_nodes != 0: 1789 options["min_graph_nodes"] = rewrite_options.min_graph_nodes 1790 1791 return options 1792 1793 def set_optimizer_experimental_options(self, options): 1794 """Set experimental options for the optimizer. 1795 1796 Args: 1797 options: Dictionary of options to modify 1798 """ 1799 self._optimizer_experimental_options.update(options) 1800 1801 self._thread_local_data.function_call_options = None 1802 1803 @property 1804 def intra_op_parallelism_threads(self): 1805 return self.config.intra_op_parallelism_threads 1806 1807 @intra_op_parallelism_threads.setter 1808 def intra_op_parallelism_threads(self, num_threads): 1809 if self._intra_op_parallelism_threads == num_threads: 1810 return 1811 1812 if self._context_handle is not None: 1813 raise RuntimeError( 1814 "Intra op parallelism cannot be modified after initialization.") 1815 1816 self._intra_op_parallelism_threads = num_threads 1817 1818 @property 1819 def inter_op_parallelism_threads(self): 1820 return self.config.inter_op_parallelism_threads 1821 1822 @inter_op_parallelism_threads.setter 1823 def inter_op_parallelism_threads(self, num_threads): 1824 if self._inter_op_parallelism_threads == num_threads: 1825 return 1826 1827 if self._context_handle is not None: 1828 raise RuntimeError( 1829 "Inter op parallelism cannot be modified after initialization.") 1830 1831 self._inter_op_parallelism_threads = num_threads 1832 1833 @property 1834 def soft_device_placement(self): 1835 return self.config.allow_soft_placement 1836 1837 @soft_device_placement.setter 1838 def soft_device_placement(self, enable): 1839 if self._context_handle is not None: 1840 pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable) 1841 1842 self._soft_device_placement = enable 1843 self._thread_local_data.function_call_options = None 1844 1845 @property 1846 def log_device_placement(self): 1847 return self.config.log_device_placement 1848 1849 @log_device_placement.setter 1850 def log_device_placement(self, enable): 1851 if self._context_handle is not None: 1852 pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable) 1853 1854 self._log_device_placement = enable 1855 self._thread_local_data.function_call_options = None 1856 1857 @property 1858 def jit_compile_rewrite(self): 1859 return self._jit_compile_rewrite 1860 1861 @jit_compile_rewrite.setter 1862 def jit_compile_rewrite(self, enable): 1863 if self._context_handle is not None: 1864 pywrap_tfe.TFE_ContextSetJitCompileRewrite(self._handle, enable) 1865 self._jit_compile_rewrite = enable 1866 1867 @property 1868 def device_policy(self): 1869 # Only get the policy from the context if it has already been initialized 1870 if self._context_handle is not None: 1871 return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle) 1872 1873 return self._device_policy 1874 1875 @device_policy.setter 1876 def device_policy(self, policy): 1877 if policy is None: 1878 policy = DEVICE_PLACEMENT_SILENT 1879 1880 if self._device_policy != policy: 1881 self._device_policy = policy 1882 1883 # Only set the policy if the context has already been initialized 1884 if self._context_handle is not None: 1885 pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy( 1886 self._handle, self._device_policy) 1887 1888 @property 1889 def use_tfrt(self): 1890 return self._use_tfrt 1891 1892 @use_tfrt.setter 1893 def use_tfrt(self, tfrt): 1894 """Sets whether to use TFRT.""" 1895 if not isinstance(tfrt, bool): 1896 raise ValueError("Expecting a boolean but got %s" % type(tfrt)) 1897 1898 if self._use_tfrt != tfrt: 1899 if self._initialized: 1900 raise ValueError("use_tfrt should be set before being initialized.") 1901 self._use_tfrt = tfrt 1902 1903 @property 1904 def use_tfrt_distributed_runtime(self): 1905 return self._use_tfrt_distributed_runtime 1906 1907 @use_tfrt_distributed_runtime.setter 1908 def use_tfrt_distributed_runtime(self, enable): 1909 """Sets whether to use TFRT distributed runtime. 1910 1911 This is only effective when use_tfrt is also true. Note that currently TFRT 1912 distributed runtime is not function complete and this config is for testing 1913 only. 1914 Args: 1915 enable: A boolean to set whether to use TFRT distributed runtime. 1916 """ 1917 if not isinstance(enable, bool): 1918 raise ValueError("Expecting a boolean but got %s" % type(enable)) 1919 1920 if self._use_tfrt_distributed_runtime != enable: 1921 if self._initialized: 1922 raise ValueError("use_tfrt should be set before being initialized.") 1923 self._use_tfrt_distributed_runtime = enable 1924 1925 @property 1926 def operation_timeout_in_ms(self): 1927 return self.config.operation_timeout_in_ms 1928 1929 @operation_timeout_in_ms.setter 1930 def operation_timeout_in_ms(self, timeout_in_ms): 1931 if self._operation_timeout_in_ms == timeout_in_ms: 1932 return 1933 1934 if self._context_handle is not None: 1935 raise RuntimeError( 1936 "Operation timeout cannot be modified after initialization.") 1937 1938 self._operation_timeout_in_ms = timeout_in_ms 1939 1940 def enable_run_metadata(self): 1941 """Enables tracing of op execution via RunMetadata. 1942 1943 To retrieve the accumulated metadata call context.export_run_metadata() 1944 and to stop tracing call context.disable_run_metadata(). 1945 """ 1946 self.ensure_initialized() 1947 pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle) 1948 1949 def disable_run_metadata(self): 1950 """Disables tracing of op execution via RunMetadata.""" 1951 if not self._context_handle: 1952 return 1953 pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle) 1954 1955 def enable_graph_collection(self): 1956 """Enables graph collection of executed functions. 1957 1958 To retrieve the accumulated graphs call context.export_run_metadata() 1959 and to stop collecting graphs call context.disable_graph_collection(). 1960 """ 1961 self.ensure_initialized() 1962 pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle) 1963 1964 def disable_graph_collection(self): 1965 """Disables graph collection of executed functions.""" 1966 if not self._context_handle: 1967 return 1968 pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle) 1969 1970 def export_run_metadata(self): 1971 """Returns a RunMetadata proto with accumulated information. 1972 1973 The returned protocol buffer contains information since the most recent call 1974 to either enable_run_metadata or export_run_metadata. 1975 1976 Returns: 1977 A RunMetadata protocol buffer. Or None if not enabled. 1978 """ 1979 if not self._context_handle: 1980 return None 1981 with c_api_util.tf_buffer() as buffer_: 1982 pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_) 1983 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 1984 run_metadata = config_pb2.RunMetadata() 1985 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1986 return run_metadata 1987 1988 @property 1989 def context_switches(self): 1990 """Returns a stack of context switches.""" 1991 return self._context_switches 1992 1993 1994class _EagerDeviceContext(object): 1995 """Context-manager forcing placement of ops and Tensors on a device.""" 1996 1997 __slots__ = ["_device_name", "_ctx", "_stack"] 1998 1999 def __init__(self, ctx, device_name): 2000 self._device_name = device_name 2001 self._ctx = ctx 2002 self._stack = [] 2003 2004 # TODO(b/189233748): Consolidate the device string parsing logic with 2005 # tensorflow/core/util/device_name_utils.cc. 2006 def __enter__(self): 2007 ctx = self._ctx 2008 old_device_name = ctx.device_name 2009 old_device_spec = ctx.device_spec 2010 new_device_name = self._device_name 2011 cache_key = (old_device_name, new_device_name) 2012 try: 2013 new_device_name, new_device_spec = _device_parsing_cache[cache_key] 2014 except TypeError: 2015 # Error while trying to compute the cache key. 2016 raise ValueError("Expecting a string device name. Got %s(%s)" % 2017 (type(new_device_name), new_device_name)) 2018 except KeyError: 2019 # Handle a cache miss. 2020 if new_device_name is not None: 2021 if not isinstance(new_device_name, str): 2022 raise ValueError("Expecting a string device name. Got %s(%s)" % 2023 (type(new_device_name), new_device_name)) 2024 device_spec = pydev.DeviceSpec.from_string(new_device_name) 2025 if old_device_name: 2026 new_device_spec = copy.copy(old_device_spec) 2027 else: 2028 ctx.ensure_initialized() 2029 new_device_spec = pydev.DeviceSpec.from_string( 2030 ctx._context_devices[0]) # pylint: disable=protected-access 2031 new_device_spec = new_device_spec.make_merged_spec(device_spec) 2032 else: 2033 new_device_spec = pydev.DeviceSpec.from_string("") 2034 new_device_name = new_device_spec.to_string() 2035 _device_parsing_cache[cache_key] = (new_device_name, new_device_spec) 2036 2037 ctx._set_device(new_device_name, new_device_spec) # pylint: disable=protected-access 2038 self._stack.append((old_device_name, old_device_spec, new_device_spec)) 2039 2040 def __exit__(self, *ex_info): 2041 ctx = self._ctx 2042 old_device_name, old_device_spec, new_device_spec = self._stack[-1] 2043 if ctx.device_spec is not new_device_spec: 2044 raise RuntimeError("Exiting device scope without proper scope nesting") 2045 del self._stack[-1] 2046 ctx._set_device(old_device_name, old_device_spec) # pylint: disable=protected-access 2047 2048 2049# Do not change directly. 2050_context = None 2051_context_lock = threading.Lock() 2052 2053 2054def _set_context_locked(ctx): 2055 global _context 2056 pywrap_tfe.TFE_Py_SetEagerContext(ctx) 2057 ctx.mark_as_global_context() 2058 _context = ctx 2059 2060 2061def _set_context(ctx): 2062 with _context_lock: 2063 _set_context_locked(ctx) 2064 2065 2066def _create_context(): 2067 with _context_lock: 2068 if _context is None: 2069 ctx = Context() 2070 _set_context_locked(ctx) 2071 2072 2073def _reset_context(): 2074 """Clears and re-initializes the singleton context. 2075 2076 Should only be used for testing. 2077 """ 2078 global _context 2079 global _device_parsing_cache 2080 2081 # Garbage collect and clear scalar cache to avoid Tensor from current context 2082 # polluting next context. 2083 gc.collect() 2084 pywrap_tfe.TFE_ClearScalarCache() 2085 with _context_lock: 2086 if _context is not None: 2087 _context._clear_caches() 2088 _context = None 2089 _create_context() 2090 _device_parsing_cache = {} 2091 2092 2093def _reset_jit_compiler_flags(): 2094 """Clears and re-initializes the TF JIT compiler flags. 2095 2096 Should only be used for testing. 2097 """ 2098 pywrap_tfe.TF_ResetJitCompilerFlags() 2099 2100 2101def context(): 2102 """Returns a singleton context object.""" 2103 if _context is None: 2104 _create_context() 2105 return _context 2106 2107 2108def context_safe(): 2109 """Returns current context (or None if one hasn't been initialized).""" 2110 return _context 2111 2112 2113def ensure_initialized(): 2114 """Initialize the context.""" 2115 context().ensure_initialized() 2116 2117 2118def initialize_logical_devices(): 2119 """Initialize the virtual devices.""" 2120 context()._initialize_logical_devices() # pylint: disable=protected-access 2121 2122 2123def set_global_seed(seed): 2124 """Sets the eager mode seed.""" 2125 context()._set_global_seed(seed) # pylint: disable=protected-access 2126 2127 2128def global_seed(): 2129 """Returns the eager mode seed.""" 2130 return context()._seed # pylint: disable=protected-access 2131 2132 2133def internal_operation_seed(): 2134 """Returns the operation seed generated based on global seed.""" 2135 return context()._internal_operation_seed() # pylint: disable=protected-access 2136 2137 2138@tf_export("executing_eagerly", v1=[]) 2139def executing_eagerly(): 2140 """Checks whether the current thread has eager execution enabled. 2141 2142 Eager execution is enabled by default and this API returns `True` 2143 in most of cases. However, this API might return `False` in the following use 2144 cases. 2145 2146 * Executing inside `tf.function`, unless under `tf.init_scope` or 2147 `tf.config.run_functions_eagerly(True)` is previously called. 2148 * Executing inside a transformation function for `tf.dataset`. 2149 * `tf.compat.v1.disable_eager_execution()` is called. 2150 2151 General case: 2152 2153 >>> print(tf.executing_eagerly()) 2154 True 2155 2156 Inside `tf.function`: 2157 2158 >>> @tf.function 2159 ... def fn(): 2160 ... with tf.init_scope(): 2161 ... print(tf.executing_eagerly()) 2162 ... print(tf.executing_eagerly()) 2163 >>> fn() 2164 True 2165 False 2166 2167 Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called: 2168 2169 >>> tf.config.run_functions_eagerly(True) 2170 >>> @tf.function 2171 ... def fn(): 2172 ... with tf.init_scope(): 2173 ... print(tf.executing_eagerly()) 2174 ... print(tf.executing_eagerly()) 2175 >>> fn() 2176 True 2177 True 2178 >>> tf.config.run_functions_eagerly(False) 2179 2180 Inside a transformation function for `tf.dataset`: 2181 2182 >>> def data_fn(x): 2183 ... print(tf.executing_eagerly()) 2184 ... return x 2185 >>> dataset = tf.data.Dataset.range(100) 2186 >>> dataset = dataset.map(data_fn) 2187 False 2188 2189 Returns: 2190 `True` if the current thread has eager execution enabled. 2191 """ 2192 ctx = context_safe() 2193 if ctx is None: 2194 return default_execution_mode == EAGER_MODE 2195 2196 return ctx.executing_eagerly() 2197 2198 2199@tf_export(v1=["executing_eagerly"]) 2200def executing_eagerly_v1(): 2201 """Checks whether the current thread has eager execution enabled. 2202 2203 Eager execution is typically enabled via 2204 `tf.compat.v1.enable_eager_execution`, but may also be enabled within the 2205 context of a Python function via tf.contrib.eager.py_func. 2206 2207 When eager execution is enabled, returns `True` in most cases. However, 2208 this API might return `False` in the following use cases. 2209 2210 * Executing inside `tf.function`, unless under `tf.init_scope` or 2211 `tf.config.run_functions_eagerly(True)` is previously called. 2212 * Executing inside a transformation function for `tf.dataset`. 2213 * `tf.compat.v1.disable_eager_execution()` is called. 2214 2215 >>> tf.compat.v1.enable_eager_execution() 2216 2217 General case: 2218 2219 >>> print(tf.executing_eagerly()) 2220 True 2221 2222 Inside `tf.function`: 2223 2224 >>> @tf.function 2225 ... def fn(): 2226 ... with tf.init_scope(): 2227 ... print(tf.executing_eagerly()) 2228 ... print(tf.executing_eagerly()) 2229 >>> fn() 2230 True 2231 False 2232 2233 Inside `tf.function` 2234 after `tf.config.run_functions_eagerly(True)` is called: 2235 2236 >>> tf.config.run_functions_eagerly(True) 2237 >>> @tf.function 2238 ... def fn(): 2239 ... with tf.init_scope(): 2240 ... print(tf.executing_eagerly()) 2241 ... print(tf.executing_eagerly()) 2242 >>> fn() 2243 True 2244 True 2245 >>> tf.config.run_functions_eagerly(False) 2246 2247 Inside a transformation function for `tf.dataset`: 2248 2249 >>> def data_fn(x): 2250 ... print(tf.executing_eagerly()) 2251 ... return x 2252 >>> dataset = tf.data.Dataset.range(100) 2253 >>> dataset = dataset.map(data_fn) 2254 False 2255 2256 Returns: 2257 `True` if the current thread has eager execution enabled. 2258 """ 2259 return executing_eagerly() 2260 2261 2262def in_eager_mode(): 2263 """Use executing_eagerly() instead. This function will be removed.""" 2264 return executing_eagerly() 2265 2266 2267def anonymous_name(): 2268 """Returns the anonymous shared name. 2269 2270 In eager mode we create anonymous resources to avoid spurious sharing issues. 2271 The runtime generates a unique name on our behalf when the reserved 2272 anonymous shared name is used as a shared name. 2273 2274 Returns: 2275 The anonymous shared name. 2276 """ 2277 2278 # The magic value is defined as 2279 # `tensorflow::ResourceHandle::ANONYMOUS_NAME` in C++. 2280 return "cd2c89b7-88b7-44c8-ad83-06c2a9158347" 2281 2282 2283def graph_mode(): 2284 """Context-manager to disable eager execution for the current thread.""" 2285 return context()._mode(GRAPH_MODE) # pylint: disable=protected-access 2286 2287 2288# Used by b/167638505 for keras backend API and Lambda layer. 2289@tf_export("__internal__.eager_context.eager_mode", v1=[]) 2290def eager_mode(): 2291 """Context-manager to enable eager execution for the current thread.""" 2292 return context()._mode(EAGER_MODE) # pylint: disable=protected-access 2293 2294 2295def scope_name(): 2296 """Name of the current scope.""" 2297 return context().scope_name 2298 2299 2300def device(name): 2301 """Context-manager to force placement of operations and Tensors on a device. 2302 2303 Example: 2304 ```python 2305 with tf.device('gpu:0'): 2306 with tf.device('cpu:0'): 2307 shape = tf.constant([], dtype=tf.int32) 2308 x = tf.random.truncated_normal(shape, tf.float32) 2309 ``` 2310 will ensure that the `shape` Tensor is on CPU but the `truncated_normal` 2311 operation runs on GPU 0. 2312 2313 Args: 2314 name: Name of the device (see context().devices()), or None to perform 2315 automatic placement. 2316 2317 Returns: 2318 Context manager for setting the device. 2319 """ 2320 ensure_initialized() 2321 return context().device(name) 2322 2323 2324# Expose some properties of Context as internally public APIs (b/160348781). 2325@tf_export("__internal__.eager_context.get_config", v1=[]) 2326def get_config(): 2327 """Get the ConfigProto of Context. 2328 2329 Returns: 2330 The ConfigProto of Context. 2331 """ 2332 return context().config 2333 2334 2335@tf_export("__internal__.eager_context.get_device_name", v1=[]) 2336def get_device_name(): 2337 """Get the device name for the current thread. 2338 2339 Returns: 2340 The device name for the current thread. 2341 """ 2342 return context().device_name 2343 2344 2345@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[]) 2346def set_soft_device_placement(enabled): 2347 """Set if soft device placements should be allowed. 2348 2349 Args: 2350 enabled: Whether to enable soft device placement. 2351 """ 2352 context().soft_device_placement = enabled 2353 2354 2355@tf_export("__internal__.eager_context.get_executor", v1=[]) 2356def get_executor(): 2357 """Get the Executor of the current thread. 2358 2359 Returns: 2360 The Executor of the current thread. 2361 """ 2362 return context().executor 2363 2364 2365@tf_export("debugging.get_log_device_placement") 2366def get_log_device_placement(): 2367 """Get if device placements are logged. 2368 2369 Returns: 2370 If device placements are logged. 2371 """ 2372 return context().log_device_placement 2373 2374 2375@tf_export("debugging.set_log_device_placement") 2376def set_log_device_placement(enabled): 2377 """Turns logging for device placement decisions on or off. 2378 2379 Operations execute on a particular device, producing and consuming tensors on 2380 that device. This may change the performance of the operation or require 2381 TensorFlow to copy data to or from an accelerator, so knowing where operations 2382 execute is useful for debugging performance issues. 2383 2384 For more advanced profiling, use the [TensorFlow 2385 profiler](https://www.tensorflow.org/guide/profiler). 2386 2387 Device placement for operations is typically controlled by a `tf.device` 2388 scope, but there are exceptions, for example operations on a `tf.Variable` 2389 which follow the initial placement of the variable. Turning off soft device 2390 placement (with `tf.config.set_soft_device_placement`) provides more explicit 2391 control. 2392 2393 >>> tf.debugging.set_log_device_placement(True) 2394 >>> tf.ones([]) 2395 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:GPU:0 2396 >>> with tf.device("CPU"): 2397 ... tf.ones([]) 2398 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:CPU:0 2399 >>> tf.debugging.set_log_device_placement(False) 2400 2401 Turning on `tf.debugging.set_log_device_placement` also logs the placement of 2402 ops inside `tf.function` when the function is called. 2403 2404 Args: 2405 enabled: Whether to enabled device placement logging. 2406 """ 2407 context().log_device_placement = enabled 2408 2409 2410@tf_contextlib.contextmanager 2411def device_policy(policy): 2412 """Context manager for setting device placement policy for current thread.""" 2413 ctx = context() 2414 old_policy = ctx.device_policy 2415 try: 2416 ctx.device_policy = policy 2417 yield 2418 finally: 2419 ctx.device_policy = old_policy 2420 2421 2422def set_execution_mode(mode): 2423 """Sets execution mode for the current thread.""" 2424 context().execution_mode = mode 2425 2426 2427# TODO(fishx): remove this method. 2428@tf_contextlib.contextmanager 2429def execution_mode(mode): 2430 """Context manager for setting execution mode for current thread.""" 2431 if mode is None: 2432 yield 2433 else: 2434 ctx = context() 2435 executor_new = executor.new_executor(mode == ASYNC) 2436 executor_old = ctx.executor 2437 try: 2438 executor_old.wait() 2439 ctx.executor = executor_new 2440 yield 2441 finally: 2442 ctx.executor = executor_old 2443 executor_new.wait() 2444 2445 2446@tf_contextlib.contextmanager 2447def executor_scope(e): 2448 """Context manager for changing executor for current thread. 2449 2450 Args: 2451 e: A Executor to execute eager ops under this scope. Setting it to None will 2452 switch back to use the default executor for the context. 2453 2454 Yields: 2455 Context manager for setting the executor for current thread. 2456 """ 2457 ctx = context() 2458 executor_old = ctx.executor 2459 try: 2460 ctx.executor = e 2461 yield 2462 finally: 2463 ctx.executor = executor_old 2464 2465 2466@tf_export("experimental.function_executor_type") 2467@tf_contextlib.contextmanager 2468def function_executor_type(executor_type): 2469 """Context manager for setting the executor of eager defined functions. 2470 2471 Eager defined functions are functions decorated by tf.contrib.eager.defun. 2472 2473 Args: 2474 executor_type: a string for the name of the executor to be used to execute 2475 functions defined by tf.contrib.eager.defun. 2476 2477 Yields: 2478 Context manager for setting the executor of eager defined functions. 2479 """ 2480 current_options = context().function_call_options 2481 old_options = copy.copy(current_options) 2482 try: 2483 current_options.executor_type = executor_type 2484 yield 2485 finally: 2486 context().function_call_options = old_options 2487 2488 2489def is_async(): 2490 """Returns true if current thread is in async mode.""" 2491 return context().is_async() 2492 2493 2494def num_gpus(): 2495 """Get the number of available GPU devices. 2496 2497 Returns: 2498 The number of available GPU devices. 2499 """ 2500 return context().num_gpus() 2501 2502 2503def enable_run_metadata(): 2504 """Enables tracing of op execution via RunMetadata. 2505 2506 To retrieve the accumulated metadata call context.export_run_metadata() 2507 and to stop tracing call context.disable_run_metadata(). 2508 """ 2509 context().enable_run_metadata() 2510 2511 2512def disable_run_metadata(): 2513 """Disables tracing of op execution via RunMetadata.""" 2514 context().disable_run_metadata() 2515 2516 2517def enable_graph_collection(): 2518 """Enables graph collection of executed functions. 2519 2520 To retrieve the accumulated graphs call context.export_run_metadata() 2521 and to stop collecting graphs call context.disable_graph_collection(). 2522 """ 2523 context().enable_graph_collection() 2524 2525 2526def disable_graph_collection(): 2527 """Disables graph collection of executed functions.""" 2528 context().disable_graph_collection() 2529 2530 2531def export_run_metadata(): 2532 """Returns a RunMetadata proto with accumulated information. 2533 2534 The returned protocol buffer contains information since the most recent call 2535 to either enable_run_metadata or export_run_metadata. 2536 2537 Returns: 2538 A RunMetadata protocol buffer. 2539 """ 2540 return context().export_run_metadata() 2541 2542 2543@contextlib.contextmanager 2544def collect_graphs(optimized=True): 2545 """Collects a flat list of pre- or post-optimization graphs. 2546 2547 The collected graphs include device placements, which can be useful for 2548 testing. 2549 2550 Usage: 2551 2552 ``` 2553 @def_function.function 2554 def f(x): 2555 return x + constant_op.constant(1.) 2556 2557 with context.collect_graphs() as graphs: 2558 with ops.device("CPU:0"): 2559 f(constant_op.constant(1.)) 2560 2561 graph, = graphs # `graph` contains a single GraphDef for inspection 2562 ``` 2563 2564 Args: 2565 optimized: whether to collect optimized graphs or non-optimized graphs 2566 2567 Yields: 2568 A list of GraphDefs, populated when the context manager exits. 2569 """ 2570 ctx = context() 2571 ctx.enable_graph_collection() 2572 try: 2573 graphs = [] 2574 yield graphs 2575 metadata = ctx.export_run_metadata() 2576 finally: 2577 ctx.disable_graph_collection() 2578 for graph in metadata.function_graphs: 2579 if optimized: 2580 graphs.append(graph.post_optimization_graph) 2581 else: 2582 graphs.append(graph.pre_optimization_graph) 2583 2584 2585def get_server_def(): 2586 return context().get_server_def() 2587 2588 2589def set_server_def(server_def): 2590 context().set_server_def(server_def) 2591 2592 2593def update_server_def(server_def): 2594 context().update_server_def(server_def) 2595 2596 2597def check_alive(worker_name): 2598 return context().check_alive(worker_name) 2599 2600 2601@tf_export("experimental.async_scope") 2602@tf_contextlib.contextmanager 2603def async_scope(): 2604 """Context manager for grouping async operations. 2605 2606 Ops/function calls inside the scope can return before finishing the actual 2607 execution. When exiting the async scope, a synchronization barrier will be 2608 automatically added to ensure the completion of all async op and function 2609 execution, potentially raising exceptions if async execution results in 2610 an error state. 2611 2612 Users may write the following code to asynchronously invoke `train_step_fn` 2613 and log the `loss` metric for every `num_steps` steps in a training loop. 2614 `train_step_fn` internally consumes data using `iterator.get_next()`, and may 2615 throw OutOfRangeError when running out of data. In the case: 2616 2617 ``` 2618 try: 2619 with tf.experimental.async_scope(): 2620 for _ in range(num_steps): 2621 # Step function updates the metric `loss` internally 2622 train_step_fn() 2623 except tf.errors.OutOfRangeError: 2624 tf.experimental.async_clear_error() 2625 logging.info('loss = %s', loss.numpy()) 2626 ``` 2627 2628 Yields: 2629 Context manager for grouping async operations. 2630 """ 2631 # TODO(haoyuzhang): replace env var once we have a config method to turn on 2632 # and off async streaming RPC 2633 remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE" 2634 old_policy = os.environ.get(remote_async_env_var) 2635 try: 2636 os.environ[remote_async_env_var] = str(True) 2637 yield 2638 # Note: sync local and remote executors iff the async block does not raise 2639 # an exception. Triggering sync after an exception may lead to derived 2640 # runtime errors and unexpected exception types. 2641 context().sync_executors() 2642 finally: 2643 if old_policy is None: 2644 del os.environ[remote_async_env_var] 2645 else: 2646 os.environ[remote_async_env_var] = old_policy 2647 2648 2649def async_wait(): 2650 """Sync all async operations and raise any errors during execution. 2651 2652 In async execution mode, an op/function call can return before finishing the 2653 actual execution. Calling this method creates a synchronization barrier for 2654 all async op and function execution. It only returns when all pending nodes 2655 are finished, potentially raising exceptions if async execution results in 2656 an error state. It is a no-op if the context is not initialized. 2657 """ 2658 disable_async_executor_env_var = "TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY" 2659 if os.environ.get(disable_async_executor_env_var) == str(True): 2660 return 2661 if context()._context_handle is not None: # pylint: disable=protected-access 2662 context().sync_executors() 2663 2664 2665@tf_export("experimental.async_clear_error") 2666def async_clear_error(): 2667 """Clear pending operations and error statuses in async execution. 2668 2669 In async execution mode, an error in op/function execution can lead to errors 2670 in subsequent ops/functions that are scheduled but not yet executed. Calling 2671 this method clears all pending operations and reset the async execution state. 2672 2673 Example: 2674 2675 ``` 2676 while True: 2677 try: 2678 # Step function updates the metric `loss` internally 2679 train_step_fn() 2680 except tf.errors.OutOfRangeError: 2681 tf.experimental.async_clear_error() 2682 break 2683 logging.info('loss = %s', loss.numpy()) 2684 ``` 2685 """ 2686 context().clear_executor_errors() 2687 2688 2689def add_function(fdef): 2690 """Add a function definition to the context.""" 2691 context().add_function(fdef) 2692 2693 2694def remove_function(name): 2695 """Remove a function from the context.""" 2696 context().remove_function(name) 2697 2698 2699def get_function_def(name): 2700 return context().get_function_def(name) 2701 2702 2703def register_custom_device(device_capsule, device_name, device_info_capsule): 2704 """Calls TFE_RegisterCustomDevice to register a custom device with Python. 2705 2706 Enables using C extensions specifying a custom device from Python. See the 2707 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for 2708 details. 2709 2710 Note that custom devices are not currently supported inside `tf.function`s. 2711 2712 Args: 2713 device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice' 2714 containing a pointer to a TFE_CustomDevice struct. The capsule retains 2715 ownership of the memory. 2716 device_name: A string indicating the name to register the custom device 2717 under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may 2718 subsequently be passed to `with tf.device(...):`. 2719 device_info_capsule: A PyCapsule with the name set to 2720 'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific 2721 struct with the initial state of the custom device (the void* device_info 2722 argument to TFE_RegisterCustomDevice). This method takes ownership of the 2723 memory and clears the capsule destructor. 2724 """ 2725 context().register_custom_device(device_capsule, device_name, 2726 device_info_capsule) 2727 2728 2729# Not every user creates a Context via context.context() 2730# (for example, enable_eager_execution in python/framework/ops.py), 2731# but they do all import this file. Note that IS_IN_GRAPH_MODE and 2732# in_graph_mode are both parameterless functions. 2733def _tmp_in_graph_mode(): 2734 if context_safe() is None: 2735 # Context not yet initialized. Assume graph mode following the 2736 # default implementation in `is_in_graph_mode`. 2737 return True 2738 return not executing_eagerly() 2739 2740 2741is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode 2742