1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""State management for eager execution.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import copy 24import random 25import threading 26 27from tensorflow.core.protobuf import config_pb2 28from tensorflow.python import pywrap_tensorflow 29from tensorflow.python import tf2 30from tensorflow.python.framework import c_api_util 31from tensorflow.python.framework import device as pydev 32from tensorflow.python.util import compat 33from tensorflow.python.util import is_in_graph_mode 34from tensorflow.python.util import tf_contextlib 35from tensorflow.python.util.tf_export import tf_export 36 37GRAPH_MODE = 0 38EAGER_MODE = 1 39 40default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE 41 42# Cache from (old_device_name, partial_new_device_name) -> (new_device_name, 43# new_device_spec). 44# Note that we do not protect this with a lock and instead rely on python's GIL 45# and the idempotent nature of writes to provide thread safety. 
46_device_parsing_cache = {} 47_starting_device_spec = pydev.DeviceSpec.from_string("") 48 49_MAXINT32 = 2**31 - 1 50 51DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT 52DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN 53DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT 54DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( 55 pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) 56SYNC = 0 57ASYNC = 1 58 59 60class _EagerTensorCache(object): 61 """Simple cache which evicts items based on length in a FIFO manner.""" 62 63 def __init__(self, max_items=256, max_tensor_size=10000): 64 self._data = collections.OrderedDict() 65 self._max_items = max_items 66 self._max_tensor_size = max_tensor_size 67 68 def put(self, key, value): 69 if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access 70 return 71 72 self._data[key] = value 73 74 if len(self._data) > self._max_items: 75 self._data.popitem(last=False) 76 77 def get(self, key): 78 return self._data.get(key, None) 79 80 def flush(self): 81 self._data = {} 82 83 84class FunctionCallOptions(object): 85 """Options applied at call sites of eager functions. 86 87 Eager functions are functions decorated with tf.contrib.eager.defun. 88 """ 89 90 def __init__(self, executor_type=None, config_proto=None): 91 """Constructor. 92 93 Args: 94 executor_type: (optional) name of the executor to be used to execute the 95 eager function. If None or an empty string, the default Tensorflow 96 executor will be used. 97 config_proto: (optional) a `config_pb2.ConfigProto` proto or 98 a serialized string of that proto. 99 The config used by Grappler when optimizing the function graph. 100 Each concrete function is optimized the first time is called. Changing 101 config_proto after the first call has no effect. 102 If config_proto is None, an empty RewriterConfig will be used. 
class FunctionCallOptions(object):
  """Options applied at call sites of eager functions.

  Eager functions are functions decorated with tf.contrib.eager.defun.
  """

  def __init__(self, executor_type=None, config_proto=None):
    """Constructor.

    Args:
      executor_type: (optional) name of the executor to be used to execute the
        eager function. If None or an empty string, the default Tensorflow
        executor will be used.
      config_proto: (optional) a `config_pb2.ConfigProto` proto or
        a serialized string of that proto.
        The config used by Grappler when optimizing the function graph.
        Each concrete function is optimized the first time is called. Changing
        config_proto after the first call has no effect.
        If config_proto is None, an empty RewriterConfig will be used.
    """
    self.config_proto_serialized = config_proto
    self.executor_type = executor_type

  @property
  def executor_type(self):
    return self._executor_type

  @executor_type.setter
  def executor_type(self, executor_type):
    self._executor_type = executor_type

  @property
  def config_proto_serialized(self):
    return self._config_proto_serialized

  @config_proto_serialized.setter
  def config_proto_serialized(self, config):
    if isinstance(config, config_pb2.ConfigProto):
      self._config_proto_serialized = config.SerializeToString()
    elif isinstance(config, (str, bytes)):
      # Bug fix: Message.SerializeToString() returns `bytes` on Python 3, so
      # a pre-serialized config (which the docstring explicitly allows) must
      # be accepted as bytes as well as str.
      self._config_proto_serialized = config
    elif config is None:
      self._config_proto_serialized = (
          config_pb2.ConfigProto().SerializeToString())
    else:
      raise ValueError("the rewriter config must be either a "
                       "config_pb2.ConfigProto, or a serialized string of that "
                       "proto or None. got: {}".format(type(config)))


class _ThreadLocalData(threading.local):
  """Thread local storage for the eager context.

  `threading.local` re-runs `__init__` in each thread that touches the
  object, so every thread starts from these defaults.
  """

  def __init__(self):
    super(_ThreadLocalData, self).__init__()
    self.device_spec = _starting_device_spec
    self.device_name = ""
    self.mode = default_execution_mode
    self.is_eager = default_execution_mode == EAGER_MODE
    self.scope_name = ""
    self.summary_writer = None
    self.summary_recording = None
    self.summary_recording_distribution_strategy = True
    self.summary_step = None
    self.scalar_cache = {}
    self._ones_rank_cache = None
    self._zeros_cache = None
    self.execution_mode = SYNC
    self.function_call_options = None

  @property
  def ones_rank_cache(self):
    # Created lazily so threads that never use it pay no cost.
    if not self._ones_rank_cache:
      self._ones_rank_cache = _EagerTensorCache()
    return self._ones_rank_cache

  @property
  def zeros_cache(self):
    # Created lazily so threads that never use it pay no cost.
    if not self._zeros_cache:
      self._zeros_cache = _EagerTensorCache()
    return self._zeros_cache
collections.namedtuple( 168 "ContextSwitch", ["is_building_function", "enter_context_fn", 169 "device_stack"]) 170 171 172# `_ContextSwitchStack` is a `threading.local` to match the semantics of 173# ``DefaultGraphStack`, which is also a `threading.local`. 174class _ContextSwitchStack(threading.local): 175 """A thread-local stack of context switches.""" 176 177 def __init__(self, eager): 178 super(_ContextSwitchStack, self).__init__() 179 self.stack = [] 180 if eager: 181 # Initialize the stack with a pointer to enter the eager context; this 182 # ensures that the fact that eager execution was enabled is propagated 183 # across threads, since (1) `enable_eager_execution` modifies a 184 # process-level flag (`default_execution_mode`) and (2) `__init__` is 185 # called each time a threading.local object is used in a separate thread. 186 self.push(is_building_function=False, enter_context_fn=eager_mode, 187 device_stack=None) 188 189 def push(self, is_building_function, enter_context_fn, device_stack): 190 """Push metadata about a context switch onto the stack. 191 192 A context switch can take any one of the two forms: installing a graph as 193 the default graph, or entering the eager context. For each context switch, 194 we record whether or not the entered context is building a function. 195 196 Args: 197 is_building_function: (bool.) Whether the context is building a function. 198 enter_context_fn: (function.) A callable that executes the context switch. 199 For example, `graph.as_default` or `eager_mode`. 200 device_stack: If applicable, the device function stack for this 201 graph. When breaking out of graphs in init_scope, the innermost nonempty 202 device stack is used. Eager contexts put `None` here and the value is 203 never used. 
204 """ 205 206 self.stack.append( 207 ContextSwitch(is_building_function, enter_context_fn, device_stack)) 208 209 def pop(self): 210 """Pop the stack.""" 211 212 self.stack.pop() 213 214 215# TODO(agarwal): rename to EagerContext / EagerRuntime ? 216# TODO(agarwal): consider keeping the corresponding Graph here. 217class Context(object): 218 """Environment in which eager operations execute.""" 219 220 # TODO(agarwal): create and link in some documentation for `execution_mode`. 221 # pylint: disable=redefined-outer-name 222 def __init__(self, 223 config=None, 224 device_policy=None, 225 execution_mode=None, 226 server_def=None): 227 """Creates a new Context. 228 229 Args: 230 config: (Optional.) A `ConfigProto` protocol buffer with configuration 231 options for the Context. Note that a lot of these options may be 232 currently unimplemented or irrelevant when eager execution is enabled. 233 device_policy: (Optional.) What policy to use when trying to run an 234 operation on a device with inputs which are not on that device. 235 When set to None, an appropriate value will be picked automatically. 236 The value picked may change between TensorFlow releases. 237 238 Defaults to DEVICE_PLACEMENT_SILENT. 239 Valid values: 240 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is 241 not correct. 242 - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the 243 right device but raises a warning. 244 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might 245 hide performance problems. 246 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, 247 raising errors on the other ones. 248 execution_mode: (Optional.) Policy controlling how operations dispatched 249 are actually executed. When set to None, an appropriate value will be 250 picked automatically. The value picked may change between TensorFlow 251 releases. 252 Valid values: 253 - SYNC: executes each operation synchronously. 
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
  """Environment in which eager operations execute.

  The underlying TFE context handle and the device list are created lazily
  (see `_initialize_handle_and_devices`), so construction is cheap and
  GPU/remote setup only happens on first real use.
  """

  # TODO(agarwal): create and link in some documentation for `execution_mode`.
  # pylint: disable=redefined-outer-name
  def __init__(self,
               config=None,
               device_policy=None,
               execution_mode=None,
               server_def=None):
    """Creates a new Context.

    Args:
      config: (Optional.) A `ConfigProto` protocol buffer with configuration
        options for the Context. Note that a lot of these options may be
        currently unimplemented or irrelevant when eager execution is enabled.
      device_policy: (Optional.) What policy to use when trying to run an
        operation on a device with inputs which are not on that device.
        When set to None, an appropriate value will be picked automatically.
        The value picked may change between TensorFlow releases.

        Defaults to DEVICE_PLACEMENT_SILENT.
        Valid values:
        - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is
          not correct.
        - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
          right device but raises a warning.
        - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
          hide performance problems.
        - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
          raising errors on the other ones.
      execution_mode: (Optional.) Policy controlling how operations dispatched
        are actually executed. When set to None, an appropriate value will be
        picked automatically. The value picked may change between TensorFlow
        releases.
        Valid values:
        - SYNC: executes each operation synchronously.
        - ASYNC: executes each operation asynchronously. These
          operations may return "non-ready" handles.
      server_def: (Optional.) A tensorflow::ServerDef proto.
        Enables execution on remote devices. GrpcServers need to be started by
        creating an identical server_def to this, and setting the appropriate
        task_indexes, so that the servers can communicate. It will then be
        possible to execute operations on remote devices.

    Raises:
      ValueError: If execution_mode is not valid.
    """
    if config is None:
      config = config_pb2.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=False,
      )
    self._config = config
    self._thread_local_data = _ThreadLocalData()
    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
    # Both of these are populated lazily by _initialize_handle_and_devices.
    self._context_handle = None
    self._context_devices = None
    self._post_execution_callbacks = []
    self._seed = None
    self._initialize_lock = threading.Lock()
    if device_policy is None:
      device_policy = DEVICE_PLACEMENT_SILENT
    self._device_policy = device_policy
    if execution_mode not in (None, SYNC, ASYNC):
      raise ValueError(
          "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
    if execution_mode is None:
      execution_mode = SYNC
    self._execution_mode = execution_mode
    self._server_def = server_def
    self._collective_ops_server_def = None

  # pylint: enable=redefined-outer-name

  def _set_global_seed(self, seed):
    """Set a global eager mode seed for random ops."""
    self._seed = seed
    self._rng = random.Random(self._seed)
    # Also clear the kernel cache, to reset any existing seeds
    if self._context_handle is not None:
      pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)

  def _internal_operation_seed(self):
    """Returns a fake operation seed.

    In eager mode, user shouldn't set or depend on operation seed.
    Here, we generate a random seed based on global seed to make
    operation's randomness different and depend on the global seed.

    Returns:
      A fake operation seed based on global seed.
    """
    return self._rng.randint(0, _MAXINT32)

  def _initialize_devices(self):
    """Helper to initialize devices."""
    # Store list of devices
    self._context_devices = []
    device_list = pywrap_tensorflow.TFE_ContextListDevices(
        self._context_handle)
    try:
      self._num_gpus = 0
      for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
        dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
        self._context_devices.append(pydev.canonical_name(dev_name))
        dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
        if dev_type == "GPU":
          self._num_gpus += 1

    finally:
      # The C device list must be freed even if enumeration fails.
      pywrap_tensorflow.TF_DeleteDeviceList(device_list)

  def _initialize_handle_and_devices(self):
    """Initialize handle and devices."""
    # Double-checked under the lock: several threads may race to initialize.
    with self._initialize_lock:
      if self._context_handle is not None:
        return
      assert self._context_devices is None
      opts = pywrap_tensorflow.TFE_NewContextOptions()
      try:
        if self._config is not None:
          config_str = self._config.SerializeToString()
          pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
        if self._device_policy is not None:
          pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
              opts, self._device_policy)
        if self._execution_mode == ASYNC:
          pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
        self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
      finally:
        pywrap_tensorflow.TFE_DeleteContextOptions(opts)
      assert not (self._server_def and self._collective_ops_server_def), (
          "Cannot enable remote execution as well as collective ops at the "
          "moment. If this is important to you, please file an issue.")
      if self._server_def is not None:
        server_def_str = self._server_def.SerializeToString()
        # NOTE(review): keep_alive_secs is hardcoded to 600 here, even when
        # the server_def was installed via set_server_def with a different
        # keep_alive_secs — confirm this is intentional.
        pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, 600,
                                                  server_def_str)
      elif self._collective_ops_server_def is not None:
        server_def_str = self._collective_ops_server_def.SerializeToString()
        pywrap_tensorflow.TFE_EnableCollectiveOps(self._context_handle,
                                                  server_def_str)

      self._initialize_devices()

  def _clear_caches(self):
    # Clears the per-thread tensor caches (used e.g. when remote tensors may
    # have become invalid).
    self.scalar_cache().clear()
    self.ones_rank_cache().flush()
    self.zeros_cache().flush()

  def set_server_def(self, server_def, keep_alive_secs=600):
    """Allow setting a server_def on the context.

    When a server def is replaced, it effectively clears a bunch of caches
    within the context. If you attempt to use a tensor object that was pointing
    to a tensor on the remote device, it will raise an error.

    Args:
      server_def: A tensorflow::ServerDef proto.
        Enables execution on remote devices.
      keep_alive_secs: Num. seconds after which the remote end will hang up.
        As long as the client is still alive, the server state for the context
        will be kept alive. If the client is killed (or there is some failure),
        the server will clean up its context keep_alive_secs after the final RPC
        it receives.

    Raises:
      ValueError: if server_def is None.
    """
    if not server_def:
      raise ValueError("server_def is None.")
    if not self._context_handle:
      # Context not yet initialized: stash the server_def for lazy init.
      self._server_def = server_def
    else:
      server_def_str = server_def.SerializeToString()
      pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
                                                keep_alive_secs, server_def_str)

      # Clear all the caches in case there are remote tensors in them.
      self._clear_caches()

      self._initialize_devices()

  def enable_collective_ops(self, server_def):
    """Enable collective ops with an appropriate server_def.

    If previously enabled, this cannot be re-enabled.

    Args:
      server_def: A tensorflow::ServerDef proto. Enables execution on remote
        devices.

    Raises:
      ValueError: if server_def is None.
    """
    if not server_def:
      raise ValueError("server_def is None.")
    if not self._context_handle:
      # Context not yet initialized: stash the server_def for lazy init.
      self._collective_ops_server_def = server_def
    else:
      server_def_str = server_def.SerializeToString()
      pywrap_tensorflow.TFE_EnableCollectiveOps(self._context_handle,
                                                server_def_str)

      self._clear_caches()
      self._initialize_devices()

  @property
  def _handle(self):
    # Lazily creates the TFE context handle on first access.
    ctx = self._context_handle
    if ctx is None:
      self._initialize_handle_and_devices()
      return self._context_handle
    else:
      return ctx

  @property
  def _devices(self):
    # Lazily populates the device list on first access.
    devices = self._context_devices
    if devices is None:
      self._initialize_handle_and_devices()
      return self._context_devices
    else:
      return devices

  def __str__(self):
    if self._context_handle is None:
      return "Eager TensorFlow Context. Devices currently uninitialized."
    else:
      devices = self._devices
      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
      for i, d in enumerate(devices):
        lines.append("   Device %d: %s" % (i, d))
      return "\n".join(lines)

  @tf_contextlib.contextmanager
  def _mode(self, mode):
    """A context manager to allow setting the mode to EAGER/GRAPH."""
    ctx = self._thread_local_data
    old_mode = ctx.mode
    old_is_eager = ctx.is_eager
    ctx.mode = mode
    ctx.is_eager = mode == EAGER_MODE
    if mode == EAGER_MODE:
      # Entering graph mode does not provide us with sufficient information to
      # record a context switch; graph-based context switches are only logged
      # when a graph is registered as the default graph.
      self.context_switches.push(False, eager_mode, None)
    try:
      yield
    finally:
      ctx.is_eager = old_is_eager
      ctx.mode = old_mode
      if mode == EAGER_MODE:
        self.context_switches.pop()

  def executing_eagerly(self):
    """Returns True if current thread has eager executing enabled."""
    return self._thread_local_data.is_eager

  def scalar_cache(self):
    """Per-device cache for scalars."""
    return self._thread_local_data.scalar_cache

  def ones_rank_cache(self):
    """Per-device cache for ones tensors."""
    return self._thread_local_data.ones_rank_cache

  def zeros_cache(self):
    """Per-device cache for zeros tensors."""
    return self._thread_local_data.zeros_cache

  @property
  def scope_name(self):
    """Returns scope name for the current thread."""
    return self._thread_local_data.scope_name

  @scope_name.setter
  def scope_name(self, s):
    """Sets scope name for the current thread."""
    self._thread_local_data.scope_name = s

  @property
  def summary_writer(self):
    """Returns default summary writer for the current thread."""
    return self._thread_local_data.summary_writer

  @summary_writer.setter
  def summary_writer(self, writer):
    """Sets default summary writer for the current thread."""
    self._thread_local_data.summary_writer = writer

  @property
  def summary_recording(self):
    """Returns summary recording condition."""
    return self._thread_local_data.summary_recording

  @summary_recording.setter
  def summary_recording(self, condition):
    """Sets summary recording condition."""
    self._thread_local_data.summary_recording = condition

  @property
  def summary_recording_distribution_strategy(self):
    """Returns summary recording condition for distribution strategy."""
    return self._thread_local_data.summary_recording_distribution_strategy

  @summary_recording_distribution_strategy.setter
  def summary_recording_distribution_strategy(self, condition):
    """Sets summary recording condition for distribution strategy."""
    self._thread_local_data.summary_recording_distribution_strategy = condition

  @property
  def summary_step(self):
    """Returns summary step variable."""
    return self._thread_local_data.summary_step

  @summary_step.setter
  def summary_step(self, step):
    """Sets summary step variable."""
    self._thread_local_data.summary_step = step

  @property
  def device_name(self):
    """Returns the device name for the current thread."""
    return self._thread_local_data.device_name

  @property
  def device_spec(self):
    """Returns the device spec for the current thread."""
    return self._thread_local_data.device_spec

  @tf_contextlib.contextmanager
  def device(self, name):
    """Context-manager to force placement of operations and Tensors on a device.

    Args:
      name: Name of the device or None to get default placement.

    Yields:
      Nothing.

    Raises:
      ValueError: If name is not a string or is an invalid device name.
    """
    eager_context = self._thread_local_data
    old_device_name = eager_context.device_name
    old_device_spec = eager_context.device_spec
    cache_key = (old_device_name, name)
    try:
      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
    except TypeError:
      # Error while trying to compute the cache key: `name` was unhashable,
      # i.e. not a string (or None).
      raise ValueError("Expecting a string device name. Got %s(%s)" %
                       (type(name), name))
    except KeyError:
      # Handle a cache miss.
      if name is not None:
        if not isinstance(name, str):
          raise ValueError("Expecting a string device name. Got %s(%s)" %
                           (type(name), name))
        device_spec = pydev.DeviceSpec.from_string(name)
        if old_device_name:
          new_device_spec = copy.copy(old_device_spec)
        else:
          # No enclosing device scope: start from the first known device so
          # partial specs are resolved against a concrete base.
          self._initialize_handle_and_devices()
          new_device_spec = pydev.DeviceSpec.from_string(
              self._context_devices[0])
        new_device_spec.merge_from(device_spec)
      else:
        new_device_spec = pydev.DeviceSpec.from_string("")
      new_device_name = new_device_spec.to_string()
      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)

    try:
      eager_context.device_name = new_device_name
      eager_context.device_spec = new_device_spec
      yield
    finally:
      eager_context.device_name = old_device_name
      eager_context.device_spec = old_device_spec

  def devices(self):
    """List of the names of devices available to execute operations."""
    return self._devices

  @property
  def execution_mode(self):
    """Gets execution mode for current thread."""
    # Only get the execution mode from the context if it has already been
    # initialized
    if self._context_handle is None:
      return self._execution_mode

    mode = self._thread_local_data.execution_mode
    if mode is None:
      mode = self._execution_mode
    return mode

  @execution_mode.setter
  def execution_mode(self, mode):
    """Sets execution mode for current thread."""
    if mode not in (None, SYNC, ASYNC):
      raise ValueError(
          "Execution mode should be None/SYNC/ASYNC. Got %s" % mode)
    if mode is None:
      mode = SYNC

    if self._thread_local_data.execution_mode != mode:
      self._thread_local_data.execution_mode = mode

      # Only set the execution mode if the context has already been initialized
      if self._context_handle is not None:
        pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._context_handle,
                                                       mode == ASYNC)
      else:
        self._execution_mode = mode

  @property
  def function_call_options(self):
    """Returns function call options for current thread.

    Note that the returned object is still referenced by the eager context.

    Returns: the FunctionCallOptions for current thread.
    """
    if self._thread_local_data.function_call_options is None:
      base_config = config_pb2.ConfigProto()
      base_config.CopyFrom(self._config)
      self._thread_local_data.function_call_options = FunctionCallOptions(
          config_proto=base_config)

    return self._thread_local_data.function_call_options

  @function_call_options.setter
  def function_call_options(self, options):
    """Returns function call options for current thread."""
    self._thread_local_data.function_call_options = options

  def async_wait(self):
    """Waits for ops dispatched in ASYNC mode to finish."""
    pywrap_tensorflow.TFE_ContextAsyncWait(self._handle)

  def async_clear_error(self):
    """Clears errors raised during ASYNC execution."""
    pywrap_tensorflow.TFE_ContextAsyncClearError(self._handle)

  def num_gpus(self):
    """The number of GPUs available to execute operations."""
    self._initialize_handle_and_devices()
    return self._num_gpus

  def add_function(self, fn):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
    """
    pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)

  def add_function_def(self, fdef):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fdef: A FunctionDef protocol buffer message.
    """
    fdef_string = fdef.SerializeToString()
    pywrap_tensorflow.TFE_ContextAddFunctionDef(
        self._handle, fdef_string, len(fdef_string))

  def has_function(self, name):
    """Check if a function `name` is registered."""
    return bool(pywrap_tensorflow.TFE_ContextHasFunction(self._handle, name))

  def add_post_execution_callback(self, callback):
    """Add a post-execution callback to the context.

    A post-execution callback is invoked immediately after an eager operation or
    function has finished execution, providing access to the op's type, name
    input and output tensors. Multiple execution callbacks can be added, in
    which case the callbacks will be invoked in the order in which they are
    added.

    Args:
      callback: a callable of the signature
        `f(op_type, op_name, attrs, inputs, outputs)`.
        `op_type` is the type of the operation that was just executed (e.g.,
          `MatMul`).
        `op_name` is the name of the operation that has was just executed. This
          name is set by the client who created the operation and can be `None`
          if it is unset.
        `attrs` contains the attributes of the operation as a `tuple` of
          alternating attribute names and attribute values.
        `inputs` is the `list` of input `Tensor`(s) to the op.
        `outputs` is the `list` of output `Tensor`(s) from the op.
        Return value(s) from the callback are ignored.
    """
    # TODO(cais): (b/64674139) Allow access to function-internal operations.
    self._post_execution_callbacks.append(callback)

  def clear_post_execution_callbacks(self):
    """Clear all post-execution callbacks added to the context."""
    del self._post_execution_callbacks[:]

  @property
  def post_execution_callbacks(self):
    """Get the list of post-execution callbacks added to the context."""
    return self._post_execution_callbacks

  @property
  def gpu_per_process_memory_fraction(self):
    return self._config.gpu_options.per_process_gpu_memory_fraction

  @gpu_per_process_memory_fraction.setter
  def gpu_per_process_memory_fraction(self, fraction):
    # GPU options are consumed when the C context is created; they cannot be
    # changed afterwards.
    if self._context_handle is not None:
      raise RuntimeError(
          "GPU options must be set at program startup")

    self._config.gpu_options.per_process_gpu_memory_fraction = fraction

  @property
  def gpu_per_process_memory_growth(self):
    return self._config.gpu_options.allow_growth

  @gpu_per_process_memory_growth.setter
  def gpu_per_process_memory_growth(self, enabled):
    if self._context_handle is not None:
      raise RuntimeError(
          "GPU options must be set at program startup")

    self._config.gpu_options.allow_growth = enabled

  @property
  def intra_op_parallelism_threads(self):
    return self._config.intra_op_parallelism_threads

  @intra_op_parallelism_threads.setter
  def intra_op_parallelism_threads(self, num_threads):
    if self._context_handle is not None:
      raise RuntimeError(
          "Intra op parallelism must be set at program startup")

    self._config.intra_op_parallelism_threads = num_threads

  @property
  def inter_op_parallelism_threads(self):
    return self._config.inter_op_parallelism_threads

  @inter_op_parallelism_threads.setter
  def inter_op_parallelism_threads(self, num_threads):
    if self._context_handle is not None:
      raise RuntimeError(
          "Inter op parallelism must be set at program startup")

    self._config.inter_op_parallelism_threads = num_threads

  @property
  def soft_device_placement(self):
    return self._config.allow_soft_placement

  @soft_device_placement.setter
  def soft_device_placement(self, enabled):
    self._config.allow_soft_placement = enabled

    # Invalidate the thread-local FunctionCallOptions so it is rebuilt from
    # the updated config on next access.
    self._thread_local_data.function_call_options = None

  @property
  def log_device_placement(self):
    return self._config.log_device_placement

  @log_device_placement.setter
  def log_device_placement(self, enabled):
    if self._context_handle is not None:
      raise RuntimeError(
          "Device placement logging must be set at program startup")

    self._config.log_device_placement = enabled

  @property
  def device_policy(self):
    # Only get the policy from the context if it has already been initialized
    if self._context_handle is not None:
      return pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(self._handle)

    return self._device_policy

  @device_policy.setter
  def device_policy(self, policy):
    if policy is None:
      policy = DEVICE_PLACEMENT_SILENT

    if self._device_policy != policy:
      self._device_policy = policy

      # Only set the policy if the context has already been initialized
      if self._context_handle is not None:
        pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
            self._handle, self._device_policy)

  def enable_run_metadata(self):
    """Enables tracing of op execution via RunMetadata.

    To retrieve the accumulated metadata call context.export_run_metadata()
    and to stop tracing call context.disable_run_metadata().
    """
    pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)

  def disable_run_metadata(self):
    """Disables tracing of op execution via RunMetadata."""
    if not self._context_handle:
      return
    pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)

  def enable_graph_collection(self):
    """Enables graph collection of executed functions.

    To retrieve the accumulated graphs call context.export_run_metadata()
    and to stop collecting graphs call context.disable_graph_collection().
    """
    pywrap_tensorflow.TFE_ContextEnableGraphCollection(self._handle)

  def disable_graph_collection(self):
    """Disables graph collections of executed functions."""
    if not self._context_handle:
      return
    pywrap_tensorflow.TFE_ContextDisableGraphCollection(self._context_handle)

  def export_run_metadata(self):
    """Returns a RunMetadata proto with accumulated information.

    The returned protocol buffer contains information since the most recent call
    to either enable_run_metadata or export_run_metadata.

    Returns:
      A RunMetadata protocol buffer. Or None if not enabled.
    """
    if not self._context_handle:
      return None
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TFE_ContextExportRunMetadata(
          self._context_handle, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    run_metadata = config_pb2.RunMetadata()
    run_metadata.ParseFromString(compat.as_bytes(proto_data))
    return run_metadata

  @property
  def context_switches(self):
    """Returns a stack of context switches."""
    return self._context_switches

  def start_step(self):
    pywrap_tensorflow.TFE_ContextStartStep(self._handle)

  def end_step(self):
    pywrap_tensorflow.TFE_ContextEndStep(self._handle)

# Process-wide singleton Context, created lazily by context().
_context = None
_context_lock = threading.Lock()


def _initialize_context():
  global _context
  with _context_lock:
    if _context is None:
      _context = Context()


def context():
  """Returns a singleton context object."""
  if _context is None:
    _initialize_context()
  return _context


def context_safe():
  """Returns current context (or None if one hasn't been initialized)."""
  return _context
def set_global_seed(seed):
  """Installs `seed` as the global seed on the singleton eager context."""
  context()._set_global_seed(seed)  # pylint: disable=protected-access


def global_seed():
  """Returns the global seed held by the singleton eager context."""
  return context()._seed  # pylint: disable=protected-access


def internal_operation_seed():
  """Returns an operation seed derived from the global seed."""
  return context()._internal_operation_seed()  # pylint: disable=protected-access


@tf_export("executing_eagerly")
def executing_eagerly():
  """Returns True if the current thread has eager execution enabled.

  Eager execution is typically enabled via `tf.enable_eager_execution`,
  but may also be enabled within the context of a Python function via
  tf.contrib.eager.py_func.
  """
  return context().executing_eagerly()


def in_eager_mode():
  """Deprecated alias for executing_eagerly(); will be removed."""
  return executing_eagerly()


def shared_name(name=None):
  """Returns the anonymous shared name GUID if no shared name is specified.

  In eager mode we need to use a unique shared name to avoid spurious sharing
  issues. The runtime generates a unique name on our behalf when the reserved
  GUID is used as a shared name.

  Args:
    name: Optional shared name

  Returns:
    Eager compatible shared name.
  """
  # Substitute the reserved GUID only when no name was given and eager
  # execution is enabled; the runtime then generates a unique name for us.
  if not name and executing_eagerly():
    return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
  return name


def graph_mode():
  """Context-manager that disables eager execution for the current thread."""
  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access


def eager_mode():
  """Context-manager that enables eager execution for the current thread."""
  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access


# TODO(agarwal): get rid of this and use ops.name_scope instead.
@contextlib.contextmanager
def namescope(name):
  """ContextManager for creating hierarchical name scopes."""
  ctx = context()
  old_name = ctx.scope_name
  # Nest under the current scope if there is one; otherwise start a new root.
  ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name
  try:
    yield
  finally:
    # Restore the previous scope even if the body raised.
    ctx.scope_name = old_name


def scope_name():
  """Name of the current scope."""
  return context().scope_name


def device(name):
  """Context-manager to force placement of operations and Tensors on a device.

  Example:
  ```python
  with tfe.device('gpu:0'):
    with tfe.device('cpu:0'):
      shape = tf.constant([], dtype=tf.int32)
    x = tf.truncated_normal(shape, tf.float32)
  ```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.

  Args:
    name: Name of the device (see context().devices()), or None to
      perform automatic placement.

  Returns:
    Context manager for setting the device.
  """
  return context().device(name)


@tf_export("config.experimental_list_devices")
def list_devices():
  """List the names of the available devices.

  Returns:
    Names of the available devices, as a `list`.
  """
  return context().devices()


@tf_export("debugging.get_log_device_placement")
def get_log_device_placement():
  """Get if device placements are logged.

  Returns:
    If device placements are logged.
  """
  return context().log_device_placement


@tf_export("debugging.set_log_device_placement")
def set_log_device_placement(enabled):
  """Set if device placements should be logged.

  Args:
    enabled: Whether to enable device placement logging.
  """
  context().log_device_placement = enabled


@tf_contextlib.contextmanager
def device_policy(policy):
  """Context manager for setting device placement policy for current thread."""
  ctx = context()
  old_policy = ctx.device_policy
  try:
    ctx.device_policy = policy
    yield
  finally:
    ctx.device_policy = old_policy


def set_execution_mode(mode):
  """Sets execution mode for the current thread."""
  context().execution_mode = mode


@tf_contextlib.contextmanager
def execution_mode(mode):
  """Context manager for setting execution mode for current thread."""
  ctx = context()
  old_mode = ctx.execution_mode
  try:
    ctx.execution_mode = mode
    yield
  finally:
    ctx.execution_mode = old_mode


@tf_export("experimental.function_executor_type")
@tf_contextlib.contextmanager
def function_executor_type(executor_type):
  """Context manager for setting the executor of eager defined functions.

  Eager defined functions are functions decorated by tf.contrib.eager.defun.

  Args:
    executor_type: a string for the name of the executor to be used to execute
      functions defined by tf.contrib.eager.defun.

  Yields:
    Context manager for setting the executor of eager defined functions.
  """
  # Snapshot the current options so they can be restored on exit; the
  # executor_type is mutated in place on the live options object.
  current_options = context().function_call_options
  old_options = copy.copy(current_options)
  try:
    current_options.executor_type = executor_type
    yield
  finally:
    context().function_call_options = old_options


def async_wait():
  """Waits for ops dispatched in ASYNC mode to finish."""
  return context().async_wait()


def async_clear_error():
  """Clears errors raised during ASYNC execution mode."""
  return context().async_clear_error()


def num_gpus():
  """Get the number of available GPU devices.

  Returns:
    The number of available GPU devices.
  """
  return context().num_gpus()


def enable_run_metadata():
  """Enables tracing of op execution via RunMetadata.

  To retrieve the accumulated metadata call context.export_run_metadata()
  and to stop tracing call context.disable_run_metadata().
  """
  context().enable_run_metadata()


def disable_run_metadata():
  """Disables tracing of op execution via RunMetadata."""
  context().disable_run_metadata()


def enable_graph_collection():
  """Enables graph collection of executed functions.

  To retrieve the accumulated graphs call context.export_run_metadata()
  and to stop collecting graphs call context.disable_graph_collection().
  """
  context().enable_graph_collection()


def disable_graph_collection():
  """Disables graph collection of executed functions."""
  context().disable_graph_collection()


def export_run_metadata():
  """Returns a RunMetadata proto with accumulated information.

  The returned protocol buffer contains information since the most recent call
  to either enable_run_metadata or export_run_metadata.

  Returns:
    A RunMetadata protocol buffer.
  """
  return context().export_run_metadata()


def set_server_def(server_def):
  """Configures the context for distributed execution with `server_def`."""
  context().set_server_def(server_def)


def add_function(fdef):
  """Add a function definition to the context."""
  context().add_function(fdef)


# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions.
def _tmp_in_graph_mode():
  if context_safe() is None:
    # Context not yet initialized. Assume graph mode following the
    # default implementation in `is_in_graph_mode`.
    return True
  return not executing_eagerly()


is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode