1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""State management for eager execution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import contextlib
23import copy
24import gc
25import os
26import random
27import threading
28
29from absl import logging
30import numpy as np
31import six
32
33from tensorflow.core.framework import function_pb2
34from tensorflow.core.protobuf import config_pb2
35from tensorflow.core.protobuf import rewriter_config_pb2
36from tensorflow.python import pywrap_tfe
37from tensorflow.python import tf2
38from tensorflow.python.client import pywrap_tf_session
39from tensorflow.python.eager import executor
40from tensorflow.python.eager import monitoring
41from tensorflow.python.framework import c_api_util
42from tensorflow.python.framework import device as pydev
43from tensorflow.python.framework import tfrt_utils
44from tensorflow.python.util import compat
45from tensorflow.python.util import is_in_graph_mode
46from tensorflow.python.util import tf_contextlib
47from tensorflow.python.util.deprecation import deprecated
48from tensorflow.python.util.tf_export import tf_export
49
50GRAPH_MODE = 0
51EAGER_MODE = 1
52
53default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE
54
55# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
56# new_device_spec).
57# Note that we do not protect this with a lock and instead rely on python's GIL
58# and the idempotent nature of writes to provide thread safety.
59_device_parsing_cache = {}
60_starting_device_spec = pydev.DeviceSpec.from_string("")
61
62_MAXINT32 = 2**31 - 1
63
64DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
65DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
66DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
67DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
68    pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
69
70SYNC = 0
71ASYNC = 1
72
73_KEEP_ALIVE_SECS = 600
74
75_python_eager_context_create_counter = monitoring.Counter(
76    "/tensorflow/api/python/eager_context_create_counter",
77    "Counter for number of eager contexts created in Python.")
78
79# Re-exporting through context.
80is_tfrt_enabled = tfrt_utils.enabled
81
82_RUN_EAGER_OP_AS_FUNCTION_ENABLED = False
83
84
85def enable_run_eager_op_as_function():
86  """Execute elementary eager ops (non-function) wrapped in a call op.
87
88  This should be functionally equivalent to running the eager op's kernel
89  directly (the default) but reduces the number of codepaths for executing
90  TF2 programs in the runtime, thereby improving consistency (in terms of
91  optimizations and rewrites for instance) and maintainability.
92  """
93  # Must be called before context is actually built.
94  global _RUN_EAGER_OP_AS_FUNCTION_ENABLED
95  _RUN_EAGER_OP_AS_FUNCTION_ENABLED = True
96
97
98def run_eager_op_as_function_enabled():
99  return _RUN_EAGER_OP_AS_FUNCTION_ENABLED
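
# A minimal usage sketch (assuming the module is imported as `context`): the
# flag is only read when a `Context` is constructed, so it has to be flipped
# before the first eager op executes.
#
#   from tensorflow.python.eager import context
#   context.enable_run_eager_op_as_function()
#   assert context.run_eager_op_as_function_enabled()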
100
101
102# Expose it as internally public APIs for Keras use cases in b/171080602.
103tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled)
104
105
106class _EagerTensorCache(object):
107  """Simple cache which evicts items based on length in a FIFO manner."""
108
109  __slots__ = ["_data", "_max_items", "_max_tensor_size"]
110
111  def __init__(self, max_items=256, max_tensor_size=10000):
112    self._data = collections.OrderedDict()
113    self._max_items = max_items
114    self._max_tensor_size = max_tensor_size
115
116  def put(self, key, value):
117    if value._num_elements() > self._max_tensor_size:  # pylint: disable=protected-access
118      return
119
120    self._data[key] = value
121
122    if len(self._data) > self._max_items:
123      self._data.popitem(last=False)
124
125  def get(self, key):
126    return self._data.get(key, None)
127
128  def flush(self):
129    self._data.clear()
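
# A small illustrative sketch of the cache behavior. `small_tensor` and
# `huge_tensor` are hypothetical placeholders for EagerTensors (which expose
# the `_num_elements()` method used above):
#
#   cache = _EagerTensorCache(max_items=2, max_tensor_size=10)
#   cache.put("a", small_tensor)   # stored
#   cache.put("b", small_tensor)   # stored
#   cache.put("c", small_tensor)   # evicts "a" (FIFO on insertion order)
#   cache.get("a")                 # -> None
#   cache.put("d", huge_tensor)    # silently skipped: too many elements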
130
131
132class FunctionCallOptions(object):
133  """Options applied at call sites of eager functions.
134
135  Eager functions are functions decorated with tf.contrib.eager.defun.
136  """
137
138  __slots__ = ["_config_proto_serialized", "_executor_type"]
139
140  def __init__(self, executor_type=None, config_proto=None):
141    """Constructor.
142
143    Args:
144      executor_type: (optional) name of the executor to be used to execute the
145        eager function. If None or an empty string, the default TensorFlow
146        executor will be used.
147      config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized
148        string of that proto. The config used by Grappler when optimizing the
149        function graph. Each concrete function is optimized the first time it is
150        called. Changing config_proto after the first call has no effect. If
151        config_proto is None, an empty RewriterConfig will be used.
152    """
153    self.config_proto_serialized = config_proto
154    self.executor_type = executor_type
155
156  @property
157  def executor_type(self):
158    return self._executor_type
159
160  @executor_type.setter
161  def executor_type(self, executor_type):
162    self._executor_type = executor_type
163
164  @property
165  def config_proto_serialized(self):
166    return self._config_proto_serialized
167
168  @config_proto_serialized.setter
169  def config_proto_serialized(self, config):
170    if isinstance(config, config_pb2.ConfigProto):
171      self._config_proto_serialized = config.SerializeToString(
172          deterministic=True)
173    elif isinstance(config, str):
174      self._config_proto_serialized = config
175    elif config is None:
176      self._config_proto_serialized = (
177          config_pb2.ConfigProto().SerializeToString())
178    else:
179      raise ValueError("the rewriter config must be either a "
180                       "config_pb2.ConfigProto, or a serialized string of that "
181                       "proto or None. got: {}".format(type(config)))
182
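# A minimal usage sketch. The proto fields follow config_pb2 /
# rewriter_config_pb2, both imported at the top of this file; the particular
# rewriter toggled here is only an example:
#
#   proto = config_pb2.ConfigProto()
#   proto.graph_options.rewrite_options.constant_folding = (
#       rewriter_config_pb2.RewriterConfig.OFF)
#   opts = FunctionCallOptions(config_proto=proto)
#   opts.config_proto_serialized  # deterministic serialization (bytes)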
183
184# Map from context_id (an int) to _TensorCaches.
185# Dicts are thread safe in CPython.
186# TODO(iga): Remove this once TensorCaches are moved to C++.
187_tensor_caches_map = {}
188
189
190class _TensorCaches(threading.local):
191  """Thread local tensor caches."""
192
193  __slots__ = ["_ones_rank_cache", "_zeros_cache"]
194
195  def __init__(self):
196    super(_TensorCaches, self).__init__()
197    self._ones_rank_cache = None
198    self._zeros_cache = None
199
200  @property
201  def ones_rank_cache(self):
202    if not self._ones_rank_cache:
203      self._ones_rank_cache = _EagerTensorCache()
204    return self._ones_rank_cache
205
206  @property
207  def zeros_cache(self):
208    if not self._zeros_cache:
209      self._zeros_cache = _EagerTensorCache()
210    return self._zeros_cache
211
212
213ContextSwitch = collections.namedtuple(
214    "ContextSwitch",
215    ["is_building_function", "enter_context_fn", "device_stack"])
216
217
218# `_ContextSwitchStack` is a `threading.local` to match the semantics of
219# `_DefaultGraphStack`, which is also a `threading.local`.
220class _ContextSwitchStack(threading.local):
221  """A thread-local stack of context switches."""
222
223  def __init__(self, eager):
224    super(_ContextSwitchStack, self).__init__()
225    self.stack = []
226    if eager:
227      # Initialize the stack with a pointer to enter the eager context; this
228      # ensures that the fact that eager execution was enabled is propagated
229      # across threads, since (1) `enable_eager_execution` modifies a
230      # process-level flag (`default_execution_mode`) and (2) `__init__` is
231      # called each time a threading.local object is used in a separate thread.
232      self.push(
233          is_building_function=False,
234          enter_context_fn=eager_mode,
235          device_stack=None)
236
237  def push(self, is_building_function, enter_context_fn, device_stack):
238    """Push metadata about a context switch onto the stack.
239
240    A context switch can take one of two forms: installing a graph as
241    the default graph, or entering the eager context. For each context switch,
242    we record whether or not the entered context is building a function.
243
244    Args:
245      is_building_function: (bool.) Whether the context is building a function.
246      enter_context_fn: (function.) A callable that executes the context switch.
247        For example, `graph.as_default` or `eager_mode`.
248      device_stack: If applicable, the device function stack for this graph.
249        When breaking out of graphs in init_scope, the innermost nonempty device
250        stack is used. Eager contexts put `None` here and the value is never
251        used.
252    """
253
254    self.stack.append(
255        ContextSwitch(is_building_function, enter_context_fn, device_stack))
256
257  def pop(self):
258    """Pop the stack."""
259
260    self.stack.pop()
261
262
263@tf_export("config.LogicalDevice")
264class LogicalDevice(
265    collections.namedtuple("LogicalDevice", ["name", "device_type"])):
266  """Abstraction for a logical device initialized by the runtime.
267
268  A `tf.config.LogicalDevice` corresponds to an initialized logical device on a
269  `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors
270  and operations can be placed on a specific logical device by calling
271  `tf.device` with a specified `tf.config.LogicalDevice`.
272
273  Fields:
274    name: The fully qualified name of the device. Can be used for Op or function
275      placement.
276    device_type: String declaring the type of device such as "CPU" or "GPU".
277  """
278  pass
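
# A short usage sketch through the public API (assumes `import tensorflow as
# tf`; device names such as "/device:GPU:0" depend on the host):
#
#   logical_gpus = tf.config.list_logical_devices("GPU")
#   if logical_gpus:
#     with tf.device(logical_gpus[0].name):
#       x = tf.constant([1.0, 2.0]) * 2.0  # placed on that logical device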
279
280
281@tf_export("config.LogicalDeviceConfiguration",
282           "config.experimental.VirtualDeviceConfiguration")
283class LogicalDeviceConfiguration(
284    collections.namedtuple("LogicalDeviceConfiguration",
285                           ["memory_limit", "experimental_priority"])):
286  """Configuration class for a logical devices.
287
288  The class specifies the parameters to configure a `tf.config.PhysicalDevice`
289  as it is initialized to a `tf.config.LogicalDevice` during runtime
290  initialization. Not all fields are valid for all device types.
291
292  See `tf.config.get_logical_device_configuration` and
293  `tf.config.set_logical_device_configuration` for usage examples.
294
295  Fields:
296    memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual
297      device. Currently only supported for GPUs.
298    experimental_priority: (optional) Priority to assign to a virtual device.
299      Lower values have higher priorities and 0 is the default.
300      Within a physical GPU, the GPU scheduler will prioritize ops on virtual
301      devices with higher priority. Currently only supported for Nvidia GPUs.
302  """
303
304  def __new__(cls, memory_limit=None, experimental_priority=None):
305    return super(LogicalDeviceConfiguration,
306                 cls).__new__(cls, memory_limit, experimental_priority)
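
# A minimal sketch splitting one physical GPU into two logical devices through
# the public API (assumes `import tensorflow as tf`; the 1024 MB limits are
# illustrative, and this must run before the runtime is initialized):
#
#   gpus = tf.config.list_physical_devices("GPU")
#   if gpus:
#     tf.config.set_logical_device_configuration(
#         gpus[0],
#         [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
#          tf.config.LogicalDeviceConfiguration(memory_limit=1024)])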
307
308
309@tf_export("config.PhysicalDevice")
310class PhysicalDevice(
311    collections.namedtuple("PhysicalDevice", ["name", "device_type"])):
312  """Abstraction for a locally visible physical device.
313
314  TensorFlow can utilize various devices such as the CPU or multiple GPUs
315  for computation. Before initializing a local device for use, the user can
316  customize certain properties of the device such as its visibility or memory
317  configuration.
318
319  Once a visible `tf.config.PhysicalDevice` is initialized, one or more
320  `tf.config.LogicalDevice` objects are created. Use
321  `tf.config.set_visible_devices` to configure the visibility of a physical
322  device and `tf.config.set_logical_device_configuration` to configure multiple
323  `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is
324  useful when separation between models is needed or to simulate a multi-device
325  environment.
326
327  Fields:
328    name: Unique identifier for device.
329    device_type: String declaring the type of device such as "CPU" or "GPU".
330  """
331  pass
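
# A minimal sketch (assumes `import tensorflow as tf`); what is listed depends
# on the host:
#
#   physical_gpus = tf.config.list_physical_devices("GPU")
#   if len(physical_gpus) > 1:
#     # Hide all but the first GPU from the runtime (before initialization).
#     tf.config.set_visible_devices(physical_gpus[:1], "GPU")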
332
333
334class _AtomicCounter(object):
335  """A simple atomic counter."""
336
337  __slots__ = ["_value", "_lock"]
338
339  def __init__(self):
340    self._value = 0
341    self._lock = threading.Lock()
342
343  def increment_and_get(self):
344    with self._lock:
345      self._value += 1
346      return self._value
347
348
349_context_id_counter = _AtomicCounter()
350
351
352class _TensorCacheDeleter(object):
353  """Deletes tensor caches for a given context."""
354
355  __slots__ = ["_context_id"]
356
357  def __init__(self, context_id):
358    self._context_id = context_id
359
360  def __del__(self):
361    if _tensor_caches_map is None:
362      return
363    if self._context_id in _tensor_caches_map:
364      del _tensor_caches_map[self._context_id]
365
366
367# TODO(agarwal): rename to EagerContext / EagerRuntime ?
368# TODO(agarwal): consider keeping the corresponding Graph here.
369class Context(object):
370  """Environment in which eager operations execute."""
371
372  # TODO(agarwal): create and link in some documentation for `execution_mode`.
373  # pylint: disable=redefined-outer-name
374  def __init__(self,
375               config=None,
376               device_policy=None,
377               execution_mode=None,
378               server_def=None):
379    """Creates a new Context.
380
381    Args:
382      config: (Optional.) A `ConfigProto` protocol buffer with configuration
383        options for the Context. Note that many of these options may be
384        currently unimplemented or irrelevant when eager execution is enabled.
385      device_policy: (Optional.) What policy to use when trying to run an
386        operation on a device with inputs which are not on that device. When set
387        to None, an appropriate value will be picked automatically. The value
388        picked may change between TensorFlow releases.  Defaults to
389        DEVICE_PLACEMENT_SILENT.
390        Valid values:
391        - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
392          correct.
393        - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right
394          device but raises a warning.
395        - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide
396          performance problems.
397        - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
398          raising errors on the other ones.
399      execution_mode: (Optional.) Policy controlling how dispatched operations
400        are actually executed. When set to None, an appropriate value will be
401        picked automatically. The value picked may change between TensorFlow
402        releases.
403        Valid values:
404        - SYNC: executes each operation synchronously.
405        - ASYNC: executes each operation asynchronously. These operations may
406          return "non-ready" handles.
407      server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution
408        on remote devices. GrpcServers need to be started by creating an
409        identical server_def to this, and setting the appropriate task_indexes,
410        so that the servers can communicate. It will then be possible to execute
411        operations on remote devices.
412
413    Raises:
414     ValueError: If execution_mode is not valid.
415    """
416    # This _id is used only to index the tensor caches.
417    # TODO(iga): Remove this when tensor caches are moved to C++.
418    self._id = _context_id_counter.increment_and_get()
419    self._tensor_cache_deleter = _TensorCacheDeleter(self._id)
420    _tensor_caches_map[self._id] = _TensorCaches()
421
422    self._config = config
423    self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
424        self,
425        is_eager=lambda: default_execution_mode == EAGER_MODE,
426        device_spec=_starting_device_spec)
427    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
428    self._context_handle = None
429    self._context_devices = None
430    self._seed = None
431    self._initialize_lock = threading.Lock()
432    self._initialized = False
433    if device_policy is None:
434      device_policy = DEVICE_PLACEMENT_SILENT
435    self._device_policy = device_policy
436    self._mirroring_policy = None
437    if execution_mode not in (None, SYNC, ASYNC):
438      raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" %
439                       execution_mode)
440    if execution_mode is None:
441      execution_mode = SYNC
442    self._default_is_async = execution_mode == ASYNC
443    self._use_tfrt = is_tfrt_enabled()
444    self._use_tfrt_distributed_runtime = None
445    self._run_eager_op_as_function = run_eager_op_as_function_enabled()
446    self._server_def = server_def
447    self._collective_ops_server_def = None
448    self._collective_leader = None
449    self._collective_scoped_allocator_enabled_ops = None
450    self._collective_use_nccl_communication = None
451    self._collective_device_filters = None
452    self._coordination_service = None
453
454    self._device_lock = threading.Lock()
455    self._physical_devices = None
456    self._physical_device_to_index = None
457    self._visible_device_list = []
458    self._memory_growth_map = None
459    self._virtual_device_map = {}
460
461    # Values set after construction
462    self._optimizer_jit = None
463    self._intra_op_parallelism_threads = None
464    self._inter_op_parallelism_threads = None
465    self._soft_device_placement = None
466    self._log_device_placement = None
467    self._enable_mlir_graph_optimization = None
468    self._optimizer_experimental_options = {}
469
470    _python_eager_context_create_counter.get_cell().increase_by(1)
471
472  # pylint: enable=redefined-outer-name
473
474  def _set_global_seed(self, seed):
475    """Set a global eager mode seed for random ops."""
476    self._seed = seed
477    # `random.Random(seed)` needs `seed` to be hashable, while values of type
478    # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them
479    # to int.
480    try:
481      hash(seed)
482    except TypeError:
483      seed = int(np.array(seed))
484    self._rng = random.Random(seed)
485    # Also clear the kernel cache, to reset any existing seeds
486    if self._context_handle is not None:
487      pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
488
489  def _internal_operation_seed(self):
490    """Returns a fake operation seed.
491
492      In eager mode, users shouldn't set or depend on the operation seed.
493      Here, we generate a random seed based on the global seed so that each
494      operation's randomness is different while still depending on the global seed.
495
496    Returns:
497      A fake operation seed based on global seed.
498    """
499    return self._rng.randint(0, _MAXINT32)
500
501  def _initialize_logical_devices(self):
502    """Helper to initialize devices."""
503    # Store list of devices
504    logical_devices = []
505    context_devices = []
506    device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
507    try:
508      self._num_gpus = 0
509      for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
510        dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
511        context_devices.append(pydev.canonical_name(dev_name))
512        spec = pydev.DeviceSpec.from_string(dev_name)
513        # If the job is localhost, we assume that the cluster has not yet been
514        # configured and thus clear the job, replica & task.
515        if spec.job == "localhost":
516          spec = spec.replace(job=None, replica=None, task=None)
517        logical_devices.append(
518            LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
519        dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
520        if dev_type == "GPU":
521          self._num_gpus += 1
522
523    finally:
524      self._logical_devices = logical_devices
525      self._context_devices = context_devices
526      pywrap_tfe.TF_DeleteDeviceList(device_list)
527
528  def ensure_initialized(self):
529    """Initialize handle and devices if not already done so."""
530    if self._initialized:
531      return
532    with self._initialize_lock:
533      if self._initialized:
534        return
535      assert self._context_devices is None
536      opts = pywrap_tfe.TFE_NewContextOptions()
537      try:
538        config_str = self.config.SerializeToString()
539        pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
540        if self._device_policy is not None:
541          pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
542              opts, self._device_policy)
543        if self._mirroring_policy is not None:
544          pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
545              opts, self._mirroring_policy)
546        if self._default_is_async == ASYNC:
547          pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
548        if self._use_tfrt is not None:
549          pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
550        # pylint: disable=g-backslash-continuation
551        if self._use_tfrt is not None and \
552            self._use_tfrt_distributed_runtime is not None:
553          pywrap_tfe.TFE_ContextOptionsSetTfrtDistributedRuntime(
554              opts, self._use_tfrt_distributed_runtime)
555        pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(
556            opts, self._run_eager_op_as_function)
557        context_handle = pywrap_tfe.TFE_NewContext(opts)
558      finally:
559        pywrap_tfe.TFE_DeleteContextOptions(opts)
560      assert not (self._server_def and self._collective_ops_server_def), (
561          "Cannot enable remote execution as well as collective ops at the "
562          "moment. If this is important to you, please file an issue.")
563      if self._server_def is not None:
564        server_def_str = self._server_def.SerializeToString()
565        pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
566                                           server_def_str)
567      elif self._collective_ops_server_def is not None:
568        server_def_str = self._collective_ops_server_def.SerializeToString()
569        pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
570
571      self._context_handle = context_handle
572      self._initialize_logical_devices()
573      self._initialized = True
574
575  def _clear_caches(self):
576    self.ones_rank_cache().flush()
577    self.zeros_cache().flush()
578    pywrap_tfe.TFE_ClearScalarCache()
579
580  def get_server_def(self):
581    return self._server_def
582
583  def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
584    """Allow setting a server_def on the context.
585
586    When a server def is replaced, it effectively clears a bunch of caches
587    within the context. If you attempt to use a tensor object that was pointing
588    to a tensor on the remote device, it will raise an error.
589
590    Args:
591      server_def: A tensorflow::ServerDef proto. Enables execution on remote
592        devices.
593      keep_alive_secs: Num. seconds after which the remote end will hang up. As
594        long as the client is still alive, the server state for the context will
595        be kept alive. If the client is killed (or there is some failure), the
596        server will clean up its context keep_alive_secs after the final RPC it
597        receives.
598
599    Raises:
600      ValueError: if server_def is None.
601    """
602    if not server_def:
603      raise ValueError("server_def is None.")
604
605    self._server_def = server_def
606
607    if self._context_handle:
608      server_def_str = server_def.SerializeToString()
609      pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
610                                         server_def_str)
611      self._initialize_logical_devices()
612
613    # Clear all the caches in case there are remote tensors in them.
614    self._clear_caches()
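
  # A sketch of constructing a server_def for a single-task cluster. The
  # address is a placeholder, and `tensorflow_server_pb2` and
  # `tf.train.ClusterSpec` are assumed from the wider TensorFlow code base
  # rather than from this file; `context()` is the module-level accessor for
  # the default Context defined further down in this module:
  #
  #   from tensorflow.core.protobuf import tensorflow_server_pb2
  #   cluster = tf.train.ClusterSpec({"worker": ["localhost:2222"]})
  #   server_def = tensorflow_server_pb2.ServerDef(
  #       cluster=cluster.as_cluster_def(), job_name="worker",
  #       task_index=0, protocol="grpc")
  #   context().set_server_def(server_def)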
615
616  def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
617    """Update a server_def on the context.
618
619    Args:
620      server_def: A tensorflow::ServerDef proto. Enables execution on remote
621        devices.
622      keep_alive_secs: Num. seconds after which the remote end will hang up. As
623        long as the client is still alive, the server state for the context will
624        be kept alive. If the client is killed (or there is some failure), the
625        server will clean up its context keep_alive_secs after the final RPC it
626        receives.
627
628    Raises:
629      ValueError: if server_def is None.
630    """
631    if not server_def:
632      raise ValueError("server_def is None.")
633
634    self._server_def = server_def
635
636    if self._context_handle:
637      server_def_str = server_def.SerializeToString()
638      pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
639                                            keep_alive_secs, server_def_str)
640      self._initialize_logical_devices()
641
642    self._clear_caches()
643
644  def check_alive(self, worker_name):
645    """Checks whether a remote worker is alive or not.
646
647    Args:
648      worker_name: a string representing the remote worker. It must be a fully
649      specified name like "/job:worker/replica:0/task:0".
650
651    Returns:
652      a boolean indicating whether the remote worker is alive or not.
653
654    Raises:
655      ValueError: if context is not initialized.
656    """
657    # TODO(yuefengz): support checking multiple workers.
658    if self._context_handle:
659      return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
660    else:
661      raise ValueError("Context is not initialized.")
662
663  def sync_executors(self):
664    """Sync both local executors and the ones on remote workers.
665
666    In async execution mode, local function calls can return before the
667    corresponding remote op/function execution requests are completed. Calling
668    this method creates a synchronization barrier for remote executors. It only
669    returns when all remote pending nodes are finished, potentially with errors
670    if any remote executors are in error state.
671
672    Raises:
673      ValueError: if context is not initialized.
674    """
675    if self._context_handle:
676      pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
677    else:
678      raise ValueError("Context is not initialized.")
679
680  def clear_executor_errors(self):
681    """Clear errors in both local executors and remote workers.
682
683    After receiving errors from remote workers, additional in-flight requests
684    could further taint the status on the remote workers due to the async nature
685    of remote execution. Calling this method blocks until all pending nodes in
686    remote executors have finished, and clears their error statuses.
687
688    Raises:
689      ValueError: if context is not initialized.
690    """
691    if self._context_handle:
692      pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
693    else:
694      raise ValueError("Context is not initialized.")
695
696  def enable_coordination_service(self, service_type):
697    if self._context_handle:
698      logging.warn("Configuring coordination service type may not be effective "
699                   "because the context is already initialized.")
700    self._coordination_service = service_type
701
702  @property
703  def coordination_service(self):
704    return self._coordination_service
705
706  def set_config_key_value(self, key, value):
707    self.ensure_initialized()
708    pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value)
709
710  def get_config_key_value(self, key):
711    self.ensure_initialized()
712    with c_api_util.tf_buffer() as buffer_:
713      pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key, buffer_)
714      value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
715    return value
716
717  def delete_config_key_value(self, key):
718    self.ensure_initialized()
719    pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key)
720
721  def report_error_to_cluster(self, error_code, error_message):
722    """Report error to other members in a multi-client cluster.
723
724    Args:
725      error_code: a `tf.errors` error code.
726      error_message: a string. The error message.
727    """
728    if self._context_handle:
729      pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code,
730                                          error_message)
731    else:
732      raise ValueError("Context is not initialized.")
733
734  def clear_kernel_cache(self):
735    """Clear kernel cache and reset all stateful kernels."""
736    if self._context_handle is not None:
737      pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
738
739  def enable_collective_ops(self, server_def):
740    """Enable distributed collective ops with an appropriate server_def.
741
742    Args:
743      server_def: A tensorflow::ServerDef proto. Enables execution on remote
744        devices.
745
746    Raises:
747      ValueError: if server_def is None.
748      RuntimeError: if this method is not called at program startup.
749    """
750    if not server_def:
751      raise ValueError("server_def is None.")
752
753    self._collective_ops_server_def = server_def
754
755    # TODO(b/129298253): Allow creating datasets/tensors before enabling
756    # collective ops.
757    if self._context_handle is not None:
758      logging.warning("Enabling collective ops after program startup may cause "
759                      "error when accessing previously created tensors.")
760      with self._initialize_lock:
761        assert self._initialized
762        server_def_str = self._collective_ops_server_def.SerializeToString()
763        pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str)
764        self._initialize_logical_devices()
765        self._clear_caches()
766
767  def configure_collective_ops(
768      self,
769      collective_leader="",
770      scoped_allocator_enabled_ops=("CollectiveReduce",),
771      use_nccl_communication=False,
772      device_filters=None):
773    """Configure collective ops.
774
775      A collective group leader is necessary for collective ops to run; the
776      other configurations mainly affect performance.
777
778    Args:
779      collective_leader: a device string for collective leader, e.g.
780        "/job:worker/replica:0/task:0"; empty string means local execution of
781          collective ops.
782      scoped_allocator_enabled_ops: a tuple or a list of op names for scoped
783        allocator to run with.
784      use_nccl_communication: whether to use nccl communication for collective
785        ops.
786      device_filters: a tuple or a list of device strings. If set, corresponding
787        task can only see the devices filtered by these device filters.
788
789    Raises:
790      RuntimeError: if this method is not called at program startup.
791    """
792    if self._collective_leader is not None:
793      if (self._collective_leader != collective_leader or
794          self._collective_scoped_allocator_enabled_ops !=
795          scoped_allocator_enabled_ops or
796          self._collective_use_nccl_communication != use_nccl_communication or
797          self._collective_device_filters != device_filters):
798        raise ValueError("Collective ops are already configured.")
799      else:
800        return
801
802    if self._context_handle is not None:
803      raise RuntimeError("Collective ops must be configured at program startup")
804
805    self._collective_leader = collective_leader
806    self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops
807    self._collective_use_nccl_communication = use_nccl_communication
808    self._collective_device_filters = device_filters
809
810  def abort_collective_ops(self, code, message):
811    """Abort the collective ops.
812
813    This is intended to be used when a peer failure is detected, which allows
814    the user to handle the case instead of hanging. This aborts all on-going
815    collectives. Afterwards, all subsequent collectives error immediately, and
816    you need to reset_context() to use collectives again.
817
818    Args:
819      code: a `tf.errors` error code.
820      message: a string. The error message.
821    """
822    self.ensure_initialized()
823    pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
824
825  def check_collective_ops_peer_health(self, task, timeout_in_ms):
826    """Check collective peer health.
827
828    This probes each task to see if it is still alive. Note that a restarted
829    task is considered a different task and is treated as not healthy.
830
831    This should only be used in multi-client, multi-worker training.
832
833    Args:
834      task: a task string, must be in the format of /job:xxx/replica:0/task:N.
835      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
836
837    Raises:
838      tf.errors.UnavailableError: when a peer is down.
839      tf.errors.FailedPreconditionError: when a peer is a different one from the
840        one this task has talked to, e.g. the peer has restarted.
841      tf.errors.InvalidArgumentError: when the task string is invalid.
842    """
843    self.ensure_initialized()
844    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
845                                                timeout_in_ms)
846
847  @property
848  def _handle(self):
849    if self._context_handle is None:
850      raise AssertionError("Context must be initialized first.")
851
852    return self._context_handle
853
854  @property
855  def _devices(self):
856    if self._context_devices is None:
857      raise AssertionError("Context must be initialized first.")
858
859    return self._context_devices
860
861  def __str__(self):
862    if self._context_handle is None:
863      return "Eager TensorFlow Context. Devices currently uninitialized."
864    else:
865      devices = self._devices
866      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
867      for i, d in enumerate(devices):
868        lines.append("   Device %d: %s" % (i, d))
869      return "\n".join(lines)
870
871  @tf_contextlib.contextmanager
872  def _mode(self, mode):
873    """A context manager to allow setting the mode to EAGER/GRAPH."""
874    ctx = self._thread_local_data
875    old_is_eager = ctx.is_eager
876    ctx.is_eager = mode == EAGER_MODE
877    if mode == EAGER_MODE:
878      # Entering graph mode does not provide us with sufficient information to
879      # record a context switch; graph-based context switches are only logged
880      # when a graph is registered as the default graph.
881      self.context_switches.push(False, eager_mode, None)
882    try:
883      yield
884    finally:
885      ctx.is_eager = old_is_eager
886      if mode == EAGER_MODE:
887        self.context_switches.pop()
888
889  def executing_eagerly(self):
890    """Returns True if current thread has eager executing enabled."""
891    return self._thread_local_data.is_eager
892
893  def ones_rank_cache(self):
894    """Per-device cache for scalars."""
895    return _tensor_caches_map[self._id].ones_rank_cache
896
897  def zeros_cache(self):
898    """Per-device cache for scalars."""
899    return _tensor_caches_map[self._id].zeros_cache
900
901  @property
902  def scope_name(self):
903    """Returns scope name for the current thread."""
904    return self._thread_local_data.scope_name
905
906  @scope_name.setter
907  def scope_name(self, s):
908    """Sets scope name for the current thread."""
909    self._thread_local_data.scope_name = s
910
911  @property
912  def device_name(self):
913    """Returns the device name for the current thread."""
914    return self._thread_local_data.device_name
915
916  @property
917  def device_spec(self):
918    """Returns the device spec for the current thread."""
919    return self._thread_local_data.device_spec
920
921  def _set_device(self, device_name, device_spec):
922    self._thread_local_data.device_name = device_name
923    self._thread_local_data.device_spec = device_spec
924
925  def device(self, name):
926    """Context-manager to force placement of operations and Tensors on a device.
927
928    Args:
929      name: Name of the device or None to get default placement.
930
931    Returns:
932      Context manager that forces device placement.
933
934    Raises:
935      ValueError: If name is not a string or is an invalid device name.
936      RuntimeError: If device scopes are not properly nested.
937    """
938    if isinstance(name, LogicalDevice):
939      name = name.name
940    elif pydev.is_device_spec(name):
941      name = name.to_string()
942    return _EagerDeviceContext(self, name)
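
  # A minimal sketch (the device string is illustrative); this is the context
  # manager that backs the public `tf.device` in eager mode. `context()` is
  # the module-level accessor for the default Context:
  #
  #   ctx = context()
  #   with ctx.device("/device:CPU:0"):
  #     ...  # ops created here are placed on the local CPU
  #   with ctx.device(None):
  #     ...  # back to automatic placement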
943
944  def devices(self):
945    """List of the names of devices available to execute operations."""
946    return self._devices
947
948  def host_address_space(self):
949    self.ensure_initialized()
950    with c_api_util.tf_buffer() as buffer_:
951      pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
952      address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
953    return address_space
954
955  # TODO(fishx): remove this property.
956  @property
957  def execution_mode(self):
958    """Gets execution mode for current thread."""
959    return ASYNC if self.is_async() else SYNC
960
961  @execution_mode.setter
962  def execution_mode(self, mode):
963    """Sets execution mode for current thread."""
964    if mode not in (None, SYNC, ASYNC):
965      raise ValueError("Execution mode should be None/SYNC/ASYNC. Got %s" %
966                       mode)
967
968    if mode is None:
969      mode = SYNC
970
971    enable_async = (mode == ASYNC)
972    if self.is_async() != enable_async:
973      # Only set the execution mode if the context has already been initialized
974      if self._context_handle is not None:
975        self.executor.wait()
976        executor_new = executor.new_executor(enable_async)
977        self._thread_local_data.executor = executor_new
978        pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
979                                                   executor_new.handle())
980      else:
981        self._default_is_async = enable_async
982
983  def is_async(self):
984    if self._context_handle is not None:
985      return self.executor.is_async()
986    else:
987      return self._default_is_async
988
989  @property
990  def executor(self):
991    self.ensure_initialized()
992    return executor.Executor(
993        pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
994
995  @executor.setter
996  def executor(self, e):
997    self.ensure_initialized()
998    pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
999
1000  @property
1001  def config(self):
1002    """Return the ConfigProto with all runtime deltas applied."""
1003    # Ensure physical devices have been discovered and config has been imported
1004    self._initialize_physical_devices()
1005
1006    config = config_pb2.ConfigProto()
1007    if self._config is not None:
1008      config.CopyFrom(self._config)
1009
1010    if self._optimizer_jit is not None:
1011      config.graph_options.optimizer_options.global_jit_level = (
1012          config_pb2.OptimizerOptions.ON_1
1013          if self._optimizer_jit else config_pb2.OptimizerOptions.OFF)
1014    if self._intra_op_parallelism_threads is not None:
1015      config.intra_op_parallelism_threads = self._intra_op_parallelism_threads
1016    if self._inter_op_parallelism_threads is not None:
1017      config.inter_op_parallelism_threads = self._inter_op_parallelism_threads
1018
1019    if self._soft_device_placement is not None:
1020      config.allow_soft_placement = self._soft_device_placement
1021    else:
1022      config.allow_soft_placement = self.executing_eagerly()
1023
1024    if self._log_device_placement is not None:
1025      config.log_device_placement = self._log_device_placement
1026
1027    is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled()
1028    config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled
1029    if (is_mlir_bridge_enabled ==
1030        config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED):
1031      config.experimental.enable_mlir_bridge = True
1032
1033    if self._enable_mlir_graph_optimization is not None:
1034      config.experimental.enable_mlir_graph_optimization = (
1035          self._enable_mlir_graph_optimization)
1036
1037    def rewriter_toggle(option):
1038      toggle = self._optimizer_experimental_options.get(option, None)
1039      if toggle is None:
1040        return
1041
1042      setattr(config.graph_options.rewrite_options, option,
1043              (rewriter_config_pb2.RewriterConfig.ON
1044               if toggle else rewriter_config_pb2.RewriterConfig.OFF))
1045
1046    def rewriter_bool(option):
1047      toggle = self._optimizer_experimental_options.get(option, None)
1048      if toggle is None:
1049        return
1050
1051      setattr(config.graph_options.rewrite_options, option, toggle)
1052
1053    rewriter_toggle("layout_optimizer")
1054    rewriter_toggle("constant_folding")
1055    rewriter_toggle("shape_optimization")
1056    rewriter_toggle("remapping")
1057    rewriter_toggle("arithmetic_optimization")
1058    rewriter_toggle("dependency_optimization")
1059    rewriter_toggle("loop_optimization")
1060    rewriter_toggle("function_optimization")
1061    rewriter_toggle("debug_stripper")
1062    rewriter_bool("disable_model_pruning")
1063    rewriter_toggle("scoped_allocator_optimization")
1064    rewriter_toggle("pin_to_host_optimization")
1065    rewriter_toggle("implementation_selector")
1066    rewriter_toggle("auto_mixed_precision")
1067    rewriter_toggle("use_plugin_optimizers")
1068    rewriter_bool("disable_meta_optimizer")
1069    nodes = self._optimizer_experimental_options.get("min_graph_nodes", None)
1070    if nodes is not None:
1071      config.graph_options.rewrite_options.min_graph_nodes = nodes
1072
1073    # Compute device counts
1074    config.device_count["CPU"] = 0
1075    config.device_count["GPU"] = 0
1076    for dev in self._physical_devices:
1077      if dev not in self._visible_device_list:
1078        continue
1079
1080      virtual_devices = self._virtual_device_map.get(dev)
1081      if virtual_devices is None:
1082        config.device_count[dev.device_type] += 1
1083      else:
1084        config.device_count[dev.device_type] += len(virtual_devices)
1085
1086    # Configure gpu_options
1087    gpu_options = self._compute_gpu_options()
1088    config.gpu_options.MergeFrom(gpu_options)
1089
1090    # Configure collective ops
1091    if self._collective_leader:
1092      config.experimental.collective_group_leader = self._collective_leader
1093    if self._collective_scoped_allocator_enabled_ops:
1094      rewrite_options = config.graph_options.rewrite_options
1095      rewrite_options.scoped_allocator_optimization = (
1096          rewriter_config_pb2.RewriterConfig.ON)
1097      del rewrite_options.scoped_allocator_opts.enable_op[:]
1098      for op in self._collective_scoped_allocator_enabled_ops:
1099        rewrite_options.scoped_allocator_opts.enable_op.append(op)
1100    if self._collective_use_nccl_communication:
1101      config.experimental.collective_nccl = True
1102    if self._collective_device_filters:
1103      del config.device_filters[:]
1104      for f in self._collective_device_filters:
1105        config.device_filters.append(f)
1106
1107    # Configure coordination service
1108    if self._coordination_service:
1109      config.experimental.coordination_service = self._coordination_service
1110
1111    return config
1112
1113  def _compute_gpu_options(self):
1114    """Build the GPUOptions proto."""
1115    visible_device_list = []
1116    virtual_devices = []
1117    gpu_index = -1
1118    memory_growths = set()
1119    for dev in self.list_physical_devices("GPU"):
1120      gpu_index += 1
1121
1122      if dev not in self._visible_device_list:
1123        continue
1124
1125      growth = self._memory_growth_map[dev]
1126      memory_growths.add(growth)
1127      visible_device_list.append(str(gpu_index))
1128
1129      if self._virtual_device_map:
1130        vdevs = self._virtual_device_map.get(dev, [])
1131        device_limits = []
1132        priority = []
1133        for virt_dev in vdevs:
1134          device_limits.append(virt_dev.memory_limit)
1135          if virt_dev.experimental_priority is not None:
1136            priority.append(virt_dev.experimental_priority)
1137        # If priority is specified, it must be specified for all virtual
1138        # devices.
1139        if priority and len(device_limits) != len(priority):
1140          raise ValueError("priority must be specified for all virtual devices")
1141
1142        virtual_devices.append(
1143            config_pb2.GPUOptions.Experimental.VirtualDevices(
1144                memory_limit_mb=device_limits, priority=priority))
1145
1146    # Only compute growth if virtual devices have not been configured and we
1147    # have GPUs
1148    if not virtual_devices and memory_growths:
1149      if len(memory_growths) > 1:
1150        raise ValueError("Memory growth cannot differ between GPU devices")
1151      allow_growth = memory_growths.pop()
1152    else:
1153      allow_growth = None
1154
1155    return config_pb2.GPUOptions(
1156        allow_growth=allow_growth,
1157        visible_device_list=",".join(visible_device_list),
1158        experimental=config_pb2.GPUOptions.Experimental(
1159            virtual_devices=virtual_devices))
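
  # A short sketch of the public knobs that feed into this proto (assumes
  # `import tensorflow as tf`; must be called before runtime initialization):
  #
  #   gpus = tf.config.list_physical_devices("GPU")
  #   for gpu in gpus:
  #     tf.config.experimental.set_memory_growth(gpu, True)
  #   # All GPUs must agree on the growth setting; otherwise the check above
  #   # raises "Memory growth cannot differ between GPU devices".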
1160
1161  @property
1162  def function_call_options(self):
1163    """Returns function call options for current thread.
1164
1165    Note that the returned object is still referenced by the eager context.
1166
1167    Returns: the FunctionCallOptions for current thread.
1168    """
1169    if self._thread_local_data.function_call_options is None:
1170      config = self.config
1171
1172      # Default to soft placement for functions unless specified
1173      if self._soft_device_placement is None:
1174        config.allow_soft_placement = True
1175      self._thread_local_data.function_call_options = FunctionCallOptions(
1176          config_proto=config)
1177
1178    return self._thread_local_data.function_call_options
1179
1180  @function_call_options.setter
1181  def function_call_options(self, options):
1182    """Returns function call options for current thread."""
1183    self._thread_local_data.function_call_options = options
1184
1185  def num_gpus(self):
1186    """The number of GPUs available to execute operations."""
1187    self.ensure_initialized()
1188    return self._num_gpus
1189
1190  def add_function(self, fn):
1191    """Add a function definition to the context.
1192
1193    Once added, the function (identified by its name) can be executed like any
1194    other operation.
1195
1196    Args:
1197      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
1198    """
1199    self.ensure_initialized()
1200    pywrap_tfe.TFE_ContextAddFunction(self._handle, fn)
1201
1202  def add_function_def(self, fdef):
1203    """Add a function definition to the context.
1204
1205    Once added, the function (identified by its name) can be executed like any
1206    other operation.
1207
1208    Args:
1209      fdef: A FunctionDef protocol buffer message.
1210    """
1211    self.ensure_initialized()
1212    fdef_string = fdef.SerializeToString()
1213    pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
1214                                         len(fdef_string))
1215
1216  def get_function_def(self, name):
1217    """Get a function definition from the context.
1218
1219    Args:
1220      name: function signature name.
1221
1222    Returns:
1223      The requested FunctionDef.
1224
1225    Raises:
1226      tf.errors.NotFoundError: if name is not the name of a registered function.
1227    """
1228    with c_api_util.tf_buffer() as buffer_:
1229      pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_)
1230      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1231    function_def = function_pb2.FunctionDef()
1232    function_def.ParseFromString(proto_data)
1233
1234    return function_def
1235
1236  def register_custom_device(self, device_capsule, device_name,
1237                             device_info_capsule):
1238    """Calls TFE_RegisterCustomDevice. See the non-member function."""
1239    self.ensure_initialized()
1240    pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
1241                                           device_name, device_info_capsule)
1242
1243  def pack_eager_tensors(self, tensors):
1244    """Pack multiple `EagerTensor`s of the same dtype and shape.
1245
1246    Args:
1247      tensors: a list of EagerTensors to pack.
1248
1249    Returns:
1250      A packed EagerTensor.
1251    """
1252    self.ensure_initialized()
1253    return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
1254
1255  def list_function_names(self):
1256    """Get a list of names of registered functions.
1257
1258    Returns:
1259      A set of names of all registered functions for the context.
1260    """
1261    self.ensure_initialized()
1262    return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle))
1263
1264  def remove_function(self, name):
1265    """Remove a function from the context.
1266
1267    Once removed, the function cannot be executed anymore.
1268
1269    Args:
1270      name: function signature name.
1271    """
1272    self.ensure_initialized()
1273    pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
1274
1275  def has_function(self, name):
1276    """Check if a function `name` is registered."""
1277    self.ensure_initialized()
1278    return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
1279
1280  def add_op_callback(self, callback):
1281    """Add a post-op callback to the context.
1282
1283    A post-op callback is invoked immediately after an eager operation or
1284    function has finished execution or after an op has been added to a graph,
1285    providing access to the op's type, name, and input and output tensors. Multiple
1286    op callbacks can be added, in which case the callbacks will be invoked in
1287    the order in which they are added.
1288
1289    Args:
1290      callback: a callable of the signature `f(op_type, inputs, attrs, outputs,
1291        op_name=None, graph=None)`. See doc strings in `op_callbacks.py` for
1292        details on the function signature and its semantics.
1293    """
1294    if callback not in self._thread_local_data.op_callbacks:
1295      self._thread_local_data.op_callbacks.append(callback)
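
  # A minimal sketch of a callback matching the documented signature (the
  # callback body is illustrative; `logging` is the absl logger imported at
  # the top of this file, and `context()` is the module-level accessor; see
  # `op_callbacks.py` for the return-value semantics):
  #
  #   def log_op(op_type, inputs, attrs, outputs, op_name=None, graph=None):
  #     logging.info("op %s (%s) produced %d output(s)",
  #                  op_name, op_type, len(outputs))
  #
  #   ctx = context()
  #   ctx.add_op_callback(log_op)
  #   ...  # run some eager ops
  #   ctx.remove_op_callback(log_op)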
1296
1297  def remove_op_callback(self, callback):
1298    """Remove an already-registered op callback.
1299
1300    Args:
1301      callback: The op callback to be removed.
1302
1303    Raises:
1304      KeyError: If `callback` is not already registered.
1305    """
1306    if callback not in self._thread_local_data.op_callbacks:
1307      raise KeyError("The specified op callback has not been registered, "
1308                     "and hence cannot be removed.")
1309    del self._thread_local_data.op_callbacks[
1310        self._thread_local_data.op_callbacks.index(callback)]
1311
1312  @property
1313  def op_callbacks(self):
1314    return self._thread_local_data.op_callbacks
1315
1316  @property
1317  def invoking_op_callbacks(self):
1318    return self._thread_local_data.invoking_op_callbacks
1319
1320  @invoking_op_callbacks.setter
1321  def invoking_op_callbacks(self, value):
1322    self._thread_local_data.invoking_op_callbacks = value
1323
1324  def _initialize_physical_devices(self, reinitialize=False):
1325    """Gets local devices visible to the system.
1326
1327    Args:
1328      reinitialize: If True, reinitializes self._physical_devices so that
1329        dynamically registered devices will also be visible to the Python front-end.
1330    """
1331    # We lazily initialize self._physical_devices since we do not want to do
1332    # this in the constructor, as the backend may not be initialized yet.
1333    with self._device_lock:
1334      if not reinitialize and self._physical_devices is not None:
1335        return
1336
1337      devs = pywrap_tfe.TF_ListPhysicalDevices()
1338      self._physical_devices = [
1339          PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1])
1340          for d in devs
1341      ]
1342      self._physical_device_to_index = {
1343          p: i for i, p in enumerate(self._physical_devices)
1344      }
1345
1346      self._visible_device_list = list(self._physical_devices)
1347      self._memory_growth_map = {
1348          d: None for d in self._physical_devices if d.device_type == "GPU"
1349      }
1350
1351    # Import device settings that may have been passed into the constructor
1352    self._import_config()
1353
1354  def reinitialize_physical_devices(self):
1355    """Gets local devices visible to the system."""
1356    # Reinitialize the physical device list after registering
1357    # the pluggable device.
1358    self._initialize_physical_devices(True)
1359
1360  def list_physical_devices(self, device_type=None):
1361    """List local devices visible to the system.
1362
1363    This API allows a client to query the devices before they have been
1364    initialized by the eager runtime. Additionally, a user can filter by device
1365    type to get only CPUs or GPUs.
1366
1367    Args:
1368      device_type: Optional device type to limit results to
1369
1370    Returns:
1371      List of PhysicalDevice objects.
1372    """
1373    self._initialize_physical_devices()
1374
1375    if device_type is None:
1376      return list(self._physical_devices)
1377
1378    return [d for d in self._physical_devices if d.device_type == device_type]
1379
1380  def get_device_details(self, device):  # pylint: disable=redefined-outer-name
1381    """Returns details about a physical devices.
1382
1383    Args:
1384      device: A `tf.config.PhysicalDevice` returned by
1385        `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
1386
1387    Returns:
1388      A dict with string keys.
1389    """
1390    if not isinstance(device, PhysicalDevice):
1391      raise ValueError("device must be a tf.config.PhysicalDevice, but got: "
1392                       "%s" % (device,))
1393    if (self._physical_device_to_index is None or
1394        device not in self._physical_device_to_index):
1395      raise ValueError("The PhysicalDevice must be one obtained from "
1396                       "calling `tf.config.list_physical_devices`, but got: "
1397                       "%s" % (device,))
1398    index = self._physical_device_to_index[device]
1399    details = pywrap_tfe.TF_GetDeviceDetails(index)
1400
1401    # Change compute_capability from a string to a tuple
1402    if "compute_capability" in details:
1403      try:
1404        major, minor = details["compute_capability"].split(".")
1405        details["compute_capability"] = (int(major), int(minor))
1406      except ValueError:
1407        raise RuntimeError("Device returned compute capability an in invalid "
1408                           "format: %s" % details["compute_capability"])
1409    return details
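
  # A short sketch via the public wrapper (assumes `import tensorflow as tf`;
  # which keys are present depends on the device, "compute_capability" being
  # reported for Nvidia GPUs; the sample values are illustrative):
  #
  #   gpus = tf.config.list_physical_devices("GPU")
  #   if gpus:
  #     details = tf.config.experimental.get_device_details(gpus[0])
  #     details.get("compute_capability")  # e.g. (7, 0), or None
  #     details.get("device_name")         # e.g. "Tesla V100-SXM2-16GB"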
1410
1411  def _import_config(self):
1412    """Import config if passed in during construction.
1413
1414    If the Context was created with a ConfigProto, such as when calling
1415    tf.compat.v1.enable_eager_execution(), then we need to pull out the
1416    various pieces we might be replacing and import them into our internal
1417    class representation.
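
    For illustration, a sketch of the kind of proto this method consumes
    (using the config_pb2 module imported above):

    ```
    config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
    config.gpu_options.visible_device_list = "0"
    # With this config, two virtual CPUs are configured on the first physical
    # CPU and only physical GPU 0 is kept visible.
    ```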
1418    """
1419    if self._config is None:
1420      return
1421
1422    num_cpus = self._config.device_count.get("CPU", 1)
1423    if num_cpus != 1:
1424      cpus = [d for d in self._physical_devices if d.device_type == "CPU"]
1425      if num_cpus == 0:
1426        self.set_visible_devices([], "CPU")
1427      elif num_cpus > 1:
1428        self.set_logical_device_configuration(
1429            cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)])
1430
1431    # Parse GPU options
1432    gpus = [d for d in self._physical_devices if d.device_type == "GPU"]
1433
1434    # If there are no GPUs detected, simply ignore all the GPU options
1435    # passed in rather than doing any validation checks.
1436    if not gpus:
1437      return
1438
1439    gpu_count = self._config.device_count.get("GPU", None)
1440
1441    visible_gpus = []
1442    # TODO(gjn): Handle importing existing virtual GPU configuration
1443    visible_indices = self._config.gpu_options.visible_device_list
1444    if visible_indices:
1445      for index in visible_indices.split(","):
1446        if int(index) >= len(gpus):
1447          raise ValueError("Invalid visible device index: %s" % index)
1448        visible_gpus.append(gpus[int(index)])
1449    else:
1450      visible_gpus = gpus
1451
1452    if gpu_count is not None:
1453      visible_gpus = visible_gpus[:gpu_count]
1454
1455    self.set_visible_devices(visible_gpus, "GPU")
1456
1457  def list_logical_devices(self, device_type=None):
1458    """Return logical devices."""
1459    self.ensure_initialized()
1460    if device_type is None:
1461      return list(self._logical_devices)
1462
1463    return [d for d in self._logical_devices if d.device_type == device_type]
1464
1465  def get_visible_devices(self, device_type=None):
1466    """Get the list of visible devices."""
1467    self._initialize_physical_devices()
1468
1469    if device_type is None:
1470      return list(self._visible_device_list)
1471
1472    return [
1473        d for d in self._visible_device_list if d.device_type == device_type
1474    ]
1475
1476  def set_visible_devices(self, devices, device_type=None):
1477    """Set the list of visible devices."""
1478    self._initialize_physical_devices()
1479
1480    if not isinstance(devices, list):
1481      devices = [devices]
1482
1483    for d in devices:
1484      if d not in self._physical_devices:
1485        raise ValueError("Unrecognized device: %s" % repr(d))
1486      if device_type is not None and d.device_type != device_type:
1487        raise ValueError("Unrecognized device: %s" % repr(d))
1488
1489    visible_device_list = []
1490    if device_type is not None:
1491      visible_device_list = [
1492          d for d in self._visible_device_list if d.device_type != device_type
1493      ]
1494
1495    visible_device_list += devices
1496
1497    if self._visible_device_list == visible_device_list:
1498      return
1499
1500    if self._context_handle is not None:
1501      raise RuntimeError(
1502          "Visible devices cannot be modified after being initialized")
1503
1504    self._visible_device_list = visible_device_list
1505
1506  def get_memory_info(self, dev):
1507    """Returns a dict of memory info for the device."""
1508    self._initialize_physical_devices()
1509    self.ensure_initialized()
1510    return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev)
1511
1512  def reset_memory_stats(self, dev):
1513    """Resets the tracked memory stats for the device."""
1514    self._initialize_physical_devices()
1515    self.ensure_initialized()
1516    pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev)
1517
1518  def get_memory_growth(self, dev):
1519    """Get if memory growth is enabled for a PhysicalDevice."""
1520    self._initialize_physical_devices()
1521
1522    if dev not in self._physical_devices:
1523      raise ValueError("Unrecognized device: %s" % repr(dev))
1524
1525    return self._memory_growth_map[dev]
1526
1527  def set_memory_growth(self, dev, enable):
1528    """Set if memory growth should be enabled for a PhysicalDevice."""
1529    self._initialize_physical_devices()
1530
1531    if dev not in self._physical_devices:
1532      raise ValueError("Unrecognized device: %s" % repr(dev))
1533
1534    if dev in self._virtual_device_map:
1535      raise ValueError(
1536          "Cannot set memory growth on device when virtual devices configured")
1537
1538    if dev.device_type != "GPU":
1539      raise ValueError("Cannot set memory growth on non-GPU devices")
1540
1541    if self._memory_growth_map.get(dev) == enable:
1542      return
1543
1544    if self._context_handle is not None:
1545      raise RuntimeError(
1546          "Physical devices cannot be modified after being initialized")
1547
1548    self._memory_growth_map[dev] = enable
1549
1550  def get_logical_device_configuration(self, dev):
1551    """Get the virtual device configuration for a PhysicalDevice."""
1552    self._initialize_physical_devices()
1553
1554    if dev not in self._physical_devices:
1555      raise ValueError("Unrecognized device: %s" % repr(dev))
1556
1557    return self._virtual_device_map.get(dev)
1558
1559  def set_logical_device_configuration(self, dev, virtual_devices):
1560    """Set the virtual device configuration for a PhysicalDevice."""
1561    self._initialize_physical_devices()
1562
1563    if dev not in self._physical_devices:
1564      raise ValueError("Unrecognized device: %s" % repr(dev))
1565
1566    if dev.device_type == "CPU":
1567      for vdev in virtual_devices:
1568        if vdev.memory_limit is not None:
1569          raise ValueError("Setting memory limit on CPU virtual devices is "
1570                           "currently not supported")
1571        if vdev.experimental_priority is not None:
1572          raise ValueError("Setting experimental_priority on CPU virtual "
1573                           " devices is currently not supported")
1574    elif dev.device_type == "GPU":
1575      for vdev in virtual_devices:
1576        if vdev.memory_limit is None:
1577          raise ValueError(
1578              "Setting memory limit is required for GPU virtual devices")
1579    else:
1580      raise ValueError("Virtual devices are not supported for %s" %
1581                       dev.device_type)
1582
1583    if self._virtual_device_map.get(dev) == virtual_devices:
1584      return
1585
1586    if self._context_handle is not None:
1587      raise RuntimeError(
1588          "Virtual devices cannot be modified after being initialized")
1589
1590    self._virtual_device_map[dev] = virtual_devices
1591
1592  def set_logical_cpu_devices(self, num_cpus, prefix=""):
1593    """Set virtual CPU devices in context.
1594
1595    If virtual CPU devices are already configured at context initialization
1596    by tf.config.set_logical_device_configuration(), this method should not be
1597    called.
1598
1599    Args:
1600      num_cpus: Number of virtual CPUs.
1601      prefix: Device name prefix.
1602
1603    Raises:
1604      RuntimeError: If virtual CPUs are already configured at context
1605        initialization.
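
    For illustration (a minimal sketch; uses the module-level `context()`
    accessor defined below):

    ```
    ctx = context()
    ctx.set_logical_cpu_devices(2)
    len(ctx.list_logical_devices("CPU"))  # -> 2
    ```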
1606    """
1607    self.ensure_initialized()
1608    # Error out if there are already multiple logical CPUs in the context.
1609    if len(self.list_logical_devices("CPU")) > 1:
1610      raise RuntimeError("Virtual CPUs already set, cannot modify again.")
1611
1612    pywrap_tfe.TFE_SetLogicalCpuDevices(self._context_handle, num_cpus, prefix)
1613    self._initialize_logical_devices()
1614
1615  def get_compiler_ir(self, device_name, function_name, args, stage="hlo"):
1616    return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name,
1617                                       stage, device_name, args)
1618
1619  @deprecated(
1620      None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
1621  def enable_xla_devices(self):
1622    """Enables XLA:CPU and XLA:GPU devices registration."""
1623    pywrap_tfe.TF_EnableXlaDevices()
1624
1625  @property
1626  def enable_mlir_bridge(self):
1627    return pywrap_tfe.TF_IsMlirBridgeEnabled()
1628
1629  @property
1630  def enable_mlir_graph_optimization(self):
1631    return self._enable_mlir_graph_optimization
1632
1633  @enable_mlir_bridge.setter
1634  def enable_mlir_bridge(self, enabled):
1635    pywrap_tfe.TF_EnableMlirBridge(enabled)
1636    self._thread_local_data.function_call_options = None
1637
1638  @enable_mlir_graph_optimization.setter
1639  def enable_mlir_graph_optimization(self, enabled):
1640    self._enable_mlir_graph_optimization = enabled
1641    self._thread_local_data.function_call_options = None
1642
1643  @property
1644  def optimizer_jit(self):
1645    level = self.config.graph_options.optimizer_options.global_jit_level
1646    return (level == config_pb2.OptimizerOptions.ON_1 or
1647            level == config_pb2.OptimizerOptions.ON_2)
1648
1649  @optimizer_jit.setter
1650  def optimizer_jit(self, enabled):
1651    self._optimizer_jit = enabled
1652
1653    self._thread_local_data.function_call_options = None
1654
1655  def get_optimizer_experimental_options(self):
1656    """Get experimental options for the optimizer.
1657
1658    Returns:
1659      Dictionary of current option values
1660    """
1661    rewrite_options = self.config.graph_options.rewrite_options
1662    options = {}
1663
1664    def rewriter_toggle(option):
1665      attr = getattr(rewrite_options, option)
1666      if attr != 0:
1667        options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON)
1668
1669    def rewriter_bool(option):
1670      options[option] = getattr(rewrite_options, option)
1671
1672    rewriter_toggle("layout_optimizer")
1673    rewriter_toggle("constant_folding")
1674    rewriter_toggle("shape_optimization")
1675    rewriter_toggle("remapping")
1676    rewriter_toggle("arithmetic_optimization")
1677    rewriter_toggle("dependency_optimization")
1678    rewriter_toggle("loop_optimization")
1679    rewriter_toggle("function_optimization")
1680    rewriter_toggle("debug_stripper")
1681    rewriter_bool("disable_model_pruning")
1682    rewriter_toggle("scoped_allocator_optimization")
1683    rewriter_toggle("pin_to_host_optimization")
1684    rewriter_toggle("implementation_selector")
1685    rewriter_toggle("auto_mixed_precision")
1686    rewriter_toggle("use_plugin_optimizers")
1687    rewriter_bool("disable_meta_optimizer")
1688
1689    if rewrite_options.min_graph_nodes != 0:
1690      options["min_graph_nodes"] = rewrite_options.min_graph_nodes
1691
1692    return options
1693
1694  def set_optimizer_experimental_options(self, options):
1695    """Set experimental options for the optimizer.
1696
1697    Args:
1698      options: Dictionary of options to modify
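
    For example, via the public wrapper (a sketch; the option names mirror the
    rewriter toggles read back by get_optimizer_experimental_options):

    ```
    tf.config.optimizer.set_experimental_options(
        {"layout_optimizer": False, "disable_meta_optimizer": True})
    tf.config.optimizer.get_experimental_options()
    ```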
1699    """
1700    self._optimizer_experimental_options.update(options)
1701
1702    self._thread_local_data.function_call_options = None
1703
1704  @property
1705  def intra_op_parallelism_threads(self):
1706    return self.config.intra_op_parallelism_threads
1707
1708  @intra_op_parallelism_threads.setter
1709  def intra_op_parallelism_threads(self, num_threads):
1710    if self._intra_op_parallelism_threads == num_threads:
1711      return
1712
1713    if self._context_handle is not None:
1714      raise RuntimeError(
1715          "Intra op parallelism cannot be modified after initialization.")
1716
1717    self._intra_op_parallelism_threads = num_threads
1718
1719  @property
1720  def inter_op_parallelism_threads(self):
1721    return self.config.inter_op_parallelism_threads
1722
1723  @inter_op_parallelism_threads.setter
1724  def inter_op_parallelism_threads(self, num_threads):
1725    if self._inter_op_parallelism_threads == num_threads:
1726      return
1727
1728    if self._context_handle is not None:
1729      raise RuntimeError(
1730          "Inter op parallelism cannot be modified after initialization.")
1731
1732    self._inter_op_parallelism_threads = num_threads
1733
1734  @property
1735  def soft_device_placement(self):
1736    return self.config.allow_soft_placement
1737
1738  @soft_device_placement.setter
1739  def soft_device_placement(self, enable):
1740    if self._context_handle is not None:
1741      pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
1742
1743    self._soft_device_placement = enable
1744    self._thread_local_data.function_call_options = None
1745
1746  @property
1747  def log_device_placement(self):
1748    return self.config.log_device_placement
1749
1750  @log_device_placement.setter
1751  def log_device_placement(self, enable):
1752    if self._context_handle is not None:
1753      pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
1754
1755    self._log_device_placement = enable
1756    self._thread_local_data.function_call_options = None
1757
1758  @property
1759  def device_policy(self):
1760    # Only get the policy from the context if it has already been initialized
1761    if self._context_handle is not None:
1762      return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
1763
1764    return self._device_policy
1765
1766  @device_policy.setter
1767  def device_policy(self, policy):
1768    if policy is None:
1769      policy = DEVICE_PLACEMENT_SILENT
1770
1771    if self._device_policy != policy:
1772      self._device_policy = policy
1773
1774      # Only set the policy if the context has already been initialized
1775      if self._context_handle is not None:
1776        pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
1777            self._handle, self._device_policy)
1778
1779  @property
1780  def use_tfrt(self):
1781    return self._use_tfrt
1782
1783  @use_tfrt.setter
1784  def use_tfrt(self, tfrt):
1785    """Sets whether to use TFRT."""
1786    if not isinstance(tfrt, bool):
1787      raise ValueError("Expecting a boolean but got %s" % type(tfrt))
1788
1789    if self._use_tfrt != tfrt:
1790      if self._initialized:
1791        raise ValueError("use_tfrt should be set before being initialized.")
1792      self._use_tfrt = tfrt
1793
1794  @property
1795  def use_tfrt_distributed_runtime(self):
1796    return self._use_tfrt_distributed_runtime
1797
1798  @use_tfrt_distributed_runtime.setter
1799  def use_tfrt_distributed_runtime(self, enable):
1800    """Sets whether to use TFRT distributed runtime.
1801
1802    This is only effective when use_tfrt is also true. Note that currently TFRT
1803    distributed runtime is not feature complete and this config is for testing
1804    only.
1805    Args:
1806      enable: A boolean to set whether to use TFRT distributed runtime.
1807    """
1808    if not isinstance(enable, bool):
1809      raise ValueError("Expecting a boolean but got %s" % type(enable))
1810
1811    if self._use_tfrt_distributed_runtime != enable:
1812      if self._initialized:
1813        raise ValueError("use_tfrt should be set before being initialized.")
1814      self._use_tfrt_distributed_runtime = enable
1815
1816  def enable_run_metadata(self):
1817    """Enables tracing of op execution via RunMetadata.
1818
1819    To retrieve the accumulated metadata call context.export_run_metadata()
1820    and to stop tracing call context.disable_run_metadata().
1821    """
1822    self.ensure_initialized()
1823    pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
1824
1825  def disable_run_metadata(self):
1826    """Disables tracing of op execution via RunMetadata."""
1827    if not self._context_handle:
1828      return
1829    pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
1830
1831  def enable_graph_collection(self):
1832    """Enables graph collection of executed functions.
1833
1834    To retrieve the accumulated graphs call context.export_run_metadata()
1835    and to stop collecting graphs call context.disable_graph_collection().
1836    """
1837    self.ensure_initialized()
1838    pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
1839
1840  def disable_graph_collection(self):
1841    """Disables graph collection of executed functions."""
1842    if not self._context_handle:
1843      return
1844    pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
1845
1846  def export_run_metadata(self):
1847    """Returns a RunMetadata proto with accumulated information.
1848
1849    The returned protocol buffer contains information since the most recent call
1850    to either enable_run_metadata or export_run_metadata.
1851
1852    Returns:
1853      A RunMetadata protocol buffer. Or None if not enabled.
1854    """
1855    if not self._context_handle:
1856      return None
1857    with c_api_util.tf_buffer() as buffer_:
1858      pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
1859      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1860    run_metadata = config_pb2.RunMetadata()
1861    run_metadata.ParseFromString(compat.as_bytes(proto_data))
1862    return run_metadata
1863
1864  @property
1865  def context_switches(self):
1866    """Returns a stack of context switches."""
1867    return self._context_switches
1868
1869
1870class _EagerDeviceContext(object):
1871  """Context-manager forcing placement of ops and Tensors on a device."""
1872
1873  __slots__ = ["_device_name", "_ctx", "_stack"]
1874
1875  def __init__(self, ctx, device_name):
1876    self._device_name = device_name
1877    self._ctx = ctx
1878    self._stack = []
1879
1880  # TODO(b/189233748): Consolidate the device string parsing logic with
1881  # tensorflow/core/util/device_name_utils.cc.
1882  def __enter__(self):
1883    ctx = self._ctx
1884    old_device_name = ctx.device_name
1885    old_device_spec = ctx.device_spec
1886    new_device_name = self._device_name
1887    cache_key = (old_device_name, new_device_name)
1888    try:
1889      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
1890    except TypeError:
1891      # Error while trying to compute the cache key.
1892      raise ValueError("Expecting a string device name. Got %s(%s)" %
1893                       (type(new_device_name), new_device_name))
1894    except KeyError:
1895      # Handle a cache miss.
1896      if new_device_name is not None:
1897        if not isinstance(new_device_name, six.string_types):
1898          raise ValueError("Expecting a string device name. Got %s(%s)" %
1899                           (type(new_device_name), new_device_name))
1900        device_spec = pydev.DeviceSpec.from_string(new_device_name)
1901        if old_device_name:
1902          new_device_spec = copy.copy(old_device_spec)
1903        else:
1904          ctx.ensure_initialized()
1905          new_device_spec = pydev.DeviceSpec.from_string(
1906              ctx._context_devices[0])  # pylint: disable=protected-access
1907        new_device_spec = new_device_spec.make_merged_spec(device_spec)
1908      else:
1909        new_device_spec = pydev.DeviceSpec.from_string("")
1910      new_device_name = new_device_spec.to_string()
1911      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
1912
1913    ctx._set_device(new_device_name, new_device_spec)  # pylint: disable=protected-access
1914    self._stack.append((old_device_name, old_device_spec, new_device_spec))
1915
1916  def __exit__(self, *ex_info):
1917    ctx = self._ctx
1918    old_device_name, old_device_spec, new_device_spec = self._stack[-1]
1919    if ctx.device_spec is not new_device_spec:
1920      raise RuntimeError("Exiting device scope without proper scope nesting")
1921    del self._stack[-1]
1922    ctx._set_device(old_device_name, old_device_spec)  # pylint: disable=protected-access
1923
1924
1925# Do not set directly. Use _set_context.
1926_context = None
1927_context_lock = threading.Lock()
1928
1929
1930def _set_context_locked(ctx):
1931  global _context
1932  pywrap_tfe.TFE_Py_SetEagerContext(ctx)
1933  _context = ctx
1934
1935
1936def _set_context(ctx):
1937  with _context_lock:
1938    _set_context_locked(ctx)
1939
1940
1941def _create_context():
1942  with _context_lock:
1943    if _context is None:
1944      ctx = Context()
1945      _set_context_locked(ctx)
1946
1947
1948def _reset_context():
1949  """Clears and re-initializes the singleton context.
1950
1951  Should only be used for testing.
1952  """
1953  global _context
1954  global _device_parsing_cache
1955
1956  # Garbage collect and clear scalar cache to avoid Tensor from current context
1957  # polluting next context.
1958  gc.collect()
1959  pywrap_tfe.TFE_ClearScalarCache()
1960  with _context_lock:
1961    if _context is not None:
1962      _context._clear_caches()
1963      _context = None
1964  _create_context()
1965  _device_parsing_cache = {}
1966
1967
1968def context():
1969  """Returns a singleton context object."""
1970  if _context is None:
1971    _create_context()
1972  return _context
1973
1974
1975def context_safe():
1976  """Returns current context (or None if one hasn't been initialized)."""
1977  return _context
1978
1979
1980def ensure_initialized():
1981  """Initialize the context."""
1982  context().ensure_initialized()
1983
1984
1985def initialize_logical_devices():
1986  """Initialize the virtual devices."""
1987  context()._initialize_logical_devices()  # pylint: disable=protected-access
1988
1989
1990def set_global_seed(seed):
1991  """Sets the eager mode seed."""
1992  context()._set_global_seed(seed)  # pylint: disable=protected-access
1993
1994
1995def global_seed():
1996  """Returns the eager mode seed."""
1997  return context()._seed  # pylint: disable=protected-access
1998
1999
2000def internal_operation_seed():
2001  """Returns the operation seed generated based on global seed."""
2002  return context()._internal_operation_seed()  # pylint: disable=protected-access
2003
2004
2005@tf_export("executing_eagerly", v1=[])
2006def executing_eagerly():
2007  """Checks whether the current thread has eager execution enabled.
2008
2009  Eager execution is enabled by default and this API returns `True`
2010  in most cases. However, this API might return `False` in the following
2011  cases.
2012
2013  *  Executing inside `tf.function`, unless under `tf.init_scope` or unless
2014     `tf.config.run_functions_eagerly(True)` was previously called.
2015  *  Executing inside a transformation function for `tf.dataset`.
2016  *  `tf.compat.v1.disable_eager_execution()` is called.
2017
2018  General case:
2019
2020  >>> print(tf.executing_eagerly())
2021  True
2022
2023  Inside `tf.function`:
2024
2025  >>> @tf.function
2026  ... def fn():
2027  ...   with tf.init_scope():
2028  ...     print(tf.executing_eagerly())
2029  ...   print(tf.executing_eagerly())
2030  >>> fn()
2031  True
2032  False
2033
2034  Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called:
2035
2036  >>> tf.config.run_functions_eagerly(True)
2037  >>> @tf.function
2038  ... def fn():
2039  ...   with tf.init_scope():
2040  ...     print(tf.executing_eagerly())
2041  ...   print(tf.executing_eagerly())
2042  >>> fn()
2043  True
2044  True
2045  >>> tf.config.run_functions_eagerly(False)
2046
2047  Inside a transformation function for `tf.dataset`:
2048
2049  >>> def data_fn(x):
2050  ...   print(tf.executing_eagerly())
2051  ...   return x
2052  >>> dataset = tf.data.Dataset.range(100)
2053  >>> dataset = dataset.map(data_fn)
2054  False
2055
2056  Returns:
2057    `True` if the current thread has eager execution enabled.
2058  """
2059  ctx = context_safe()
2060  if ctx is None:
2061    return default_execution_mode == EAGER_MODE
2062
2063  return ctx.executing_eagerly()
2064
2065
2066@tf_export(v1=["executing_eagerly"])
2067def executing_eagerly_v1():
2068  """Checks whether the current thread has eager execution enabled.
2069
2070  Eager execution is typically enabled via
2071  `tf.compat.v1.enable_eager_execution`, but may also be enabled within the
2072  context of a Python function via tf.contrib.eager.py_func.
2073
2074  When eager execution is enabled, returns `True` in most cases. However,
2075  this API might return `False` in the following use cases.
2076
2077  *  Executing inside `tf.function`, unless under `tf.init_scope` or unless
2078     `tf.config.run_functions_eagerly(True)` was previously called.
2079  *  Executing inside a transformation function for `tf.dataset`.
2080  *  `tf.compat.v1.disable_eager_execution()` is called.
2081
2082  >>> tf.compat.v1.enable_eager_execution()
2083
2084  General case:
2085
2086  >>> print(tf.executing_eagerly())
2087  True
2088
2089  Inside `tf.function`:
2090
2091  >>> @tf.function
2092  ... def fn():
2093  ...   with tf.init_scope():
2094  ...     print(tf.executing_eagerly())
2095  ...   print(tf.executing_eagerly())
2096  >>> fn()
2097  True
2098  False
2099
2100  Inside `tf.function` after
2101  `tf.config.run_functions_eagerly(True)` is called:
2102
2103  >>> tf.config.run_functions_eagerly(True)
2104  >>> @tf.function
2105  ... def fn():
2106  ...   with tf.init_scope():
2107  ...     print(tf.executing_eagerly())
2108  ...   print(tf.executing_eagerly())
2109  >>> fn()
2110  True
2111  True
2112  >>> tf.config.run_functions_eagerly(False)
2113
2114  Inside a transformation function for `tf.dataset`:
2115
2116  >>> def data_fn(x):
2117  ...   print(tf.executing_eagerly())
2118  ...   return x
2119  >>> dataset = tf.data.Dataset.range(100)
2120  >>> dataset = dataset.map(data_fn)
2121  False
2122
2123  Returns:
2124    `True` if the current thread has eager execution enabled.
2125  """
2126  return executing_eagerly()
2127
2128
2129def in_eager_mode():
2130  """Use executing_eagerly() instead. This function will be removed."""
2131  return executing_eagerly()
2132
2133
2134def shared_name(name=None):
2135  """Returns the anonymous shared name GUID if no shared name is specified.
2136
2137  In eager mode we need to use a unique shared name to avoid spurious sharing
2138  issues. The runtime generates a unique name on our behalf when the reserved
2139  GUID is used as a shared name.
2140
2141  Args:
2142    name: Optional shared name
2143
2144  Returns:
2145    Eager compatible shared name.
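
  For example (a minimal sketch):

  ```
  shared_name("my_table")  # -> "my_table"
  shared_name(None)  # -> the reserved GUID when executing eagerly, else None
  ```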
2146  """
2147  if name or not executing_eagerly():
2148    return name
2149
2150  # Ensure a unique name when eager execution is enabled to avoid spurious
2151  # sharing issues.
2152  return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
2153
2154
2155def graph_mode():
2156  """Context-manager to disable eager execution for the current thread."""
2157  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access
2158
2159
2160# Used by b/167638505 for keras backend API and Lambda layer.
2161@tf_export("__internal__.eager_context.eager_mode", v1=[])
2162def eager_mode():
2163  """Context-manager to enable eager execution for the current thread."""
2164  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access
2165
2166
2167def scope_name():
2168  """Name of the current scope."""
2169  return context().scope_name
2170
2171
2172def device(name):
2173  """Context-manager to force placement of operations and Tensors on a device.
2174
2175  Example:
2176  ```python
2177  with tf.device('gpu:0'):
2178    with tf.device('cpu:0'):
2179      shape = tf.constant([], dtype=tf.int32)
2180    x = tf.random.truncated_normal(shape, dtype=tf.float32)
2181  ```
2182  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
2183  operation runs on GPU 0.
2184
2185  Args:
2186    name: Name of the device (see context().devices()), or None to perform
2187      automatic placement.
2188
2189  Returns:
2190    Context manager for setting the device.
2191  """
2192  ensure_initialized()
2193  return context().device(name)
2194
2195
2196# Expose some properties of Context as internally public APIs (b/160348781).
2197@tf_export("__internal__.eager_context.get_config", v1=[])
2198def get_config():
2199  """Get the ConfigProto of Context.
2200
2201  Returns:
2202    The ConfigProto of Context.
2203  """
2204  return context().config
2205
2206
2207@tf_export("__internal__.eager_context.get_device_name", v1=[])
2208def get_device_name():
2209  """Get the device name for the current thread.
2210
2211  Returns:
2212    The device name for the current thread.
2213  """
2214  return context().device_name
2215
2216
2217@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[])
2218def set_soft_device_placement(enabled):
2219  """Set if soft device placements should be allowed.
2220
2221  Args:
2222    enabled: Whether to enable soft device placement.
2223  """
2224  context().soft_device_placement = enabled
2225
2226
2227@tf_export("__internal__.eager_context.get_executor", v1=[])
2228def get_executor():
2229  """Get the Executor of the current thread.
2230
2231  Returns:
2232    The Executor of the current thread.
2233  """
2234  return context().executor
2235
2236
2237@tf_export("debugging.get_log_device_placement")
2238def get_log_device_placement():
2239  """Get if device placements are logged.
2240
2241  Returns:
2242    If device placements are logged.
2243  """
2244  return context().log_device_placement
2245
2246
2247@tf_export("debugging.set_log_device_placement")
2248def set_log_device_placement(enabled):
2249  """Turns logging for device placement decisions on or off.
2250
2251  Operations execute on a particular device, producing and consuming tensors on
2252  that device. This may change the performance of the operation or require
2253  TensorFlow to copy data to or from an accelerator, so knowing where operations
2254  execute is useful for debugging performance issues.
2255
2256  For more advanced profiling, use the [TensorFlow
2257  profiler](https://www.tensorflow.org/guide/profiler).
2258
2259  Device placement for operations is typically controlled by a `tf.device`
2260  scope, but there are exceptions, for example operations on a `tf.Variable`
2261  which follow the initial placement of the variable. Turning off soft device
2262  placement (with `tf.config.set_soft_device_placement`) provides more explicit
2263  control.
2264
2265  >>> tf.debugging.set_log_device_placement(True)
2266  >>> tf.ones([])
2267  >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:GPU:0
2268  >>> with tf.device("CPU"):
2269  ...  tf.ones([])
2270  >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:CPU:0
2271  >>> tf.debugging.set_log_device_placement(False)
2272
2273  Turning on `tf.debugging.set_log_device_placement` also logs the placement of
2274  ops inside `tf.function` when the function is called.
2275
2276  Args:
2277    enabled: Whether to enable device placement logging.
2278  """
2279  context().log_device_placement = enabled
2280
2281
2282@tf_contextlib.contextmanager
2283def device_policy(policy):
2284  """Context manager for setting device placement policy for current thread."""
2285  ctx = context()
2286  old_policy = ctx.device_policy
2287  try:
2288    ctx.device_policy = policy
2289    yield
2290  finally:
2291    ctx.device_policy = old_policy
2292
2293
2294def set_execution_mode(mode):
2295  """Sets execution mode for the current thread."""
2296  context().execution_mode = mode
2297
2298
2299# TODO(fishx): remove this method.
2300@tf_contextlib.contextmanager
2301def execution_mode(mode):
2302  """Context manager for setting execution mode for current thread."""
2303  if mode is None:
2304    yield
2305  else:
2306    ctx = context()
2307    executor_new = executor.new_executor(mode == ASYNC)
2308    executor_old = ctx.executor
2309    try:
2310      executor_old.wait()
2311      ctx.executor = executor_new
2312      yield
2313    finally:
2314      ctx.executor = executor_old
2315      executor_new.wait()
2316
2317
2318@tf_contextlib.contextmanager
2319def executor_scope(e):
2320  """Context manager for changing executor for current thread.
2321
2322  Args:
2323    e: An Executor to execute eager ops under this scope. Setting it to None
2324      will switch back to the default executor for the context.
2325
2326  Yields:
2327    Context manager for setting the executor for current thread.
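
  For example (a sketch; `new_executor` and `wait` come from the executor
  module imported in this file):

  ```
  async_executor = executor.new_executor(True)  # enable_async=True
  with executor_scope(async_executor):
    pass  # ops dispatched here are queued on the async executor
  async_executor.wait()  # block until everything queued above has finished
  ```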
2328  """
2329  ctx = context()
2330  executor_old = ctx.executor
2331  try:
2332    ctx.executor = e
2333    yield
2334  finally:
2335    ctx.executor = executor_old
2336
2337
2338@tf_export("experimental.function_executor_type")
2339@tf_contextlib.contextmanager
2340def function_executor_type(executor_type):
2341  """Context manager for setting the executor of eager defined functions.
2342
2343  Eager defined functions are functions decorated by tf.contrib.eager.defun.
2344
2345  Args:
2346    executor_type: a string for the name of the executor to be used to execute
2347      functions defined by tf.contrib.eager.defun.
2348
2349  Yields:
2350    Context manager for setting the executor of eager defined functions.
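
  For example (a sketch; "SINGLE_THREADED_EXECUTOR" is assumed to name an
  executor registered in the runtime):

  ```
  @tf.function
  def f(x):
    return x + 1.0

  with tf.experimental.function_executor_type("SINGLE_THREADED_EXECUTOR"):
    f(tf.constant(1.0))
  ```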
2351  """
2352  current_options = context().function_call_options
2353  old_options = copy.copy(current_options)
2354  try:
2355    current_options.executor_type = executor_type
2356    yield
2357  finally:
2358    context().function_call_options = old_options
2359
2360
2361def is_async():
2362  """Returns true if current thread is in async mode."""
2363  return context().is_async()
2364
2365
2366def num_gpus():
2367  """Get the number of available GPU devices.
2368
2369  Returns:
2370    The number of available GPU devices.
2371  """
2372  return context().num_gpus()
2373
2374
2375def enable_run_metadata():
2376  """Enables tracing of op execution via RunMetadata.
2377
2378  To retrieve the accumulated metadata call context.export_run_metadata()
2379  and to stop tracing call context.disable_run_metadata().
2380  """
2381  context().enable_run_metadata()
2382
2383
2384def disable_run_metadata():
2385  """Disables tracing of op execution via RunMetadata."""
2386  context().disable_run_metadata()
2387
2388
2389def enable_graph_collection():
2390  """Enables graph collection of executed functions.
2391
2392  To retrieve the accumulated graphs call context.export_run_metadata()
2393  and to stop collecting graphs call context.disable_graph_collection().
2394  """
2395  context().enable_graph_collection()
2396
2397
2398def disable_graph_collection():
2399  """Disables graph collection of executed functions."""
2400  context().disable_graph_collection()
2401
2402
2403def export_run_metadata():
2404  """Returns a RunMetadata proto with accumulated information.
2405
2406  The returned protocol buffer contains information since the most recent call
2407  to either enable_run_metadata or export_run_metadata.
2408
2409  Returns:
2410    A RunMetadata protocol buffer.
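
  For example (a sketch using the module-level helpers defined in this file):

  ```
  enable_run_metadata()
  tf.ones([2, 2]) + tf.ones([2, 2])
  metadata = export_run_metadata()  # a config_pb2.RunMetadata proto
  disable_run_metadata()
  ```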
2411  """
2412  return context().export_run_metadata()
2413
2414
2415@contextlib.contextmanager
2416def collect_graphs(optimized=True):
2417  """Collects a flat list of pre- or post-optimization graphs.
2418
2419  The collected graphs include device placements, which can be useful for
2420  testing.
2421
2422  Usage:
2423
2424  ```
2425  @def_function.function
2426  def f(x):
2427    return x + constant_op.constant(1.)
2428
2429  with context.collect_graphs() as graphs:
2430    with ops.device("CPU:0"):
2431      f(constant_op.constant(1.))
2432
2433  graph, = graphs  # `graph` contains a single GraphDef for inspection
2434  ```
2435
2436  Args:
2437    optimized: whether to collect optimized graphs or non-optimized graphs
2438
2439  Yields:
2440    A list of GraphDefs, populated when the context manager exits.
2441  """
2442  ctx = context()
2443  ctx.enable_graph_collection()
2444  try:
2445    graphs = []
2446    yield graphs
2447    metadata = ctx.export_run_metadata()
2448  finally:
2449    ctx.disable_graph_collection()
2450  for graph in metadata.function_graphs:
2451    if optimized:
2452      graphs.append(graph.post_optimization_graph)
2453    else:
2454      graphs.append(graph.pre_optimization_graph)
2455
2456
2457def get_server_def():
2458  return context().get_server_def()
2459
2460
2461def set_server_def(server_def):
2462  context().set_server_def(server_def)
2463
2464
2465def update_server_def(server_def):
2466  context().update_server_def(server_def)
2467
2468
2469def check_alive(worker_name):
2470  return context().check_alive(worker_name)
2471
2472
2473@tf_export("experimental.async_scope")
2474@tf_contextlib.contextmanager
2475def async_scope():
2476  """Context manager for grouping async operations.
2477
2478  Ops/function calls inside the scope can return before finishing the actual
2479  execution. When exiting the async scope, a synchronization barrier will be
2480  automatically added to ensure the completion of all async op and function
2481  execution, potentially raising exceptions if async execution results in
2482  an error state.
2483
2484  Users may write the following code to asynchronously invoke `train_step_fn`
2485  and log the `loss` metric for every `num_steps` steps in a training loop.
2486  `train_step_fn` internally consumes data using `iterator.get_next()`, and may
2487  throw OutOfRangeError when running out of data. In that case:
2488
2489  ```
2490  try:
2491    with tf.experimental.async_scope():
2492      for _ in range(num_steps):
2493        # Step function updates the metric `loss` internally
2494        train_step_fn()
2495  except tf.errors.OutOfRangeError:
2496    tf.experimental.async_clear_error()
2497  logging.info('loss = %s', loss.numpy())
2498  ```
2499
2500  Yields:
2501    Context manager for grouping async operations.
2502  """
2503  # TODO(haoyuzhang): replace env var once we have a config method to turn on
2504  # and off async streaming RPC
2505  remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
2506  old_policy = os.environ.get(remote_async_env_var)
2507  try:
2508    os.environ[remote_async_env_var] = str(True)
2509    yield
2510    # Note: sync local and remote executors iff the async block does not raise
2511    # an exception. Triggering sync after an exception may lead to derived
2512    # runtime errors and unexpected exception types.
2513    context().sync_executors()
2514  finally:
2515    if old_policy is None:
2516      del os.environ[remote_async_env_var]
2517    else:
2518      os.environ[remote_async_env_var] = old_policy
2519
2520
2521def async_wait():
2522  """Sync all async operations and raise any errors during execution.
2523
2524  In async execution mode, an op/function call can return before finishing the
2525  actual execution. Calling this method creates a synchronization barrier for
2526  all async op and function execution. It only returns when all pending nodes
2527  are finished, potentially raising exceptions if async execution results in
2528  an error state. It is a no-op if the context is not initialized.
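
  For example (a sketch; assumes async execution was enabled via the public
  `tf.config.experimental.set_synchronous_execution(False)` API):

  ```
  tf.config.experimental.set_synchronous_execution(False)
  y = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))  # may return immediately
  async_wait()  # blocks until pending ops finish, raising any deferred error
  ```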
2529  """
2530  if context()._context_handle is not None:  # pylint: disable=protected-access
2531    context().sync_executors()
2532
2533
2534@tf_export("experimental.async_clear_error")
2535def async_clear_error():
2536  """Clear pending operations and error statuses in async execution.
2537
2538  In async execution mode, an error in op/function execution can lead to errors
2539  in subsequent ops/functions that are scheduled but not yet executed. Calling
2540  this method clears all pending operations and resets the async execution state.
2541
2542  Example:
2543
2544  ```
2545  while True:
2546    try:
2547      # Step function updates the metric `loss` internally
2548      train_step_fn()
2549    except tf.errors.OutOfRangeError:
2550      tf.experimental.async_clear_error()
2551      break
2552  logging.info('loss = %s', loss.numpy())
2553  ```
2554  """
2555  context().clear_executor_errors()
2556
2557
2558def add_function(fdef):
2559  """Add a function definition to the context."""
2560  context().add_function(fdef)
2561
2562
2563def remove_function(name):
2564  """Remove a function from the context."""
2565  context().remove_function(name)
2566
2567
2568def get_function_def(name):
2569  return context().get_function_def(name)
2570
2571
2572def register_custom_device(device_capsule, device_name, device_info_capsule):
2573  """Calls TFE_RegisterCustomDevice to register a custom device with Python.
2574
2575  Enables using C extensions specifying a custom device from Python. See the
2576  experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
2577  details.
2578
2579  Note that custom devices are not currently supported inside `tf.function`s.
2580
2581  Args:
2582    device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
2583      containing a pointer to a TFE_CustomDevice struct. The capsule retains
2584      ownership of the memory.
2585    device_name: A string indicating the name to register the custom device
2586      under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
2587      subsequently be passed to `with tf.device(...):`.
2588    device_info_capsule: A PyCapsule with the name set to
2589      'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
2590      struct with the initial state of the custom device (the void* device_info
2591      argument to TFE_RegisterCustomDevice). This method takes ownership of the
2592      memory and clears the capsule destructor.
2593  """
2594  context().register_custom_device(device_capsule, device_name,
2595                                   device_info_capsule)
2596
2597
2598# Not every user creates a Context via context.context()
2599# (for example, enable_eager_execution in python/framework/ops.py),
2600# but they do all import this file.  Note that IS_IN_GRAPH_MODE and
2601# in_graph_mode are both parameterless functions.
2602def _tmp_in_graph_mode():
2603  if context_safe() is None:
2604    # Context not yet initialized. Assume graph mode following the
2605    # default implementation in `is_in_graph_mode`.
2606    return True
2607  return not executing_eagerly()
2608
2609
2610is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode
2611