# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""State management for eager execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import copy
import random
import threading

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import tf2
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export

GRAPH_MODE = 0
EAGER_MODE = 1

default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE

# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
# new_device_spec).
# Note that we do not protect this with a lock and instead rely on python's GIL
# and the idempotent nature of writes to provide thread safety.
_device_parsing_cache = {}
_starting_device_spec = pydev.DeviceSpec.from_string("")

_MAXINT32 = 2**31 - 1

DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
    pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
SYNC = 0
ASYNC = 1


class _EagerTensorCache(object):
  """Simple cache which evicts items based on length in a FIFO manner.
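
  A minimal usage sketch (keys and values are hypothetical; `put` relies on
  the value's private `_num_elements()` method, so values are assumed to be
  eager `Tensor`s):

  ```python
  cache = _EagerTensorCache(max_items=2, max_tensor_size=10)
  cache.put("a", tensor_a)  # Cached, if tensor_a has <= 10 elements.
  cache.put("b", tensor_b)
  cache.put("c", tensor_c)  # Evicts "a", the oldest (FIFO) entry.
  cache.get("a")            # -> None
  ```
  """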

  def __init__(self, max_items=256, max_tensor_size=10000):
    self._data = collections.OrderedDict()
    self._max_items = max_items
    self._max_tensor_size = max_tensor_size

  def put(self, key, value):
    if value._num_elements() > self._max_tensor_size:  # pylint: disable=protected-access
      return

    self._data[key] = value

    if len(self._data) > self._max_items:
      self._data.popitem(last=False)

  def get(self, key):
    return self._data.get(key, None)

  def flush(self):
    # Reset to an OrderedDict so that `put` can keep evicting in FIFO order:
    # a plain dict's popitem() does not accept the `last` argument.
    self._data = collections.OrderedDict()


class FunctionCallOptions(object):
  """Options applied at call sites of eager functions.

  Eager functions are functions decorated with `tf.contrib.eager.defun`.
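
  A construction sketch (`config_pb2` is imported at the top of this module;
  the specific config fields are illustrative):

  ```python
  config = config_pb2.ConfigProto(allow_soft_placement=True)
  options = FunctionCallOptions(executor_type="", config_proto=config)
  # The setter serializes the proto; a pre-serialized string also works.
  assert options.config_proto_serialized == config.SerializeToString()
  ```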
  """

  def __init__(self, executor_type=None, config_proto=None):
    """Constructor.

    Args:
      executor_type: (optional) name of the executor to be used to execute the
        eager function. If None or an empty string, the default TensorFlow
        executor will be used.
      config_proto: (optional) a `config_pb2.ConfigProto` proto or
        a serialized string of that proto.
        The config used by Grappler when optimizing the function graph.
        Each concrete function is optimized the first time it is called.
        Changing config_proto after the first call has no effect.
        If config_proto is None, an empty RewriterConfig will be used.
    """
    self.config_proto_serialized = config_proto
    self.executor_type = executor_type

  @property
  def executor_type(self):
    return self._executor_type

  @executor_type.setter
  def executor_type(self, executor_type):
    self._executor_type = executor_type

  @property
  def config_proto_serialized(self):
    return self._config_proto_serialized

  @config_proto_serialized.setter
  def config_proto_serialized(self, config):
    if isinstance(config, config_pb2.ConfigProto):
      self._config_proto_serialized = config.SerializeToString()
    elif isinstance(config, str):
      self._config_proto_serialized = config
    elif config is None:
      self._config_proto_serialized = (
          config_pb2.ConfigProto().SerializeToString())
    else:
      raise ValueError("The rewriter config must be a config_pb2.ConfigProto, "
                       "a serialized string of that proto, or None. Got: "
                       "{}".format(type(config)))


class _ThreadLocalData(threading.local):
  """Thread local storage for the eager context."""

  def __init__(self):
    super(_ThreadLocalData, self).__init__()
    self.device_spec = _starting_device_spec
    self.device_name = ""
    self.mode = default_execution_mode
    self.is_eager = default_execution_mode == EAGER_MODE
    self.scope_name = ""
    self.summary_writer = None
    self.summary_recording = None
    self.summary_recording_distribution_strategy = True
    self.summary_step = None
    self.scalar_cache = {}
    self._ones_rank_cache = None
    self._zeros_cache = None
    self.execution_mode = SYNC
    self.function_call_options = None

  @property
  def ones_rank_cache(self):
    if not self._ones_rank_cache:
      self._ones_rank_cache = _EagerTensorCache()
    return self._ones_rank_cache

  @property
  def zeros_cache(self):
    if not self._zeros_cache:
      self._zeros_cache = _EagerTensorCache()
    return self._zeros_cache


ContextSwitch = collections.namedtuple(
    "ContextSwitch", ["is_building_function", "enter_context_fn",
                      "device_stack"])


# `_ContextSwitchStack` is a `threading.local` to match the semantics of
# `_DefaultGraphStack`, which is also a `threading.local`.
class _ContextSwitchStack(threading.local):
  """A thread-local stack of context switches."""

  def __init__(self, eager):
    super(_ContextSwitchStack, self).__init__()
    self.stack = []
    if eager:
      # Initialize the stack with a pointer to enter the eager context; this
      # ensures that the fact that eager execution was enabled is propagated
      # across threads, since (1) `enable_eager_execution` modifies a
      # process-level flag (`default_execution_mode`) and (2) `__init__` is
      # called each time a threading.local object is used in a separate thread.
      self.push(is_building_function=False, enter_context_fn=eager_mode,
                device_stack=None)

  def push(self, is_building_function, enter_context_fn, device_stack):
    """Push metadata about a context switch onto the stack.

    A context switch can take one of two forms: installing a graph as the
    default graph, or entering the eager context. For each context switch,
    we record whether or not the entered context is building a function.

    Args:
      is_building_function: (bool.) Whether the context is building a function.
      enter_context_fn: (function.) A callable that executes the context
        switch. For example, `graph.as_default` or `eager_mode`.
      device_stack: If applicable, the device function stack for this
        graph. When breaking out of graphs in init_scope, the innermost nonempty
        device stack is used. Eager contexts put `None` here and the value is
        never used.
    """

    self.stack.append(
        ContextSwitch(is_building_function, enter_context_fn, device_stack))

  def pop(self):
    """Pop the stack."""

    self.stack.pop()


# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
  """Environment in which eager operations execute."""

  # TODO(agarwal): create and link in some documentation for `execution_mode`.
  # pylint: disable=redefined-outer-name
  def __init__(self,
               config=None,
               device_policy=None,
               execution_mode=None,
               server_def=None):
    """Creates a new Context.

    Args:
      config: (Optional.) A `ConfigProto` protocol buffer with configuration
        options for the Context. Note that many of these options may be
        currently unimplemented or irrelevant when eager execution is enabled.
      device_policy: (Optional.) What policy to use when trying to run an
        operation on a device with inputs which are not on that device.
        When set to None, an appropriate value will be picked automatically.
        The value picked may change between TensorFlow releases.

        Defaults to DEVICE_PLACEMENT_SILENT.
        Valid values:
        - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is
          not correct.
        - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
          right device but raises a warning.
        - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
          hide performance problems.
        - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
          raising errors on the other ones.
      execution_mode: (Optional.) Policy controlling how operations dispatched
        are actually executed. When set to None, an appropriate value will be
        picked automatically. The value picked may change between TensorFlow
        releases.
        Valid values:
        - SYNC: executes each operation synchronously.
        - ASYNC: executes each operation asynchronously. These
          operations may return "non-ready" handles.
      server_def: (Optional.) A tensorflow::ServerDef proto.
        Enables execution on remote devices. GrpcServers need to be started
        with an identical server_def and the appropriate task_index set, so
        that the servers can communicate. It will then be possible to execute
        operations on remote devices.

    Raises:
      ValueError: If execution_mode is not valid.
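
    Most user code gets the per-process singleton via `context()` rather than
    constructing a `Context` directly; a construction sketch (the constants
    below are defined in this module):

    ```python
    ctx = Context(
        config=config_pb2.ConfigProto(log_device_placement=True),
        device_policy=DEVICE_PLACEMENT_WARN,
        execution_mode=SYNC)
    with ctx.device("cpu:0"):
      pass  # Ops created here would be placed on the CPU.
    ```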
    """
    if config is None:
      config = config_pb2.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=False,
      )
    self._config = config
    self._thread_local_data = _ThreadLocalData()
    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
    self._context_handle = None
    self._context_devices = None
    self._post_execution_callbacks = []
    self._seed = None
    self._initialize_lock = threading.Lock()
    if device_policy is None:
      device_policy = DEVICE_PLACEMENT_SILENT
    self._device_policy = device_policy
    if execution_mode not in (None, SYNC, ASYNC):
      raise ValueError(
          "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode)
    if execution_mode is None:
      execution_mode = SYNC
    self._execution_mode = execution_mode
    self._server_def = server_def
    self._collective_ops_server_def = None

  # pylint: enable=redefined-outer-name

  def _set_global_seed(self, seed):
    """Set a global eager mode seed for random ops."""
    self._seed = seed
    self._rng = random.Random(self._seed)
    # Also clear the kernel cache, to reset any existing seeds
    if self._context_handle is not None:
      pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)

  def _internal_operation_seed(self):
    """Returns a fake operation seed.

    In eager mode, users shouldn't set or depend on operation seeds.
    Here, we generate a random seed based on the global seed so that
    operations get different randomness while still depending on the
    global seed.

    Returns:
      A fake operation seed based on the global seed.
    """
    return self._rng.randint(0, _MAXINT32)

  def _initialize_devices(self):
    """Helper to initialize devices."""
    # Store list of devices
    self._context_devices = []
    device_list = pywrap_tensorflow.TFE_ContextListDevices(
        self._context_handle)
    try:
      self._num_gpus = 0
      for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
        dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
        self._context_devices.append(pydev.canonical_name(dev_name))
        dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
        if dev_type == "GPU":
          self._num_gpus += 1

    finally:
      pywrap_tensorflow.TF_DeleteDeviceList(device_list)

  def _initialize_handle_and_devices(self):
    """Initialize handle and devices."""
    with self._initialize_lock:
      if self._context_handle is not None:
        return
      assert self._context_devices is None
      opts = pywrap_tensorflow.TFE_NewContextOptions()
      try:
        if self._config is not None:
          config_str = self._config.SerializeToString()
          pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
        if self._device_policy is not None:
          pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
              opts, self._device_policy)
        if self._execution_mode == ASYNC:
          pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
        self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
      finally:
        pywrap_tensorflow.TFE_DeleteContextOptions(opts)
      assert not (self._server_def and self._collective_ops_server_def), (
          "Cannot enable remote execution as well as collective ops at the "
          "moment. If this is important to you, please file an issue.")
      if self._server_def is not None:
        server_def_str = self._server_def.SerializeToString()
        pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, 600,
                                                  server_def_str)
      elif self._collective_ops_server_def is not None:
        server_def_str = self._collective_ops_server_def.SerializeToString()
        pywrap_tensorflow.TFE_EnableCollectiveOps(self._context_handle,
                                                  server_def_str)

      self._initialize_devices()

  def _clear_caches(self):
    self.scalar_cache().clear()
    self.ones_rank_cache().flush()
    self.zeros_cache().flush()

  def set_server_def(self, server_def, keep_alive_secs=600):
    """Allow setting a server_def on the context.

    When a server_def is replaced, it effectively clears a number of caches
    within the context. If you attempt to use a tensor object that was
    pointing to a tensor on the remote device, it will raise an error.

    Args:
      server_def: A tensorflow::ServerDef proto.
        Enables execution on remote devices.
      keep_alive_secs: Number of seconds after which the remote end will hang
        up. As long as the client is still alive, the server state for the
        context will be kept alive. If the client is killed (or there is some
        failure), the server will clean up its context keep_alive_secs after
        the final RPC it receives.

    Raises:
      ValueError: if server_def is None.
    """
    if not server_def:
      raise ValueError("server_def is None.")
    if not self._context_handle:
      self._server_def = server_def
    else:
      server_def_str = server_def.SerializeToString()
      pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
                                                keep_alive_secs, server_def_str)

      # Clear all the caches in case there are remote tensors in them.
      self._clear_caches()

      self._initialize_devices()

  def enable_collective_ops(self, server_def):
    """Enable collective ops with an appropriate server_def.

    If previously enabled, this cannot be re-enabled.

    Args:
      server_def: A tensorflow::ServerDef proto. Enables execution on remote
        devices.

    Raises:
      ValueError: if server_def is None.
    """
    if not server_def:
      raise ValueError("server_def is None.")
    if not self._context_handle:
      self._collective_ops_server_def = server_def
    else:
      server_def_str = server_def.SerializeToString()
      pywrap_tensorflow.TFE_EnableCollectiveOps(self._context_handle,
                                                server_def_str)

      self._clear_caches()
      self._initialize_devices()

  @property
  def _handle(self):
    ctx = self._context_handle
    if ctx is None:
      self._initialize_handle_and_devices()
      return self._context_handle
    else:
      return ctx

  @property
  def _devices(self):
    devices = self._context_devices
    if devices is None:
      self._initialize_handle_and_devices()
      return self._context_devices
    else:
      return devices

  def __str__(self):
    if self._context_handle is None:
      return "Eager TensorFlow Context. Devices currently uninitialized."
    else:
      devices = self._devices
      lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
      for i, d in enumerate(devices):
        lines.append("   Device %d: %s" % (i, d))
      return "\n".join(lines)

  @tf_contextlib.contextmanager
  def _mode(self, mode):
    """A context manager to allow setting the mode to EAGER/GRAPH."""
    ctx = self._thread_local_data
    old_mode = ctx.mode
    old_is_eager = ctx.is_eager
    ctx.mode = mode
    ctx.is_eager = mode == EAGER_MODE
    if mode == EAGER_MODE:
      # Entering graph mode does not provide us with sufficient information to
      # record a context switch; graph-based context switches are only logged
      # when a graph is registered as the default graph.
      self.context_switches.push(False, eager_mode, None)
    try:
      yield
    finally:
      ctx.is_eager = old_is_eager
      ctx.mode = old_mode
      if mode == EAGER_MODE:
        self.context_switches.pop()

  def executing_eagerly(self):
    """Returns True if the current thread has eager execution enabled."""
    return self._thread_local_data.is_eager

  def scalar_cache(self):
    """Per-device cache for scalars."""
    return self._thread_local_data.scalar_cache

  def ones_rank_cache(self):
    """Per-device cache for ones tensors."""
    return self._thread_local_data.ones_rank_cache

  def zeros_cache(self):
    """Per-device cache for zeros tensors."""
    return self._thread_local_data.zeros_cache

  @property
  def scope_name(self):
    """Returns scope name for the current thread."""
    return self._thread_local_data.scope_name

  @scope_name.setter
  def scope_name(self, s):
    """Sets scope name for the current thread."""
    self._thread_local_data.scope_name = s

  @property
  def summary_writer(self):
    """Returns default summary writer for the current thread."""
    return self._thread_local_data.summary_writer

  @summary_writer.setter
  def summary_writer(self, writer):
    """Sets default summary writer for the current thread."""
    self._thread_local_data.summary_writer = writer

  @property
  def summary_recording(self):
    """Returns summary recording condition."""
    return self._thread_local_data.summary_recording

  @summary_recording.setter
  def summary_recording(self, condition):
    """Sets summary recording condition."""
    self._thread_local_data.summary_recording = condition

  @property
  def summary_recording_distribution_strategy(self):
    """Returns summary recording condition for distribution strategy."""
    return self._thread_local_data.summary_recording_distribution_strategy

  @summary_recording_distribution_strategy.setter
  def summary_recording_distribution_strategy(self, condition):
    """Sets summary recording condition for distribution strategy."""
    self._thread_local_data.summary_recording_distribution_strategy = condition

  @property
  def summary_step(self):
    """Returns summary step variable."""
    return self._thread_local_data.summary_step

  @summary_step.setter
  def summary_step(self, step):
    """Sets summary step variable."""
    self._thread_local_data.summary_step = step

  @property
  def device_name(self):
    """Returns the device name for the current thread."""
    return self._thread_local_data.device_name

  @property
  def device_spec(self):
    """Returns the device spec for the current thread."""
    return self._thread_local_data.device_spec

  @tf_contextlib.contextmanager
  def device(self, name):
    """Context-manager to force placement of operations and Tensors on a device.

    Args:
      name: Name of the device or None to get default placement.

    Yields:
      Nothing.

    Raises:
      ValueError: If name is not a string or is an invalid device name.
    """
    eager_context = self._thread_local_data
    old_device_name = eager_context.device_name
    old_device_spec = eager_context.device_spec
    cache_key = (old_device_name, name)
    try:
      new_device_name, new_device_spec = _device_parsing_cache[cache_key]
    except TypeError:
      # Error while trying to compute the cache key.
      raise ValueError("Expecting a string device name. Got %s(%s)" %
                       (type(name), name))
    except KeyError:
      # Handle a cache miss.
      if name is not None:
        if not isinstance(name, str):
          raise ValueError("Expecting a string device name. Got %s(%s)" %
                           (type(name), name))
        device_spec = pydev.DeviceSpec.from_string(name)
        if old_device_name:
          new_device_spec = copy.copy(old_device_spec)
        else:
          self._initialize_handle_and_devices()
          new_device_spec = pydev.DeviceSpec.from_string(
              self._context_devices[0])
        new_device_spec.merge_from(device_spec)
      else:
        new_device_spec = pydev.DeviceSpec.from_string("")
      new_device_name = new_device_spec.to_string()
      _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)

    try:
      eager_context.device_name = new_device_name
      eager_context.device_spec = new_device_spec
      yield
    finally:
      eager_context.device_name = old_device_name
      eager_context.device_spec = old_device_spec

  def devices(self):
    """List of the names of devices available to execute operations."""
    return self._devices

  @property
  def execution_mode(self):
    """Gets execution mode for current thread."""
    # Only get the execution mode from the context if it has already been
    # initialized
    if self._context_handle is None:
      return self._execution_mode

    mode = self._thread_local_data.execution_mode
    if mode is None:
      mode = self._execution_mode
    return mode

  @execution_mode.setter
  def execution_mode(self, mode):
    """Sets execution mode for current thread."""
    if mode not in (None, SYNC, ASYNC):
      raise ValueError(
          "Execution mode should be None/SYNC/ASYNC. Got %s" % mode)
    if mode is None:
      mode = SYNC

    if self._thread_local_data.execution_mode != mode:
      self._thread_local_data.execution_mode = mode

      # Only set the execution mode if the context has already been initialized
      if self._context_handle is not None:
        pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._context_handle,
                                                       mode == ASYNC)
      else:
        self._execution_mode = mode

  @property
  def function_call_options(self):
    """Returns function call options for current thread.

    Note that the returned object is still referenced by the eager context.

    Returns:
      The FunctionCallOptions for the current thread.
    """
    if self._thread_local_data.function_call_options is None:
      base_config = config_pb2.ConfigProto()
      base_config.CopyFrom(self._config)
      self._thread_local_data.function_call_options = FunctionCallOptions(
          config_proto=base_config)

    return self._thread_local_data.function_call_options

  @function_call_options.setter
  def function_call_options(self, options):
    """Sets function call options for current thread."""
    self._thread_local_data.function_call_options = options

  def async_wait(self):
    """Waits for ops dispatched in ASYNC mode to finish."""
    pywrap_tensorflow.TFE_ContextAsyncWait(self._handle)

  def async_clear_error(self):
    """Clears errors raised during ASYNC execution."""
    pywrap_tensorflow.TFE_ContextAsyncClearError(self._handle)

  def num_gpus(self):
    """The number of GPUs available to execute operations."""
    self._initialize_handle_and_devices()
    return self._num_gpus

  def add_function(self, fn):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
    """
    pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)

  def add_function_def(self, fdef):
    """Add a function definition to the context.

    Once added, the function (identified by its name) can be executed like any
    other operation.

    Args:
      fdef: A FunctionDef protocol buffer message.
    """
    fdef_string = fdef.SerializeToString()
    pywrap_tensorflow.TFE_ContextAddFunctionDef(
        self._handle, fdef_string, len(fdef_string))

  def has_function(self, name):
    """Check if a function `name` is registered."""
    return bool(pywrap_tensorflow.TFE_ContextHasFunction(self._handle, name))

  def add_post_execution_callback(self, callback):
    """Add a post-execution callback to the context.

    A post-execution callback is invoked immediately after an eager operation
    or function has finished execution, providing access to the op's type,
    name, input and output tensors. Multiple execution callbacks can be added,
    in which case the callbacks will be invoked in the order in which they are
    added.

    Args:
      callback: a callable of the signature
        `f(op_type, op_name, attrs, inputs, outputs)`.
        `op_type` is the type of the operation that was just executed (e.g.,
          `MatMul`).
        `op_name` is the name of the operation that was just executed. This
          name is set by the client who created the operation and can be `None`
          if it is unset.
        `attrs` contains the attributes of the operation as a `tuple` of
          alternating attribute names and attribute values.
        `inputs` is the `list` of input `Tensor`(s) to the op.
        `outputs` is the `list` of output `Tensor`(s) from the op.
        Return value(s) from the callback are ignored.
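
    A sketch of a logging callback with the documented signature (`print` is
    purely illustrative):

    ```python
    def log_op(op_type, op_name, attrs, inputs, outputs):
      del attrs, inputs, outputs  # Unused by this example.
      print("executed %s (name=%s)" % (op_type, op_name))

    context().add_post_execution_callback(log_op)
    ```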
    """
    # TODO(cais): (b/64674139) Allow access to function-internal operations.
    self._post_execution_callbacks.append(callback)

  def clear_post_execution_callbacks(self):
    """Clear all post-execution callbacks added to the context."""
    del self._post_execution_callbacks[:]

  @property
  def post_execution_callbacks(self):
    """Get the list of post-execution callbacks added to the context."""
    return self._post_execution_callbacks

  @property
  def gpu_per_process_memory_fraction(self):
    return self._config.gpu_options.per_process_gpu_memory_fraction

  @gpu_per_process_memory_fraction.setter
  def gpu_per_process_memory_fraction(self, fraction):
    if self._context_handle is not None:
      raise RuntimeError(
          "GPU options must be set at program startup")

    self._config.gpu_options.per_process_gpu_memory_fraction = fraction

  @property
  def gpu_per_process_memory_growth(self):
    return self._config.gpu_options.allow_growth

  @gpu_per_process_memory_growth.setter
  def gpu_per_process_memory_growth(self, enabled):
    if self._context_handle is not None:
      raise RuntimeError(
          "GPU options must be set at program startup")

    self._config.gpu_options.allow_growth = enabled

  @property
  def intra_op_parallelism_threads(self):
    return self._config.intra_op_parallelism_threads

  @intra_op_parallelism_threads.setter
  def intra_op_parallelism_threads(self, num_threads):
    if self._context_handle is not None:
      raise RuntimeError(
          "Intra op parallelism must be set at program startup")

    self._config.intra_op_parallelism_threads = num_threads

  @property
  def inter_op_parallelism_threads(self):
    return self._config.inter_op_parallelism_threads

  @inter_op_parallelism_threads.setter
  def inter_op_parallelism_threads(self, num_threads):
    if self._context_handle is not None:
      raise RuntimeError(
          "Inter op parallelism must be set at program startup")

    self._config.inter_op_parallelism_threads = num_threads

  @property
  def soft_device_placement(self):
    return self._config.allow_soft_placement

  @soft_device_placement.setter
  def soft_device_placement(self, enabled):
    self._config.allow_soft_placement = enabled

    self._thread_local_data.function_call_options = None

  @property
  def log_device_placement(self):
    return self._config.log_device_placement

  @log_device_placement.setter
  def log_device_placement(self, enabled):
    if self._context_handle is not None:
      raise RuntimeError(
          "Device placement logging must be set at program startup")

    self._config.log_device_placement = enabled

  @property
  def device_policy(self):
    # Only get the policy from the context if it has already been initialized
    if self._context_handle is not None:
      return pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(self._handle)

    return self._device_policy

  @device_policy.setter
  def device_policy(self, policy):
    if policy is None:
      policy = DEVICE_PLACEMENT_SILENT

    if self._device_policy != policy:
      self._device_policy = policy

      # Only set the policy if the context has already been initialized
      if self._context_handle is not None:
        pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
            self._handle, self._device_policy)

  def enable_run_metadata(self):
    """Enables tracing of op execution via RunMetadata.

    To retrieve the accumulated metadata call context.export_run_metadata()
    and to stop tracing call context.disable_run_metadata().
    """
    pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._handle)

  def disable_run_metadata(self):
    """Disables tracing of op execution via RunMetadata."""
    if not self._context_handle:
      return
    pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)

  def enable_graph_collection(self):
    """Enables graph collection of executed functions.

    To retrieve the accumulated graphs call context.export_run_metadata()
    and to stop collecting graphs call context.disable_graph_collection().
    """
    pywrap_tensorflow.TFE_ContextEnableGraphCollection(self._handle)

  def disable_graph_collection(self):
    """Disables graph collection of executed functions."""
    if not self._context_handle:
      return
    pywrap_tensorflow.TFE_ContextDisableGraphCollection(self._context_handle)

  def export_run_metadata(self):
    """Returns a RunMetadata proto with accumulated information.

    The returned protocol buffer contains information since the most recent
    call to either enable_run_metadata or export_run_metadata.

    Returns:
      A RunMetadata protocol buffer, or None if not enabled.
    """
    if not self._context_handle:
      return None
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TFE_ContextExportRunMetadata(
          self._context_handle, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    run_metadata = config_pb2.RunMetadata()
    run_metadata.ParseFromString(compat.as_bytes(proto_data))
    return run_metadata

  @property
  def context_switches(self):
    """Returns a stack of context switches."""
    return self._context_switches

  def start_step(self):
    pywrap_tensorflow.TFE_ContextStartStep(self._handle)

  def end_step(self):
    pywrap_tensorflow.TFE_ContextEndStep(self._handle)


_context = None
_context_lock = threading.Lock()


def _initialize_context():
  global _context
  with _context_lock:
    if _context is None:
      _context = Context()


def context():
  """Returns a singleton context object."""
  if _context is None:
    _initialize_context()
  return _context


def context_safe():
  """Returns current context (or None if one hasn't been initialized)."""
  return _context


def set_global_seed(seed):
  """Sets the eager mode seed."""
  context()._set_global_seed(seed)  # pylint: disable=protected-access


def global_seed():
  """Returns the eager mode seed."""
  return context()._seed  # pylint: disable=protected-access


def internal_operation_seed():
  """Returns the operation seed generated based on global seed."""
  return context()._internal_operation_seed()  # pylint: disable=protected-access


@tf_export("executing_eagerly")
def executing_eagerly():
  """Returns True if the current thread has eager execution enabled.

  Eager execution is typically enabled via `tf.enable_eager_execution`,
  but may also be enabled within the context of a Python function via
  `tf.contrib.eager.py_func`.
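
  For example (a sketch; assumes a TF 1.x-style program where graph mode is
  the default):

  ```python
  print(tf.executing_eagerly())  # False
  tf.enable_eager_execution()
  print(tf.executing_eagerly())  # True
  ```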
  """
  return context().executing_eagerly()


def in_eager_mode():
  """Use executing_eagerly() instead. This function will be removed."""
  return executing_eagerly()


def shared_name(name=None):
  """Returns the anonymous shared name GUID if no shared name is specified.

  In eager mode we need to use a unique shared name to avoid spurious sharing
  issues. The runtime generates a unique name on our behalf when the reserved
  GUID is used as a shared name.

  Args:
    name: Optional shared name.

  Returns:
    Eager compatible shared name.
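
  For example (a sketch of both behaviors):

  ```python
  shared_name("my_table")  # -> "my_table", in eager or graph mode.
  shared_name()            # -> the reserved GUID when executing eagerly,
                           #    None in graph mode.
  ```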
  """
  if name or not executing_eagerly():
    return name

  # Ensure a unique name when eager execution is enabled to avoid spurious
  # sharing issues.
  return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"


def graph_mode():
  """Context-manager to disable eager execution for the current thread."""
  return context()._mode(GRAPH_MODE)  # pylint: disable=protected-access


def eager_mode():
  """Context-manager to enable eager execution for the current thread."""
  return context()._mode(EAGER_MODE)  # pylint: disable=protected-access


# TODO(agarwal): get rid of this and use ops.name_scope instead.
@contextlib.contextmanager
def namescope(name):
  """ContextManager for creating hierarchical name scopes."""
  ctx = context()
  old_name = ctx.scope_name
  ctx.scope_name = "%s/%s" % (old_name, name) if old_name else name
  try:
    yield
  finally:
    ctx.scope_name = old_name


def scope_name():
  """Name of the current scope."""
  return context().scope_name


def device(name):
  """Context-manager to force placement of operations and Tensors on a device.

  Example:
  ```python
  with tfe.device('gpu:0'):
    with tfe.device('cpu:0'):
      shape = tf.constant([], dtype=tf.int32)
    x = tf.truncated_normal(shape, dtype=tf.float32)
  ```
  will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
  operation runs on GPU 0.

  Args:
    name: Name of the device (see context().devices()), or None to
      perform automatic placement.

  Returns:
    Context manager for setting the device.
  """
  return context().device(name)


@tf_export("config.experimental_list_devices")
def list_devices():
  """List the names of the available devices.

  Returns:
    Names of the available devices, as a `list`.
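
  For example (illustrative output; actual names depend on the host):

  ```python
  tf.config.experimental_list_devices()
  # ['/job:localhost/replica:0/task:0/device:CPU:0', ...]
  ```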
  """
  return context().devices()


@tf_export("debugging.get_log_device_placement")
def get_log_device_placement():
  """Returns whether device placements are logged.

  Returns:
    True if device placements are logged.
  """
  return context().log_device_placement


@tf_export("debugging.set_log_device_placement")
def set_log_device_placement(enabled):
  """Sets whether device placements should be logged.

  Args:
    enabled: Whether to enable device placement logging.
  """
  context().log_device_placement = enabled


@tf_contextlib.contextmanager
def device_policy(policy):
  """Context manager for setting device placement policy for current thread."""
  ctx = context()
  old_policy = ctx.device_policy
  try:
    ctx.device_policy = policy
    yield
  finally:
    ctx.device_policy = old_policy


def set_execution_mode(mode):
  """Sets execution mode for the current thread."""
  context().execution_mode = mode


@tf_contextlib.contextmanager
def execution_mode(mode):
  """Context manager for setting execution mode for the current thread.
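
  For example (a sketch; `ASYNC`, `SYNC`, and `async_wait` are defined in
  this module, while `an_eager_op` is hypothetical):

  ```python
  with execution_mode(ASYNC):
    a = an_eager_op()  # May return a "non-ready" handle in ASYNC mode.
    async_wait()       # Block until pending ops finish.
  ```
  """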
  ctx = context()
  old_mode = ctx.execution_mode
  try:
    ctx.execution_mode = mode
    yield
  finally:
    ctx.execution_mode = old_mode


@tf_export("experimental.function_executor_type")
@tf_contextlib.contextmanager
def function_executor_type(executor_type):
  """Context manager for setting the executor of eager defined functions.

  Eager defined functions are functions decorated by `tf.contrib.eager.defun`.

  Args:
    executor_type: a string for the name of the executor to be used to execute
      functions defined by `tf.contrib.eager.defun`.

  Yields:
    Nothing.
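
  For example (a sketch; "my_executor" is a hypothetical executor name known
  to the runtime):

  ```python
  with tf.experimental.function_executor_type("my_executor"):
    my_defun_decorated_fn()  # Runs with the requested executor.
  ```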
  """
  current_options = context().function_call_options
  old_options = copy.copy(current_options)
  try:
    current_options.executor_type = executor_type
    yield
  finally:
    context().function_call_options = old_options


def async_wait():
  """Waits for ops dispatched in ASYNC mode to finish."""
  return context().async_wait()


def async_clear_error():
  """Clears errors raised during ASYNC execution mode."""
  return context().async_clear_error()


def num_gpus():
  """Get the number of available GPU devices.

  Returns:
    The number of available GPU devices.
  """
  return context().num_gpus()


def enable_run_metadata():
  """Enables tracing of op execution via RunMetadata.

  To retrieve the accumulated metadata call context.export_run_metadata()
  and to stop tracing call context.disable_run_metadata().
  """
  context().enable_run_metadata()


def disable_run_metadata():
  """Disables tracing of op execution via RunMetadata."""
  context().disable_run_metadata()


def enable_graph_collection():
  """Enables graph collection of executed functions.

  To retrieve the accumulated graphs call context.export_run_metadata()
  and to stop collecting graphs call context.disable_graph_collection().
  """
  context().enable_graph_collection()


def disable_graph_collection():
  """Disables graph collection of executed functions."""
  context().disable_graph_collection()


def export_run_metadata():
  """Returns a RunMetadata proto with accumulated information.

  The returned protocol buffer contains information since the most recent
  call to either enable_run_metadata or export_run_metadata.

  Returns:
    A RunMetadata protocol buffer, or None if run metadata was not enabled.
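
  For example (a sketch of the enable/export cycle):

  ```python
  enable_run_metadata()
  # ... execute some eager ops ...
  run_metadata = export_run_metadata()
  disable_run_metadata()
  ```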
  """
  return context().export_run_metadata()


def set_server_def(server_def):
  context().set_server_def(server_def)


def add_function(fdef):
  """Add a function definition to the context."""
  context().add_function(fdef)


# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file.  Note that IS_IN_GRAPH_MODE and
# in_graph_mode are both parameterless functions.
def _tmp_in_graph_mode():
  if context_safe() is None:
    # Context not yet initialized. Assume graph mode following the
    # default implementation in `is_in_graph_mode`.
    return True
  return not executing_eagerly()


is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode
1171