1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""State management for eager execution."""
16
17import collections
18import contextlib
19import copy
20import gc
21import os
22import random
23import threading
24
25from absl import logging
26import numpy as np
27
28from tensorflow.core.framework import function_pb2
29from tensorflow.core.protobuf import config_pb2
30from tensorflow.core.protobuf import coordination_config_pb2
31from tensorflow.core.protobuf import rewriter_config_pb2
32from tensorflow.python import pywrap_tfe
33from tensorflow.python import tf2
34from tensorflow.python.client import pywrap_tf_session
35from tensorflow.python.eager import executor
36from tensorflow.python.eager import monitoring
37from tensorflow.python.framework import c_api_util
38from tensorflow.python.framework import device as pydev
39from tensorflow.python.framework import tfrt_utils
40from tensorflow.python.util import compat
41from tensorflow.python.util import is_in_graph_mode
42from tensorflow.python.util import tf_contextlib
43from tensorflow.python.util.deprecation import deprecated
44from tensorflow.python.util.tf_export import tf_export
45
46GRAPH_MODE = 0
47EAGER_MODE = 1
48
49default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE
50
51# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
52# new_device_spec).
53# Note that we do not protect this with a lock and instead rely on python's GIL
54# and the idempotent nature of writes to provide thread safety.
55_device_parsing_cache = {}
56_starting_device_spec = pydev.DeviceSpec.from_string("")
57
58_MAXINT32 = 2**31 - 1
59
60DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
61DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
62DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
63DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
64    pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
65
66SYNC = 0
67ASYNC = 1
68
69_KEEP_ALIVE_SECS = 600
70
71_python_eager_context_create_counter = monitoring.Counter(
72    "/tensorflow/api/python/eager_context_create_counter",
73    "Counter for number of eager contexts created in Python.")
74
75# Re-exporting through context.
76is_tfrt_enabled = tfrt_utils.enabled
77
78# This flag and the associated environment var are transient and will eventually
79# be removed, once this experiment is enabled by default.
80_JIT_COMPILE_REWRITE_ENABLED = os.getenv("TF_JIT_COMPILE_REWRITE") == "1"
81
82
83def run_eager_op_as_function_enabled():
84  return True
85
86
87# This method should only be called after the context has been initialized.
88def enable_jit_compile_rewrite():
89  """Run jit_compile functions through rewrite pass.
90
91  This runs jit_compile functions through all of the multidevice function
92  rewrite passes.
93  """
94  global _JIT_COMPILE_REWRITE_ENABLED
95  _JIT_COMPILE_REWRITE_ENABLED = True
96  if context_safe() is not None:
97    context_safe().jit_compile_rewrite = True
98
99
100# This method should only be called after the context has been initialized.
101def disable_jit_compile_rewrite():
102  global _JIT_COMPILE_REWRITE_ENABLED
103  _JIT_COMPILE_REWRITE_ENABLED = False
104  if context_safe() is not None:
105    context_safe().jit_compile_rewrite = False
106
107
108def jit_compile_rewrite_enabled():
109  if context_safe() is not None:
110    return context_safe().jit_compile_rewrite
111  return _JIT_COMPILE_REWRITE_ENABLED
112
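# Illustrative sketch (editorial note, not part of the original module): the
# rewrite flag above is seeded once, at import time, from the
# TF_JIT_COMPILE_REWRITE environment variable, and can be flipped later with
# the helpers defined here.
#
#   os.environ["TF_JIT_COMPILE_REWRITE"] = "1"  # only effective before import
#   enable_jit_compile_rewrite()                # takes effect immediately
#   assert jit_compile_rewrite_enabled()
#   disable_jit_compile_rewrite()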
113
114# Expose it as internally public APIs for Keras use cases in b/171080602.
115tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled)
116
117
118class _EagerTensorCache(object):
119  """Simple cache which evicts items based on length in a FIFO manner."""
120
121  __slots__ = ["_data", "_max_items", "_max_tensor_size"]
122
123  def __init__(self, max_items=256, max_tensor_size=10000):
124    self._data = collections.OrderedDict()
125    self._max_items = max_items
126    self._max_tensor_size = max_tensor_size
127
128  def put(self, key, value):
129    if value._num_elements() > self._max_tensor_size:  # pylint: disable=protected-access
130      return
131
132    self._data[key] = value
133
134    if len(self._data) > self._max_items:
135      self._data.popitem(last=False)
136
137  def get(self, key):
138    return self._data.get(key, None)
139
140  def flush(self):
141    self._data.clear()
142
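# Illustrative sketch (editorial note, not part of the original module): the
# cache above drops oversized tensors and evicts the oldest entry once
# `max_items` is exceeded. `t1`..`t3` stand for arbitrary EagerTensors.
#
#   cache = _EagerTensorCache(max_items=2, max_tensor_size=10000)
#   cache.put("a", t1)
#   cache.put("b", t2)
#   cache.put("c", t3)        # "a" is evicted first (FIFO insertion order)
#   cache.get("a")            # -> None
#   cache.get("b") is t2      # -> True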
143
144class FunctionCallOptions:
145  """Options applied at call sites of eager functions.
146
147  Eager functions are functions decorated with tf.function (formerly tf.contrib.eager.defun).
148  """
149
150  __slots__ = ["_config_proto_serialized", "_executor_type"]
151
152  def __init__(self, executor_type=None, config_proto=None):
153    """Constructor.
154
155    Args:
156      executor_type: (optional) name of the executor to be used to execute the
157        eager function. If None or an empty string, the default TensorFlow
158        executor will be used.
159      config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized
160        string of that proto. The config used by Grappler when optimizing the
161        function graph. Each concrete function is optimized the first time it
162        is called. Changing config_proto after the first call has no effect. If
163        config_proto is None, an empty RewriterConfig will be used.
164    """
165    self.config_proto_serialized = config_proto
166    self.executor_type = executor_type
167
168  @property
169  def executor_type(self):
170    return self._executor_type
171
172  @executor_type.setter
173  def executor_type(self, executor_type):
174    self._executor_type = executor_type
175
176  @property
177  def config_proto_serialized(self):
178    return self._config_proto_serialized
179
180  @config_proto_serialized.setter
181  def config_proto_serialized(self, config):
182    if isinstance(config, config_pb2.ConfigProto):
183      self._config_proto_serialized = config.SerializeToString(
184          deterministic=True)
185    elif isinstance(config, str):
186      self._config_proto_serialized = config
187    elif config is None:
188      self._config_proto_serialized = (
189          config_pb2.ConfigProto().SerializeToString())
190    else:
191      raise ValueError("the rewriter config must be either a "
192                       "config_pb2.ConfigProto, a serialized string of that "
193                       "proto, or None. Got: {}".format(type(config)))
194
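# Illustrative sketch (editorial note, not part of the original module): the
# config_proto_serialized setter above accepts a ConfigProto, an
# already-serialized string, or None.
#
#   proto = config_pb2.ConfigProto(allow_soft_placement=True)
#   opts = FunctionCallOptions(config_proto=proto)
#   opts.config_proto_serialized        # deterministic serialization of `proto`
#   FunctionCallOptions(config_proto=None)  # empty ConfigProto, default executor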
195
196# Map from context_id (an int) to _TensorCaches.
197# Dicts are thread safe in CPython.
198# TODO(iga): Remove this once TensorCaches are moved to C++.
199_tensor_caches_map = {}
200
201
202class _TensorCaches(threading.local):
203  """Thread local tensor caches."""
204
205  __slots__ = ["_ones_rank_cache", "_zeros_cache"]
206
207  def __init__(self):
208    super().__init__()
209    self._ones_rank_cache = None
210    self._zeros_cache = None
211
212  @property
213  def ones_rank_cache(self):
214    if not self._ones_rank_cache:
215      self._ones_rank_cache = _EagerTensorCache()
216    return self._ones_rank_cache
217
218  @property
219  def zeros_cache(self):
220    if not self._zeros_cache:
221      self._zeros_cache = _EagerTensorCache()
222    return self._zeros_cache
223
224
225ContextSwitch = collections.namedtuple(
226    "ContextSwitch",
227    ["is_building_function", "enter_context_fn", "device_stack"])
228
229
230# `_ContextSwitchStack` is a `threading.local` to match the semantics of
231# `_DefaultGraphStack`, which is also a `threading.local`.
232class _ContextSwitchStack(threading.local):
233  """A thread-local stack of context switches."""
234
235  def __init__(self, eager):
236    super().__init__()
237    self.stack = []
238    if eager:
239      # Initialize the stack with a pointer to enter the eager context; this
240      # ensures that the fact that eager execution was enabled is propagated
241      # across threads, since (1) `enable_eager_execution` modifies a
242      # process-level flag (`default_execution_mode`) and (2) `__init__` is
243      # called each time a threading.local object is used in a separate thread.
244      self.push(
245          is_building_function=False,
246          enter_context_fn=eager_mode,
247          device_stack=None)
248
249  def push(self, is_building_function, enter_context_fn, device_stack):
250    """Push metadata about a context switch onto the stack.
251
252    A context switch can take any one of the two forms: installing a graph as
253    the default graph, or entering the eager context. For each context switch,
254    we record whether or not the entered context is building a function.
255
256    Args:
257      is_building_function: (bool.) Whether the context is building a function.
258      enter_context_fn: (function.) A callable that executes the context switch.
259        For example, `graph.as_default` or `eager_mode`.
260      device_stack: If applicable, the device function stack for this graph.
261        When breaking out of graphs in init_scope, the innermost nonempty device
262        stack is used. Eager contexts put `None` here and the value is never
263        used.
264    """
265
266    self.stack.append(
267        ContextSwitch(is_building_function, enter_context_fn, device_stack))
268
269  def pop(self):
270    """Pop the stack."""
271
272    self.stack.pop()
273
274
275@tf_export("config.LogicalDevice")
276class LogicalDevice(
277    collections.namedtuple("LogicalDevice", ["name", "device_type"])):
278  """Abstraction for a logical device initialized by the runtime.
279
280  A `tf.config.LogicalDevice` corresponds to an initialized logical device on a
281  `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors
282  and operations can be placed on a specific logical device by calling
283  `tf.device` with a specified `tf.config.LogicalDevice`.
284
285  Fields:
286    name: The fully qualified name of the device. Can be used for Op or function
287      placement.
288    device_type: String declaring the type of device such as "CPU" or "GPU".
289  """
290  pass
291
292
293@tf_export("config.LogicalDeviceConfiguration",
294           "config.experimental.VirtualDeviceConfiguration")
295class LogicalDeviceConfiguration(
296    collections.namedtuple("LogicalDeviceConfiguration", [
297        "memory_limit", "experimental_priority", "experimental_device_ordinal"
298    ])):
299  """Configuration class for a logical devices.
300
301  The class specifies the parameters to configure a `tf.config.PhysicalDevice`
302  as it is initialized to a `tf.config.LogicalDevice` during runtime
303  initialization. Not all fields are valid for all device types.
304
305  See `tf.config.get_logical_device_configuration` and
306  `tf.config.set_logical_device_configuration` for usage examples.
307
308  Fields:
309    memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual
310      device. Currently only supported for GPUs.
311    experimental_priority: (optional) Priority to assign to a virtual device.
312      Lower values have higher priorities and 0 is the default.
313      Within a physical GPU, the GPU scheduler will prioritize ops on virtual
314      devices with higher priority. Currently only supported for Nvidia GPUs.
315    experimental_device_ordinal: (optional) Ordinal number to order the
316      virtual device. A LogicalDevice with a lower ordinal number will
317      receive a lower device id. Physical device id and location in the
318      list are used to break ties. Currently only supported for Nvidia
319      GPUs.
320  """
321
322  def __new__(cls,
323              memory_limit=None,
324              experimental_priority=None,
325              experimental_device_ordinal=0):
326    return super().__new__(cls, memory_limit, experimental_priority,
327                           experimental_device_ordinal)
328
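# Illustrative sketch (editorial note, not part of the original module): the
# public tf.config API wraps this namedtuple to split one physical GPU into
# several logical devices.
#
#   gpus = tf.config.list_physical_devices("GPU")
#   if gpus:
#     tf.config.set_logical_device_configuration(
#         gpus[0],
#         [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
#          tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
#     tf.config.list_logical_devices("GPU")  # now reports two logical GPUs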
329
330@tf_export("config.PhysicalDevice")
331class PhysicalDevice(
332    collections.namedtuple("PhysicalDevice", ["name", "device_type"])):
333  """Abstraction for a locally visible physical device.
334
335  TensorFlow can utilize various devices such as the CPU or multiple GPUs
336  for computation. Before initializing a local device for use, the user can
337  customize certain properties of the device such as its visibility or memory
338  configuration.
339
340  Once a visible `tf.config.PhysicalDevice` is initialized one or more
341  `tf.config.LogicalDevice` objects are created. Use
342  `tf.config.set_visible_devices` to configure the visibility of a physical
343  device and `tf.config.set_logical_device_configuration` to configure multiple
344  `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is
345  useful when separation between models is needed or to simulate a multi-device
346  environment.
347
348  Fields:
349    name: Unique identifier for device.
350    device_type: String declaring the type of device such as "CPU" or "GPU".
351  """
352  pass
353
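# Illustrative sketch (editorial note, not part of the original module):
# physical devices are usually configured through the tf.config wrappers
# before the runtime is initialized.
#
#   gpus = tf.config.list_physical_devices("GPU")
#   if gpus:
#     tf.config.set_visible_devices(gpus[:1], "GPU")  # hide all but one GPU
#     tf.config.get_visible_devices("GPU")            # -> [gpus[0]]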
354
355class _AtomicCounter(object):
356  """A simple atomic counter."""
357
358  __slots__ = ["_value", "_lock"]
359
360  def __init__(self):
361    self._value = 0
362    self._lock = threading.Lock()
363
364  def increment_and_get(self):
365    with self._lock:
366      self._value += 1
367      return self._value
368
369
370_context_id_counter = _AtomicCounter()
371
372
373class _TensorCacheDeleter(object):
374  """Deletes tensor caches for a given context."""
375
376  __slots__ = ["_context_id"]
377
378  def __init__(self, context_id):
379    self._context_id = context_id
380
381  def __del__(self):
382    if _tensor_caches_map is None:
383      return
384    if self._context_id in _tensor_caches_map:
385      del _tensor_caches_map[self._context_id]
386
387
388# TODO(agarwal): rename to EagerContext / EagerRuntime ?
389# TODO(agarwal): consider keeping the corresponding Graph here.
390class Context:
391  """Environment in which eager operations execute."""
392
393  # TODO(agarwal): create and link in some documentation for `execution_mode`.
394  # pylint: disable=redefined-outer-name
395  def __init__(self,
396               config=None,
397               device_policy=None,
398               execution_mode=None,
399               server_def=None):
400    """Creates a new Context.
401
402    Args:
403      config: (Optional.) A `ConfigProto` protocol buffer with configuration
404        options for the Context. Note that a lot of these options may be
405        currently unimplemented or irrelevant when eager execution is enabled.
406      device_policy: (Optional.) What policy to use when trying to run an
407        operation on a device with inputs which are not on that device. When set
408        to None, an appropriate value will be picked automatically. The value
409        picked may change between TensorFlow releases.  Defaults to
410        DEVICE_PLACEMENT_SILENT.
411        Valid values:
412        - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
413          correct.
414        - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right
415          device but raises a warning.
416        - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide
417          performance problems.
418        - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
419          raising errors on the other ones.
420      execution_mode: (Optional.) Policy controlling how operations dispatched
421        are actually executed. When set to None, an appropriate value will be
422        picked automatically. The value picked may change between TensorFlow
423        releases.
424        Valid values:
425        - SYNC: executes each operation synchronously.
426        - ASYNC: executes each operation asynchronously. These operations may
427          return "non-ready" handles.
428      server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution
429        on remote devices. GrpcServers need to be started by creating an
430        identical server_def to this, and setting the appropriate task_indexes,
431        so that the servers can communicate. It will then be possible to execute
432        operations on remote devices.
433
434    Raises:
435     ValueError: If execution_mode is not valid.
436    """
437    # This _id is used only to index the tensor caches.
438    # TODO(iga): Remove this when tensor caches are moved to C++.
439    self._id = _context_id_counter.increment_and_get()
440    self._tensor_cache_deleter = _TensorCacheDeleter(self._id)
441    _tensor_caches_map[self._id] = _TensorCaches()
442
443    self._config = config
444    self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
445        self,
446        is_eager=lambda: default_execution_mode == EAGER_MODE,
447        device_spec=_starting_device_spec)
448    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
449    self._context_handle = None
450    self._context_devices = None
451    self._seed = None
452    self._initialize_lock = threading.Lock()
453    self._initialized = False
454    if device_policy is None:
455      device_policy = DEVICE_PLACEMENT_SILENT
456    self._device_policy = device_policy
457    self._mirroring_policy = None
458    if execution_mode not in (None, SYNC, ASYNC):
459      raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" %
460                       execution_mode)
461    if execution_mode is None:
462      execution_mode = SYNC
463    self._default_is_async = execution_mode == ASYNC
464    self._use_tfrt = is_tfrt_enabled()
465    self._use_tfrt_distributed_runtime = None
466    self._jit_compile_rewrite = jit_compile_rewrite_enabled()
467    self._server_def = server_def
468    self._collective_ops_server_def = None
469    self._collective_leader = None
470    self._collective_scoped_allocator_enabled_ops = None
471    self._collective_use_nccl_communication = None
472    self._collective_device_filters = None
473    self._coordination_service_config = None
474
475    self._device_lock = threading.Lock()
476    self._physical_devices = None
477    self._physical_device_to_index = None
478    self._pluggable_devices = None
479    self._visible_device_list = []
480    self._memory_growth_map = None
481    self._virtual_device_map = {}
482
483    # Values set after construction
484    self._optimizer_jit = None
485    self._intra_op_parallelism_threads = None
486    self._inter_op_parallelism_threads = None
487    self._soft_device_placement = None
488    self._log_device_placement = None
489    self._operation_timeout_in_ms = None
490    self._enable_mlir_graph_optimization = None
491    self._optimizer_experimental_options = {}
492
493    _python_eager_context_create_counter.get_cell().increase_by(1)
494
495    self._is_global_context = False
496
497  # pylint: enable=redefined-outer-name
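  # Illustrative sketch (editorial note, not part of the original module): user
  # code normally does not construct a Context directly; the constructor
  # arguments above roughly correspond to public wrappers such as
  #
  #   tf.config.experimental.set_device_policy("warn")        # DEVICE_PLACEMENT_WARN
  #   tf.config.experimental.set_synchronous_execution(False)  # ASYNC execution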
498
499  def _set_global_seed(self, seed):
500    """Set a global eager mode seed for random ops."""
501    self._seed = seed
502    # `random.Random(seed)` needs `seed` to be hashable, while values of type
503    # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them
504    # to int.
505    try:
506      hash(seed)
507    except TypeError:
508      seed = int(np.array(seed))
509    self._rng = random.Random(seed)
510    # Also clear the kernel cache, to reset any existing seeds
511    if self._context_handle is not None:
512      pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
513
514  def _internal_operation_seed(self):
515    """Returns a fake operation seed.
516
517      In eager mode, the user shouldn't set or depend on the operation seed.
518      Here, we generate a random seed based on global seed to make
519      operation's randomness different and depend on the global seed.
520
521    Returns:
522      A fake operation seed based on global seed.
523    """
524    return self._rng.randint(0, _MAXINT32)
525
526  def _initialize_logical_devices(self):
527    """Helper to initialize devices."""
528    # Store list of devices
529    logical_devices = []
530    context_devices = []
531    device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
532    try:
533      self._num_gpus = 0
534      current_job, current_task = None, None
535      server_def = self._server_def or self._collective_ops_server_def
536      if server_def is not None:
537        current_job, current_task = server_def.job_name, server_def.task_index
538      for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
539        dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
540        context_devices.append(pydev.canonical_name(dev_name))
541        spec = pydev.DeviceSpec.from_string(dev_name)
542        # If the job is localhost, we assume that the cluster has not yet been
543        # configured and thus clear the job, replica & task.
544        if spec.job == "localhost":
545          spec = spec.replace(job=None, replica=None, task=None)
546        logical_devices.append(
547            LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
548        dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
549        if (dev_type == "GPU" and spec.job == current_job and
550            spec.task == current_task):
551          self._num_gpus += 1
552
553    finally:
554      self._logical_devices = logical_devices
555      self._context_devices = context_devices
556      pywrap_tfe.TF_DeleteDeviceList(device_list)
557
558  def ensure_initialized(self):
559    """Initialize handle and devices if not already done so."""
560    if self._initialized:
561      return
562    with self._initialize_lock:
563      if self._initialized:
564        return
565      assert self._context_devices is None
566      opts = pywrap_tfe.TFE_NewContextOptions()
567      try:
568        config_str = self.config.SerializeToString()
569        pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
570        if self._device_policy is not None:
571          pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
572              opts, self._device_policy)
573        if self._mirroring_policy is not None:
574          pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
575              opts, self._mirroring_policy)
576        if self._default_is_async == ASYNC:
577          pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
578        if self._use_tfrt is not None:
579          pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
580        # pylint: disable=g-backslash-continuation
581        if self._use_tfrt is not None and \
582            self._use_tfrt_distributed_runtime is not None:
583          pywrap_tfe.TFE_ContextOptionsSetTfrtDistributedRuntime(
584              opts, self._use_tfrt_distributed_runtime)
585        pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True)
586        pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite(
587            opts, self._jit_compile_rewrite)
588        context_handle = pywrap_tfe.TFE_NewContext(opts)
589      finally:
590        pywrap_tfe.TFE_DeleteContextOptions(opts)
591      assert not (self._server_def and self._collective_ops_server_def), (
592          "Cannot enable remote execution as well as collective ops at the "
593          "moment. If this is important to you, please file an issue.")
594      if self._server_def is not None:
595        server_def_str = self._server_def.SerializeToString()
596        pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
597                                           server_def_str)
598      elif self._collective_ops_server_def is not None:
599        server_def_str = self._collective_ops_server_def.SerializeToString()
600        pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
601
602      self._context_handle = context_handle
603      self._initialize_logical_devices()
604      self._initialized = True
605
606      if self._is_global_context:
607        pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle)
608
609  def mark_as_global_context(self):
610    # If the context was already initialized, publish it. Otherwise wait with
611    # publication until it's initialized.
612    if self._initialized:
613      pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle)
614    self._is_global_context = True
615
616  def _clear_caches(self):
617    self.ones_rank_cache().flush()
618    self.zeros_cache().flush()
619    pywrap_tfe.TFE_ClearScalarCache()
620
621  def get_server_def(self):
622    return self._server_def
623
624  def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
625    """Allow setting a server_def on the context.
626
627    When a server def is replaced, it effectively clears a bunch of caches
628    within the context. If you attempt to use a tensor object that was pointing
629    to a tensor on the remote device, it will raise an error.
630
631    Args:
632      server_def: A tensorflow::ServerDef proto. Enables execution on remote
633        devices.
634      keep_alive_secs: Num. seconds after which the remote end will hang up. As
635        long as the client is still alive, the server state for the context will
636        be kept alive. If the client is killed (or there is some failure), the
637        server will clean up its context keep_alive_secs after the final RPC it
638        receives.
639
640    Raises:
641      ValueError: if server_def is None.
642    """
643    if not server_def:
644      raise ValueError("server_def is None.")
645
646    self._server_def = server_def
647
648    if self._context_handle:
649      server_def_str = server_def.SerializeToString()
650      pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
651                                         server_def_str)
652      self._initialize_logical_devices()
653
654    # Clear all the caches in case there are remote tensors in them.
655    self._clear_caches()
656
657  def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
658    """Update a server_def on the context.
659
660    Args:
661      server_def: A tensorflow::ServerDef proto. Enables execution on remote
662        devices.
663      keep_alive_secs: Num. seconds after which the remote end will hang up. As
664        long as the client is still alive, the server state for the context will
665        be kept alive. If the client is killed (or there is some failure), the
666        server will clean up its context keep_alive_secs after the final RPC it
667        receives.
668
669    Raises:
670      ValueError: if server_def is None.
671    """
672    if not server_def:
673      raise ValueError("server_def is None.")
674
675    self._server_def = server_def
676
677    if self._context_handle:
678      server_def_str = server_def.SerializeToString()
679      pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
680                                            keep_alive_secs, server_def_str)
681      self._initialize_logical_devices()
682
683    self._clear_caches()
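  # Illustrative sketch (editorial note, not part of the original module): a
  # server def is usually installed through the public remote-connection
  # helpers rather than by calling set_server_def()/update_server_def()
  # directly, e.g.
  #
  #   tf.config.experimental_connect_to_remote_host("localhost:8470")
  #
  # which builds a ServerDef and hands it to the current context.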
684
685  def check_alive(self, worker_name):
686    """Checks whether a remote worker is alive or not.
687
688    Args:
689      worker_name: a string representing the remote worker. It must be a fully
690      specified name like "/job:worker/replica:0/task:0".
691
692    Returns:
693      a boolean indicating whether the remote worker is alive or not.
694
695    Raises:
696      ValueError: if context is not initialized.
697    """
698    # TODO(yuefengz): support checking multiple workers.
699    if self._context_handle:
700      return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
701    else:
702      raise ValueError("Context is not initialized.")
703
704  def sync_executors(self):
705    """Sync both local executors and the ones on remote workers.
706
707    In async execution mode, local function calls can return before the
708    corresponding remote op/function execution requests are completed. Calling
709    this method creates a synchronization barrier for remote executors. It only
710    returns when all remote pending nodes are finished, potentially with errors
711    if any remote executors are in error state.
712
713    Raises:
714      ValueError: if context is not initialized.
715    """
716    if self._context_handle:
717      pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
718    else:
719      raise ValueError("Context is not initialized.")
720
721  def clear_executor_errors(self):
722    """Clear errors in both local executors and remote workers.
723
724    After receiving errors from remote workers, additional requests on the fly
725    could further taint the status on the remote workers due to the async nature
726    of remote execution. Calling this method blocks on waiting for all pending
727    nodes in remote executors to finish and clear their error statuses.
728
729    Raises:
730      ValueError: if context is not initialized.
731    """
732    if self._context_handle:
733      pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
734    else:
735      raise ValueError("Context is not initialized.")
736
737  def configure_coordination_service(self,
738                                     service_type,
739                                     service_leader="",
740                                     enable_health_check=True,
741                                     cluster_register_timeout_in_ms=0,
742                                     heartbeat_timeout_in_ms=0,
743                                     coordinated_jobs=None):
744    """Enable distributed coordination service with specified configs."""
745    if self._context_handle:
746      logging.warning("Configuring coordination service type may not be "
747                      "effective because the context is already initialized.")
748    config = coordination_config_pb2.CoordinationServiceConfig()
749    config.service_type = service_type
750    if service_leader:
751      config.service_leader = pydev.canonical_name(service_leader)
752    config.enable_health_check = enable_health_check
753    config.cluster_register_timeout_in_ms = cluster_register_timeout_in_ms
754    config.heartbeat_timeout_in_ms = heartbeat_timeout_in_ms
755    if coordinated_jobs is not None:
756      if isinstance(coordinated_jobs, list):
757        config.coordinated_jobs.extend(coordinated_jobs)
758      else:
759        raise ValueError("`coordinated_jobs` must be a list of job names or "
760                         "None, but got: %s" % (coordinated_jobs,))
761    self._coordination_service_config = config
762
763  @property
764  def coordination_service(self):
765    return self._coordination_service_config
766
767  def set_config_key_value(self, key, value):
768    ensure_initialized()
769    pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value)
770
771  def get_config_key_value(self, key):
772    ensure_initialized()
773    with c_api_util.tf_buffer() as buffer_:
774      pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key, buffer_)
775      value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
776    return value
777
778  def delete_config_key_value(self, key):
779    ensure_initialized()
780    pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key)
781
782  def report_error_to_cluster(self, error_code, error_message):
783    """Report error to other members in a multi-client cluster.
784
785    Args:
786      error_code: a `tf.errors` error code.
787      error_message: a string. The error message.
788    """
789    if self._context_handle:
790      pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code,
791                                          error_message)
792    else:
793      raise ValueError("Context is not initialized.")
794
795  def clear_kernel_cache(self):
796    """Clear kernel cache and reset all stateful kernels."""
797    if self._context_handle is not None:
798      pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
799
800  def enable_collective_ops(self, server_def):
801    """Enable distributed collective ops with an appropriate server_def.
802
803    Args:
804      server_def: A tensorflow::ServerDef proto. Enables execution on remote
805        devices.
806
807    Raises:
808      ValueError: if server_def is None.
809      RuntimeError: if this method is not called at program startup.
810    """
811    if not server_def:
812      raise ValueError("server_def is None.")
813
814    self._collective_ops_server_def = server_def
815
816    # TODO(b/129298253): Allow creating datasets/tensors before enabling
817    # collective ops.
818    if self._context_handle is not None:
819      logging.warning("Enabling collective ops after program startup may cause "
820                      "error when accessing previously created tensors.")
821      with self._initialize_lock:
822        assert self._initialized
823        server_def_str = self._collective_ops_server_def.SerializeToString()
824        pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str)
825        self._initialize_logical_devices()
826        self._clear_caches()
827
828  def configure_collective_ops(
829      self,
830      collective_leader="",
831      scoped_allocator_enabled_ops=("CollectiveReduce",),
832      use_nccl_communication=False,
833      device_filters=None):
834    """Configure collective ops.
835
836      Collective group leader is necessary for collective ops to run; other
837      configurations are mainly for performance.
838
839    Args:
840      collective_leader: a device string for collective leader, e.g.
841        "/job:worker/replica:0/task:0"; empty string means local execution of
842          collective ops.
843      scoped_allocator_enabled_ops: a tuple or a list of op names for scoped
844        allocator to run with.
845      use_nccl_communication: whether to use nccl communication for collective
846        ops.
847      device_filters: a tuple or a list of device strings. If set, corresponding
848        task can only see the devices filtered by these device filters.
849
850    Raises:
851      RuntimeError: if this method is not called at program startup.
852    """
853    if self._collective_leader is not None:
854      if (self._collective_leader != collective_leader or
855          self._collective_scoped_allocator_enabled_ops !=
856          scoped_allocator_enabled_ops or
857          self._collective_use_nccl_communication != use_nccl_communication or
858          self._collective_device_filters != device_filters):
859        raise ValueError("Collective ops are already configured.")
860      else:
861        return
862
863    if self._context_handle is not None:
864      raise RuntimeError("Collective ops must be configured at program startup")
865
866    self._collective_leader = collective_leader
867    self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops
868    self._collective_use_nccl_communication = use_nccl_communication
869    self._collective_device_filters = device_filters
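  # Illustrative sketch (editorial note, not part of the original module):
  # collective ops are configured once, before the context handle is created,
  # e.g.
  #
  #   context().configure_collective_ops(
  #       collective_leader="/job:worker/replica:0/task:0",
  #       use_nccl_communication=True,
  #       device_filters=["/job:worker/task:0"])
  #
  # Calling it again with different arguments raises ValueError, and calling it
  # after the context has been initialized raises RuntimeError.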
870
871  def abort_collective_ops(self, code, message):
872    """Abort the collective ops.
873
874    This is intended to be used when a peer failure is detected, which allows
875    the user to handle the case instead of hanging. This aborts all ongoing
876    collectives. After this, all subsequent collectives error immediately, and
877    you need to reset_context() to use collectives again.
878
879    Args:
880      code: a `tf.errors` error code.
881      message: a string. The error message.
882    """
883    self.ensure_initialized()
884    pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
885
886  def check_collective_ops_peer_health(self, task, timeout_in_ms):
887    """Check collective peer health.
888
889    This probes each task to see if it is still alive. Note that a restarted
890    task is considered a different task and is considered not healthy.
891
892    This should only be used in multi-client, multi-worker training.
893
894    Args:
895      task: a task string, must be in the format of /job:xxx/replica:0/task:N.
896      timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
897
898    Raises:
899      tf.errors.UnavailableError: when a peer is down.
900      tf.errors.FailedPreconditionError: when a peer is a different one from the
901        one this task has talked to, e.g. the peer has restarted.
902      tf.errors.InvalidArgumentError: when the task string is invalid.
903    """
904    self.ensure_initialized()
905    pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
906                                                timeout_in_ms)
907
908  @property
909  def _handle(self):
910    if self._context_handle is None:
911      raise AssertionError("Context must be initialized first.")
912
913    return self._context_handle
914
915  @property
916  def _devices(self):
917    if self._context_devices is None:
918      raise AssertionError("Context must be initialized first.")
919
920    return self._context_devices
921
922  def __str__(self):
923    if self._context_handle is None:
924      return "Eager TensorFlow Context. Devices currently uninitialized."
925    else:
926      devices = self._devices
927      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
928      for i, d in enumerate(devices):
929        lines.append("   Device %d: %s" % (i, d))
930      return "\n".join(lines)
931
932  @tf_contextlib.contextmanager
933  def _mode(self, mode):
934    """A context manager to allow setting the mode to EAGER/GRAPH."""
935    ctx = self._thread_local_data
936    old_is_eager = ctx.is_eager
937    ctx.is_eager = mode == EAGER_MODE
938    if mode == EAGER_MODE:
939      # Entering graph mode does not provide us with sufficient information to
940      # record a context switch; graph-based context switches are only logged
941      # when a graph is registered as the default graph.
942      self.context_switches.push(False, eager_mode, None)
943    try:
944      yield
945    finally:
946      ctx.is_eager = old_is_eager
947      if mode == EAGER_MODE:
948        self.context_switches.pop()
949
950  def executing_eagerly(self):
951    """Returns True if current thread has eager executing enabled."""
952    return self._thread_local_data.is_eager
953
954  def ones_rank_cache(self):
955    """Per-device cache for scalars."""
956    return _tensor_caches_map[self._id].ones_rank_cache
957
958  def zeros_cache(self):
959    """Per-device cache for scalars."""
960    return _tensor_caches_map[self._id].zeros_cache
961
962  @property
963  def scope_name(self):
964    """Returns scope name for the current thread."""
965    return self._thread_local_data.scope_name
966
967  @scope_name.setter
968  def scope_name(self, s):
969    """Sets scope name for the current thread."""
970    self._thread_local_data.scope_name = s
971
972  @property
973  def device_name(self):
974    """Returns the device name for the current thread."""
975    return self._thread_local_data.device_name
976
977  @property
978  def device_spec(self):
979    """Returns the device spec for the current thread."""
980    return self._thread_local_data.device_spec
981
982  def _set_device(self, device_name, device_spec):
983    self._thread_local_data.device_name = device_name
984    self._thread_local_data.device_spec = device_spec
985
986  def device(self, name):
987    """Context-manager to force placement of operations and Tensors on a device.
988
989    Args:
990      name: Name of the device or None to get default placement.
991
992    Returns:
993      Context manager that forces device placement.
994
995    Raises:
996      ValueError: If name is not a string or is an invalid device name.
997      RuntimeError: If device scopes are not properly nested.
998    """
999    if isinstance(name, LogicalDevice):
1000      name = name.name
1001    elif pydev.is_device_spec(name):
1002      name = name.to_string()
1003    return _EagerDeviceContext(self, name)
1004
1005  def devices(self):
1006    """List of the names of devices available to execute operations."""
1007    return self._devices
1008
1009  def host_address_space(self):
1010    self.ensure_initialized()
1011    with c_api_util.tf_buffer() as buffer_:
1012      pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
1013      address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
1014    return address_space
1015
1016  # TODO(fishx): remove this property.
1017  @property
1018  def execution_mode(self):
1019    """Gets execution mode for current thread."""
1020    return ASYNC if self.is_async() else SYNC
1021
1022  @execution_mode.setter
1023  def execution_mode(self, mode):
1024    """Sets execution mode for current thread."""
1025    if mode not in (None, SYNC, ASYNC):
1026      raise ValueError("Execution mode should be None/SYNC/ASYNC. Got %s" %
1027                       mode)
1028
1029    if mode is None:
1030      mode = SYNC
1031
1032    enable_async = (mode == ASYNC)
1033    if self.is_async() != enable_async:
1034      # Only set the execution mode if the context has already been initialized
1035      if self._context_handle is not None:
1036        self.executor.wait()
1037        executor_new = executor.new_executor(enable_async)
1038        self._thread_local_data.executor = executor_new
1039        pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
1040                                                   executor_new.handle())
1041      else:
1042        self._default_is_async = enable_async
1043
1044  def is_async(self):
1045    if self._context_handle is not None:
1046      return self.executor.is_async()
1047    else:
1048      return self._default_is_async
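  # Illustrative sketch (editorial note, not part of the original module):
  # switching the current thread to asynchronous dispatch and then draining
  # pending work.
  #
  #   ctx = context()
  #   ctx.execution_mode = ASYNC   # ops may return "non-ready" handles
  #   ...                          # enqueue some eager ops
  #   ctx.executor.wait()          # block until pending nodes have finished
  #   ctx.execution_mode = SYNC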
1049
1050  @property
1051  def executor(self):
1052    self.ensure_initialized()
1053    return executor.Executor(
1054        pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
1055
1056  @executor.setter
1057  def executor(self, e):
1058    self.ensure_initialized()
1059    pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
1060
1061  @property
1062  def config(self):
1063    """Return the ConfigProto with all runtime deltas applied."""
1064    # Ensure physical devices have been discovered and config has been imported
1065    self._initialize_physical_devices()
1066
1067    config = config_pb2.ConfigProto()
1068    if self._config is not None:
1069      config.CopyFrom(self._config)
1070
1071    if self._optimizer_jit is not None:
1072      config.graph_options.optimizer_options.global_jit_level = (
1073          config_pb2.OptimizerOptions.ON_1
1074          if self._optimizer_jit else config_pb2.OptimizerOptions.OFF)
1075    if self._intra_op_parallelism_threads is not None:
1076      config.intra_op_parallelism_threads = self._intra_op_parallelism_threads
1077    if self._inter_op_parallelism_threads is not None:
1078      config.inter_op_parallelism_threads = self._inter_op_parallelism_threads
1079
1080    if self._soft_device_placement is not None:
1081      config.allow_soft_placement = self._soft_device_placement
1082    else:
1083      config.allow_soft_placement = self.executing_eagerly()
1084
1085    if self._log_device_placement is not None:
1086      config.log_device_placement = self._log_device_placement
1087
1088    if self._operation_timeout_in_ms is not None:
1089      config.operation_timeout_in_ms = self._operation_timeout_in_ms
1090
1091    is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled()
1092    config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled
1093    if (is_mlir_bridge_enabled ==
1094        config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED):
1095      config.experimental.enable_mlir_bridge = True
1096
1097    if self._enable_mlir_graph_optimization is not None:
1098      config.experimental.enable_mlir_graph_optimization = (
1099          self._enable_mlir_graph_optimization)
1100
1101    def rewriter_toggle(option):
1102      toggle = self._optimizer_experimental_options.get(option, None)
1103      if toggle is None:
1104        return
1105
1106      setattr(config.graph_options.rewrite_options, option,
1107              (rewriter_config_pb2.RewriterConfig.ON
1108               if toggle else rewriter_config_pb2.RewriterConfig.OFF))
1109
1110    def rewriter_bool(option):
1111      toggle = self._optimizer_experimental_options.get(option, None)
1112      if toggle is None:
1113        return
1114
1115      setattr(config.graph_options.rewrite_options, option, toggle)
1116
1117    rewriter_toggle("layout_optimizer")
1118    rewriter_toggle("constant_folding")
1119    rewriter_toggle("shape_optimization")
1120    rewriter_toggle("remapping")
1121    rewriter_toggle("arithmetic_optimization")
1122    rewriter_toggle("dependency_optimization")
1123    rewriter_toggle("loop_optimization")
1124    rewriter_toggle("function_optimization")
1125    rewriter_toggle("debug_stripper")
1126    rewriter_bool("disable_model_pruning")
1127    rewriter_toggle("scoped_allocator_optimization")
1128    rewriter_toggle("pin_to_host_optimization")
1129    rewriter_toggle("implementation_selector")
1130    rewriter_toggle("auto_mixed_precision")
1131    rewriter_toggle("use_plugin_optimizers")
1132    rewriter_bool("disable_meta_optimizer")
1133    rewriter_toggle("auto_mixed_precision_onednn_bfloat16")
1134    rewriter_toggle("auto_mixed_precision_mkl")
1135    nodes = self._optimizer_experimental_options.get("min_graph_nodes", None)
1136    if nodes is not None:
1137      config.graph_options.rewrite_options.min_graph_nodes = nodes
1138
1139    # Compute device counts
1140    config.device_count["CPU"] = 0
1141    config.device_count["GPU"] = 0
1142    for dev in self._physical_devices:
1143      if dev not in self._visible_device_list:
1144        continue
1145
1146      virtual_devices = self._virtual_device_map.get(dev)
1147      if virtual_devices is None:
1148        config.device_count[dev.device_type] += 1
1149      else:
1150        config.device_count[dev.device_type] += len(virtual_devices)
1151
1152    # Configure gpu_options
1153    gpu_options = self._compute_gpu_options()
1154    config.gpu_options.MergeFrom(gpu_options)
1155
1156    # Configure collective ops
1157    if self._collective_leader:
1158      config.experimental.collective_group_leader = self._collective_leader
1159    if self._collective_scoped_allocator_enabled_ops:
1160      rewrite_options = config.graph_options.rewrite_options
1161      rewrite_options.scoped_allocator_optimization = (
1162          rewriter_config_pb2.RewriterConfig.ON)
1163      del rewrite_options.scoped_allocator_opts.enable_op[:]
1164      for op in self._collective_scoped_allocator_enabled_ops:
1165        rewrite_options.scoped_allocator_opts.enable_op.append(op)
1166    if self._collective_use_nccl_communication:
1167      config.experimental.collective_nccl = True
1168    if self._collective_device_filters:
1169      del config.device_filters[:]
1170      for f in self._collective_device_filters:
1171        config.device_filters.append(f)
1172
1173    # Configure coordination service
1174    if self._coordination_service_config:
1175      config.experimental.coordination_config.CopyFrom(
1176          self._coordination_service_config)
1177
1178    return config
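  # Illustrative sketch (editorial note, not part of the original module): most
  # of the deltas folded into the ConfigProto above are set through tf.config,
  # e.g.
  #
  #   tf.config.optimizer.set_experimental_options(
  #       {"layout_optimizer": False, "disable_meta_optimizer": True})
  #   tf.config.set_soft_device_placement(True)
  #   tf.config.threading.set_intra_op_parallelism_threads(4)
  #
  # and then show up in the proto returned by this property.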
1179
1180  def _compute_gpu_options(self):
1181    """Build the GPUOptions proto."""
1182    visible_device_list = []
1183    virtual_devices = []
1184    gpu_index = -1
1185    memory_growths = set()
1186    gpu_devices = self.list_physical_devices("GPU")
1187    pluggable_devices = self._pluggable_devices
1188    compatible_devices = gpu_devices
1189    for dev in pluggable_devices:
1190      if dev not in gpu_devices:
1191        compatible_devices.append(dev)
1192    for dev in compatible_devices:
1193      gpu_index += 1
1194
1195      if dev not in self._visible_device_list:
1196        continue
1197
1198      growth = self._memory_growth_map[dev]
1199      memory_growths.add(growth)
1200      visible_device_list.append(str(gpu_index))
1201
1202      if self._virtual_device_map:
1203        vdevs = self._virtual_device_map.get(dev, [])
1204        device_ordinals = []
1205        device_limits = []
1206        priority = []
1207        for virt_dev in vdevs:
1208          device_ordinals.append(virt_dev.experimental_device_ordinal)
1209          device_limits.append(virt_dev.memory_limit)
1210          if virt_dev.experimental_priority is not None:
1211            priority.append(virt_dev.experimental_priority)
1212        # If priority is specified, it must be specified for all virtual
1213        # devices.
1214        if priority and len(device_limits) != len(priority):
1215          raise ValueError("priority must be specified for all virtual devices")
1216
1217        virtual_devices.append(
1218            config_pb2.GPUOptions.Experimental.VirtualDevices(
1219                memory_limit_mb=device_limits,
1220                priority=priority,
1221                device_ordinal=device_ordinals))
1222
1223    # Only compute growth if virtual devices have not been configured and we
1224    # have GPUs
1225    if not virtual_devices and memory_growths:
1226      if len(memory_growths) > 1:
1227        raise ValueError("Memory growth cannot differ between GPU devices")
1228      allow_growth = memory_growths.pop()
1229    else:
1230      allow_growth = None
1231
1232    return config_pb2.GPUOptions(
1233        allow_growth=allow_growth,
1234        visible_device_list=",".join(visible_device_list),
1235        experimental=config_pb2.GPUOptions.Experimental(
1236            virtual_devices=virtual_devices))
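  # Illustrative sketch (editorial note, not part of the original module): the
  # memory-growth map consumed above is populated via the public wrapper
  #
  #   gpus = tf.config.list_physical_devices("GPU")
  #   if gpus:
  #     tf.config.experimental.set_memory_growth(gpus[0], True)
  #
  # Note that the growth setting must agree across GPUs unless virtual devices
  # are configured, per the ValueError raised above.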
1237
1238  @property
1239  def function_call_options(self):
1240    """Returns function call options for current thread.
1241
1242    Note that the returned object is still referenced by the eager context.
1243
1244    Returns: the FunctionCallOptions for current thread.
1245    """
1246    if self._thread_local_data.function_call_options is None:
1247      config = self.config
1248
1249      # Default to soft placement for functions unless specified
1250      if self._soft_device_placement is None:
1251        config.allow_soft_placement = True
1252      self._thread_local_data.function_call_options = FunctionCallOptions(
1253          config_proto=config)
1254
1255    return self._thread_local_data.function_call_options
1256
1257  @function_call_options.setter
1258  def function_call_options(self, options):
1259    """Returns function call options for current thread."""
1260    self._thread_local_data.function_call_options = options
1261
1262  def num_gpus(self):
1263    """The number of GPUs available to execute operations."""
1264    self.ensure_initialized()
1265    return self._num_gpus
1266
1267  def add_function(self, fn):
1268    """Add a function definition to the context.
1269
1270    Once added, the function (identified by its name) can be executed like any
1271    other operation.
1272
1273    Args:
1274      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
1275    """
1276    self.ensure_initialized()
1277    pywrap_tfe.TFE_ContextAddFunction(self._handle, fn)
1278
1279  def add_function_def(self, fdef):
1280    """Add a function definition to the context.
1281
1282    Once added, the function (identified by its name) can be executed like any
1283    other operation.
1284
1285    Args:
1286      fdef: A FunctionDef protocol buffer message.
1287    """
1288    self.ensure_initialized()
1289    fdef_string = fdef.SerializeToString()
1290    pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
1291                                         len(fdef_string))
1292
1293  def get_function_def(self, name):
1294    """Get a function definition from the context.
1295
1296    Args:
1297      name: function signature name.
1298
1299    Returns:
1300      The requested FunctionDef.
1301
1302    Raises:
1303      tf.errors.NotFoundError: if name is not the name of a registered function.
1304    """
1305    with c_api_util.tf_buffer() as buffer_:
1306      pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_)
1307      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1308    function_def = function_pb2.FunctionDef()
1309    function_def.ParseFromString(proto_data)
1310
1311    return function_def
1312
1313  def register_custom_device(self, device_capsule, device_name,
1314                             device_info_capsule):
1315    """Calls TFE_RegisterCustomDevice. See the non-member function."""
1316    self.ensure_initialized()
1317    pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
1318                                           device_name, device_info_capsule)
1319
1320  def pack_eager_tensors(self, tensors):
1321    """Pack multiple `EagerTensor`s of the same dtype and shape.
1322
1323    Args:
1324      tensors: a list of EagerTensors to pack.
1325
1326    Returns:
1327      A packed EagerTensor.
1328    """
1329    self.ensure_initialized()
1330    return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
1331
1332  def list_function_names(self):
1333    """Get a list of names of registered functions.
1334
1335    Returns:
1336      A set of names of all registered functions for the context.
1337    """
1338    self.ensure_initialized()
1339    return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle))
1340
1341  def remove_function(self, name):
1342    """Remove a function from the context.
1343
1344    Once removed, the function cannot be executed anymore.
1345
1346    Args:
1347      name: function signature name.
1348    """
1349    self.ensure_initialized()
1350    pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
1351
1352  def has_function(self, name):
1353    """Check if a function `name` is registered."""
1354    self.ensure_initialized()
1355    return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
1356
1357  def add_op_callback(self, callback):
1358    """Add a post-op callback to the context.
1359
1360    A post-op callback is invoked immediately after an eager operation or
1361    function has finished execution or after an op has been added to a graph,
1362    providing access to the op's type, name, input and output tensors. Multiple
1363    op callbacks can be added, in which case the callbacks will be invoked in
1364    the order in which they are added.
1365
1366    Args:
1367      callback: a callable of the signature `f(op_type, inputs, attrs, outputs,
1368        op_name=None, graph=None)`. See doc strings in `op_callbacks.py` for
1369        details on the function signature and its semantics.
1370    """
1371    if callback not in self._thread_local_data.op_callbacks:
1372      self._thread_local_data.op_callbacks.append(callback)
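  # Illustrative sketch (editorial note, not part of the original module): a
  # callback matching the signature documented above, which simply logs op
  # types.
  #
  #   def log_op(op_type, inputs, attrs, outputs, op_name=None, graph=None):
  #     logging.info("saw op %s", op_type)
  #     return None  # or replacement outputs; see op_callbacks.py
  #
  #   context().add_op_callback(log_op)
  #   ...
  #   context().remove_op_callback(log_op)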
1373
1374  def remove_op_callback(self, callback):
1375    """Remove an already-registered op callback.
1376
1377    Args:
1378      callback: The op callback to be removed.
1379
1380    Raises:
1381      KeyError: If `callback` is not already registered.
1382    """
1383    if callback not in self._thread_local_data.op_callbacks:
1384      raise KeyError("The specified op callback has not been registered, "
1385                     "and hence cannot be removed.")
1386    del self._thread_local_data.op_callbacks[
1387        self._thread_local_data.op_callbacks.index(callback)]
1388
1389  @property
1390  def op_callbacks(self):
1391    return self._thread_local_data.op_callbacks
1392
1393  @property
1394  def invoking_op_callbacks(self):
1395    return self._thread_local_data.invoking_op_callbacks
1396
1397  @invoking_op_callbacks.setter
1398  def invoking_op_callbacks(self, value):
1399    self._thread_local_data.invoking_op_callbacks = value
1400
1401  def _initialize_physical_devices(self, reinitialize=False):
1402    """Gets local devices visible to the system.
1403
1404    Args:
1405      reinitialize: If True, reinitializes self._physical_devices so that
1406        dynamically registered devices will also be visible to the Python front-end.
1407    """
    # We lazily initialize self._physical_devices since we do not want to do
    # this in the constructor, as the backend may not be initialized yet.
1410    with self._device_lock:
1411      if not reinitialize and self._physical_devices is not None:
1412        return
1413
1414      devs = pywrap_tfe.TF_ListPhysicalDevices()
1415      self._physical_devices = [
1416          PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1])
1417          for d in devs
1418      ]
1419      self._physical_device_to_index = {
1420          p: i for i, p in enumerate(self._physical_devices)
1421      }
1422      # We maintain a separate list just so we can check whether the device in
1423      # _physical_devices is a PluggableDevice.
1424      pluggable_devs = pywrap_tfe.TF_ListPluggablePhysicalDevices()
1425      self._pluggable_devices = [
1426          PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1])
1427          for d in pluggable_devs
1428      ]
1429
1430      self._visible_device_list = list(self._physical_devices)
1431      self._memory_growth_map = {
1432          d: None
1433          for d in self._physical_devices
1434          if d.device_type == "GPU" or d in self._pluggable_devices
1435      }
1436
1437    # Import device settings that may have been passed into the constructor
1438    self._import_config()
1439
1440  def reinitialize_physical_devices(self):
1441    """Gets local devices visible to the system."""
1442    # Reinitialize the physical device list after registering
1443    # the pluggable device.
1444    self._initialize_physical_devices(True)
1445
1446  def list_physical_devices(self, device_type=None):
1447    """List local devices visible to the system.
1448
    This API allows a client to query the devices before they have been
    initialized by the eager runtime. Additionally, a user can filter by device
    type to get only CPUs or GPUs.
1452
1453    Args:
      device_type: Optional device type to limit results to.
1455
1456    Returns:
1457      List of PhysicalDevice objects.
1458    """
1459    self._initialize_physical_devices()
1460
1461    if device_type is None:
1462      return list(self._physical_devices)
1463
1464    return [d for d in self._physical_devices if d.device_type == device_type]
1465
1466  def get_device_details(self, device):  # pylint: disable=redefined-outer-name
1467    """Returns details about a physical devices.
1468
1469    Args:
1470      device: A `tf.config.PhysicalDevice` returned by
1471        `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
1472
1473    Returns:
1474      A dict with string keys.
1475    """
1476    if not isinstance(device, PhysicalDevice):
1477      raise ValueError("device must be a tf.config.PhysicalDevice, but got: "
1478                       "%s" % (device,))
1479    if (self._physical_device_to_index is None or
1480        device not in self._physical_device_to_index):
1481      raise ValueError("The PhysicalDevice must be one obtained from "
1482                       "calling `tf.config.list_physical_devices`, but got: "
1483                       "%s" % (device,))
1484    index = self._physical_device_to_index[device]
1485    details = pywrap_tfe.TF_GetDeviceDetails(index)
1486
1487    # Change compute_capability from a string to a tuple
1488    if "compute_capability" in details:
1489      try:
1490        major, minor = details["compute_capability"].split(".")
1491        details["compute_capability"] = (int(major), int(minor))
1492      except ValueError:
1493        raise RuntimeError("Device returned compute capability an in invalid "
1494                           "format: %s" % details["compute_capability"])
1495    return details
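
  # Illustrative sketch (assumes at least one GPU is present): the returned
  # dict uses string keys such as "device_name" and, for CUDA GPUs,
  # "compute_capability" (converted to a tuple above).
  #
  #   gpu = context().list_physical_devices("GPU")[0]
  #   context().get_device_details(gpu).get("compute_capability")
  #   # ==> e.g. (8, 0)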
1496
1497  def _import_config(self):
1498    """Import config if passed in during construction.
1499
    If the Context was created with a ConfigProto, such as when calling
    tf.compat.v1.enable_eager_execution(), then we need to pull out the
    various pieces we might be replacing and import them into our internal
    class representation.
1504    """
1505    if self._config is None:
1506      return
1507
1508    num_cpus = self._config.device_count.get("CPU", 1)
1509    if num_cpus != 1:
1510      cpus = [d for d in self._physical_devices if d.device_type == "CPU"]
1511      if num_cpus == 0:
1512        self.set_visible_devices([], "CPU")
1513      elif num_cpus > 1:
1514        self.set_logical_device_configuration(
1515            cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)])
1516
1517    # Parse GPU options
1518    gpus = [d for d in self._physical_devices if d.device_type == "GPU"]
1519
1520    # If there are no GPUs detected, simply ignore all the GPU options passed in
1521    # rather than doing any validation checks.
1522    if not gpus:
1523      return
1524
1525    gpu_count = self._config.device_count.get("GPU", None)
1526
1527    visible_gpus = []
1528    # TODO(gjn): Handle importing existing virtual GPU configuration
1529    visible_indices = self._config.gpu_options.visible_device_list
1530    if visible_indices:
1531      for index in visible_indices.split(","):
1532        if int(index) >= len(gpus):
1533          raise ValueError("Invalid visible device index: %s" % index)
1534        visible_gpus.append(gpus[int(index)])
1535    else:
1536      visible_gpus = gpus
1537
1538    if gpu_count is not None:
1539      visible_gpus = visible_gpus[:gpu_count]
1540
1541    self.set_visible_devices(visible_gpus, "GPU")
1542
1543  def list_logical_devices(self, device_type=None):
1544    """Return logical devices."""
1545    self.ensure_initialized()
1546    if device_type is None:
1547      return list(self._logical_devices)
1548
1549    return [d for d in self._logical_devices if d.device_type == device_type]
1550
1551  def get_visible_devices(self, device_type=None):
1552    """Get the list of visible devices."""
1553    self._initialize_physical_devices()
1554
1555    if device_type is None:
1556      return list(self._visible_device_list)
1557
1558    return [
1559        d for d in self._visible_device_list if d.device_type == device_type
1560    ]
1561
1562  def set_visible_devices(self, devices, device_type=None):
1563    """Set the list of visible devices."""
1564    self._initialize_physical_devices()
1565
1566    if not isinstance(devices, list):
1567      devices = [devices]
1568
1569    for d in devices:
1570      if d not in self._physical_devices:
1571        raise ValueError("Unrecognized device: %s" % repr(d))
1572      if device_type is not None and d.device_type != device_type:
1573        raise ValueError("Unrecognized device: %s" % repr(d))
1574
1575    visible_device_list = []
1576    if device_type is not None:
1577      visible_device_list = [
1578          d for d in self._visible_device_list if d.device_type != device_type
1579      ]
1580
1581    visible_device_list += devices
1582
1583    if self._visible_device_list == visible_device_list:
1584      return
1585
1586    if self._context_handle is not None:
1587      raise RuntimeError(
1588          "Visible devices cannot be modified after being initialized")
1589
1590    self._visible_device_list = visible_device_list
1591
1592  def get_memory_info(self, dev):
1593    """Returns a dict of memory info for the device."""
1594    self._initialize_physical_devices()
1595    self.ensure_initialized()
1596    return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev)
1597
1598  def reset_memory_stats(self, dev):
1599    """Resets the tracked memory stats for the device."""
1600    self._initialize_physical_devices()
1601    self.ensure_initialized()
1602    pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev)
1603
1604  def get_memory_growth(self, dev):
1605    """Get if memory growth is enabled for a PhysicalDevice."""
1606    self._initialize_physical_devices()
1607
1608    if dev not in self._physical_devices:
1609      raise ValueError("Unrecognized device: %s" % repr(dev))
1610
1611    return self._memory_growth_map[dev]
1612
1613  def set_memory_growth(self, dev, enable):
1614    """Set if memory growth should be enabled for a PhysicalDevice."""
1615    self._initialize_physical_devices()
1616
1617    if dev not in self._physical_devices:
1618      raise ValueError("Unrecognized device: %s" % repr(dev))
1619
1620    if dev in self._virtual_device_map:
1621      raise ValueError(
1622          "Cannot set memory growth on device when virtual devices configured")
1623
1624    if dev.device_type != "GPU" and dev not in self._pluggable_devices:
1625      raise ValueError(
1626          "Cannot set memory growth on non-GPU and non-Pluggable devices")
1627
1628    if self._memory_growth_map.get(dev) == enable:
1629      return
1630
1631    if self._context_handle is not None:
1632      raise RuntimeError(
1633          "Physical devices cannot be modified after being initialized")
1634
1635    self._memory_growth_map[dev] = enable
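
  # Illustrative sketch (not part of the original code; must run before the
  # context is initialized, as enforced above): enable on-demand allocation
  # for every GPU.
  #
  #   ctx = context()
  #   for gpu in ctx.list_physical_devices("GPU"):
  #     ctx.set_memory_growth(gpu, True)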
1636
1637  def get_logical_device_configuration(self, dev):
1638    """Get the virtual device configuration for a PhysicalDevice."""
1639    self._initialize_physical_devices()
1640
1641    if dev not in self._physical_devices:
1642      raise ValueError("Unrecognized device: %s" % repr(dev))
1643
1644    return self._virtual_device_map.get(dev)
1645
1646  def set_logical_device_configuration(self, dev, virtual_devices):
1647    """Set the virtual device configuration for a PhysicalDevice."""
1648    self._initialize_physical_devices()
1649
1650    if dev not in self._physical_devices:
1651      raise ValueError("Unrecognized device: %s" % repr(dev))
1652
1653    if dev.device_type == "CPU":
1654      for vdev in virtual_devices:
1655        if vdev.memory_limit is not None:
1656          raise ValueError("Setting memory limit on CPU virtual devices is "
1657                           "currently not supported")
        if vdev.experimental_priority is not None:
          raise ValueError("Setting experimental_priority on CPU virtual "
                           "devices is currently not supported")
        if vdev.experimental_device_ordinal != 0:
          raise ValueError("Setting experimental_device_ordinal on CPU virtual "
                           "devices is currently not supported")
1664    elif dev.device_type == "GPU":
1665      for vdev in virtual_devices:
1666        if vdev.memory_limit is None:
1667          raise ValueError(
1668              "Setting memory limit is required for GPU virtual devices")
1669    else:
1670      raise ValueError("Virtual devices are not supported for %s" %
1671                       dev.device_type)
1672
1673    if self._virtual_device_map.get(dev) == virtual_devices:
1674      return
1675
1676    if self._context_handle is not None:
1677      raise RuntimeError(
1678          "Virtual devices cannot be modified after being initialized")
1679
1680    self._virtual_device_map[dev] = virtual_devices
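
  # Illustrative sketch (not part of the original code; must run before the
  # context is initialized, as enforced above): split one physical GPU into
  # two logical devices with fixed memory limits (in megabytes).
  #
  #   ctx = context()
  #   gpu = ctx.list_physical_devices("GPU")[0]
  #   ctx.set_logical_device_configuration(gpu, [
  #       LogicalDeviceConfiguration(memory_limit=1024),
  #       LogicalDeviceConfiguration(memory_limit=1024),
  #   ])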
1681
1682  def set_logical_cpu_devices(self, num_cpus, prefix=""):
1683    """Set virtual CPU devices in context.
1684
1685    If virtual CPU devices are already configured at context initialization
1686    by tf.config.set_logical_device_configuration(), this method should not be
1687    called.
1688
1689    Args:
1690      num_cpus: Number of virtual CPUs.
1691      prefix: Device name prefix.
1692
1693    Raises:
      RuntimeError: If virtual CPUs are already configured at context
        initialization.
1696    """
1697    server_def = self._server_def or self._collective_ops_server_def
1698    local_prefix = ["/device"]
1699    if server_def is not None:
1700      local_prefix.append("/job:%s/replica:0/task:%d" % (server_def.job_name,
1701                                                         server_def.task_index))
1702    logical_local_devices = [d for d in self.list_logical_devices("CPU") if
1703                             d.name.startswith(tuple(local_prefix))]
1704    self.ensure_initialized()
    # Error out if there are already multiple logical CPUs in the context.
1706    if len(logical_local_devices) > 1:
1707      raise RuntimeError("Virtual CPUs already set, cannot modify again.")
1708
1709    pywrap_tfe.TFE_SetLogicalCpuDevices(self._context_handle, num_cpus, prefix)
1710    self._initialize_logical_devices()
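
  # Illustrative sketch (not part of the original code): expose two logical
  # CPU devices, e.g. for testing multi-device placement on a single host.
  #
  #   context().set_logical_cpu_devices(2)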
1711
1712  def get_compiler_ir(self, device_name, function_name, args, stage="hlo"):
1713    return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name,
1714                                       stage, device_name, args)
1715
1716  @deprecated(
1717      None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
1718  def enable_xla_devices(self):
1719    """Enables XLA:CPU and XLA:GPU devices registration."""
1720    pywrap_tfe.TF_EnableXlaDevices()
1721
1722  @property
1723  def enable_mlir_bridge(self):
1724    return pywrap_tfe.TF_IsMlirBridgeEnabled()
1725
1726  @property
1727  def enable_mlir_graph_optimization(self):
1728    return self._enable_mlir_graph_optimization
1729
1730  @enable_mlir_bridge.setter
1731  def enable_mlir_bridge(self, enabled):
1732    pywrap_tfe.TF_EnableMlirBridge(enabled)
1733    self._thread_local_data.function_call_options = None
1734
1735  @enable_mlir_graph_optimization.setter
1736  def enable_mlir_graph_optimization(self, enabled):
1737    self._enable_mlir_graph_optimization = enabled
1738    self._thread_local_data.function_call_options = None
1739
1740  @property
1741  def optimizer_jit(self):
1742    level = self.config.graph_options.optimizer_options.global_jit_level
1743    return (level == config_pb2.OptimizerOptions.ON_1 or
1744            level == config_pb2.OptimizerOptions.ON_2)
1745
1746  @optimizer_jit.setter
1747  def optimizer_jit(self, enabled):
1748    self._optimizer_jit = enabled
1749
1750    self._thread_local_data.function_call_options = None
1751
1752  def get_optimizer_experimental_options(self):
1753    """Get experimental options for the optimizer.
1754
1755    Returns:
1756      Dictionary of current option values
1757    """
1758    rewrite_options = self.config.graph_options.rewrite_options
1759    options = {}
1760
1761    def rewriter_toggle(option):
1762      attr = getattr(rewrite_options, option)
1763      if attr != 0:
1764        options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON)
1765
1766    def rewriter_bool(option):
1767      options[option] = getattr(rewrite_options, option)
1768
1769    rewriter_toggle("layout_optimizer")
1770    rewriter_toggle("constant_folding")
1771    rewriter_toggle("shape_optimization")
1772    rewriter_toggle("remapping")
1773    rewriter_toggle("arithmetic_optimization")
1774    rewriter_toggle("dependency_optimization")
1775    rewriter_toggle("loop_optimization")
1776    rewriter_toggle("function_optimization")
1777    rewriter_toggle("debug_stripper")
1778    rewriter_bool("disable_model_pruning")
1779    rewriter_toggle("scoped_allocator_optimization")
1780    rewriter_toggle("pin_to_host_optimization")
1781    rewriter_toggle("implementation_selector")
1782    rewriter_toggle("auto_mixed_precision")
1783    rewriter_toggle("use_plugin_optimizers")
1784    rewriter_bool("disable_meta_optimizer")
1785    rewriter_toggle("auto_mixed_precision_onednn_bfloat16")
1786    rewriter_toggle("auto_mixed_precision_mkl")
1787
1788    if rewrite_options.min_graph_nodes != 0:
1789      options["min_graph_nodes"] = rewrite_options.min_graph_nodes
1790
1791    return options
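
  # Illustrative note (not part of the original code): only explicitly set
  # toggles appear in the result, so a freshly created context typically
  # returns something like
  #
  #   context().get_optimizer_experimental_options()
  #   # ==> {'disable_model_pruning': False, 'disable_meta_optimizer': False}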
1792
1793  def set_optimizer_experimental_options(self, options):
1794    """Set experimental options for the optimizer.
1795
1796    Args:
1797      options: Dictionary of options to modify
1798    """
1799    self._optimizer_experimental_options.update(options)
1800
1801    self._thread_local_data.function_call_options = None
1802
1803  @property
1804  def intra_op_parallelism_threads(self):
1805    return self.config.intra_op_parallelism_threads
1806
1807  @intra_op_parallelism_threads.setter
1808  def intra_op_parallelism_threads(self, num_threads):
1809    if self._intra_op_parallelism_threads == num_threads:
1810      return
1811
1812    if self._context_handle is not None:
1813      raise RuntimeError(
1814          "Intra op parallelism cannot be modified after initialization.")
1815
1816    self._intra_op_parallelism_threads = num_threads
1817
1818  @property
1819  def inter_op_parallelism_threads(self):
1820    return self.config.inter_op_parallelism_threads
1821
1822  @inter_op_parallelism_threads.setter
1823  def inter_op_parallelism_threads(self, num_threads):
1824    if self._inter_op_parallelism_threads == num_threads:
1825      return
1826
1827    if self._context_handle is not None:
1828      raise RuntimeError(
1829          "Inter op parallelism cannot be modified after initialization.")
1830
1831    self._inter_op_parallelism_threads = num_threads
1832
1833  @property
1834  def soft_device_placement(self):
1835    return self.config.allow_soft_placement
1836
1837  @soft_device_placement.setter
1838  def soft_device_placement(self, enable):
1839    if self._context_handle is not None:
1840      pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
1841
1842    self._soft_device_placement = enable
1843    self._thread_local_data.function_call_options = None
1844
1845  @property
1846  def log_device_placement(self):
1847    return self.config.log_device_placement
1848
1849  @log_device_placement.setter
1850  def log_device_placement(self, enable):
1851    if self._context_handle is not None:
1852      pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
1853
1854    self._log_device_placement = enable
1855    self._thread_local_data.function_call_options = None
1856
1857  @property
1858  def jit_compile_rewrite(self):
1859    return self._jit_compile_rewrite
1860
1861  @jit_compile_rewrite.setter
1862  def jit_compile_rewrite(self, enable):
1863    if self._context_handle is not None:
1864      pywrap_tfe.TFE_ContextSetJitCompileRewrite(self._handle, enable)
1865    self._jit_compile_rewrite = enable
1866
1867  @property
1868  def device_policy(self):
1869    # Only get the policy from the context if it has already been initialized
1870    if self._context_handle is not None:
1871      return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
1872
1873    return self._device_policy
1874
1875  @device_policy.setter
1876  def device_policy(self, policy):
1877    if policy is None:
1878      policy = DEVICE_PLACEMENT_SILENT
1879
1880    if self._device_policy != policy:
1881      self._device_policy = policy
1882
1883      # Only set the policy if the context has already been initialized
1884      if self._context_handle is not None:
1885        pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
1886            self._handle, self._device_policy)
1887
1888  @property
1889  def use_tfrt(self):
1890    return self._use_tfrt
1891
1892  @use_tfrt.setter
1893  def use_tfrt(self, tfrt):
1894    """Sets whether to use TFRT."""
1895    if not isinstance(tfrt, bool):
1896      raise ValueError("Expecting a boolean but got %s" % type(tfrt))
1897
1898    if self._use_tfrt != tfrt:
1899      if self._initialized:
1900        raise ValueError("use_tfrt should be set before being initialized.")
1901      self._use_tfrt = tfrt
1902
1903  @property
1904  def use_tfrt_distributed_runtime(self):
1905    return self._use_tfrt_distributed_runtime
1906
1907  @use_tfrt_distributed_runtime.setter
1908  def use_tfrt_distributed_runtime(self, enable):
1909    """Sets whether to use TFRT distributed runtime.
1910
1911    This is only effective when use_tfrt is also true. Note that currently TFRT
1912    distributed runtime is not function complete and this config is for testing
1913    only.
1914    Args:
1915      enable: A boolean to set whether to use TFRT distributed runtime.
1916    """
1917    if not isinstance(enable, bool):
1918      raise ValueError("Expecting a boolean but got %s" % type(enable))
1919
1920    if self._use_tfrt_distributed_runtime != enable:
1921      if self._initialized:
1922        raise ValueError("use_tfrt should be set before being initialized.")
1923      self._use_tfrt_distributed_runtime = enable
1924
1925  @property
1926  def operation_timeout_in_ms(self):
1927    return self.config.operation_timeout_in_ms
1928
1929  @operation_timeout_in_ms.setter
1930  def operation_timeout_in_ms(self, timeout_in_ms):
1931    if self._operation_timeout_in_ms == timeout_in_ms:
1932      return
1933
1934    if self._context_handle is not None:
1935      raise RuntimeError(
1936          "Operation timeout cannot be modified after initialization.")
1937
1938    self._operation_timeout_in_ms = timeout_in_ms
1939
1940  def enable_run_metadata(self):
1941    """Enables tracing of op execution via RunMetadata.
1942
1943    To retrieve the accumulated metadata call context.export_run_metadata()
1944    and to stop tracing call context.disable_run_metadata().
1945    """
1946    self.ensure_initialized()
1947    pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
1948
1949  def disable_run_metadata(self):
1950    """Disables tracing of op execution via RunMetadata."""
1951    if not self._context_handle:
1952      return
1953    pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
1954
1955  def enable_graph_collection(self):
1956    """Enables graph collection of executed functions.
1957
1958    To retrieve the accumulated graphs call context.export_run_metadata()
1959    and to stop collecting graphs call context.disable_graph_collection().
1960    """
1961    self.ensure_initialized()
1962    pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
1963
1964  def disable_graph_collection(self):
1965    """Disables graph collection of executed functions."""
1966    if not self._context_handle:
1967      return
1968    pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
1969
1970  def export_run_metadata(self):
1971    """Returns a RunMetadata proto with accumulated information.
1972
1973    The returned protocol buffer contains information since the most recent call
1974    to either enable_run_metadata or export_run_metadata.
1975
1976    Returns:
1977      A RunMetadata protocol buffer. Or None if not enabled.
1978    """
1979    if not self._context_handle:
1980      return None
1981    with c_api_util.tf_buffer() as buffer_:
1982      pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
1983      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1984    run_metadata = config_pb2.RunMetadata()
1985    run_metadata.ParseFromString(compat.as_bytes(proto_data))
1986    return run_metadata
1987
1988  @property
1989  def context_switches(self):
1990    """Returns a stack of context switches."""
1991    return self._context_switches
1992
1993
1994class _EagerDeviceContext(object):
1995  """Context-manager forcing placement of ops and Tensors on a device."""
1996
1997  __slots__ = ["_device_name", "_ctx", "_stack"]
1998
1999  def __init__(self, ctx, device_name):
2000    self._device_name = device_name
2001    self._ctx = ctx
2002    self._stack = []
2003
2004  # TODO(b/189233748): Consolidate the device string parsing logic with
2005  # tensorflow/core/util/device_name_utils.cc.
2006  def __enter__(self):
2007    ctx = self._ctx
2008    old_device_name = ctx.device_name
2009    old_device_spec = ctx.device_spec
2010    new_device_name = self._device_name
2011    cache_key = (old_device_name, new_device_name)
2012    try:
2013      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
2014    except TypeError:
2015      # Error while trying to compute the cache key.
2016      raise ValueError("Expecting a string device name. Got %s(%s)" %
2017                       (type(new_device_name), new_device_name))
2018    except KeyError:
2019      # Handle a cache miss.
2020      if new_device_name is not None:
2021        if not isinstance(new_device_name, str):
2022          raise ValueError("Expecting a string device name. Got %s(%s)" %
2023                           (type(new_device_name), new_device_name))
2024        device_spec = pydev.DeviceSpec.from_string(new_device_name)
2025        if old_device_name:
2026          new_device_spec = copy.copy(old_device_spec)
2027        else:
2028          ctx.ensure_initialized()
2029          new_device_spec = pydev.DeviceSpec.from_string(
2030              ctx._context_devices[0])  # pylint: disable=protected-access
2031        new_device_spec = new_device_spec.make_merged_spec(device_spec)
2032      else:
2033        new_device_spec = pydev.DeviceSpec.from_string("")
2034      new_device_name = new_device_spec.to_string()
2035      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
2036
2037    ctx._set_device(new_device_name, new_device_spec)  # pylint: disable=protected-access
2038    self._stack.append((old_device_name, old_device_spec, new_device_spec))
2039
2040  def __exit__(self, *ex_info):
2041    ctx = self._ctx
2042    old_device_name, old_device_spec, new_device_spec = self._stack[-1]
2043    if ctx.device_spec is not new_device_spec:
2044      raise RuntimeError("Exiting device scope without proper scope nesting")
2045    del self._stack[-1]
2046    ctx._set_device(old_device_name, old_device_spec)  # pylint: disable=protected-access
2047
2048
2049# Do not change directly.
2050_context = None
2051_context_lock = threading.Lock()
2052
2053
2054def _set_context_locked(ctx):
2055  global _context
2056  pywrap_tfe.TFE_Py_SetEagerContext(ctx)
2057  ctx.mark_as_global_context()
2058  _context = ctx
2059
2060
2061def _set_context(ctx):
2062  with _context_lock:
2063    _set_context_locked(ctx)
2064
2065
2066def _create_context():
2067  with _context_lock:
2068    if _context is None:
2069      ctx = Context()
2070      _set_context_locked(ctx)
2071
2072
2073def _reset_context():
2074  """Clears and re-initializes the singleton context.
2075
2076  Should only be used for testing.
2077  """
2078  global _context
2079  global _device_parsing_cache
2080
2081  # Garbage collect and clear scalar cache to avoid Tensor from current context
2082  # polluting next context.
2083  gc.collect()
2084  pywrap_tfe.TFE_ClearScalarCache()
2085  with _context_lock:
2086    if _context is not None:
2087      _context._clear_caches()
2088      _context = None
2089  _create_context()
2090  _device_parsing_cache = {}
2091
2092
2093def _reset_jit_compiler_flags():
2094  """Clears and re-initializes the TF JIT compiler flags.
2095
2096  Should only be used for testing.
2097  """
2098  pywrap_tfe.TF_ResetJitCompilerFlags()
2099
2100
2101def context():
2102  """Returns a singleton context object."""
2103  if _context is None:
2104    _create_context()
2105  return _context
2106
2107
2108def context_safe():
2109  """Returns current context (or None if one hasn't been initialized)."""
2110  return _context
2111
2112
2113def ensure_initialized():
2114  """Initialize the context."""
2115  context().ensure_initialized()
2116
2117
2118def initialize_logical_devices():
2119  """Initialize the virtual devices."""
2120  context()._initialize_logical_devices()  # pylint: disable=protected-access
2121
2122
2123def set_global_seed(seed):
2124  """Sets the eager mode seed."""
2125  context()._set_global_seed(seed)  # pylint: disable=protected-access
2126
2127
2128def global_seed():
2129  """Returns the eager mode seed."""
2130  return context()._seed  # pylint: disable=protected-access
2131
2132
2133def internal_operation_seed():
2134  """Returns the operation seed generated based on global seed."""
2135  return context()._internal_operation_seed()  # pylint: disable=protected-access
2136
2137
2138@tf_export("executing_eagerly", v1=[])
2139def executing_eagerly():
2140  """Checks whether the current thread has eager execution enabled.
2141
2142  Eager execution is enabled by default and this API returns `True`
  in most cases. However, this API might return `False` in the following use
2144  cases.
2145
  *  Executing inside `tf.function`, unless under `tf.init_scope` or unless
     `tf.config.run_functions_eagerly(True)` was previously called.
  *  Executing inside a transformation function for `tf.data.Dataset`.
2149  *  `tf.compat.v1.disable_eager_execution()` is called.
2150
2151  General case:
2152
2153  >>> print(tf.executing_eagerly())
2154  True
2155
2156  Inside `tf.function`:
2157
2158  >>> @tf.function
2159  ... def fn():
2160  ...   with tf.init_scope():
2161  ...     print(tf.executing_eagerly())
2162  ...   print(tf.executing_eagerly())
2163  >>> fn()
2164  True
2165  False
2166
2167  Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called:
2168
2169  >>> tf.config.run_functions_eagerly(True)
2170  >>> @tf.function
2171  ... def fn():
2172  ...   with tf.init_scope():
2173  ...     print(tf.executing_eagerly())
2174  ...   print(tf.executing_eagerly())
2175  >>> fn()
2176  True
2177  True
2178  >>> tf.config.run_functions_eagerly(False)
2179
  Inside a transformation function for `tf.data.Dataset`:
2181
2182  >>> def data_fn(x):
2183  ...   print(tf.executing_eagerly())
2184  ...   return x
2185  >>> dataset = tf.data.Dataset.range(100)
2186  >>> dataset = dataset.map(data_fn)
2187  False
2188
2189  Returns:
2190    `True` if the current thread has eager execution enabled.
2191  """
2192  ctx = context_safe()
2193  if ctx is None:
2194    return default_execution_mode == EAGER_MODE
2195
2196  return ctx.executing_eagerly()
2197
2198
2199@tf_export(v1=["executing_eagerly"])
2200def executing_eagerly_v1():
2201  """Checks whether the current thread has eager execution enabled.
2202
2203  Eager execution is typically enabled via
2204  `tf.compat.v1.enable_eager_execution`, but may also be enabled within the
2205  context of a Python function via tf.contrib.eager.py_func.
2206
2207  When eager execution is enabled, returns `True` in most cases. However,
2208  this API might return `False` in the following use cases.
2209
  *  Executing inside `tf.function`, unless under `tf.init_scope` or unless
     `tf.config.run_functions_eagerly(True)` was previously called.
  *  Executing inside a transformation function for `tf.data.Dataset`.
2213  *  `tf.compat.v1.disable_eager_execution()` is called.
2214
2215  >>> tf.compat.v1.enable_eager_execution()
2216
2217  General case:
2218
2219  >>> print(tf.executing_eagerly())
2220  True
2221
2222  Inside `tf.function`:
2223
2224  >>> @tf.function
2225  ... def fn():
2226  ...   with tf.init_scope():
2227  ...     print(tf.executing_eagerly())
2228  ...   print(tf.executing_eagerly())
2229  >>> fn()
2230  True
2231  False
2232
  Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called:
2235
2236  >>> tf.config.run_functions_eagerly(True)
2237  >>> @tf.function
2238  ... def fn():
2239  ...   with tf.init_scope():
2240  ...     print(tf.executing_eagerly())
2241  ...   print(tf.executing_eagerly())
2242  >>> fn()
2243  True
2244  True
2245  >>> tf.config.run_functions_eagerly(False)
2246
  Inside a transformation function for `tf.data.Dataset`:
2248
2249  >>> def data_fn(x):
2250  ...   print(tf.executing_eagerly())
2251  ...   return x
2252  >>> dataset = tf.data.Dataset.range(100)
2253  >>> dataset = dataset.map(data_fn)
2254  False
2255
2256  Returns:
2257    `True` if the current thread has eager execution enabled.
2258  """
2259  return executing_eagerly()
2260
2261
2262def in_eager_mode():
2263  """Use executing_eagerly() instead. This function will be removed."""
2264  return executing_eagerly()
2265
2266
2267def anonymous_name():
2268  """Returns the anonymous shared name.
2269
2270  In eager mode we create anonymous resources to avoid spurious sharing issues.
2271  The runtime generates a unique name on our behalf when the reserved
2272  anonymous shared name is used as a shared name.
2273
2274  Returns:
2275    The anonymous shared name.
2276  """
2277
2278  # The magic value is defined as
2279  # `tensorflow::ResourceHandle::ANONYMOUS_NAME` in C++.
2280  return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
2281
2282
2283def graph_mode():
2284  """Context-manager to disable eager execution for the current thread."""
2285  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access
2286
2287
2288# Used by b/167638505 for keras backend API and Lambda layer.
2289@tf_export("__internal__.eager_context.eager_mode", v1=[])
2290def eager_mode():
2291  """Context-manager to enable eager execution for the current thread."""
2292  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access
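

# A minimal usage sketch (illustrative only; this helper is not part of the
# original module): switch the current thread between graph and eager modes.
def _example_mode_switch():
  """Hypothetical sketch of `graph_mode`/`eager_mode` as context managers."""
  with graph_mode():
    assert not executing_eagerly()  # This thread now builds graphs.
  with eager_mode():
    assert executing_eagerly()  # Back to executing ops eagerly.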
2293
2294
2295def scope_name():
2296  """Name of the current scope."""
2297  return context().scope_name
2298
2299
2300def device(name):
2301  """Context-manager to force placement of operations and Tensors on a device.
2302
2303  Example:
2304  ```python
2305  with tf.device('gpu:0'):
2306    with tf.device('cpu:0'):
2307      shape = tf.constant([], dtype=tf.int32)
    x = tf.random.truncated_normal(shape, dtype=tf.float32)
2309  ```
2310  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
2311  operation runs on GPU 0.
2312
2313  Args:
2314    name: Name of the device (see context().devices()), or None to perform
2315      automatic placement.
2316
2317  Returns:
2318    Context manager for setting the device.
2319  """
2320  ensure_initialized()
2321  return context().device(name)
2322
2323
2324# Expose some properties of Context as internally public APIs (b/160348781).
2325@tf_export("__internal__.eager_context.get_config", v1=[])
2326def get_config():
2327  """Get the ConfigProto of Context.
2328
2329  Returns:
2330    The ConfigProto of Context.
2331  """
2332  return context().config
2333
2334
2335@tf_export("__internal__.eager_context.get_device_name", v1=[])
2336def get_device_name():
2337  """Get the device name for the current thread.
2338
2339  Returns:
2340    The device name for the current thread.
2341  """
2342  return context().device_name
2343
2344
2345@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[])
2346def set_soft_device_placement(enabled):
2347  """Set if soft device placements should be allowed.
2348
2349  Args:
2350    enabled: Whether to enable soft device placement.
2351  """
2352  context().soft_device_placement = enabled
2353
2354
2355@tf_export("__internal__.eager_context.get_executor", v1=[])
2356def get_executor():
2357  """Get the Executor of the current thread.
2358
2359  Returns:
2360    The Executor of the current thread.
2361  """
2362  return context().executor
2363
2364
2365@tf_export("debugging.get_log_device_placement")
2366def get_log_device_placement():
2367  """Get if device placements are logged.
2368
2369  Returns:
2370    If device placements are logged.
2371  """
2372  return context().log_device_placement
2373
2374
2375@tf_export("debugging.set_log_device_placement")
2376def set_log_device_placement(enabled):
2377  """Turns logging for device placement decisions on or off.
2378
2379  Operations execute on a particular device, producing and consuming tensors on
2380  that device. This may change the performance of the operation or require
2381  TensorFlow to copy data to or from an accelerator, so knowing where operations
2382  execute is useful for debugging performance issues.
2383
2384  For more advanced profiling, use the [TensorFlow
2385  profiler](https://www.tensorflow.org/guide/profiler).
2386
2387  Device placement for operations is typically controlled by a `tf.device`
2388  scope, but there are exceptions, for example operations on a `tf.Variable`
2389  which follow the initial placement of the variable. Turning off soft device
2390  placement (with `tf.config.set_soft_device_placement`) provides more explicit
2391  control.
2392
2393  >>> tf.debugging.set_log_device_placement(True)
2394  >>> tf.ones([])
2395  >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:GPU:0
2396  >>> with tf.device("CPU"):
2397  ...  tf.ones([])
2398  >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:CPU:0
2399  >>> tf.debugging.set_log_device_placement(False)
2400
2401  Turning on `tf.debugging.set_log_device_placement` also logs the placement of
2402  ops inside `tf.function` when the function is called.
2403
2404  Args:
    enabled: Whether to enable device placement logging.
2406  """
2407  context().log_device_placement = enabled
2408
2409
2410@tf_contextlib.contextmanager
2411def device_policy(policy):
2412  """Context manager for setting device placement policy for current thread."""
2413  ctx = context()
2414  old_policy = ctx.device_policy
2415  try:
2416    ctx.device_policy = policy
2417    yield
2418  finally:
2419    ctx.device_policy = old_policy
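

# A minimal usage sketch (illustrative only; this helper is not part of the
# original module): relax placement checks around a block of eager ops.
def _example_device_policy_usage():
  """Hypothetical sketch of the `device_policy` context manager."""
  with device_policy(DEVICE_PLACEMENT_SILENT):
    # Ops executed here silently copy inputs across devices when needed,
    # instead of raising placement errors.
    pass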
2420
2421
2422def set_execution_mode(mode):
2423  """Sets execution mode for the current thread."""
2424  context().execution_mode = mode
2425
2426
2427# TODO(fishx): remove this method.
2428@tf_contextlib.contextmanager
2429def execution_mode(mode):
2430  """Context manager for setting execution mode for current thread."""
2431  if mode is None:
2432    yield
2433  else:
2434    ctx = context()
2435    executor_new = executor.new_executor(mode == ASYNC)
2436    executor_old = ctx.executor
2437    try:
2438      executor_old.wait()
2439      ctx.executor = executor_new
2440      yield
2441    finally:
2442      ctx.executor = executor_old
2443      executor_new.wait()
2444
2445
2446@tf_contextlib.contextmanager
2447def executor_scope(e):
2448  """Context manager for changing executor for current thread.
2449
2450  Args:
    e: An Executor to execute eager ops under this scope. Setting it to None
      will switch back to the default executor for the context.
2453
2454  Yields:
2455    Context manager for setting the executor for current thread.
2456  """
2457  ctx = context()
2458  executor_old = ctx.executor
2459  try:
2460    ctx.executor = e
2461    yield
2462  finally:
2463    ctx.executor = executor_old
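

# A minimal usage sketch (illustrative only; this helper is not part of the
# original module): run ops on a private async executor and then drain it,
# mirroring the pattern used by `execution_mode` above.
def _example_executor_scope_usage():
  """Hypothetical sketch of the `executor_scope` context manager."""
  private_executor = executor.new_executor(True)  # True => async dispatch.
  with executor_scope(private_executor):
    pass  # Eager ops issued here are dispatched to `private_executor`.
  private_executor.wait()  # Block until all dispatched ops have completed.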
2464
2465
2466@tf_export("experimental.function_executor_type")
2467@tf_contextlib.contextmanager
2468def function_executor_type(executor_type):
2469  """Context manager for setting the executor of eager defined functions.
2470
2471  Eager defined functions are functions decorated by tf.contrib.eager.defun.
2472
2473  Args:
2474    executor_type: a string for the name of the executor to be used to execute
2475      functions defined by tf.contrib.eager.defun.
2476
2477  Yields:
2478    Context manager for setting the executor of eager defined functions.
2479  """
2480  current_options = context().function_call_options
2481  old_options = copy.copy(current_options)
2482  try:
2483    current_options.executor_type = executor_type
2484    yield
2485  finally:
2486    context().function_call_options = old_options
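

# A minimal usage sketch (illustrative only; this helper is not part of the
# original module, and the executor-type string below is an assumption):
def _example_function_executor_type_usage():
  """Hypothetical sketch of the `function_executor_type` context manager."""
  with function_executor_type("SINGLE_THREADED_EXECUTOR"):
    pass  # Functions called here request the named executor type.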
2487
2488
2489def is_async():
2490  """Returns true if current thread is in async mode."""
2491  return context().is_async()
2492
2493
2494def num_gpus():
2495  """Get the number of available GPU devices.
2496
2497  Returns:
2498    The number of available GPU devices.
2499  """
2500  return context().num_gpus()
2501
2502
2503def enable_run_metadata():
2504  """Enables tracing of op execution via RunMetadata.
2505
2506  To retrieve the accumulated metadata call context.export_run_metadata()
2507  and to stop tracing call context.disable_run_metadata().
2508  """
2509  context().enable_run_metadata()
2510
2511
2512def disable_run_metadata():
2513  """Disables tracing of op execution via RunMetadata."""
2514  context().disable_run_metadata()
2515
2516
2517def enable_graph_collection():
2518  """Enables graph collection of executed functions.
2519
2520  To retrieve the accumulated graphs call context.export_run_metadata()
2521  and to stop collecting graphs call context.disable_graph_collection().
2522  """
2523  context().enable_graph_collection()
2524
2525
2526def disable_graph_collection():
2527  """Disables graph collection of executed functions."""
2528  context().disable_graph_collection()
2529
2530
2531def export_run_metadata():
2532  """Returns a RunMetadata proto with accumulated information.
2533
2534  The returned protocol buffer contains information since the most recent call
2535  to either enable_run_metadata or export_run_metadata.
2536
2537  Returns:
2538    A RunMetadata protocol buffer.
2539  """
2540  return context().export_run_metadata()
2541
2542
2543@contextlib.contextmanager
2544def collect_graphs(optimized=True):
2545  """Collects a flat list of pre- or post-optimization graphs.
2546
2547  The collected graphs include device placements, which can be useful for
2548  testing.
2549
2550  Usage:
2551
2552  ```
2553  @def_function.function
2554  def f(x):
2555    return x + constant_op.constant(1.)
2556
2557  with context.collect_graphs() as graphs:
2558    with ops.device("CPU:0"):
2559      f(constant_op.constant(1.))
2560
2561  graph, = graphs  # `graph` contains a single GraphDef for inspection
2562  ```
2563
2564  Args:
2565    optimized: whether to collect optimized graphs or non-optimized graphs
2566
2567  Yields:
2568    A list of GraphDefs, populated when the context manager exits.
2569  """
2570  ctx = context()
2571  ctx.enable_graph_collection()
2572  try:
2573    graphs = []
2574    yield graphs
2575    metadata = ctx.export_run_metadata()
2576  finally:
2577    ctx.disable_graph_collection()
2578  for graph in metadata.function_graphs:
2579    if optimized:
2580      graphs.append(graph.post_optimization_graph)
2581    else:
2582      graphs.append(graph.pre_optimization_graph)
2583
2584
2585def get_server_def():
2586  return context().get_server_def()
2587
2588
2589def set_server_def(server_def):
2590  context().set_server_def(server_def)
2591
2592
2593def update_server_def(server_def):
2594  context().update_server_def(server_def)
2595
2596
2597def check_alive(worker_name):
2598  return context().check_alive(worker_name)
2599
2600
2601@tf_export("experimental.async_scope")
2602@tf_contextlib.contextmanager
2603def async_scope():
2604  """Context manager for grouping async operations.
2605
2606  Ops/function calls inside the scope can return before finishing the actual
2607  execution. When exiting the async scope, a synchronization barrier will be
2608  automatically added to ensure the completion of all async op and function
2609  execution, potentially raising exceptions if async execution results in
2610  an error state.
2611
2612  Users may write the following code to asynchronously invoke `train_step_fn`
2613  and log the `loss` metric for every `num_steps` steps in a training loop.
  `train_step_fn` internally consumes data using `iterator.get_next()`, and may
  raise OutOfRangeError when it runs out of data. In that case:
2616
2617  ```
2618  try:
2619    with tf.experimental.async_scope():
2620      for _ in range(num_steps):
2621        # Step function updates the metric `loss` internally
2622        train_step_fn()
2623  except tf.errors.OutOfRangeError:
2624    tf.experimental.async_clear_error()
2625  logging.info('loss = %s', loss.numpy())
2626  ```
2627
2628  Yields:
2629    Context manager for grouping async operations.
2630  """
2631  # TODO(haoyuzhang): replace env var once we have a config method to turn on
2632  # and off async streaming RPC
2633  remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
2634  old_policy = os.environ.get(remote_async_env_var)
2635  try:
2636    os.environ[remote_async_env_var] = str(True)
2637    yield
2638    # Note: sync local and remote executors iff the async block does not raise
2639    # an exception. Triggering sync after an exception may lead to derived
2640    # runtime errors and unexpected exception types.
2641    context().sync_executors()
2642  finally:
2643    if old_policy is None:
2644      del os.environ[remote_async_env_var]
2645    else:
2646      os.environ[remote_async_env_var] = old_policy
2647
2648
2649def async_wait():
2650  """Sync all async operations and raise any errors during execution.
2651
2652  In async execution mode, an op/function call can return before finishing the
2653  actual execution. Calling this method creates a synchronization barrier for
2654  all async op and function execution. It only returns when all pending nodes
2655  are finished, potentially raising exceptions if async execution results in
2656  an error state. It is a no-op if the context is not initialized.
2657  """
2658  disable_async_executor_env_var = "TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY"
2659  if os.environ.get(disable_async_executor_env_var) == str(True):
2660    return
2661  if context()._context_handle is not None:  # pylint: disable=protected-access
2662    context().sync_executors()
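

# A minimal usage sketch (illustrative only; this helper is not part of the
# original module): pair async execution with an explicit barrier.
def _example_async_wait_usage():
  """Hypothetical sketch pairing `execution_mode(ASYNC)` with `async_wait`."""
  with execution_mode(ASYNC):
    pass  # Ops here may return before the runtime finishes executing them.
  async_wait()  # Surface any deferred errors before continuing.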
2663
2664
2665@tf_export("experimental.async_clear_error")
2666def async_clear_error():
2667  """Clear pending operations and error statuses in async execution.
2668
2669  In async execution mode, an error in op/function execution can lead to errors
  in subsequent ops/functions that are scheduled but not yet executed. Calling
  this method clears all pending operations and resets the async execution
  state.
2672
2673  Example:
2674
2675  ```
2676  while True:
2677    try:
2678      # Step function updates the metric `loss` internally
2679      train_step_fn()
2680    except tf.errors.OutOfRangeError:
2681      tf.experimental.async_clear_error()
2682      break
2683  logging.info('loss = %s', loss.numpy())
2684  ```
2685  """
2686  context().clear_executor_errors()
2687
2688
2689def add_function(fdef):
2690  """Add a function definition to the context."""
2691  context().add_function(fdef)
2692
2693
2694def remove_function(name):
2695  """Remove a function from the context."""
2696  context().remove_function(name)
2697
2698
2699def get_function_def(name):
2700  return context().get_function_def(name)
2701
2702
2703def register_custom_device(device_capsule, device_name, device_info_capsule):
2704  """Calls TFE_RegisterCustomDevice to register a custom device with Python.
2705
2706  Enables using C extensions specifying a custom device from Python. See the
2707  experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
2708  details.
2709
2710  Note that custom devices are not currently supported inside `tf.function`s.
2711
2712  Args:
2713    device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
2714      containing a pointer to a TFE_CustomDevice struct. The capsule retains
2715      ownership of the memory.
2716    device_name: A string indicating the name to register the custom device
2717      under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
2718      subsequently be passed to `with tf.device(...):`.
2719    device_info_capsule: A PyCapsule with the name set to
2720      'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
2721      struct with the initial state of the custom device (the void* device_info
2722      argument to TFE_RegisterCustomDevice). This method takes ownership of the
2723      memory and clears the capsule destructor.
2724  """
2725  context().register_custom_device(device_capsule, device_name,
2726                                   device_info_capsule)
2727
2728
2729# Not every user creates a Context via context.context()
2730# (for example, enable_eager_execution in python/framework/ops.py),
2731# but they do all import this file.  Note that IS_IN_GRAPH_MODE and
2732# in_graph_mode are both parameterless functions.
2733def _tmp_in_graph_mode():
2734  if context_safe() is None:
2735    # Context not yet initialized. Assume graph mode following the
2736    # default implementation in `is_in_graph_mode`.
2737    return True
2738  return not executing_eagerly()
2739
2740
2741is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode
2742