# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions used to construct graphs."""
# pylint: disable=g-bad-name
import collections
import copy
import re
import sys
import threading
import types
from absl import app

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import full_type_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import config_pb2
# pywrap_tensorflow must be imported first to avoid protobuf issues.
# (b/143110113)
# pylint: disable=invalid-import-order,g-bad-import-order,unused-import
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import pywrap_tfe
# pylint: enable=invalid-import-order,g-bad-import-order,unused-import
from tensorflow.python import tf2
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import monitoring
from tensorflow.python.eager import tape
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import traceable_stack
from tensorflow.python.framework import versions
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace as profiler_trace
from tensorflow.python.types import core as core_tf_types
from tensorflow.python.types import internal
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util
from tensorflow.python.util import memory
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_stack
from tensorflow.python.util import traceback_utils
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import kwarg_only
from tensorflow.python.util.tf_export import tf_export

# TODO(b/218887885): Loaded lazily due to a circular dependency with this file.
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
ag_ctx = LazyLoader(
    "ag_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")


# Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported.
_USE_C_API = True
_USE_C_SHAPES = True

_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/ops_eager_execution",
    "Whether ops.enable_eager_execution() is called.")

_tensor_equality_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/enable_tensor_equality",
    "Whether ops.enable_tensor_equality() is called.")

_control_flow_api_gauge = monitoring.BoolGauge(
    "/tensorflow/api/enable_control_flow_v2",
    "Whether enable_control_flow_v2() is called.")

_tf_function_api_guage = monitoring.BoolGauge(
    "/tensorflow/api/tf_function",
    "Whether tf.function() is used.")

# pylint: disable=protected-access
_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE
# pylint: enable=protected-access


def tensor_id(tensor):
  """Returns a unique identifier for this Tensor."""
  return tensor._id  # pylint: disable=protected-access


class _UserDeviceSpec(object):
  """Store user-specified device and provide computation of merged device."""

  def __init__(self, device_name_or_function):
    self._device_name_or_function = device_name_or_function
    self.display_name = str(self._device_name_or_function)
    self.function = device_name_or_function
    self.raw_string = None

    if isinstance(device_name_or_function, pydev.MergeDevice):
      self.is_null_merge = device_name_or_function.is_null_merge

    elif callable(device_name_or_function):
      self.is_null_merge = False
      dev_func = self._device_name_or_function
      func_name = function_utils.get_func_name(dev_func)
      func_code = function_utils.get_func_code(dev_func)
      if func_code:
        fname = func_code.co_filename
        lineno = func_code.co_firstlineno
      else:
        fname = "unknown"
        lineno = -1
      self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)

    elif device_name_or_function is None:
      # NOTE(taylorrobie): This MUST be False. None signals a break in the
      #   device stack, so `is_null_merge` must be False for such a case to
      #   allow callers to safely skip over null merges without missing a None.
      self.is_null_merge = False

    else:
      self.raw_string = device_name_or_function
      self.function = pydev.merge_device(device_name_or_function)
      self.is_null_merge = self.function.is_null_merge

    # We perform this check in __init__ because it is of non-trivial cost,
    # and self.string_merge is typically called many times.
    self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)

  def string_merge(self, node_def):
    if self.fast_string_merge:
      return self.function.shortcut_string_merge(node_def)

    return compat.as_str(_device_string(self.function(node_def)))


class NullContextmanager(object):

  def __init__(self, *args, **kwargs):
    pass

  def __enter__(self):
    pass

  def __exit__(self, type_arg, value_arg, traceback_arg):
    return False  # False values do not suppress exceptions


def _override_helper(clazz_object, operator, func):
  """Overrides (string) operator on Tensors to call func.

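  A minimal sketch (illustrative only; the real operator overloads are
  installed by `math_ops` and friends at import time):

  ```python
  def _my_add(x, y):
    return NotImplemented  # Placeholder implementation.

  _override_helper(Tensor, "__add__", _my_add)
  ```
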
  Args:
    clazz_object: the class to override for; either Tensor or SparseTensor.
    operator: the string name of the operator to override.
    func: the function that replaces the overridden operator.

  Raises:
    ValueError: If operator is not allowed to be overwritten.
  """
  if operator not in Tensor.OVERLOADABLE_OPERATORS:
    raise ValueError(f"Overriding {operator} is disallowed. "
                     f"Allowed operators are {Tensor.OVERLOADABLE_OPERATORS}.")
  setattr(clazz_object, operator, func)


def _as_graph_element(obj):
  """Convert `obj` to a graph element if possible, otherwise return `None`.

  Args:
    obj: Object to convert.

  Returns:
    The result of `obj._as_graph_element()` if that method is available;
        otherwise `None`.
  """
  conv_fn = getattr(obj, "_as_graph_element", None)
  if conv_fn and callable(conv_fn):
    return conv_fn()
  return None


# Deprecated - do not use.
# This API exists to avoid breaking estimator and tensorflow-mesh, which depend
# on this internal API. The stub should be safe to use after TF 2.3 is released.
def is_dense_tensor_like(t):
  return isinstance(t, core_tf_types.Tensor)


def uid():
  """A unique (within this program execution) integer."""
  return pywrap_tfe.TFE_Py_UID()


def numpy_text(tensor, is_repr=False):
  """Human readable representation of a tensor's numpy value."""
  if tensor.dtype.is_numpy_compatible:
    # pylint: disable=protected-access
    text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
    # pylint: enable=protected-access
  else:
    text = "<unprintable>"
  if "\n" in text:
    text = "\n" + text
  return text


def value_text(tensor, is_repr=False):
  """Either the NumPy value or a custom TensorFlow formatting of `tensor`.

  Custom formatting is used for custom device tensors, e.g. parallel tensors
  with multiple components on different devices.

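  For example (illustrative):

  ```python
  value_text(tf.constant([1.0, 2.0]))
  # -> "[1. 2.]"
  value_text(tf.constant([1.0, 2.0]), is_repr=True)
  # -> "numpy=array([1., 2.], dtype=float32)"
  ```
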
  Args:
    tensor: The tensor to format.
    is_repr: Controls the style/verbosity of formatting.

  Returns:
    The formatted tensor.
  """
  # pylint: disable=protected-access  # friend access
  if tensor._prefer_custom_summarizer():
    text = tensor._summarize_value()
    # pylint: enable=protected-access
    if is_repr:
      text = "value=" + text
  else:
    text = numpy_text(tensor, is_repr=is_repr)
    if is_repr:
      text = "numpy=" + text
  return text


@tf_export(v1=["enable_tensor_equality"])
def enable_tensor_equality():
  """Compare Tensors with element-wise comparison and thus be unhashable.

  Comparing tensors element-wise allows comparisons such as
  `tf.Variable(1.0) == 1.0`. Element-wise equality implies that tensors are
  unhashable. Thus tensors can no longer be directly used in sets or as keys in
  a dictionary.
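
  For example, once element-wise equality is enabled, `==` returns a boolean
  tensor rather than comparing by id:

  >>> tf.compat.v1.enable_tensor_equality()
  >>> tf.constant([1, 2]) == tf.constant([1, 3])
  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>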
272  """
273  logging.vlog(1, "Enabling tensor equality")
274  _tensor_equality_api_usage_gauge.get_cell().set(True)
275  Tensor._USE_EQUALITY = True  # pylint: disable=protected-access
276
277
278@tf_export(v1=["disable_tensor_equality"])
279def disable_tensor_equality():
280  """Compare Tensors by their id and be hashable.
281
282  This is a legacy behaviour of TensorFlow and is highly discouraged.
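
  For example (this restores id-based comparison and hashing):

  >>> tf.compat.v1.disable_tensor_equality()
  >>> t = tf.constant([1, 2])
  >>> t == t  # Compared by object identity, yielding a Python bool.
  True
  >>> tf.compat.v1.enable_tensor_equality()  # Restore the TF2 default.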
283  """
284  logging.vlog(1, "Disabling tensor equality")
285  _tensor_equality_api_usage_gauge.get_cell().set(False)
286  Tensor._USE_EQUALITY = False  # pylint: disable=protected-access
287
288
289# TODO(mdan): This object should subclass Symbol, not just Tensor.
290@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"])
291class Tensor(internal.NativeObject, core_tf_types.Tensor):
292  """A `tf.Tensor` represents a multidimensional array of elements.
293
294  All elements are of a single known data type.
295
296  When writing a TensorFlow program, the main object that is
297  manipulated and passed around is the `tf.Tensor`.
298
299  A `tf.Tensor` has the following properties:
300
301  * a single data type (float32, int32, or string, for example)
302  * a shape
303
304  TensorFlow supports eager execution and graph execution.  In eager
305  execution, operations are evaluated immediately.  In graph
306  execution, a computational graph is constructed for later
307  evaluation.
308
309  TensorFlow defaults to eager execution.  In the example below, the
310  matrix multiplication results are calculated immediately.
311
312  >>> # Compute some values using a Tensor
313  >>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
314  >>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
315  >>> e = tf.matmul(c, d)
316  >>> print(e)
317  tf.Tensor(
318  [[1. 3.]
319   [3. 7.]], shape=(2, 2), dtype=float32)
320
321  Note that during eager execution, you may discover your `Tensors` are actually
322  of type `EagerTensor`.  This is an internal detail, but it does give you
323  access to a useful function, `numpy`:
324
325  >>> type(e)
326  <class '...ops.EagerTensor'>
327  >>> print(e.numpy())
328    [[1. 3.]
329     [3. 7.]]
330
331  In TensorFlow, `tf.function`s are a common way to define graph execution.
332
333  A Tensor's shape (that is, the rank of the Tensor and the size of
334  each dimension) may not always be fully known.  In `tf.function`
335  definitions, the shape may only be partially known.
336
337  Most operations produce tensors of fully-known shapes if the shapes of their
338  inputs are also fully known, but in some cases it's only possible to find the
339  shape of a tensor at execution time.
340
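  For example, when tracing a `tf.function` with a partially-specified
  `tf.TensorSpec`, the shape seen at trace time may contain `None` dimensions:

  >>> @tf.function(input_signature=[tf.TensorSpec([None, 2])])
  ... def f(x):
  ...   print("Traced with shape:", x.shape)
  ...   return x
  >>> _ = f.get_concrete_function()
  Traced with shape: (None, 2)
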
  A number of specialized tensors are available: see `tf.Variable`,
  `tf.constant`, `tf.placeholder`, `tf.sparse.SparseTensor`, and
  `tf.RaggedTensor`.

  Caution: when constructing a tensor from a numpy array or pandas dataframe,
  the underlying buffer may be re-used:

  ```python
  a = np.array([1, 2, 3])
  b = tf.constant(a)
  a[0] = 4
  print(b)  # tf.Tensor([4 2 3], shape=(3,), dtype=int64)
  ```

  Note: this is an implementation detail that is subject to change and users
  should not rely on this behaviour.

  For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor).

  """

  # List of Python operators that we allow to override.
  OVERLOADABLE_OPERATORS = {
      # Binary.
      "__add__",
      "__radd__",
      "__sub__",
      "__rsub__",
      "__mul__",
      "__rmul__",
      "__div__",
      "__rdiv__",
      "__truediv__",
      "__rtruediv__",
      "__floordiv__",
      "__rfloordiv__",
      "__mod__",
      "__rmod__",
      "__lt__",
      "__le__",
      "__gt__",
      "__ge__",
      "__ne__",
      "__eq__",
      "__and__",
      "__rand__",
      "__or__",
      "__ror__",
      "__xor__",
      "__rxor__",
      "__getitem__",
      "__pow__",
      "__rpow__",
      # Unary.
      "__invert__",
      "__neg__",
      "__abs__",
      "__matmul__",
      "__rmatmul__"
  }

  # Whether to allow hashing or numpy-style equality
  _USE_EQUALITY = tf2.enabled()

  def __init__(self, op, value_index, dtype):
    """Creates a new `Tensor`.

    Args:
      op: An `Operation`. `Operation` that computes this tensor.
      value_index: An `int`. Index of the operation's endpoint that produces
        this tensor.
      dtype: A `DType`. Type of elements stored in this tensor.

    Raises:
      TypeError: If the op is not an `Operation`.
    """
    if not isinstance(op, Operation):
      raise TypeError(f"op needs to be an Operation. "
                      f"An instance of type {type(op).__name__} is provided.")

    self._op = op
    self._value_index = value_index
    self._dtype = dtypes.as_dtype(dtype)
    # This will be set by self._as_tf_output().
    self._tf_output = None
    # This will be set by self.shape().
    self._shape_val = None
    # List of operations that use this Tensor as input.  We maintain this list
    # to easily navigate a computation graph.
    self._consumers = []
    self._id = uid()
    self._name = None

  def __getattr__(self, name):
    if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size",
                "tolist", "data"}:
      # TODO(wangpeng): Export the enable_numpy_behavior knob
      raise AttributeError(
          f"{type(self).__name__} object has no attribute '{name}'. " + """
        If you are looking for numpy-related methods, please run the following:
        from tensorflow.python.ops.numpy_ops import np_config
        np_config.enable_numpy_behavior()
      """)
    self.__getattribute__(name)

  @staticmethod
  def _create_with_tf_output(op, value_index, dtype, tf_output):
    ret = Tensor(op, value_index, dtype)
    ret._tf_output = tf_output
    return ret

  @property
  def op(self):
    """The `Operation` that produces this tensor as an output."""
    return self._op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._dtype

  @property
  def graph(self):
    """The `Graph` that contains this tensor."""
    return self._op.graph

  @property
  def name(self):
    """The string name of this tensor."""
    if self._name is None:
      assert self._op.name
      self._name = "%s:%d" % (self._op.name, self._value_index)
    return self._name

  @property
  def device(self):
    """The name of the device on which this tensor will be produced, or None."""
    return self._op.device

  @property
  def shape(self):
    """Returns a `tf.TensorShape` that represents the shape of this tensor.

    >>> t = tf.constant([1,2,3,4,5])
    >>> t.shape
    TensorShape([5])

    `tf.Tensor.shape` is equivalent to `tf.Tensor.get_shape()`.

    In a `tf.function` or when building a model using
    `tf.keras.Input`, they return the build-time shape of the
    tensor, which may be partially unknown.

    A `tf.TensorShape` is not a tensor. Use `tf.shape(t)` to get a tensor
    containing the shape, calculated at runtime.

    See `tf.Tensor.get_shape()`, and `tf.TensorShape` for details and examples.
    """
    if self._shape_val is None:
      self._shape_val = self._c_api_shape()
    return self._shape_val

  def _c_api_shape(self):
    """Returns the TensorShape of this tensor according to the C API."""
    with self._op._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper(
          c_graph, self._as_tf_output())
    if unknown_shape:
      return tensor_shape.unknown_shape()
    else:
      shape_vec = [None if d == -1 else d for d in shape_vec]
      return tensor_shape.TensorShape(shape_vec)

  @property
  def _shape(self):
    logging.warning("Tensor._shape is private, use Tensor.shape "
                    "instead. Tensor._shape will eventually be removed.")
    return self.shape

  @_shape.setter
  def _shape(self, value):
    raise ValueError(
        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")

  def _disallow_when_autograph_unavailable(self, task):
    raise errors.OperatorNotAllowedInGraphError(
        f"{task} is not allowed: AutoGraph is unavailable in this runtime. See"
        " https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code"
        " for more information.")

  def _disallow_when_autograph_disabled(self, task):
    raise errors.OperatorNotAllowedInGraphError(
        f"{task} is not allowed: AutoGraph is disabled in this function."
        " Try decorating it directly with @tf.function.")

  def _disallow_when_autograph_enabled(self, task):
    raise errors.OperatorNotAllowedInGraphError(
        f"{task} is not allowed: AutoGraph did convert this function. This"
        " might indicate you are trying to use an unsupported feature.")

  def _disallow_in_graph_mode(self, task):
    raise errors.OperatorNotAllowedInGraphError(
        f"{task} is not allowed in Graph execution. Use Eager execution or"
        " decorate this function with @tf.function.")

  def _disallow_bool_casting(self):
    if not ag_ctx.INSPECT_SOURCE_SUPPORTED:
      self._disallow_when_autograph_unavailable(
          "Using a symbolic `tf.Tensor` as a Python `bool`")
    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
      self._disallow_when_autograph_disabled(
          "Using a symbolic `tf.Tensor` as a Python `bool`")
    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
      self._disallow_when_autograph_enabled(
          "Using a symbolic `tf.Tensor` as a Python `bool`")
    else:
      # Default: V1-style Graph execution.
      self._disallow_in_graph_mode(
          "Using a symbolic `tf.Tensor` as a Python `bool`")

  def _disallow_iteration(self):
    if not ag_ctx.INSPECT_SOURCE_SUPPORTED:
      self._disallow_when_autograph_unavailable(
          "Iterating over a symbolic `tf.Tensor`")
    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
      self._disallow_when_autograph_disabled(
          "Iterating over a symbolic `tf.Tensor`")
    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
      self._disallow_when_autograph_enabled(
          "Iterating over a symbolic `tf.Tensor`")
    else:
      # Default: V1-style Graph execution.
      self._disallow_in_graph_mode("Iterating over a symbolic `tf.Tensor`")

  def __iter__(self):
    if not context.executing_eagerly():
      self._disallow_iteration()

    shape = self._shape_tuple()
    if shape is None:
      raise TypeError("Cannot iterate over a tensor with unknown shape.")
    if not shape:
      raise TypeError("Cannot iterate over a scalar tensor.")
    if shape[0] is None:
      raise TypeError(
          "Cannot iterate over a tensor with unknown first dimension.")
    return _TensorIterator(self, shape[0])

  def _shape_as_list(self):
    if self.shape.ndims is not None:
      return [dim.value for dim in self.shape.dims]
    else:
      return None

  def _shape_tuple(self):
    shape = self._shape_as_list()
    if shape is None:
      return None
    return tuple(shape)

  def _rank(self):
    """Integer rank of this Tensor, if known, else None.

    Returns:
      Integer rank or None
    """
    return self.shape.ndims

  def get_shape(self):
    """Returns a `tf.TensorShape` that represents the shape of this tensor.

    In eager execution the shape is always fully-known.

    >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    >>> print(a.shape)
    (2, 3)

    `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`.


    When executing in a `tf.function` or building a model using
    `tf.keras.Input`, `Tensor.shape` may return a partial shape (including
    `None` for unknown dimensions). See `tf.TensorShape` for more details.

    >>> inputs = tf.keras.Input(shape = [10])
    >>> # Unknown batch size
    >>> print(inputs.shape)
    (None, 10)

    The shape is computed using shape inference functions that are
    registered for each `tf.Operation`.

    The returned `tf.TensorShape` is determined at *build* time, without
    executing the underlying kernel. It is not a `tf.Tensor`. If you need a
    shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or
    use the `tf.shape(tensor)` function, which returns the tensor's shape at
    *execution* time.

    This is useful for debugging and providing early errors. For
    example, when tracing a `tf.function`, no ops are being executed, and
    shapes may be unknown (see the [Concrete Functions
    Guide](https://www.tensorflow.org/guide/concrete_function) for details).

    >>> @tf.function
    ... def my_matmul(a, b):
    ...   result = a@b
    ...   # the `print` executes during tracing.
    ...   print("Result shape: ", result.shape)
    ...   return result

    The shape inference functions propagate shapes to the extent possible:

    >>> f = my_matmul.get_concrete_function(
    ...   tf.TensorSpec([None,3]),
    ...   tf.TensorSpec([3,5]))
    Result shape: (None, 5)

    Tracing may fail if a shape mismatch can be detected:

    >>> cf = my_matmul.get_concrete_function(
    ...   tf.TensorSpec([None,3]),
    ...   tf.TensorSpec([4,5]))
    Traceback (most recent call last):
    ...
    ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op:
    'MatMul') with input shapes: [?,3], [4,5].

    In some cases, the inferred shape may have unknown dimensions. If
    the caller has additional information about the values of these
    dimensions, `tf.ensure_shape` or `Tensor.set_shape()` can be used to augment
    the inferred shape.

    >>> @tf.function
    ... def my_fun(a):
    ...   a = tf.ensure_shape(a, [5, 5])
    ...   # the `print` executes during tracing.
    ...   print("Result shape: ", a.shape)
    ...   return a

    >>> cf = my_fun.get_concrete_function(
    ...   tf.TensorSpec([None, None]))
    Result shape: (5, 5)

    Returns:
      A `tf.TensorShape` representing the shape of this tensor.

    """
    return self.shape

  def set_shape(self, shape):
    """Updates the shape of this tensor.

    Note: It is recommended to use `tf.ensure_shape` instead of
    `Tensor.set_shape`, because `tf.ensure_shape` provides better checking for
    programming errors and can create guarantees for compiler
    optimization.

    With eager execution this operates as a shape assertion.
    Here the shapes match:

    >>> t = tf.constant([[1,2,3]])
    >>> t.set_shape([1, 3])

    Passing a `None` in the new shape allows any value for that axis:

    >>> t.set_shape([1,None])

    An error is raised if an incompatible shape is passed.

    >>> t.set_shape([1,5])
    Traceback (most recent call last):
    ...
    ValueError: Tensor's shape (1, 3) is not compatible with supplied
    shape [1, 5]

    When executing in a `tf.function`, or building a model using
    `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with
    the current shape of this tensor, and set the tensor's shape to the
    merged value (see `tf.TensorShape.merge_with` for details):

    >>> t = tf.keras.Input(shape=[None, None, 3])
    >>> print(t.shape)
    (None, None, None, 3)

    Dimensions set to `None` are not updated:

    >>> t.set_shape([None, 224, 224, None])
    >>> print(t.shape)
    (None, 224, 224, 3)

    The main use case for this is to provide additional shape information
    that cannot be inferred from the graph alone.

    For example, if you know all the images in a dataset have shape [28, 28, 3],
    you can set it with `Tensor.set_shape`:

    >>> @tf.function
    ... def load_image(filename):
    ...   raw = tf.io.read_file(filename)
    ...   image = tf.image.decode_png(raw, channels=3)
    ...   # the `print` executes during tracing.
    ...   print("Initial shape: ", image.shape)
    ...   image.set_shape([28, 28, 3])
    ...   print("Final shape: ", image.shape)
    ...   return image

    Trace the function; see the [Concrete Functions
    Guide](https://www.tensorflow.org/guide/concrete_function) for details.

    >>> cf = load_image.get_concrete_function(
    ...     tf.TensorSpec([], dtype=tf.string))
    Initial shape:  (None, None, 3)
    Final shape: (28, 28, 3)

    Similarly, the `tf.io.parse_tensor` function could return a tensor with
    any shape, where even the `tf.rank` is unknown. If you know that all your
    serialized tensors will be 2D, set it with `set_shape`:

    >>> @tf.function
    ... def my_parse(string_tensor):
    ...   result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
    ...   # the `print` executes during tracing.
    ...   print("Initial shape: ", result.shape)
    ...   result.set_shape([None, None])
    ...   print("Final shape: ", result.shape)
    ...   return result

    Trace the function:

    >>> concrete_parse = my_parse.get_concrete_function(
    ...     tf.TensorSpec([], dtype=tf.string))
    Initial shape:  <unknown>
    Final shape:  (None, None)

    Make sure it works:

    >>> t = tf.ones([5,3], dtype=tf.float32)
    >>> serialized = tf.io.serialize_tensor(t)
    >>> print(serialized.dtype)
    <dtype: 'string'>
    >>> print(serialized.shape)
    ()
    >>> t2 = concrete_parse(serialized)
    >>> print(t2.shape)
    (5, 3)

    Caution: `set_shape` ensures that the applied shape is compatible with
    the existing shape, but it does not check at runtime. Setting
    incorrect shapes can result in inconsistencies between the
    statically-known graph and the runtime value of tensors. For runtime
    validation of the shape, use `tf.ensure_shape` instead, which also modifies
    the `shape` of the tensor.

    >>> # Serialize a rank-3 tensor
    >>> t = tf.ones([5,5,5], dtype=tf.float32)
    >>> serialized = tf.io.serialize_tensor(t)
    >>> # The function still runs, even though it `set_shape([None,None])`
    >>> t2 = concrete_parse(serialized)
    >>> print(t2.shape)
    (5, 5, 5)

    Args:
      shape: A `TensorShape` representing the shape of this tensor, a
        `TensorShapeProto`, a list, a tuple, or None.

    Raises:
      ValueError: If `shape` is not compatible with the current shape of
        this tensor.
    """
    # Reset cached shape.
    self._shape_val = None

    # We want set_shape to be reflected in the C API graph for when we run it.
    if not isinstance(shape, tensor_shape.TensorShape):
      shape = tensor_shape.TensorShape(shape)
    dim_list = []
    if shape.dims is None:
      unknown_shape = True
    else:
      unknown_shape = False
      for dim in shape.dims:
        if dim.value is None:
          dim_list.append(-1)
        else:
          dim_list.append(dim.value)
    try:
      with self._op._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
        pywrap_tf_session.TF_GraphSetTensorShape_wrapper(
            c_graph, self._as_tf_output(), dim_list, unknown_shape)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(e.message)

  @property
  def value_index(self):
    """The index of this tensor in the outputs of its `Operation`."""
    return self._value_index

  def consumers(self):
    """Returns a list of `Operation`s that consume this tensor.

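    For example, in graph mode (eager tensors do not track consumers):

    >>> g = tf.Graph()
    >>> with g.as_default():
    ...   a = tf.constant(1.0, name="a")
    ...   b = tf.identity(a, name="b")
    >>> a.consumers()
    [<tf.Operation 'b' type=Identity>]
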
    Returns:
      A list of `Operation`s.
    """
    consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper(
        self._as_tf_output())
    # pylint: disable=protected-access
    return [
        self.graph._get_operation_by_name_unsafe(name)
        for name in consumer_names
    ]
    # pylint: enable=protected-access

  def _as_node_def_input(self):
    """Return a value to use for the NodeDef "input" attribute.

    The returned string can be used in a NodeDef "input" attribute
    to indicate that the NodeDef uses this Tensor as input.

    Raises:
      ValueError: if this Tensor's Operation does not have a name.

    Returns:
      a string.
    """
    assert self._op.name
    if self._value_index == 0:
      return self._op.name
    else:
      return "%s:%d" % (self._op.name, self._value_index)

  def _as_tf_output(self):
    # pylint: disable=protected-access
    # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object
    # also guarantees that a dictionary of tf_output objects will retain a
    # deterministic (yet unsorted) order which prevents memory blowup in the
    # cache of executor(s) stored for every session.
    if self._tf_output is None:
      self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index)
    return self._tf_output
    # pylint: enable=protected-access

  def __str__(self):
    return "Tensor(\"%s\"%s%s%s)" % (
        self.name,
        (", shape=%s" %
         self.get_shape()) if self.get_shape().ndims is not None else "",
        (", dtype=%s" % self._dtype.name) if self._dtype else "",
        (", device=%s" % self.device) if self.device else "")

  def __repr__(self):
    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
                                                   self._dtype.name)

  def __hash__(self):
    g = getattr(self, "graph", None)
    if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
        (g is None or g.building_function)):
      raise TypeError("Tensor is unhashable. "
                      "Instead, use tensor.ref() as the key.")
    else:
      return id(self)

  def __copy__(self):
    # TODO(b/77597810): get rid of Tensor copies.
    cls = self.__class__
    result = cls.__new__(cls)
    result.__dict__.update(self.__dict__)
    return result

  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  def __array__(self, dtype=None):
    del dtype
    raise NotImplementedError(
        f"Cannot convert a symbolic tf.Tensor ({self.name}) to a numpy array."
        f" This error may indicate that you're trying to pass a Tensor to"
        f" a NumPy call, which is not supported.")

  def __len__(self):
    raise TypeError(f"len is not well defined for a symbolic Tensor "
                    f"({self.name}). Please call `x.shape` rather than "
                    f"`len(x)` for shape information.")

  # TODO(mdan): This convoluted machinery is hard to maintain. Clean up.
  @staticmethod
  def _override_operator(operator, func):
    _override_helper(Tensor, operator, func)

  def __bool__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This overload raises a `TypeError` when the user inadvertently
    treats a `Tensor` as a boolean (most commonly in an `if` or `while`
    statement), in code that was not converted by AutoGraph. For example:

    ```python
    if tf.constant(True):  # Will raise.
      # ...

    if tf.constant(5) < tf.constant(7):  # Will raise.
      # ...
    ```

    Raises:
      `TypeError`.
    """
    self._disallow_bool_casting()

  def __nonzero__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This is the Python 2.x counterpart to `__bool__()` above.

    Raises:
      `TypeError`.
    """
    self._disallow_bool_casting()

  def eval(self, feed_dict=None, session=None):
    """Evaluates this tensor in a `Session`.

    Note: If you are not using `compat.v1` libraries, you should not need this
    (or `feed_dict` or `Session`).  In eager execution (or within `tf.function`)
    you do not need to call `eval`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

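    For example, a minimal sketch using the `compat.v1` API:

    ```python
    tf.compat.v1.disable_eager_execution()
    c = tf.constant(2.0) * 3.0
    with tf.compat.v1.Session():
      print(c.eval())  # ==> 6.0
    ```
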
    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this tensor. If
        none, the default session will be used.

    Returns:
      A numpy array corresponding to the value of this tensor.
    """
    return _eval_using_default_session(self, feed_dict, self.graph, session)

  @deprecation.deprecated(None, "Use ref() instead.")
  def experimental_ref(self):
    return self.ref()

  def ref(self):
    # tf.Variable also has the same ref() API.  If you update the
    # documentation here, please update tf.Variable.ref() as well.
    """Returns a hashable reference object to this Tensor.

    The primary use case for this API is to put tensors in a set/dictionary.
    We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer
    available as of TensorFlow 2.0.

    The following will raise an exception as of TensorFlow 2.0:

    >>> x = tf.constant(5)
    >>> y = tf.constant(10)
    >>> z = tf.constant(10)
    >>> tensor_set = {x, y, z}
    Traceback (most recent call last):
      ...
    TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
    >>> tensor_dict = {x: 'five', y: 'ten'}
    Traceback (most recent call last):
      ...
    TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.

    Instead, we can use `tensor.ref()`.

    >>> tensor_set = {x.ref(), y.ref(), z.ref()}
    >>> x.ref() in tensor_set
    True
    >>> tensor_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
    >>> tensor_dict[y.ref()]
    'ten'

    Also, the reference object provides a `.deref()` function that returns the
    original Tensor.

    >>> x = tf.constant(5)
    >>> x.ref().deref()
    <tf.Tensor: shape=(), dtype=int32, numpy=5>
    """
    return object_identity.Reference(self)

  def __tf_tracing_type__(self, signature_context):
    return tensor_spec.TensorSpec(
        self.shape, self.dtype).__tf_tracing_type__(signature_context)


# TODO(agarwal): consider getting rid of this.
# TODO(mdan): This object should not subclass ops.Tensor.
class _EagerTensorBase(Tensor):
  """Base class for EagerTensor."""

  # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and
  # only work for scalars; values are cast as per numpy.
  def __complex__(self):
    return complex(self._numpy())

  def __int__(self):
    return int(self._numpy())

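  # Note: `__long__` is a Python 2 holdover; the `long` builtin does not exist
  # on Python 3, where this method is never invoked.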
  def __long__(self):
    return long(self._numpy())

  def __float__(self):
    return float(self._numpy())

  def __index__(self):
    return self._numpy().__index__()

  def __bool__(self):
    return bool(self._numpy())

  __nonzero__ = __bool__

  def __format__(self, format_spec):
    if self._prefer_custom_summarizer():
      return self._summarize_value().__format__(format_spec)
    elif self.dtype.is_numpy_compatible:
      # Not numpy_text here, otherwise the __format__ behaves differently.
      return self._numpy().__format__(format_spec)
    else:
      return "<unprintable>".__format__(format_spec)

  def __reduce__(self):
    return convert_to_tensor, (self._numpy(),)

  def __copy__(self):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    return self

  def __deepcopy__(self, memo):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    del memo
    return self

  def __str__(self):
    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (
        value_text(self, is_repr=False), self.shape, self.dtype.name)

  def __repr__(self):
    return "<tf.Tensor: shape=%s, dtype=%s, %s>" % (
        self.shape, self.dtype.name, value_text(self, is_repr=True))

  def __len__(self):
    """Returns the length of the first dimension in the Tensor."""
    if not self.shape.ndims:
      raise TypeError("Scalar tensor has no `len()`")
    # pylint: disable=protected-access
    try:
      return self._shape_tuple()[0]
    except core._NotOkStatusException as e:
      raise core._status_to_exception(e) from None

  def __array__(self, dtype=None):
    a = self._numpy()
    if not dtype:
      return a

    return np.array(a, dtype=dtype)

  def _numpy_internal(self):
    raise NotImplementedError()

  def _numpy(self):
    try:
      return self._numpy_internal()
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e) from None  # pylint: disable=protected-access

  @property
  def dtype(self):
    # Note: using the intern table directly here as this is
    # performance-sensitive in some models.
    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access

  def numpy(self):
    """Copy of the contents of this Tensor into a NumPy array or scalar.

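    For example:

    >>> t = tf.constant([[1, 2], [3, 4]])
    >>> t.numpy()
    array([[1, 2],
           [3, 4]], dtype=int32)
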
    Unlike NumPy arrays, Tensors are immutable, so this method has to copy
    the contents to ensure safety. Use `memoryview` to get a readonly
    view of the contents without doing a copy:

    >>> t = tf.constant([42])
    >>> np.array(memoryview(t))
    array([42], dtype=int32)

    Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor
    is on GPU, it will have to be transferred to CPU first in order for
    `memoryview` to work.

    Returns:
      A NumPy array of the same shape and dtype or a NumPy scalar, if this
      Tensor has rank 0.

    Raises:
      ValueError: If the dtype of this Tensor does not have a compatible
        NumPy dtype.
    """
    # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
    maybe_arr = self._numpy()  # pylint: disable=protected-access
    return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr

  @property
  def backing_device(self):
    """Returns the name of the device holding this tensor's memory.

    `.backing_device` is usually the same as `.device`, which returns
    the device on which the kernel of the operation that produced this tensor
    ran. However, some operations can produce tensors on a different device
    (e.g., an operation that executes on the GPU but produces output tensors
    in host memory).
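
    For example (illustrative; the exact device string depends on the machine):

    ```python
    t = tf.constant([1.0])
    t.device          # e.g. "/job:localhost/replica:0/task:0/device:CPU:0"
    t.backing_device  # Usually the same string as `t.device`.
    ```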
1167    """
1168    raise NotImplementedError()
1169
1170  def _datatype_enum(self):
1171    raise NotImplementedError()
1172
1173  def _shape_tuple(self):
1174    """The shape of this Tensor, as a tuple.
1175
1176    This is more performant than tuple(shape().as_list()) as it avoids
1177    two list and one object creation. Marked private for now as from an API
1178    perspective, it would be better to have a single performant way of
1179    getting a shape rather than exposing shape() and shape_tuple()
1180    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
1181    but ideally one would work things out and remove the need for this method.
1182
1183    Returns:
1184      tuple with the shape.
1185    """
1186    raise NotImplementedError()
1187
1188  def _rank(self):
1189    """Integer rank of this Tensor.
1190
1191    Unlike regular Tensors, the rank is always known for EagerTensors.
1192
1193    This is more performant than len(self._shape_tuple())
1194
1195    Returns:
1196      Integer rank
1197    """
1198    raise NotImplementedError()
1199
1200  def _num_elements(self):
1201    """Number of elements of this Tensor.
1202
1203    Unlike regular Tensors, the number of elements is always known for
1204    EagerTensors.
1205
1206    This is more performant than tensor.shape.num_elements
1207
1208    Returns:
      Number of elements in the tensor.
1210    """
1211    raise NotImplementedError()
1212
1213  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
1214    raise NotImplementedError()
1215
1216  @staticmethod
1217  def _override_operator(name, func):
1218    setattr(_EagerTensorBase, name, func)
1219
1220  def _copy_nograd(self, ctx=None, device_name=None):
1221    """Copies tensor to dest device, but doesn't record the operation."""
1222    # Creates a new tensor on the dest device.
1223    if ctx is None:
1224      ctx = context.context()
1225    if device_name is None:
1226      device_name = ctx.device_name
1227    # pylint: disable=protected-access
1228    try:
1229      ctx.ensure_initialized()
1230      new_tensor = self._copy_to_device(device_name)
1231    except core._NotOkStatusException as e:
1232      raise core._status_to_exception(e) from None
1233    return new_tensor
1234
1235  def _copy(self, ctx=None, device_name=None):
1236    """Copies tensor to dest device."""
1237    new_tensor = self._copy_nograd(ctx, device_name)
1238    # Record the copy on tape and define backprop copy as well.
1239    if context.executing_eagerly():
1240      self_device = self.device
1241
1242      def grad_fun(dresult):
1243        return [
1244            dresult._copy(device_name=self_device)
1245            if hasattr(dresult, "_copy") else dresult
1246        ]
1247
1248      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
1249    return new_tensor
1250    # pylint: enable=protected-access
1251
1252  @property
1253  def shape(self):
1254    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
1255      # pylint: disable=protected-access
1256      try:
1257        # `_tensor_shape` is declared and defined in the definition of
1258        # `EagerTensor`, in C.
1259        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
1260      except core._NotOkStatusException as e:
1261        raise core._status_to_exception(e) from None
1262
1263    return self._tensor_shape
1264
1265  def get_shape(self):
1266    """Alias of Tensor.shape."""
1267    return self.shape
1268
1269  def _shape_as_list(self):
1270    """The shape of the tensor as a list."""
1271    return list(self._shape_tuple())
1272
1273  @property
1274  def ndim(self):
1275    """Returns the number of Tensor dimensions."""
1276    return self.shape.ndims
1277
1278  @deprecation.deprecated(None, "Use tf.identity instead.")
1279  def cpu(self):
1280    """A copy of this Tensor with contents backed by host memory."""
1281    return self._copy(context.context(), "CPU:0")
1282
1283  @deprecation.deprecated(None, "Use tf.identity instead.")
1284  def gpu(self, gpu_index=0):
1285    """A copy of this Tensor with contents backed by memory on the GPU.
1286
1287    Args:
      gpu_index: Identifies which GPU to place the contents of the returned
        Tensor on.

    Returns:
      A GPU-memory backed Tensor object initialized with the same contents
      as this Tensor.
    """
    return self._copy(context.context(), "GPU:" + str(gpu_index))

  def set_shape(self, shape):
    if not self.shape.is_compatible_with(shape):
      raise ValueError(f"Tensor's shape {self.shape} is not compatible "
                       f"with supplied shape {shape}.")

  # Methods not supported / implemented for Eager Tensors.
  @property
  def op(self):
    raise AttributeError(
        "Tensor.op is undefined when eager execution is enabled.")

  @property
  def graph(self):
    raise AttributeError(
        "Tensor.graph is undefined when eager execution is enabled.")

  @property
  def name(self):
    raise AttributeError(
        "Tensor.name is undefined when eager execution is enabled.")

  @property
  def value_index(self):
    raise AttributeError(
        "Tensor.value_index is undefined when eager execution is enabled.")

  def consumers(self):
    raise NotImplementedError(
        "Tensor.consumers is undefined when eager execution is enabled.")

  def _add_consumer(self, consumer):
    raise NotImplementedError(
        "_add_consumer not supported when eager execution is enabled.")

  def _as_node_def_input(self):
    raise NotImplementedError(
        "_as_node_def_input not supported when eager execution is enabled.")

  def _as_tf_output(self):
    raise NotImplementedError(
        "_as_tf_output not supported when eager execution is enabled.")

  def eval(self, feed_dict=None, session=None):
    raise NotImplementedError(
        "eval is not supported when eager execution is enabled, "
        "is .numpy() what you're looking for?")


# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
# It is exposed as an __internal__ api for now (b/171081052), though we
# expect it to be eventually covered by tf Tensor types and typing.
EagerTensor = tf_export("__internal__.EagerTensor", v1=[])(
    pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase))


@tf_export(v1=["convert_to_tensor"])
@dispatch.add_dispatch_support
def convert_to_tensor_v1_with_dispatch(
    value,
    dtype=None,
    name=None,
    preferred_dtype=None,
    dtype_hint=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    dtype_hint: same meaning as preferred_dtype, and overrides it.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  return convert_to_tensor_v1(value, dtype=dtype, name=name,
                              preferred_dtype=preferred_dtype,
                              dtype_hint=dtype_hint)


def convert_to_tensor_v1(value,
                         dtype=None,
                         name=None,
                         preferred_dtype=None,
                         dtype_hint=None):
  """Converts the given `value` to a `Tensor` (with the TF1 API)."""
  preferred_dtype = deprecation.deprecated_argument_lookup(
      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
  return convert_to_tensor_v2(value, dtype, preferred_dtype, name)


@tf_export("convert_to_tensor", v1=[])
@dispatch.add_dispatch_support
def convert_to_tensor_v2_with_dispatch(
    value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars.

  For example:

  >>> import numpy as np
  >>> def my_func(arg):
  ...   arg = tf.convert_to_tensor(arg, dtype=tf.float32)
  ...   return arg

  >>> # The following calls are equivalent.
  ...
  >>> value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  >>> print(value_1)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  >>> print(value_2)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  >>> print(value_3)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor, used when dtype
      is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so dtype_hint can be used as a soft preference.
      If the conversion to `dtype_hint` is not possible, this argument has no
      effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  return convert_to_tensor_v2(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)


def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`."""
  return convert_to_tensor(
      value=value,
      dtype=dtype,
      name=name,
      preferred_dtype=dtype_hint,
      as_ref=False)


def _add_error_prefix(msg, *, name=None):
  return msg if name is None else f"{name}: {msg}"


def pack_eager_tensors(tensors, ctx=None):
  """Pack multiple `EagerTensor`s of the same dtype and shape.

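  A minimal sketch of the calling convention:

  ```python
  t0 = tf.constant([1.0, 2.0])
  t1 = tf.constant([3.0, 4.0])
  packed = pack_eager_tensors([t0, t1])  # A single packed EagerTensor.
  ```
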
1511  Args:
1512    tensors: a list of EagerTensors to pack.
1513    ctx: context.context().
1514
1515  Returns:
1516    A packed EagerTensor.
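
  Example (a minimal sketch; assumes eager execution, that the runtime
  supports packing these particular tensors, and that `constant_op` refers to
  `tensorflow.python.framework.constant_op`, which is not imported in this
  module):

  ```python
  t0 = constant_op.constant([1.0, 2.0])
  t1 = constant_op.constant([3.0, 4.0])
  packed = pack_eager_tensors([t0, t1])
  # `packed` behaves as a single EagerTensor; gradients cannot be computed
  # through it.
  ```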
1517  """
1518  if not isinstance(tensors, list):
1519    raise TypeError(f"tensors must be a list, but got a {type(tensors)}")
1520
1521  if not tensors:
1522    raise ValueError("Cannot pack an empty list of tensors.")
1523
1524  dtype = tensors[0].dtype
1525  shape = tensors[0].shape
1526  handle_data = tensors[0]._handle_data  # pylint: disable=protected-access
1527  is_resource = dtype == dtypes.resource
1528  for i in range(len(tensors)):
1529    t = tensors[i]
1530    if not isinstance(t, EagerTensor):
1531      raise TypeError(f"All tensors being packed must be EagerTensor. "
1532                      f"Found an item of type {type(t)}.")
1533
1534    if t.dtype != dtype:
1535      raise ValueError(
1536          f"All tensors being packed should have the same dtype {dtype}, "
1537          f"but the {i}-th tensor is of dtype {t.dtype}")
1538    if t.shape != shape:
1539      raise ValueError(
1540          f"All tensors being packed should have the same shape {shape}, "
1541          f"but the {i}-th tensor is of shape {t.shape}")
1542    # pylint: disable=protected-access
1543    if is_resource and t._handle_data != handle_data:
1544      raise ValueError(
1545          f"All tensors being packed should have the same handle data "
1546          f"{handle_data}, "
1547          f"but the {i}-th tensor is of handle data {t._handle_data}")
1548    # pylint: enable=protected-access
1549
1550  if ctx is None:
1551    ctx = context.context()
1552
1553  # Propagate handle data for resource variables
1554  packed_tensor = ctx.pack_eager_tensors(tensors)
1555  if handle_data is not None:
1556    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access
1557
1558  def grad_fun(_):
1559    raise ValueError(
1560        "Computing gradients through pack_eager_tensors is not supported.")
1561
1562  tape.record_operation("pack_eager_tensors", [packed_tensor], tensors,
1563                        grad_fun)
1564
1565  return packed_tensor
1566
1567
1568@profiler_trace.trace_wrapper("convert_to_tensor")
1569def convert_to_tensor(value,
1570                      dtype=None,
1571                      name=None,
1572                      as_ref=False,
1573                      preferred_dtype=None,
1574                      dtype_hint=None,
1575                      ctx=None,
1576                      accepted_result_types=(Tensor,)):
1577  """Implementation of the public convert_to_tensor."""
1578  # TODO(b/142518781): Fix all call-sites and remove redundant arg
1579  preferred_dtype = preferred_dtype or dtype_hint
1580  if isinstance(value, EagerTensor):
1581    if ctx is None:
1582      ctx = context.context()
1583    if not ctx.executing_eagerly():
1584      graph = get_default_graph()
1585      if not graph.building_function:
1586        raise RuntimeError(
1587            _add_error_prefix(
1588                "Attempting to capture an EagerTensor without "
1589                "building a function.",
1590                name=name))
1591      return graph.capture(value, name=name)
1592
1593  if dtype is not None:
1594    dtype = dtypes.as_dtype(dtype)
1595  if isinstance(value, Tensor):
1596    if dtype is not None and not dtype.is_compatible_with(value.dtype):
1597      raise ValueError(
1598          _add_error_prefix(
1599              f"Tensor conversion requested dtype {dtype.name} "
1600              f"for Tensor with dtype {value.dtype.name}: {value!r}",
1601              name=name))
1602    return value
1603
1604  if preferred_dtype is not None:
1605    preferred_dtype = dtypes.as_dtype(preferred_dtype)
1606
1607  # See below for the reason why it's `type(value)` and not just `value`.
1608  # https://docs.python.org/3.8/reference/datamodel.html#special-lookup
1609  overload = getattr(type(value), "__tf_tensor__", None)
1610  if overload is not None:
1611    return overload(value, dtype, name)  #  pylint: disable=not-callable
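  # For example (an illustrative sketch; `MyValue` is a hypothetical class and
  # `constant_op` is not imported in this module), an object can opt into
  # conversion by defining the protocol method on its class:
  #
  #   class MyValue:
  #     def __tf_tensor__(self, dtype=None, name=None):
  #       return constant_op.constant([1.0, 2.0], dtype=dtype, name=name)
  #
  #   convert_to_tensor(MyValue())  # dispatches via type(value).__tf_tensor__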
1612
1613  for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
1614    # If dtype is None but preferred_dtype is not None, we try to
1615    # cast to preferred_dtype first.
1616    ret = None
1617    if dtype is None and preferred_dtype is not None:
1618      try:
1619        ret = conversion_func(
1620            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
1621      except (TypeError, ValueError):
1622        # Could not coerce the conversion to use the preferred dtype.
1623        pass
1624      else:
1625        if (ret is not NotImplemented and
1626            ret.dtype.base_dtype != preferred_dtype.base_dtype):
1627          raise RuntimeError(
1628              _add_error_prefix(
1629                  f"Conversion function {conversion_func!r} for type "
1630                  f"{base_type} returned incompatible dtype: requested = "
1631                  f"{preferred_dtype.base_dtype.name}, "
1632                  f"actual = {ret.dtype.base_dtype.name}",
1633                  name=name))
1634
1635    if ret is None:
1636      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1637
1638    if ret is NotImplemented:
1639      continue
1640
1641    if not isinstance(ret, accepted_result_types):
1642      raise RuntimeError(
1643          _add_error_prefix(
1644              f"Conversion function {conversion_func!r} for type "
1645              f"{base_type} returned non-Tensor: {ret!r}",
1646              name=name))
1647    if dtype and not dtype.is_compatible_with(ret.dtype):
1648      raise RuntimeError(
1649          _add_error_prefix(
1650              f"Conversion function {conversion_func} for type {base_type} "
1651              f"returned incompatible dtype: requested = {dtype.name}, "
1652              f"actual = {ret.dtype.name}",
1653              name=name))
1654    return ret
1655  raise TypeError(
1656      _add_error_prefix(
1657          f"Cannot convert {value!r} with type {type(value)} to Tensor: "
1658          f"no conversion function registered.",
1659          name=name))
1660
1661
1662internal_convert_to_tensor = convert_to_tensor
1663
1664
1665def internal_convert_n_to_tensor(values,
1666                                 dtype=None,
1667                                 name=None,
1668                                 as_ref=False,
1669                                 preferred_dtype=None,
1670                                 ctx=None):
1671  """Converts `values` to a list of `Tensor` objects.
1672
1673  Args:
1674    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1675    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1676    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1677      which case element `i` will be given the name `name + '_' + i`.
1678    as_ref: True if the caller wants the results as ref tensors.
1679    preferred_dtype: Optional element type for the returned tensors, used when
1680      dtype is None. In some cases, a caller may not have a dtype in mind when
1681      converting to a tensor, so preferred_dtype can be used as a soft
1682      preference.  If the conversion to `preferred_dtype` is not possible, this
1683      argument has no effect.
1684    ctx: The value of context.context().
1685
1686  Returns:
1687    A list of `Tensor` and/or `IndexedSlices` objects.
1688
1689  Raises:
1690    TypeError: If no conversion function is registered for an element in
1691      `values`.
1692    RuntimeError: If a registered conversion function returns an invalid
1693      value.
1694  """
1695  if not isinstance(values, collections_abc.Sequence):
1696    raise TypeError("values must be a sequence.")
1697  ret = []
1698  if ctx is None:
1699    ctx = context.context()
1700  for i, value in enumerate(values):
1701    n = None if name is None else "%s_%d" % (name, i)
1702    ret.append(
1703        convert_to_tensor(
1704            value,
1705            dtype=dtype,
1706            name=n,
1707            as_ref=as_ref,
1708            preferred_dtype=preferred_dtype,
1709            ctx=ctx))
1710  return ret
1711
1712
1713def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
1714  """Converts `values` to a list of `Tensor` objects.
1715
1716  Args:
1717    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1718    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1719    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1720      which case element `i` will be given the name `name + '_' + i`.
1721    preferred_dtype: Optional element type for the returned tensors, used when
1722      dtype is None. In some cases, a caller may not have a dtype in mind when
1723      converting to a tensor, so preferred_dtype can be used as a soft
1724      preference.  If the conversion to `preferred_dtype` is not possible, this
1725      argument has no effect.
1726
1727  Returns:
1728    A list of `Tensor` and/or `IndexedSlices` objects.
1729
1730  Raises:
1731    TypeError: If no conversion function is registered for an element in
1732      `values`.
1733    RuntimeError: If a registered conversion function returns an invalid
1734      value.
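
  Example (an illustrative sketch; in graph mode the `name` prefix yields
  per-element op names such as `x_0`, `x_1`, ...):

  ```python
  ts = convert_n_to_tensor([1.0, 2.0, 3.0], dtype=dtypes.float32, name="x")
  # `ts` is a list of three scalar float32 Tensors.
  ```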
1735  """
1736  return internal_convert_n_to_tensor(
1737      values=values,
1738      dtype=dtype,
1739      name=name,
1740      preferred_dtype=preferred_dtype,
1741      as_ref=False)
1742
1743
1744def convert_to_tensor_or_composite(value, dtype=None, name=None):
1745  """Converts the given object to a `Tensor` or `CompositeTensor`.
1746
1747  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
1748  is converted to a `Tensor` using `convert_to_tensor()`.
1749
1750  Args:
1751    value: A `CompositeTensor` or an object that can be consumed by
1752      `convert_to_tensor()`.
1753    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1754      `CompositeTensor`.
1755    name: (Optional.) A name to use if a new `Tensor` is created.
1756
1757  Returns:
1758    A `Tensor` or `CompositeTensor`, based on `value`.
1759
1760  Raises:
1761    ValueError: If `dtype` does not match the element type of `value`.
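
  Example (illustrative; `tf.RaggedTensor` is one kind of `CompositeTensor`):

  ```python
  rt = tf.ragged.constant([[1, 2], [3]])
  assert convert_to_tensor_or_composite(rt) is rt    # returned unmodified
  dense = convert_to_tensor_or_composite([1, 2, 3])  # converted to a Tensor
  ```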
1762  """
1763  return internal_convert_to_tensor_or_composite(
1764      value=value, dtype=dtype, name=name, as_ref=False)
1765
1766
1767def internal_convert_to_tensor_or_composite(value,
1768                                            dtype=None,
1769                                            name=None,
1770                                            as_ref=False):
1771  """Converts the given object to a `Tensor` or `CompositeTensor`.
1772
1773  If `value` is a `CompositeTensor` it is returned unmodified.  Otherwise, it
1774  is converted to a `Tensor` using `convert_to_tensor()`.
1775
1776  Args:
1777    value: A `CompositeTensor`, or an object that can be consumed by
1778      `convert_to_tensor()`.
1779    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1780      `CompositeTensor`.
1781    name: (Optional.) A name to use if a new `Tensor` is created.
1782    as_ref: True if the caller wants the results as ref tensors.
1783
1784  Returns:
1785    A `Tensor` or `CompositeTensor`, based on `value`.
1786
1787  Raises:
1788    ValueError: If `dtype` does not match the element type of `value`.
1789  """
1790  if isinstance(value, composite_tensor.CompositeTensor):
1791    value_dtype = getattr(value, "dtype", None)
1792    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value_dtype):
1793      raise ValueError(f"Tensor conversion dtype mismatch. "
1794                       f"Requested dtype is {dtypes.as_dtype(dtype).name}, "
1795                       f"Tensor has dtype {value.dtype.name}: {value!r}")
1796    return value
1797  else:
1798    return convert_to_tensor(
1799        value,
1800        dtype=dtype,
1801        name=name,
1802        as_ref=as_ref,
1803        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))
1804
1805
1806def internal_convert_n_to_tensor_or_composite(values,
1807                                              dtype=None,
1808                                              name=None,
1809                                              as_ref=False):
1810  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.
1811
1812  Any `CompositeTensor` objects in `values` are returned unmodified.
1813
1814  Args:
1815    values: A list of `None`, `CompositeTensor`, or objects that can be consumed
1816      by `convert_to_tensor()`.
1817    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
1818      `CompositeTensor`s.
1819    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1820      which case element `i` will be given the name `name + '_' + i`.
1821    as_ref: True if the caller wants the results as ref tensors.
1822
1823  Returns:
1824    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.
1825
1826  Raises:
1827    TypeError: If no conversion function is registered for an element in
1828      `values`.
1829    RuntimeError: If a registered conversion function returns an invalid
1830      value.
1831  """
1832  if not isinstance(values, collections_abc.Sequence):
1833    raise TypeError("values must be a sequence.")
1834  ret = []
1835  for i, value in enumerate(values):
1836    if value is None:
1837      ret.append(value)
1838    else:
1839      n = None if name is None else "%s_%d" % (name, i)
1840      ret.append(
1841          internal_convert_to_tensor_or_composite(
1842              value, dtype=dtype, name=n, as_ref=as_ref))
1843  return ret
1844
1845
1846def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
1847  """Converts `values` to a list of `Output` or `CompositeTensor` objects.
1848
1849  Any `CompositeTensor` objects in `values` are returned unmodified.
1850
1851  Args:
1852    values: A list of `None`, `CompositeTensor`, or objects that can be
1853      consumed by `convert_to_tensor()`.
1854    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
1855      `CompositeTensor`s.
1856    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1857      which case element `i` will be given the name `name + '_' + i`.
1858
1859  Returns:
1860    A list of `Tensor` and/or `CompositeTensor` objects.
1861
1862  Raises:
1863    TypeError: If no conversion function is registered for an element in
1864      `values`.
1865    RuntimeError: If a registered conversion function returns an invalid
1866      value.
1867  """
1868  return internal_convert_n_to_tensor_or_composite(
1869      values=values, dtype=dtype, name=name, as_ref=False)
1870
1871
1872def _device_string(dev_spec):
1873  if pydev.is_device_spec(dev_spec):
1874    return dev_spec.to_string()
1875  else:
1876    return dev_spec
1877
1878
1879def _NodeDef(op_type, name, attrs=None):
1880  """Create a NodeDef proto.
1881
1882  Args:
1883    op_type: Value for the "op" attribute of the NodeDef proto.
1884    name: Value for the "name" attribute of the NodeDef proto.
1885    attrs: Dictionary where the key is the attribute name (a string)
1886      and the value is the respective "attr" attribute of the NodeDef proto (an
1887      AttrValue).
1888
1889  Returns:
1890    A node_def_pb2.NodeDef protocol buffer.
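
  Example (illustrative only):

  ```python
  node_def = _NodeDef(
      "Const", "my_const",
      attrs={"dtype": attr_value_pb2.AttrValue(
          type=dtypes.float32.as_datatype_enum)})
  ```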
1891  """
1892  node_def = node_def_pb2.NodeDef(op=compat.as_bytes(op_type),
1893                                  name=compat.as_bytes(name))
1894  if attrs:
1895    for k, v in attrs.items():
1896      node_def.attr[k].CopyFrom(v)
1897  return node_def
1898
1899
1900# Copied from core/framework/node_def_util.cc
1901# TODO(mrry,josh11b): Consolidate this validation in C++ code.
1902_VALID_OP_NAME_REGEX = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
1903_VALID_SCOPE_NAME_REGEX = re.compile(r"^[A-Za-z0-9_.\\/>-]*$")
1904
1905
1906@tf_export("__internal__.create_c_op", v1=[])
1907@traceback_utils.filter_traceback
1908def _create_c_op(graph,
1909                 node_def,
1910                 inputs,
1911                 control_inputs,
1912                 op_def=None,
1913                 extract_traceback=True):
1914  """Creates a TF_Operation.
1915
1916  Args:
1917    graph: a `Graph`.
1918    node_def: `node_def_pb2.NodeDef` for the operation to create.
1919    inputs: A flattened list of `Tensor`s. This function handles grouping
1920      tensors into lists as per attributes in the `node_def`.
1921    control_inputs: A list of `Operation`s to set as control dependencies.
1922    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
1923      specified, is looked up from the `graph` using `node_def.op`.
1924    extract_traceback: if True, extract the current Python traceback to the
1925      TF_Operation.
1926
1927  Returns:
1928    A wrapped TF_Operation*.
1929  """
1930  if op_def is None:
1931    op_def = graph._get_op_def(node_def.op)  # pylint: disable=protected-access
1932  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
1933  # Refactor so we don't have to do this here.
1934  inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
1935  # pylint: disable=protected-access
1936  with graph._c_graph.get() as c_graph:
1937    op_desc = pywrap_tf_session.TF_NewOperation(c_graph,
1938                                                compat.as_str(node_def.op),
1939                                                compat.as_str(node_def.name))
1940  if node_def.device:
1941    pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
1942  # Add inputs
1943  for op_input in inputs:
1944    if isinstance(op_input, (list, tuple)):
1945      pywrap_tf_session.TF_AddInputList(op_desc,
1946                                        [t._as_tf_output() for t in op_input])
1947    else:
1948      pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())
1949
1950  # Add control inputs
1951  for control_input in control_inputs:
1952    pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
1953  # pylint: enable=protected-access
1954
1955  # Add attrs
1956  for name, attr_value in node_def.attr.items():
1957    serialized = attr_value.SerializeToString()
1958    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
1959    # It might be worth creating a convenient way to re-use the same status.
1960    pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
1961                                           serialized)
1962
1963  try:
1964    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
1965  except errors.InvalidArgumentError as e:
1966    # Convert to ValueError for backwards compatibility.
1967    raise ValueError(e.message)
1968
1969  # Record the current Python stack trace as the creating stacktrace of this
1970  # TF_Operation.
1971  if extract_traceback:
1972    tf_stack.extract_stack_for_op(c_op, stacklevel=3)
1973
1974  return c_op
1975
1976
1977@tf_export("Operation")
1978class Operation(object):
1979  """Represents a graph node that performs computation on tensors.
1980
1981  An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor`
1982  objects as input, and produces zero or more `Tensor` objects as output.
1983  Objects of type `Operation` are created by calling a Python op constructor
1984  (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default`
1985  context manager.
1986
1987  For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an
1988  `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and
1989  produces `c` as output.
1990
1991  If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be
1992  executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for
1993  calling `tf.compat.v1.get_default_session().run(op)`.
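
  For example, under `tf.Graph.as_default`:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   c = tf.constant(30.0)
  >>> c.op.type
  'Const'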
1994  """
1995
1996  def __init__(self,
1997               node_def,
1998               g,
1999               inputs=None,
2000               output_types=None,
2001               control_inputs=None,
2002               input_types=None,
2003               original_op=None,
2004               op_def=None):
2005    r"""Creates an `Operation`.
2006
2007    NOTE: This constructor validates the name of the `Operation` (passed
2008    as `node_def.name`). Valid `Operation` names match the following
2009    regular expression:
2010
2011        [A-Za-z0-9.][A-Za-z0-9_.\\-/]*
2012
2013    Args:
2014      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`. Used for
2015        attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and
2016        `device`.  The `input` attribute is irrelevant here as it will be
2017        computed when generating the model.
2018      g: `Graph`. The parent graph.
2019      inputs: list of `Tensor` objects. The inputs to this `Operation`.
2020      output_types: list of `DType` objects.  List of the types of the `Tensors`
2021        computed by this operation.  The length of this list indicates the
2022        number of output endpoints of the `Operation`.
2023      control_inputs: list of operations or tensors from which to have a control
2024        dependency.
2025      input_types: List of `DType` objects representing the types of the tensors
2026        accepted by the `Operation`.  By default uses `[x.dtype.base_dtype for x
2027        in inputs]`.  Operations that expect reference-typed inputs must specify
2028        these explicitly.
2029      original_op: Optional. Used to associate the new `Operation` with an
2030        existing `Operation` (for example, a replica with the op that was
2031        replicated).
2032      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type
2033        that this `Operation` represents.
2034
2035    Raises:
2036      TypeError: if control inputs are not Operations or Tensors,
2037        or if `node_def` is not a `NodeDef`,
2038        or if `g` is not a `Graph`,
2039        or if `inputs` are not tensors,
2040        or if `inputs` and `input_types` are incompatible.
2041      ValueError: if the `node_def` name is not valid.
2042    """
2043    if not isinstance(g, Graph):
2044      raise TypeError(f"Argument g must be a Graph. "
2045                      f"Received an instance of type {type(g)}")
2046
2047    # TODO(feyu): This message is redundant with the check below. We raise it
2048    # to help users to migrate. Remove this after 07/01/2022.
2049    if isinstance(node_def, pywrap_tf_session.TF_Operation):
2050      raise ValueError(
2051          "Calling Operation() with node_def of a TF_Operation is deprecated. "
2052          "Please switch to Operation.from_c_op.")
2053
2054    if not isinstance(node_def, node_def_pb2.NodeDef):
2055      raise TypeError(f"Argument node_def must be a NodeDef. "
2056                      f"Received an instance of type: {type(node_def)}.")
2057    if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
2058      raise ValueError(
2059          f"Cannot create a tensor proto whose content is larger than 2GB. "
2060          f"Size of tensor is {node_def.ByteSize()} bytes.")
2061
2062    # TODO(mdan): This does not belong here. Graph::AddNode should handle it.
2063    if not _VALID_OP_NAME_REGEX.match(node_def.name):
2064      raise ValueError(
2065          f"`{node_def.name}` is not a valid node name. "
2066          f"Accepted names conform to Regex /{_VALID_OP_NAME_REGEX}/")
2067
2068    # FIXME(b/225400189): output_types is unused. Consider remove it from
2069    # the argument list.
2070    del output_types
2071
2072    if inputs is None:
2073      inputs = []
2074    elif not isinstance(inputs, list):
2075      raise TypeError(f"Argument inputs shall be a list of Tensors. "
2076                      f"Received an instance of type {type(inputs)}")
2077    for a in inputs:
2078      if not isinstance(a, Tensor):
2079        raise TypeError(f"Items of argument inputs shall be Tensor. "
2080                        f"Received an instance of type {type(a)}.")
2081    if input_types is None:
2082      input_types = [i.dtype.base_dtype for i in inputs]
2083    else:
2084      if not all(
2085          x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)):
2086        raise TypeError("In op '%s', input types (%s) are not compatible "
2087                        "with expected types (%s)" %
2088                        (node_def.name, [i.dtype for i in inputs], input_types))
2089
2090    # Build the list of control inputs.
2091    control_input_ops = []
2092    if control_inputs:
2093      for c in control_inputs:
2094        control_op = None
2095        if isinstance(c, Operation):
2096          control_op = c
2097        elif isinstance(c, (Tensor, IndexedSlices)):
2098          control_op = c.op
2099        else:
2100          raise TypeError(f"Control input must be an Operation, "
2101                          f"a Tensor, or IndexedSlices. "
2102                          f"Received an instance of type {type(c)}.")
2103        control_input_ops.append(control_op)
2104
2105    # Initialize c_op from node_def and other inputs
2106    c_op = _create_c_op(g, node_def, inputs, control_input_ops, op_def=op_def)
2107    self._init_from_c_op(c_op=c_op, g=g)
2108
2109    self._original_op = original_op
2110
2111    # Post process for control flows.
2112    self._control_flow_post_processing(input_tensors=inputs)
2113
2114    # Removes this frame from the Python traceback.
2115    # We adjust stacklevel directly to avoid triggering serialization.
2116    self.traceback._stacklevel += 1  # pylint: disable=protected-access
2117
2118  @classmethod
2119  def _from_c_op(cls, c_op, g):
2120    """Create an Operation from a TF_Operation.
2121
2122    For internal use only: This is useful for creating an `Operation` for ops
2123    indirectly created by C API methods, e.g. the ops created by
2124    TF_ImportGraphDef.
2125
2126    Args:
2127      c_op: a TF_Operation.
2128      g: A Graph.
2129
2130    Returns:
2131      an Operation object.
2132    """
2133    self = object.__new__(cls)
2134
2135    self._init_from_c_op(c_op=c_op, g=g)  # pylint: disable=protected-access
2136    return self
2137
2138  def _init_from_c_op(self, c_op, g):
2139    """Initializes Operation from a TF_Operation."""
2140
2141    if not isinstance(g, Graph):
2142      raise TypeError(f"Operation initialization requires a Graph, "
2143                      f"got {type(g)} for argument g.")
2144
2145    if not isinstance(c_op, pywrap_tf_session.TF_Operation):
2146      raise TypeError(f"Operation initialization requires a TF_Operation, "
2147                      f"got {type(c_op)} for argument c_op.")
2148
2149    self._original_op = None
2150
2151    self._graph = g
2152    self._c_op = c_op
2153
2154    # This will be set by self.inputs.
2155    self._inputs_val = None
2156
2157    # List of _UserDevSpecs holding code location of device context manager
2158    # invocations and the user's original argument to them.
2159    self._device_code_locations = None
2160    # Dict mapping op name to file and line information for op colocation
2161    # context managers.
2162    self._colocation_code_locations = None
2163    self._control_flow_context = g._get_control_flow_context()  # pylint: disable=protected-access
2164
2165    # Gradient function for this op. There are three ways to specify the
2166    # gradient function; the first one found is used, in the following order.
2167    # 1. self._gradient_function
2168    # 2. Gradient name registered by "_gradient_op_type" attribute.
2169    # 3. Gradient name registered by op.type.
2170    self._gradient_function = None
2171
2172    op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op))  # pylint: disable=protected-access
2173
2174    self._is_stateful = op_def.is_stateful
2175
2176    # Initialize self._outputs.
2177    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
2178    self._outputs = []
2179    for i in range(num_outputs):
2180      tf_output = c_api_util.tf_output(self._c_op, i)
2181      output_type = pywrap_tf_session.TF_OperationOutputType(tf_output)
2182      tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output)  # pylint: disable=protected-access
2183      self._outputs.append(tensor)
2184
2185    self._id_value = g._add_op(self, self.name)  # pylint: disable=protected-access
2186
2187  def _control_flow_post_processing(self, input_tensors=None):
2188    """Add this op to its control flow context.
2189
2190    This may add new ops and change this op's inputs. self.inputs must be
2191    available before calling this method.
2192
2193    Args:
2194      input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
2195        of this op, which should be equivalent to `self.inputs`. Pass this
2196        argument to avoid evaluating `self.inputs` unnecessarily.
2197    """
2198    if input_tensors is None:
2199      input_tensors = self.inputs
2200    for input_tensor in input_tensors:
2201      control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
2202    if self._control_flow_context is not None:
2203      self._control_flow_context.AddOp(self)
2204
2205  def colocation_groups(self):
2206    """Returns the list of colocation groups of the op."""
2207    default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)]
2208    try:
2209      class_attr = self.get_attr("_class")
2210    except ValueError:
2211      # This op has no explicit colocation group, so it is itself its
2212      # own root of a colocation group.
2213      return default_colocation_group
2214
2215    attr_groups = [
2216        class_name for class_name in class_attr
2217        if class_name.startswith(b"loc:@")
2218    ]
2219
2220    # If there are no colocation groups in the explicit _class field,
2221    # return the default colocation group.
2222    return attr_groups if attr_groups else default_colocation_group
2223
2224  def values(self):
2225    """DEPRECATED: Use outputs."""
2226    return tuple(self.outputs)
2227
2228  def _get_control_flow_context(self):
2229    """Returns the control flow context of this op.
2230
2231    Returns:
2232      A context object.
2233    """
2234    return self._control_flow_context
2235
2236  def _set_control_flow_context(self, ctx):
2237    """Sets the current control flow context of this op.
2238
2239    Args:
2240      ctx: a context object.
2241    """
2242    self._control_flow_context = ctx
2243
2244  @property
2245  def name(self):
2246    """The full name of this operation."""
2247    return pywrap_tf_session.TF_OperationName(self._c_op)
2248
2249  @property
2250  def _id(self):
2251    """The unique integer id of this operation."""
2252    return self._id_value
2253
2254  @property
2255  def device(self):
2256    """The name of the device to which this op has been assigned, if any.
2257
2258    Returns:
2259      The string name of the device to which this op has been
2260      assigned, or an empty string if it has not been assigned to a
2261      device.
2262    """
2263    return pywrap_tf_session.TF_OperationDevice(self._c_op)
2264
2265  @property
2266  def _device_assignments(self):
2267    """Code locations for device context managers active at op creation.
2268
2269    This property will return a list of traceable_stack.TraceableObject
2270    instances where .obj is a string representing the assigned device
2271    (or information about the function that would be applied to this op
2272    to compute the desired device) and the filename and lineno members
2273    record the location of the relevant device context manager.
2274
2275    For example, suppose file_a contained these lines:
2276
2277      file_a.py:
2278        15: with tf.device('/gpu:0'):
2279        16:   node_b = tf.constant(4, name='NODE_B')
2280
2281    Then a TraceableObject t_obj representing the device context manager
2282    would have these member values:
2283
2284      t_obj.obj -> '/gpu:0'
2285      t_obj.filename = 'file_a.py'
2286      t_obj.lineno = 15
2287
2288    and node_b.op._device_assignments would return the list [t_obj].
2289
2290    Returns:
2291      A list of traceable_stack.TraceableObject instances, as per this
2292      method's description above.
2293    """
2294    return self._device_code_locations or []
2295
2296  @property
2297  def _colocation_dict(self):
2298    """Code locations for colocation context managers active at op creation.
2299
2300    This property will return a dictionary for which the keys are nodes with
2301    which this Operation is colocated, and for which the values are
2302    traceable_stack.TraceableObject instances.  The TraceableObject instances
2303    record the location of the relevant colocation context manager but have the
2304    "obj" field set to None to prevent leaking private data.
2305
2306    For example, suppose file_a contained these lines:
2307
2308      file_a.py:
2309        14: node_a = tf.constant(3, name='NODE_A')
2310        15: with tf.compat.v1.colocate_with(node_a):
2311        16:   node_b = tf.constant(4, name='NODE_B')
2312
2313    Then a TraceableObject t_obj representing the colocation context manager
2314    would have these member values:
2315
2316      t_obj.obj -> None
2317      t_obj.filename = 'file_a.py'
2318      t_obj.lineno = 15
2319
2320    and node_b.op._colocation_dict would return the dictionary
2321
2322      { 'NODE_A': t_obj }
2323
2324    Returns:
2325      {str: traceable_stack.TraceableObject} as per this method's description,
2326      above.
2327    """
2328    locations_dict = self._colocation_code_locations or {}
2329    return locations_dict.copy()
2330
2331  @property
2332  def _output_types(self):
2333    """List this operation's output types.
2334
2335    Returns:
2336      List of the types of the Tensors computed by this operation.
2337      Each element in the list is an integer whose value is one of
2338      the TF_DataType enums defined in pywrap_tf_session.h
2339      The length of this list indicates the number of output endpoints
2340      of the operation.
2341    """
2342    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
2343    output_types = [
2344        int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i)))
2345        for i in range(num_outputs)
2346    ]
2347
2348    return output_types
2349
2350  def _tf_output(self, output_idx):
2351    """Create and return a new TF_Output for output_idx'th output of this op."""
2352    tf_output = pywrap_tf_session.TF_Output()
2353    tf_output.oper = self._c_op
2354    tf_output.index = output_idx
2355    return tf_output
2356
2357  def _tf_input(self, input_idx):
2358    """Create and return a new TF_Input for input_idx'th input of this op."""
2359    tf_input = pywrap_tf_session.TF_Input()
2360    tf_input.oper = self._c_op
2361    tf_input.index = input_idx
2362    return tf_input
2363
2364  def _set_device(self, device):  # pylint: disable=redefined-outer-name
2365    """Set the device of this operation.
2366
2367    Args:
2368      device: string or device. The device to set.
2369    """
2370    self._set_device_from_string(compat.as_str(_device_string(device)))
2371
2372  def _set_device_from_string(self, device_str):
2373    """Fast path to set device if the type is known to be a string.
2374
2375    This function is called frequently enough during graph construction that
2376    there are non-trivial performance gains if the caller can guarantee that
2377    the specified device is already a string.
2378
2379    Args:
2380      device_str: A string specifying where to place this op.
2381    """
2382    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2383      pywrap_tf_session.SetRequestedDevice(
2384          c_graph,
2385          self._c_op,  # pylint: disable=protected-access
2386          device_str)
2387
2388  def _update_input(self, index, tensor):
2389    """Update the input to this operation at the given index.
2390
2391    NOTE: This is for TF internal use only. Please don't use it.
2392
2393    Args:
2394      index: the index of the input to update.
2395      tensor: the Tensor to be used as the input at the given index.
2396
2397    Raises:
2398      TypeError: if tensor is not a Tensor,
2399        or if input tensor type is not convertible to dtype.
2400      ValueError: if the Tensor is from a different graph.
2401    """
2402    if not isinstance(tensor, Tensor):
2403      raise TypeError("tensor must be a Tensor: %s" % tensor)
2404    _assert_same_graph(self, tensor)
2405
2406    # Reset cached inputs.
2407    self._inputs_val = None
2408    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2409      pywrap_tf_session.UpdateEdge(
2410          c_graph,
2411          tensor._as_tf_output(),  # pylint: disable=protected-access
2412          self._tf_input(index))
2413
2414  def _add_while_inputs(self, tensors):
2415    """See AddWhileInputHack in python_api.h.
2416
2417    NOTE: This is for TF internal use only. Please don't use it.
2418
2419    Args:
2420      tensors: list of Tensors
2421
2422    Raises:
2423      TypeError: if tensor is not a Tensor,
2424        or if input tensor type is not convertible to dtype.
2425      ValueError: if the Tensor is from a different graph.
2426    """
2427    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2428      for tensor in tensors:
2429        if not isinstance(tensor, Tensor):
2430          raise TypeError("tensor must be a Tensor: %s" % tensor)
2431        _assert_same_graph(self, tensor)
2432
2433        # Reset cached inputs.
2434        self._inputs_val = None
2435        pywrap_tf_session.AddWhileInputHack(
2436            c_graph,  # pylint: disable=protected-access
2437            tensor._as_tf_output(),  # pylint: disable=protected-access
2438            self._c_op)
2439
2440  def _add_control_inputs(self, ops):
2441    """Add a list of new control inputs to this operation.
2442
2443    Args:
2444      ops: the list of Operations to add as control input.
2445
2446    Raises:
2447      TypeError: if ops is not a list of Operations.
2448      ValueError: if any op in ops is from a different graph.
2449    """
2450    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2451      for op in ops:
2452        if not isinstance(op, Operation):
2453          raise TypeError("op must be an Operation: %s" % op)
2454        pywrap_tf_session.AddControlInput(
2455            c_graph,
2456            self._c_op,  # pylint: disable=protected-access
2457            op._c_op)  # pylint: disable=protected-access
2458
2459  def _add_control_input(self, op):
2460    """Add a new control input to this operation.
2461
2462    Args:
2463      op: the Operation to add as control input.
2464
2465    Raises:
2466      TypeError: if op is not an Operation.
2467      ValueError: if op is from a different graph.
2468    """
2469    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2470      if not isinstance(op, Operation):
2471        raise TypeError("op must be an Operation: %s" % op)
2472      pywrap_tf_session.AddControlInput(
2473          c_graph,
2474          self._c_op,  # pylint: disable=protected-access
2475          op._c_op)  # pylint: disable=protected-access
2476
2477  def _remove_all_control_inputs(self):
2478    """Removes any control inputs to this operation."""
2479    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2480      pywrap_tf_session.RemoveAllControlInputs(c_graph, self._c_op)  # pylint: disable=protected-access
2481
2482  def _add_outputs(self, types, shapes):
2483    """Adds new Tensors to self.outputs.
2484
2485    Note: this is generally unsafe to use. This is used in certain situations in
2486    conjunction with _set_type_list_attr.
2487
2488    Args:
2489      types: list of DTypes
2490      shapes: list of TensorShapes
2491    """
2492    assert len(types) == len(shapes)
2493    orig_num_outputs = len(self.outputs)
2494    for i in range(len(types)):
2495      t = Tensor(self, orig_num_outputs + i, types[i])
2496      self._outputs.append(t)
2497      t.set_shape(shapes[i])
2498
2499  def __str__(self):
2500    return str(self.node_def)
2501
2502  def __repr__(self):
2503    return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
2504
2505  def __tf_tensor__(self, dtype=None, name=None):
2506    """Raises a helpful error."""
2507    raise TypeError("can't convert Operation '{}' to Tensor".format(self.name))
2508
2509  @property
2510  def outputs(self):
2511    """The list of `Tensor` objects representing the outputs of this op."""
2512    return self._outputs
2513
2514  @property
2515  def inputs(self):
2516    """The sequence of `Tensor` objects representing the data inputs of this op."""
2517    if self._inputs_val is None:
2518      # pylint: disable=protected-access
2519      self._inputs_val = tuple(
2520          self.graph._get_tensor_by_tf_output(i)
2521          for i in pywrap_tf_session.GetOperationInputs(self._c_op))
2522      # pylint: enable=protected-access
2523    return self._inputs_val
2524
2525  @property
2526  def _input_types(self):
2527    num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op)
2528    input_types = [
2529        dtypes.as_dtype(
2530            pywrap_tf_session.TF_OperationInputType(self._tf_input(i)))
2531        for i in range(num_inputs)
2532    ]
2533    return input_types
2534
2535  @property
2536  def control_inputs(self):
2537    """The `Operation` objects on which this op has a control dependency.
2538
2539    Before this op is executed, TensorFlow will ensure that the
2540    operations in `self.control_inputs` have finished executing. This
2541    mechanism can be used to run ops sequentially for performance
2542    reasons, or to ensure that the side effects of an op are observed
2543    in the correct order.
2544
2545    Returns:
2546      A list of `Operation` objects.
2547
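    For example (an illustrative sketch, graph mode):

    ```python
    with tf.Graph().as_default():
      a = tf.constant(1.0)
      with tf.control_dependencies([a]):
        b = tf.constant(2.0)
      assert a.op in b.op.control_inputs
    ```
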
2548    """
2549    control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper(
2550        self._c_op)
2551    # pylint: disable=protected-access
2552    return [
2553        self.graph._get_operation_by_name_unsafe(
2554            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
2555    ]
2556    # pylint: enable=protected-access
2557
2558  @property
2559  def _control_outputs(self):
2560    """The `Operation` objects which have a control dependency on this op.
2561
2562    Before any of the ops in `self._control_outputs` can execute, TensorFlow
2563    will ensure that `self` has finished executing.
2564
2565    Returns:
2566      A list of `Operation` objects.
2567
2568    """
2569    control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
2570        self._c_op)
2571    # pylint: disable=protected-access
2572    return [
2573        self.graph._get_operation_by_name_unsafe(
2574            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
2575    ]
2576    # pylint: enable=protected-access
2577
2578  @property
2579  def type(self):
2580    """The type of the op (e.g. `"MatMul"`)."""
2581    return pywrap_tf_session.TF_OperationOpType(self._c_op)
2582
2583  @property
2584  def graph(self):
2585    """The `Graph` that contains this operation."""
2586    return self._graph
2587
2588  @property
2589  def node_def(self):
2590    # pylint: disable=line-too-long
2591    """Returns the `NodeDef` representation of this operation.
2592
2593    Returns:
2594      A
2595      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
2596      protocol buffer.
2597    """
2598    # pylint: enable=line-too-long
2599    with c_api_util.tf_buffer() as buf:
2600      pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
2601      data = pywrap_tf_session.TF_GetBuffer(buf)
2602    node_def = node_def_pb2.NodeDef()
2603    node_def.ParseFromString(compat.as_bytes(data))
2604    return node_def
2605
2606  @property
2607  def op_def(self):
2608    # pylint: disable=line-too-long
2609    """Returns the `OpDef` proto that represents the type of this op.
2610
2611    Returns:
2612      An
2613      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
2614      protocol buffer.
2615    """
2616    # pylint: enable=line-too-long
2617    return self._graph._get_op_def(self.type)
2618
2619  @property
2620  def traceback(self):
2621    """Returns the call stack from when this operation was constructed."""
2622    # FIXME(b/225423591): This object contains a dangling reference if _c_op
2623    # goes out of scope.
2624    return pywrap_tf_session.TF_OperationGetStackTrace(self._c_op)
2625
2626  def _set_attr(self, attr_name, attr_value):
2627    """Private method used to set an attribute in the node_def."""
2628    buf = pywrap_tf_session.TF_NewBufferFromString(
2629        compat.as_bytes(attr_value.SerializeToString()))
2630    try:
2631      self._set_attr_with_buf(attr_name, buf)
2632    finally:
2633      pywrap_tf_session.TF_DeleteBuffer(buf)
2634
2635  def _set_attr_with_buf(self, attr_name, attr_buf):
2636    """Set an attr in the node_def with a pre-allocated buffer."""
2637    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2638      # pylint: disable=protected-access
2639      pywrap_tf_session.SetAttr(c_graph, self._c_op, attr_name, attr_buf)
2640      # pylint: enable=protected-access
2641
2642  def _set_func_attr(self, attr_name, func_name):
2643    """Private method used to set a function attribute in the node_def."""
2644    func = attr_value_pb2.NameAttrList(name=func_name)
2645    self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
2646
2647  def _set_func_list_attr(self, attr_name, func_names):
2648    """Private method used to set a list(function) attribute in the node_def."""
2649    funcs = [attr_value_pb2.NameAttrList(name=func_name)
2650             for func_name in func_names]
2651    funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
2652    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))
2653
2654  def _set_type_list_attr(self, attr_name, types):
2655    """Private method used to set a list(type) attribute in the node_def."""
2656    if not types:
2657      return
2658    if isinstance(types[0], dtypes.DType):
2659      types = [dt.as_datatype_enum for dt in types]
2660    types_list = attr_value_pb2.AttrValue.ListValue(type=types)
2661    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
2662
2663  def _set_shape_list_attr(self, attr_name, shapes):
2664    """Private method used to set a list(shape) attribute in the node_def."""
2665    shapes = [s.as_proto() for s in shapes]
2666    shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
2667    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
2668
2669  def _clear_attr(self, attr_name):
2670    """Private method used to clear an attribute in the node_def."""
2671    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2672      # pylint: disable=protected-access
2673      pywrap_tf_session.ClearAttr(c_graph, self._c_op, attr_name)
2674      # pylint: enable=protected-access
2675
2676  def get_attr(self, name):
2677    """Returns the value of the attr of this op with the given `name`.
2678
2679    Args:
2680      name: The name of the attr to fetch.
2681
2682    Returns:
2683      The value of the attr, as a Python object.
2684
2685    Raises:
2686      ValueError: If this op does not have an attr with the given `name`.
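
    For example (illustrative), for a graph-mode op `c.op` created by
    `c = tf.matmul(a, b, transpose_a=True)` with float32 inputs:

    ```python
    c.op.get_attr("transpose_a")  # -> True
    c.op.get_attr("T")            # -> tf.float32
    ```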
2687    """
2688    fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
2689    try:
2690      with c_api_util.tf_buffer() as buf:
2691        pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
2692        data = pywrap_tf_session.TF_GetBuffer(buf)
2693    except errors.InvalidArgumentError as e:
2694      # Convert to ValueError for backwards compatibility.
2695      raise ValueError(e.message)
2696    x = attr_value_pb2.AttrValue()
2697    x.ParseFromString(data)
2698
2699    oneof_value = x.WhichOneof("value")
2700    if oneof_value is None:
2701      return []
2702    if oneof_value == "list":
2703      for f in fields:
2704        if getattr(x.list, f):
2705          if f == "type":
2706            return [dtypes.as_dtype(t) for t in x.list.type]
2707          else:
2708            return list(getattr(x.list, f))
2709      return []
2710    if oneof_value == "type":
2711      return dtypes.as_dtype(x.type)
2712    assert oneof_value in fields, "Unsupported field type in " + str(x)
2713    return getattr(x, oneof_value)
2714
2715  def _get_attr_type(self, name):
2716    """Returns the `DType` value of the attr of this op with the given `name`."""
2717    try:
2718      dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
2719      return _DTYPES_INTERN_TABLE[dtype_enum]
2720    except errors.InvalidArgumentError as e:
2721      # Convert to ValueError for backwards compatibility.
2722      raise ValueError(e.message)
2723
2724  def _get_attr_bool(self, name):
2725    """Returns the `bool` value of the attr of this op with the given `name`."""
2726    try:
2727      return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
2728    except errors.InvalidArgumentError as e:
2729      # Convert to ValueError for backwards compatibility.
2730      raise ValueError(e.message)
2731
2732  def _get_attr_int(self, name):
2733    """Returns the `int` value of the attr of this op with the given `name`."""
2734    try:
2735      return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
2736    except errors.InvalidArgumentError as e:
2737      # Convert to ValueError for backwards compatibility.
2738      raise ValueError(e.message)
2739
2740  def experimental_set_type(self, type_proto):
2741    """Sets the corresponding node's `experimental_type` field.
2742
2743    See the description of `NodeDef.experimental_type` for more info.
2744
2745    Args:
2746      type_proto: A FullTypeDef proto message. The root type_id of this object
2747        must be `TFT_PRODUCT`, even for ops which only have a single return
2748        value.
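
    Example (a minimal sketch, assuming `op` is an existing `Operation`):

    ```python
    t = full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_PRODUCT)
    op.experimental_set_type(t)
    ```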
2749    """
2750    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
2751      if (type_proto.type_id
2752          not in (full_type_pb2.TFT_UNSET, full_type_pb2.TFT_PRODUCT)):
2753        raise ValueError("error setting the type of ", self.name,
2754                         ": expected TFT_UNSET or TFT_PRODUCT, got ",
2755                         type_proto.type_id)
2756      pywrap_tf_session.SetFullType(c_graph, self._c_op,
2757                                    type_proto.SerializeToString())  # pylint:disable=protected-access
2758
2759  def run(self, feed_dict=None, session=None):
2760    """Runs this operation in a `Session`.
2761
2762    Calling this method will execute all preceding operations that
2763    produce the inputs needed for this operation.
2764
2765    *N.B.* Before invoking `Operation.run()`, its graph must have been
2766    launched in a session, and either a default session must be
2767    available, or `session` must be specified explicitly.
2768
2769    Args:
2770      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
2771        `tf.Session.run` for a description of the valid feed values.
2772      session: (Optional.) The `Session` to be used to run to this operation. If
2773        none, the default session will be used.
2774    """
2775    _run_using_default_session(self, feed_dict, self.graph, session)
2776
2777# TODO(b/185395742): Clean up usages of _gradient_registry
2778gradient_registry = _gradient_registry = registry.Registry("gradient")
2779
2780
2781@tf_export("RegisterGradient")
2782class RegisterGradient(object):
2783  """A decorator for registering the gradient function for an op type.
2784
2785  This decorator is only used when defining a new op type. For an op
2786  with `m` inputs and `n` outputs, the gradient function is a function
2787  that takes the original `Operation` and `n` `Tensor` objects
2788  (representing the gradients with respect to each output of the op),
2789  and returns `m` `Tensor` objects (representing the partial gradients
2790  with respect to each input of the op).
2791
2792  For example, assuming that operations of type `"Sub"` take two
2793  inputs `x` and `y`, and return a single output `x - y`, the
2794  following gradient function would be registered:
2795
2796  ```python
2797  @tf.RegisterGradient("Sub")
2798  def _sub_grad(unused_op, grad):
2799    return grad, tf.negative(grad)
2800  ```
2801
2802  The decorator argument `op_type` is the string type of an
2803  operation. This corresponds to the `OpDef.name` field for the proto
2804  that defines the operation.
2805  """
2806
2807  __slots__ = ["_op_type"]
2808
2809  def __init__(self, op_type):
2810    """Creates a new decorator with `op_type` as the Operation type.
2811
2812    Args:
2813      op_type: The string type of an operation. This corresponds to the
2814        `OpDef.name` field for the proto that defines the operation.
2815
2816    Raises:
2817      TypeError: If `op_type` is not string.
2818    """
2819    if not isinstance(op_type, str):
2820      raise TypeError("op_type must be a string")
2821    self._op_type = op_type
2822
2823  def __call__(self, f):
2824    """Registers the function `f` as gradient function for `op_type`."""
2825    gradient_registry.register(f, self._op_type)
2826    return f
2827
2828
2829@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient")
2830@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"])
2831def no_gradient(op_type):
2832  """Specifies that ops of type `op_type` is not differentiable.
2833
2834  This function should *not* be used for operations that have a
2835  well-defined gradient that is not yet implemented.
2836
2837  This function is only used when defining a new op type. It may be
2838  used for ops such as `tf.size()` that are not differentiable.  For
2839  example:
2840
2841  ```python
2842  tf.no_gradient("Size")
2843  ```
2844
2845  The gradient computed for `op_type` will then propagate zeros.
2846
2847  For ops that have a well-defined gradient that is not yet implemented,
2848  no declaration should be made, and an error *must* be thrown if
2849  an attempt to request its gradient is made.
2850
2851  Args:
2852    op_type: The string type of an operation. This corresponds to the
2853      `OpDef.name` field for the proto that defines the operation.
2854
2855  Raises:
2856    TypeError: If `op_type` is not a string.
2857
2858  """
2859  if not isinstance(op_type, str):
2860    raise TypeError("op_type must be a string")
2861  gradient_registry.register(None, op_type)
2862
2863
2864# Aliases for the old names, will be eventually removed.
2865NoGradient = no_gradient
2866NotDifferentiable = no_gradient
2867
2868
2869def get_gradient_function(op):
2870  """Returns the function that computes gradients for "op"."""
2871  if not op.inputs:
2872    return None
2873
2874  gradient_function = op._gradient_function  # pylint: disable=protected-access
2875  if gradient_function:
2876    return gradient_function
2877
2878  try:
2879    op_type = op.get_attr("_gradient_op_type")
2880  except ValueError:
2881    op_type = op.type
2882  return gradient_registry.lookup(op_type)
2883
2884
2885def set_shape_and_handle_data_for_outputs(_):
2886  """No op. TODO(b/74620627): Remove this."""
2887  pass
2888
2889
2890class OpStats(object):
2891  """A holder for statistics about an operator.
2892
2893  This class holds information about the resource requirements for an op,
2894  including the size of its weight parameters on-disk and how many FLOPS it
2895  requires to execute forward inference.
2896
2897  If you define a new operation, you can create a function that will return a
2898  set of information about its usage of the CPU and disk space when serialized.
2899  The function itself takes a Graph object (set up so you can call methods
2900  like get_tensor_by_name to help calculate the results) and a NodeDef
2901  argument.
2902
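  For example (illustrative), statistics of the same type can be accumulated
  with `+=`:

  ```python
  total = ops.OpStats("flops", 0)
  total += ops.OpStats("flops", 1024)
  # total.value is now 1024
  ```
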
2903  """
2904
2905  __slots__ = ["_statistic_type", "_value"]
2906
2907  def __init__(self, statistic_type, value=None):
2908    """Sets up the initial placeholders for the statistics."""
2909    self.statistic_type = statistic_type
2910    self.value = value
2911
2912  @property
2913  def statistic_type(self):
2914    return self._statistic_type
2915
2916  @statistic_type.setter
2917  def statistic_type(self, statistic_type):
2918    self._statistic_type = statistic_type
2919
2920  @property
2921  def value(self):
2922    return self._value
2923
2924  @value.setter
2925  def value(self, value):
2926    self._value = value
2927
2928  def __iadd__(self, other):
2929    if other.statistic_type != self.statistic_type:
2930      raise ValueError("Can't add an OpStat of type %s to one of %s." %
2931                       (self.statistic_type, other.statistic_type))
2932    if self.value is None:
2933      self.value = other.value
2934    elif other.value is not None:
2935      self._value += other.value
2936    return self
2937
2938
2939_stats_registry = registry.Registry("statistical functions")
2940
2941
2942class RegisterStatistics(object):
2943  """A decorator for registering the statistics function for an op type.
2944
2945  This decorator can be defined for an op type so that it gives a
2946  report on the resources used by an instance of an operator, in the
2947  form of an OpStats object.
2948
2949  Well-known types of statistics include these so far:
2950
2951  - flops: When running a graph, the bulk of the computation happens doing
2952    numerical calculations like matrix multiplications. This type allows a node
2953    to return how many floating-point operations it takes to complete. The
2954    total number of FLOPs for a graph is a good guide to its expected latency.
2955
2956  You can add your own statistics just by picking a new type string, registering
2957  functions for the ops you care about, and then calling get_stats_for_node_def.
2958
2959  If a statistic for an op is registered multiple times, a KeyError will be
2960  raised.
2961
2962  Since statistics are counted on a per-op basis, they are not suitable for
2963  model parameters (capacity), which are expected to be counted only once,
2964  even if they are shared by multiple ops (e.g. in an RNN).
2965
2966  For example, you can define a new metric called doohickey for a Foo operation
2967  by placing this in your code:
2968
2969  ```python
2970  @ops.RegisterStatistics("Foo", "doohickey")
2971  def _calc_foo_bojangles(unused_graph, unused_node_def):
2972    return ops.OpStats("doohickey", 20)
2973  ```
2974
2975  Then in client code you can retrieve the value by making this call:
2976
2977  ```python
2978  doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
2979  ```
2980
2981  If the NodeDef is for an op with a registered doohickey function, you'll get
2982  back the calculated amount in doohickey.value, or None if it's not defined.
2983
2984  """
2985
2986  __slots__ = ["_op_type", "_statistic_type"]
2987
2988  def __init__(self, op_type, statistic_type):
2989    """Saves the `op_type` as the `Operation` type."""
2990    if not isinstance(op_type, str):
2991      raise TypeError("op_type must be a string.")
2992    if "," in op_type:
2993      raise TypeError("op_type must not contain a comma.")
2994    self._op_type = op_type
2995    if not isinstance(statistic_type, str):
2996      raise TypeError("statistic_type must be a string.")
2997    if "," in statistic_type:
2998      raise TypeError("statistic_type must not contain a comma.")
2999    self._statistic_type = statistic_type
3000
3001  def __call__(self, f):
3002    """Registers "f" as the statistics function for "op_type"."""
3003    _stats_registry.register(f, self._op_type + "," + self._statistic_type)
3004    return f
3005
3006
3007def get_stats_for_node_def(graph, node, statistic_type):
3008  """Looks up the node's statistics function in the registry and calls it.
3009
3010  This function takes a Graph object and a NodeDef from a GraphDef, and if
3011  there's an associated statistics method, calls it and returns a result. If no
3012  function has been registered for the particular node type, it returns an empty
3013  statistics object.
3014
3015  Args:
3016    graph: A Graph object that's been set up with the node's graph.
3017    node: A NodeDef describing the operator.
3018    statistic_type: A string identifying the statistic we're interested in.
3019
3020  Returns:
3021    An OpStats object containing information about resource usage.
3022  """
3023
3024  try:
3025    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
3026    result = stats_func(graph, node)
3027  except LookupError:
3028    result = OpStats(statistic_type)
3029  return result
3030
3031
3032def name_from_scope_name(name):
3033  """Returns the name of an op given the name of its scope.
3034
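  For example (a quick sketch):

  ```python
  name_from_scope_name("foo/bar/")  # -> "foo/bar"
  name_from_scope_name("foo/bar")   # -> "foo/bar"
  ```
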
3035  Args:
3036    name: the name of the scope.
3037
3038  Returns:
3039    the name of the op (equal to scope name minus any trailing slash).
3040  """
3041  return name[:-1] if (name and name[-1] == "/") else name
3042
3043
3044_MUTATION_LOCK_GROUP = 0
3045_SESSION_RUN_LOCK_GROUP = 1
3046
3047
3048@tf_contextlib.contextmanager
3049def resource_creator_scope(resource_type, resource_creator):
3050  with get_default_graph()._resource_creator_scope(resource_type,  # pylint: disable=protected-access
3051                                                   resource_creator):
3052    yield
3053
3054
3055@tf_export("Graph")
3056class Graph(object):
3057  """A TensorFlow computation, represented as a dataflow graph.
3058
3059  Graphs are used by `tf.function`s to represent the function's computations.
3060  Each graph contains a set of `tf.Operation` objects, which represent units of
3061  computation; and `tf.Tensor` objects, which represent the units of data that
3062  flow between operations.
3063
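  For example, a brief sketch of how a `tf.function`'s traced computation is
  represented as a `tf.Graph` (a concrete function is obtained here for a
  specific input):

  ```python
  @tf.function
  def double(x):
    return x * 2

  g = double.get_concrete_function(tf.constant(1.0)).graph
  print(isinstance(g, tf.Graph))  # True
  ```
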
3064  ### Using graphs directly (deprecated)
3065
3066  A `tf.Graph` can be constructed and used directly without a `tf.function`, as
3067  was required in TensorFlow 1, but this is deprecated and it is recommended to
3068  use a `tf.function` instead. If a graph is directly used, other deprecated
3069  TensorFlow 1 classes are also required to execute the graph, such as a
3070  `tf.compat.v1.Session`.
3071
3072  A default graph can be registered with the `tf.Graph.as_default` context
3073  manager. Then, operations will be added to the graph instead of being executed
3074  eagerly. For example:
3075
3076  ```python
3077  g = tf.Graph()
3078  with g.as_default():
3079    # Define operations and tensors in `g`.
3080    c = tf.constant(30.0)
3081    assert c.graph is g
3082  ```
3083
3084  `tf.compat.v1.get_default_graph()` can be used to obtain the default graph.
3085
3086  Important note: This class *is not* thread-safe for graph construction. All
3087  operations should be created from a single thread, or external
3088  synchronization must be provided. Unless otherwise specified, all methods
3089  are not thread-safe.
3090
3091  A `Graph` instance supports an arbitrary number of "collections"
3092  that are identified by name. For convenience when building a large
3093  graph, collections can store groups of related objects: for
3094  example, the `tf.Variable` uses a collection (named
3095  `tf.GraphKeys.GLOBAL_VARIABLES`) for
3096  all variables that are created during the construction of a graph. The caller
3097  may define additional collections by specifying a new name.
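
  For example, a minimal sketch of using a custom collection:

  ```python
  g = tf.Graph()
  with g.as_default():
    c = tf.constant(1.0)
    g.add_to_collection("my_things", c)
  assert g.get_collection("my_things") == [c]
  ```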
3098  """
3099
3100  def __init__(self):
3101    """Creates a new, empty Graph."""
3102    # Protects core state that can be returned via public accessors.
3103    # Thread-safety is provided on a best-effort basis to support buggy
3104    # programs, and is not guaranteed by the public `tf.Graph` API.
3105    #
3106    # NOTE(mrry): This does not protect the various stacks. A warning will
3107    # be reported if these are used from multiple threads
3108    self._lock = threading.RLock()
3109    # The group lock synchronizes Session.run calls with methods that create
3110    # and mutate ops (e.g. Graph.create_op()). This synchronization is
3111    # necessary because it's illegal to modify an operation after it's been run.
3112    # The group lock allows any number of threads to mutate ops at the same time
3113    # but if any modification is going on, all Session.run calls have to wait.
3114    # Similarly, if one or more Session.run calls are going on, all mutate ops
3115    # have to wait until all Session.run calls have finished.
3116    self._group_lock = lock_util.GroupLock(num_groups=2)
3117    self._nodes_by_id = {}  # GUARDED_BY(self._lock)
3118    self._next_id_counter = 0  # GUARDED_BY(self._lock)
3119    self._nodes_by_name = {}  # GUARDED_BY(self._lock)
3120    self._version = 0  # GUARDED_BY(self._lock)
3121    # Maps a name used in the graph to the next id to use for that name.
3122    self._names_in_use = {}
3123    self._stack_state_is_thread_local = False
3124    self._thread_local = threading.local()
3125    # Functions that will be applied to choose a device if none is specified.
3126    # In TF2.x or after switch_to_thread_local(),
3127    # self._thread_local._device_function_stack is used instead.
3128    self._graph_device_function_stack = traceable_stack.TraceableStack()
3129    # Default original_op applied to new ops.
3130    self._default_original_op = None
3131    # Current control flow context. It could be either CondContext or
3132    # WhileContext defined in ops/control_flow_ops.py
3133    self._control_flow_context = None
3134    # A new node will depend on the union of all of the nodes in the stack.
3135    # In TF2.x or after switch_to_thread_local(),
3136    # self._thread_local._control_dependencies_stack is used instead.
3137    self._graph_control_dependencies_stack = []
3138    # Arbitrary collections of objects.
3139    self._collections = {}
3140    # The graph-level random seed
3141    self._seed = None
3142    # A dictionary of attributes that should be applied to all ops.
3143    self._attr_scope_map = {}
3144    # A map from op type to the kernel label that should be used.
3145    self._op_to_kernel_label_map = {}
3146    # A map from op type to an alternative op type that should be used when
3147    # computing gradients.
3148    self._gradient_override_map = {}
3149    # A map from op type to a gradient function that should be used instead.
3150    self._gradient_function_map = {}
3151    # True if the graph is considered "finalized".  In that case no
3152    # new operations can be added.
3153    self._finalized = False
3154    # Functions defined in the graph
3155    self._functions = collections.OrderedDict()
3156    # Default GraphDef versions
3157    self._graph_def_versions = versions_pb2.VersionDef(
3158        producer=versions.GRAPH_DEF_VERSION,
3159        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
3160    self._building_function = False
3161    # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
3162    # self._thread_local._colocation_stack is used instead.
3163    self._graph_colocation_stack = traceable_stack.TraceableStack()
3164    # Set of tensors that are dangerous to feed!
3165    self._unfeedable_tensors = object_identity.ObjectIdentitySet()
3166    # Set of operations that are dangerous to fetch!
3167    self._unfetchable_ops = set()
3168    # A map of tensor handle placeholder to tensor dtype.
3169    self._handle_feeders = {}
3170    # A map from tensor handle to its read op.
3171    self._handle_readers = {}
3172    # A map from tensor handle to its move op.
3173    self._handle_movers = {}
3174    # A map from tensor handle to its delete op.
3175    self._handle_deleters = {}
3176    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
3177    # will be shared when defining function graphs, for example, so optimizers
3178    # being called inside function definitions behave as if they were seeing the
3179    # actual outside graph).
3180    self._graph_key = "graph-key-%d/" % (uid(),)
3181    # A string with the last reduction method passed to
3182    # losses.compute_weighted_loss(), or None. This is required only for
3183    # backward compatibility with Estimator and optimizer V1 use cases.
3184    self._last_loss_reduction = None
3185    # Flag that is used to indicate whether loss has been scaled by optimizer.
3186    # If this flag has been set, then the estimator uses it to scale the loss back
3187    # before reporting. This is required only for backward compatibility with
3188    # Estimator and optimizer V1 use cases.
3189    self._is_loss_scaled_by_optimizer = False
3190    self._container = ""
3191
3192    # The current AutomaticControlDependencies context manager.
3193    self.experimental_acd_manager = None
3194    # Set to True if this graph is being built in an
3195    # AutomaticControlDependencies context.
3196    # Deprecated: use acd_manager instead.
3197    self._add_control_dependencies = False
3198
3199    # Cache for OpDef protobufs retrieved via the C API.
3200    self._op_def_cache = {}
3201    # Cache for constant results of `broadcast_gradient_args()`. The keys are
3202    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
3203    # values are tuples of reduction indices: (rx, ry).
3204    self._bcast_grad_args_cache = {}
3205    # Cache for constant results of `reduced_shape()`. The keys are pairs of
3206    # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
3207    # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
3208    self._reduced_shape_cache = {}
3209
3210    # TODO(skyewm): fold as much of the above as possible into the C
3211    # implementation
3212    self._c_graph = c_api_util.ScopedTFGraph(self._graph_key)
3213    # The C API requires all ops to have shape functions. Disable this
3214    # requirement (many custom ops do not have shape functions, and we don't
3215    # want to break these existing cases).
3216    with self._c_graph.get() as c_graph:
3217      pywrap_tf_session.SetRequireShapeInferenceFns(c_graph, False)
3218    if tf2.enabled():
3219      self.switch_to_thread_local()
3220
3221  # Note: this method is private because the API of tf.Graph() is public and
3222  # frozen, and this functionality is still not ready for public visibility.
3223  @tf_contextlib.contextmanager
3224  def _variable_creator_scope(self, creator, priority=100):
3225    """Scope which defines a variable creation function.
3226
3227    Args:
3228      creator: A callable taking `next_creator` and `kwargs`. See the
3229        `tf.variable_creator_scope` docstring.
3230      priority: Creators with a higher `priority` are called first. Within the
3231        same priority, creators are called inner-to-outer.
3232
3233    Yields:
3234      `_variable_creator_scope` is a context manager with a side effect, but
3235      doesn't return a value.
3236
3237    Raises:
3238      RuntimeError: If variable creator scopes are not properly nested.
3239    """
3240    # This step keeps a reference to the existing stack, and it also initializes
3241    # self._thread_local._variable_creator_stack if it doesn't exist yet.
3242    old = self._variable_creator_stack
3243    new = list(old)
3244    new.append((priority, creator))
3245    # Sorting is stable, so we'll put higher-priority creators later in the list
3246    # but otherwise maintain registration order.
3247    new.sort(key=lambda item: item[0])
3248    self._thread_local._variable_creator_stack = new  # pylint: disable=protected-access
3249    try:
3250      yield
3251    finally:
3252      if self._thread_local._variable_creator_stack is not new:  # pylint: disable=protected-access
3253        raise RuntimeError(
3254            "Exiting variable_creator_scope without proper nesting.")
3255      self._thread_local._variable_creator_stack = old  # pylint: disable=protected-access
3256
3257  # TODO(b/192405401): unify resource_creator_scope with variable_creator_scope.
3258  # pylint: disable=protected-access
3259  @tf_contextlib.contextmanager
3260  def _resource_creator_scope(self, resource_type, creator):
3261    """Scope which defines a resource creation function used by some resource.
3262
3263    The resource should be a subclass of CapturableResource with a class method
3264    `cls._resource_type`, the output of which is what the `resource_type`
3265    argument should be. By default, `cls._resource_type` returns the class name,
3266    `cls.__name__`. Within a scope, creators added with the same
3267    `resource_type` argument are composed together and applied to all classes
3268    with this `_resource_type`.
3269
3270
3271    `creator` is expected to be a function with the following signature:
3272
3273    ```
3274      def resource_creator(next_creator, *a, **kwargs)
3275    ```
3276
3277    If the creator does want to create an instance, it should eventually call
3278    `next_creator` rather than calling the class initialization method
3279    directly. This helps make creators
3280    composable. A creator may choose to create multiple instances, return
3281    already existing instances, or simply register that an instance was created
3282    and defer to the next creator in line. Creators can also modify keyword
3283    arguments seen by the next creators.
3284
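    A minimal sketch of a composable creator (the logging behavior is purely
    illustrative):

    ```python
    def logging_creator(next_creator, *a, **kwargs):
      print("Creating a resource with kwargs:", kwargs)
      return next_creator(*a, **kwargs)
    ```
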
3285    Valid keyword arguments in `kwargs` depend on the specific resource
3286    class. For StaticHashTable, this may be:
3287    * initializer: The table initializer to use.
3288    * default_value: The value to use if a key is missing in the table.
3289    * name: Optional name for the table, defaults to None.
3290
3291
3292    Args:
3293      resource_type: the output of the resource class's `_resource_type` method.
3294      creator: the passed creator for the resource.
3295
3296    Yields:
3297      A scope in which the creator is active.
3298
3299    Raises:
3300      RuntimeError: If resource_creator_scope is exited without proper nesting.
3301    """
3302    # This step keeps a reference to the existing stack, and it also initializes
3303    # self._thread_local._resource_creator_stack if it doesn't exist yet.
3304    old = self._resource_creator_stack
3305    new = copy.deepcopy(old)
3306    if isinstance(resource_type, (list, tuple)):
3307      for r in resource_type:
3308        new[r].append(creator)
3309    else:
3310      new[resource_type].append(creator)
3311    self._thread_local._resource_creator_stack = new
3312    try:
3313      yield
3314    finally:
3315      if self._thread_local._resource_creator_stack is not new:
3316        raise RuntimeError(
3317            "Exiting resource_creator_scope without proper nesting.")
3318      self._thread_local._resource_creator_stack = old
3319
3320  @property
3321  def _resource_creator_stack(self):
3322    if not hasattr(self._thread_local, "_resource_creator_stack"):
3323      self._thread_local._resource_creator_stack = collections.defaultdict(list)
3324    return self._thread_local._resource_creator_stack
3325
3326  @_resource_creator_stack.setter
3327  def _resource_creator_stack(self, resource_creator_stack):
3328    self._thread_local._resource_creator_stack = resource_creator_stack
3329  # pylint: enable=protected-access
3330
3331  # Note: this method is private because the API of tf.Graph() is public and
3332  # frozen, and this functionality is still not ready for public visibility.
3333  @property
3334  def _variable_creator_stack(self):
3335    if not hasattr(self._thread_local, "_variable_creator_stack"):
3336      self._thread_local._variable_creator_stack = []  # pylint: disable=protected-access
3337
3338    # This previously returned a copy of the stack instead of the stack itself,
3339    # to guard against accidental mutation. Consider, however, code that wants
3340    # to save and restore the variable creator stack:
3341    #     def f():
3342    #       original_stack = graph._variable_creator_stack
3343    #       graph._variable_creator_stack = new_stack
3344    #       ...  # Some code
3345    #       graph._variable_creator_stack = original_stack
3346    #
3347    # And lets say you have some code that calls this function with some
3348    # variable_creator:
3349    #     def g():
3350    #       with variable_scope.variable_creator_scope(creator):
3351    #         f()
3352    # When exiting the variable creator scope, it would see a different stack
3353    # object than it expected, leading to an "Exiting variable_creator_scope
3354    # without proper nesting" error.
3355    return self._thread_local._variable_creator_stack  # pylint: disable=protected-access
3356
3357  @_variable_creator_stack.setter
3358  def _variable_creator_stack(self, variable_creator_stack):
3359    self._thread_local._variable_creator_stack = variable_creator_stack  # pylint: disable=protected-access
3360
3361  def _check_not_finalized(self):
3362    """Check if the graph is finalized.
3363
3364    Raises:
3365      RuntimeError: If the graph is finalized.
3366    """
3367    if self._finalized:
3368      raise RuntimeError("Graph is finalized and cannot be modified.")
3369
3370  def _add_op(self, op, op_name):
3371    """Adds 'op' to the graph and returns the unique ID for the added Operation.
3372
3373    Args:
3374      op: the Operation to add.
3375      op_name: the name of the Operation.
3376
3377    Returns:
3378      An integer that is a unique ID for the added Operation.
3379    """
3380    self._check_not_finalized()
3381    with self._lock:
3382      self._next_id_counter += 1
3383      op_id = self._next_id_counter
3384      self._nodes_by_id[op_id] = op
3385      self._nodes_by_name[op_name] = op
3386      self._version = max(self._version, op_id)
3387      return op_id
3388
3389  @property
3390  def version(self):
3391    """Returns a version number that increases as ops are added to the graph.
3392
3393    Note that this is unrelated to the
3394    `tf.Graph.graph_def_versions`.
3395
3396    Returns:
3397       An integer version that increases as ops are added to the graph.
3398    """
3399    if self._finalized:
3400      return self._version
3401
3402    with self._lock:
3403      return self._version
3404
3405  @property
3406  def graph_def_versions(self):
3407    # pylint: disable=line-too-long
3408    """The GraphDef version information of this graph.
3409
3410    For details on the meaning of each version, see
3411    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).
3412
3413    Returns:
3414      A `VersionDef`.
3415    """
3416    # pylint: enable=line-too-long
3417    with c_api_util.tf_buffer() as buf:
3418      with self._c_graph.get() as c_graph:
3419        pywrap_tf_session.TF_GraphVersions(c_graph, buf)
3420      data = pywrap_tf_session.TF_GetBuffer(buf)
3421    version_def = versions_pb2.VersionDef()
3422    version_def.ParseFromString(compat.as_bytes(data))
3423    return version_def
3424
3425  @property
3426  def seed(self):
3427    """The graph-level random seed of this graph."""
3428    return self._seed
3429
3430  @seed.setter
3431  def seed(self, seed):
3432    self._seed = seed
3433
3434  @property
3435  def finalized(self):
3436    """True if this graph has been finalized."""
3437    return self._finalized
3438
3439  def finalize(self):
3440    """Finalizes this graph, making it read-only.
3441
3442    After calling `g.finalize()`, no new operations can be added to
3443    `g`.  This method is used to ensure that no operations are added
3444    to a graph when it is shared between multiple threads, for example
3445    when using a `tf.compat.v1.train.QueueRunner`.
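
    For example (a sketch):

    ```python
    g = tf.Graph()
    g.finalize()
    with g.as_default():
      tf.constant(1.0)  # Raises RuntimeError: the graph has been finalized.
    ```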
3446    """
3447    self._finalized = True
3448
3449  def _unsafe_unfinalize(self):
3450    """Opposite of `finalize`.
3451
3452    Internal interface.
3453
3454    NOTE: Unfinalizing a graph could have negative impact on performance,
3455    especially in a multi-threaded environment.  Unfinalizing a graph
3456    when it is in use by a Session may lead to undefined behavior. Ensure
3457    that all sessions using a graph are closed before calling this method.
3458    """
3459    self._finalized = False
3460
3461  def _get_control_flow_context(self):
3462    """Returns the current control flow context.
3463
3464    Returns:
3465      A context object.
3466    """
3467    return self._control_flow_context
3468
3469  def _set_control_flow_context(self, ctx):
3470    """Sets the current control flow context.
3471
3472    Args:
3473      ctx: a context object.
3474    """
3475    self._control_flow_context = ctx
3476
3477  def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
3478    """If this graph contains functions, copy them to `graph_def`."""
3479    bytesize = starting_bytesize
3480    for f in self._functions.values():
3481      bytesize += f.definition.ByteSize()
3482      if bytesize >= (1 << 31) or bytesize < 0:
3483        raise ValueError("GraphDef cannot be larger than 2GB.")
3484      graph_def.library.function.extend([f.definition])
3485      if f.grad_func_name:
3486        grad_def = function_pb2.GradientDef()
3487        grad_def.function_name = f.name
3488        grad_def.gradient_func = f.grad_func_name
3489        graph_def.library.gradient.extend([grad_def])
3490
3491  def _as_graph_def(self, from_version=None, add_shapes=False):
3492    # pylint: disable=line-too-long
3493    """Returns a serialized `GraphDef` representation of this graph.
3494
3495    The serialized `GraphDef` can be imported into another `Graph`
3496    (using `tf.import_graph_def`) or used with the
3497    [C++ Session API](https://chromium.googlesource.com/external/github.com/tensorflow/tensorflow/+/r0.10/tensorflow/g3doc/api_docs/cc/index.md).
3498
3499    This method is thread-safe.
3500
3501    Args:
3502      from_version: Optional.  If this is set, returns a `GraphDef` containing
3503        only the nodes that were added to this graph since its `version`
3504        property had the given value.
3505      add_shapes: If true, adds an "_output_shapes" list attr to each node with
3506        the inferred shapes of each of its outputs.
3507
3508    Returns:
3509      A tuple containing a
3510      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3511      protocol buffer, and the version of the graph to which that
3512      `GraphDef` corresponds.
3513
3514    Raises:
3515      ValueError: If the `graph_def` would be too large.
3516
3517    """
3518    # pylint: enable=line-too-long
3519    with self._lock:
3520      with c_api_util.tf_buffer() as buf:
3521        with self._c_graph.get() as c_graph:
3522          pywrap_tf_session.TF_GraphToGraphDef(c_graph, buf)
3523          data = pywrap_tf_session.TF_GetBuffer(buf)
3524      graph = graph_pb2.GraphDef()
3525      graph.ParseFromString(compat.as_bytes(data))
3526      # Strip the experimental library field iff it's empty.
3527      if not graph.library.function:
3528        graph.ClearField("library")
3529
3530      if add_shapes:
3531        for node in graph.node:
3532          op = self._nodes_by_name[node.name]
3533          if op.outputs:
3534            node.attr["_output_shapes"].list.shape.extend(
3535                [output.get_shape().as_proto() for output in op.outputs])
3536        for function_def in graph.library.function:
3537          defined_function = self._functions[function_def.signature.name]
3538          try:
3539            func_graph = defined_function.graph
3540          except AttributeError:
3541            # _DefinedFunction doesn't have a graph, _EagerDefinedFunction
3542            # does. Both rely on ops.py, so we can't really isinstance check
3543            # them.
3544            continue
3545          input_shapes = function_def.attr["_input_shapes"]
3546          try:
3547            func_graph_inputs = func_graph.inputs
3548          except AttributeError:
3549            continue
3550          # TODO(b/141471245): Fix the inconsistency when inputs of func graph
3551          # are appended during gradient computation of while/cond.
3552          assert len(input_shapes.list.shape) in [0, len(func_graph_inputs)]
3553          # If the function_def has inputs already filled out, skip this step.
3554          if not input_shapes.list.shape:
3555            for input_tensor, arg_def in zip(func_graph_inputs,
3556                                             function_def.signature.input_arg):
3557              input_shapes.list.shape.add().CopyFrom(
3558                  input_tensor.get_shape().as_proto())
3559              if input_tensor.dtype == dtypes.resource:
3560                _copy_handle_data_to_arg_def(input_tensor, arg_def)
3561
3562          for output_tensor, arg_def in zip(func_graph.outputs,
3563                                            function_def.signature.output_arg):
3564            if output_tensor.dtype == dtypes.resource:
3565              _copy_handle_data_to_arg_def(output_tensor, arg_def)
3566
3567          for node in function_def.node_def:
3568            try:
3569              op = func_graph.get_operation_by_name(node.name)
3570            except KeyError:
3571              continue
3572            outputs = op.outputs
3573
3574            if op.type == "StatefulPartitionedCall":
3575              # Filter out any extra outputs (possibly added by function
3576              # backpropagation rewriting).
3577              num_outputs = len(node.attr["Tout"].list.type)
3578              outputs = outputs[:num_outputs]
3579
3580            node.attr["_output_shapes"].list.shape.extend(
3581                [output.get_shape().as_proto() for output in outputs])
3582
3583    return graph, self._version
3584
3585  def as_graph_def(self, from_version=None, add_shapes=False):
3586    # pylint: disable=line-too-long
3587    """Returns a serialized `GraphDef` representation of this graph.
3588
3589    The serialized `GraphDef` can be imported into another `Graph`
3590    (using `tf.import_graph_def`) or used with the
3591    [C++ Session API](../../api_docs/cc/index.md).
3592
3593    This method is thread-safe.
3594
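    For example, a small sketch:

    ```python
    g = tf.Graph()
    with g.as_default():
      tf.constant(1.0, name="c")
    print([node.name for node in g.as_graph_def().node])  # ["c"]
    ```
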
3595    Args:
3596      from_version: Optional.  If this is set, returns a `GraphDef` containing
3597        only the nodes that were added to this graph since its `version`
3598        property had the given value.
3599      add_shapes: If true, adds an "_output_shapes" list attr to each node with
3600        the inferred shapes of each of its outputs.
3601
3602    Returns:
3603      A
3604      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3605      protocol buffer.
3606
3607    Raises:
3608      ValueError: If the `graph_def` would be too large.
3609    """
3610    # pylint: enable=line-too-long
3611    result, _ = self._as_graph_def(from_version, add_shapes)
3612    return result
3613
3614  def _is_function(self, name):
3615    """Tests whether 'name' is registered in this graph's function library.
3616
3617    Args:
3618      name: string op name.
3619
3620    Returns:
3621      bool indicating whether or not 'name' is registered in function library.
3622    """
3623    return compat.as_str(name) in self._functions
3624
3625  def _get_function(self, name):
3626    """Returns the function definition for 'name'.
3627
3628    Args:
3629      name: string function name.
3630
3631    Returns:
3632      The function def proto.
3633    """
3634    return self._functions.get(compat.as_str(name), None)
3635
3636  def _add_function(self, function):
3637    """Adds a function to the graph.
3638
3639    After the function has been added, you can call the function by
3640    passing its name in place of an op name to
3641    `Graph.create_op()`.
3642
3643    Args:
3644      function: A `_DefinedFunction` object.
3645
3646    Raises:
3647      ValueError: if another function is defined with the same name.
3648    """
3649    self._check_not_finalized()
3650
3651    name = function.name
3652    # Sanity checks on gradient definition.
3653    if (function.grad_func_name is not None) and (function.python_grad_func is
3654                                                  not None):
3655      raise ValueError("Gradient defined twice for function %s" % name)
3656
3657    # Add function to graph
3658    # pylint: disable=protected-access
3659    with self._c_graph.get() as c_graph:
3660      with function._c_func.get() as func:
3661        if function._grad_func:
3662          with function._grad_func._c_func.get() as gradient:
3663            pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, gradient)
3664        else:
3665          pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, None)
3666    # pylint: enable=protected-access
3667
3668    self._functions[compat.as_str(name)] = function
3669
3670    # Need a new-enough consumer to support the functions we add to the graph.
3671    if self._graph_def_versions.min_consumer < 12:
3672      self._graph_def_versions.min_consumer = 12
3673
3674  @property
3675  def building_function(self):
3676    """Returns True iff this graph represents a function."""
3677    return self._building_function
3678
3679  # Helper functions to create operations.
3680  @deprecated_args(None,
3681                   "Shapes are always computed; don't use the compute_shapes "
3682                   "as it has no effect.", "compute_shapes")
3683  @traceback_utils.filter_traceback
3684  def create_op(
3685      self,
3686      op_type,
3687      inputs,
3688      dtypes=None,  # pylint: disable=redefined-outer-name
3689      input_types=None,
3690      name=None,
3691      attrs=None,
3692      op_def=None,
3693      compute_shapes=True,
3694      compute_device=True):
3695    """Creates an `Operation` in this graph.
3696
3697    This is a low-level interface for creating an `Operation`. Most
3698    programs will not call this method directly, and instead use the
3699    Python op constructors, such as `tf.constant()`, which add ops to
3700    the default graph.
3701
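    For example, a minimal sketch in graph mode (the op type used here is
    illustrative; most programs use the op constructors instead):

    ```python
    g = tf.Graph()
    with g.as_default():
      c = tf.constant(1.0)
      op = g.create_op("Identity", [c], [tf.float32])
      print(op.outputs[0])
    ```
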
3702    Args:
3703      op_type: The `Operation` type to create. This corresponds to the
3704        `OpDef.name` field for the proto that defines the operation.
3705      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3706      dtypes: (Optional) A list of `DType` objects that will be the types of the
3707        tensors that the operation produces.
3708      input_types: (Optional.) A list of `DType`s that will be the types of the
3709        tensors that the operation consumes. By default, uses the base `DType`
3710        of each input in `inputs`. Operations that expect reference-typed inputs
3711        must specify `input_types` explicitly.
3712      name: (Optional.) A string name for the operation. If not specified, a
3713        name is generated based on `op_type`.
3714      attrs: (Optional.) A dictionary where the key is the attribute name (a
3715        string) and the value is the respective `attr` attribute of the
3716        `NodeDef` proto that will represent the operation (an `AttrValue`
3717        proto).
3718      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3719        the operation will have.
3720      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
3721        computed).
3722      compute_device: (Optional.) If True, device functions will be executed to
3723        compute the device property of the Operation.
3724
3725    Raises:
3726      TypeError: if any of the inputs is not a `Tensor`.
3727      ValueError: if colocation conflicts with existing device assignment.
3728
3729    Returns:
3730      An `Operation` object.
3731    """
3732    del compute_shapes
3733    for idx, a in enumerate(inputs):
3734      if not isinstance(a, Tensor):
3735        raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
3736    return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
3737                                    attrs, op_def, compute_device)
3738
3739  def _create_op_internal(
3740      self,
3741      op_type,
3742      inputs,
3743      dtypes=None,  # pylint: disable=redefined-outer-name
3744      input_types=None,
3745      name=None,
3746      attrs=None,
3747      op_def=None,
3748      compute_device=True):
3749    """Creates an `Operation` in this graph.
3750
3751    Implements `Graph.create_op()` without the overhead of the deprecation
3752    wrapper.
3753
3754    Args:
3755      op_type: The `Operation` type to create. This corresponds to the
3756        `OpDef.name` field for the proto that defines the operation.
3757      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3758      dtypes: (Optional) A list of `DType` objects that will be the types of the
3759        tensors that the operation produces.
3760      input_types: (Optional.) A list of `DType`s that will be the types of the
3761        tensors that the operation consumes. By default, uses the base `DType`
3762        of each input in `inputs`. Operations that expect reference-typed inputs
3763        must specify `input_types` explicitly.
3764      name: (Optional.) A string name for the operation. If not specified, a
3765        name is generated based on `op_type`.
3766      attrs: (Optional.) A dictionary where the key is the attribute name (a
3767        string) and the value is the respective `attr` attribute of the
3768        `NodeDef` proto that will represent the operation (an `AttrValue`
3769        proto).
3770      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3771        the operation will have.
3772      compute_device: (Optional.) If True, device functions will be executed to
3773        compute the device property of the Operation.
3774
3775    Raises:
3776      ValueError: if colocation conflicts with existing device assignment.
3777
3778    Returns:
3779      An `Operation` object.
3780    """
3781    self._check_not_finalized()
3782    if name is None:
3783      name = op_type
3784    # If a name ends with a '/', it is a "name scope" and we use it as-is,
3785    # after removing the trailing '/'.
3786    if name and name[-1] == "/":
3787      name = name_from_scope_name(name)
3788    else:
3789      name = self.unique_name(name)
3790
3791    node_def = _NodeDef(op_type, name, attrs)
3792
3793    input_ops = set(t.op for t in inputs)
3794    control_inputs = self._control_dependencies_for_inputs(input_ops)
3795    # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
3796    # Session.run call cannot occur between creating and mutating the op.
3797    with self._mutation_lock():
3798      ret = Operation(
3799          node_def,
3800          self,
3801          inputs=inputs,
3802          output_types=dtypes,
3803          control_inputs=control_inputs,
3804          input_types=input_types,
3805          original_op=self._default_original_op,
3806          op_def=op_def)
3807      self._create_op_helper(ret, compute_device=compute_device)
3808    return ret
3809
3810  def _create_op_from_tf_operation(self, c_op, compute_device=True):
3811    """Creates an `Operation` in this graph from the supplied TF_Operation.
3812
3813    This method is like create_op() except the new Operation is constructed
3814    using `c_op`. The returned Operation will have `c_op` as its _c_op
3815    field. This is used to create Operation objects around TF_Operations created
3816    indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
3817
3818    This function does not call Operation._control_flow_post_processing or
3819    Graph._control_dependencies_for_inputs (since the inputs may not be
3820    available yet). The caller is responsible for calling these methods.
3821
3822    Args:
3823      c_op: a wrapped TF_Operation
3824      compute_device: (Optional.) If True, device functions will be executed to
3825        compute the device property of the Operation.
3826
3827    Returns:
3828      An `Operation` object.
3829    """
3830    self._check_not_finalized()
3831    ret = Operation._from_c_op(c_op=c_op, g=self)  # pylint: disable=protected-access
3832    # If a name_scope was created with ret.name but no nodes were created in it,
3833    # the name will still appear in _names_in_use even though the name hasn't
3834    # been used. This is ok, just leave _names_in_use as-is in this case.
3835    # TODO(skyewm): make the C API guarantee no name conflicts.
3836    name_key = ret.name.lower()
3837    if name_key not in self._names_in_use:
3838      self._names_in_use[name_key] = 1
3839    self._create_op_helper(ret, compute_device=compute_device)
3840    return ret
3841
3842  def _create_op_helper(self, op, compute_device=True):
3843    """Common logic for creating an op in this graph."""
3844    # Apply any additional attributes requested. Do not overwrite any existing
3845    # attributes.
3846    for key, value in self._attr_scope_map.items():
3847      try:
3848        op.get_attr(key)
3849      except ValueError:
3850        if callable(value):
3851          value = value(op.node_def)
3852          if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
3853            raise TypeError(
3854                "Callable for scope map key '%s' must return either None or "
3855                "an AttrValue protocol buffer; but it returned: %s" %
3856                (key, value))
3857        if value:
3858          op._set_attr(key, value)  # pylint: disable=protected-access
3859
3860    # Apply a kernel label if one has been specified for this op type.
3861    try:
3862      kernel_label = self._op_to_kernel_label_map[op.type]
3863      op._set_attr("_kernel",  # pylint: disable=protected-access
3864                   attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
3865    except KeyError:
3866      pass
3867
3868    op._gradient_function = self._gradient_function_map.get(op.type)  # pylint: disable=protected-access
3869
3870    # Apply the overriding op type for gradients if one has been specified for
3871    # this op type.
3872    try:
3873      mapped_op_type = self._gradient_override_map[op.type]
3874      op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
3875                   attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
3876    except KeyError:
3877      pass
3878
3879    self._record_op_seen_by_control_dependencies(op)
3880
3881    if compute_device:
3882      self._apply_device_functions(op)
3883
3884    # Snapshot the colocation stack metadata before we might generate error
3885    # messages using it.  Note that this snapshot depends on the actual stack
3886    # and is independent of the op's _class attribute.
3887    # pylint: disable=protected-access
3888    op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
3889    # pylint: enable=protected-access
3890
3891    if self._colocation_stack:
3892      all_colocation_groups = []
3893      is_device_set = False
3894      for colocation_op in self._colocation_stack.peek_objs():
3895        try:
3896          all_colocation_groups.extend(colocation_op.colocation_groups())
3897        except AttributeError:
3898          pass
3899        if colocation_op.device and not is_device_set:
3900          # pylint: disable=protected-access
3901          op._set_device(colocation_op.device)
3902          # pylint: enable=protected-access
3903          is_device_set = True
3904
3905      all_colocation_groups = sorted(set(all_colocation_groups))
3906      # pylint: disable=protected-access
3907      op._set_attr(
3908          "_class",
3909          attr_value_pb2.AttrValue(
3910              list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
3911      # pylint: enable=protected-access
3912
3913    # Sets "container" attribute if
3914    # (1) self._container is not None
3915    # (2) "is_stateful" is set in OpDef
3916    # (3) "container" attribute is in OpDef
3917    # (4) "container" attribute is empty (i.e. not already set)
3918    if self._container and op._is_stateful:  # pylint: disable=protected-access
3919      try:
3920        container_attr = op.get_attr("container")
3921      except ValueError:
3922        # "container" attribute is not in OpDef
3923        pass
3924      else:
3925        if not container_attr:
3926          op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
3927              s=compat.as_bytes(self._container)))
3928
3929  def _add_new_tf_operations(self, compute_devices=True):
3930    """Creates `Operations` in this graph for any new TF_Operations.
3931
3932    This is useful for when TF_Operations are indirectly created by the C API
3933    outside of the Operation constructor (e.g. by TF_ImportGraphDef,
3934    TF_FinishWhile). This ensures there are corresponding Operations for all
3935    TF_Operations in the underlying TF_Graph.
3936
3937    Args:
3938      compute_devices: (Optional.) If True, device functions will be executed to
3939        compute the device properties of each new Operation.
3940
3941    Returns:
3942      A list of the new `Operation` objects.
3943    """
3944    self._check_not_finalized()
3945
3946    # Create all Operation objects before accessing their inputs since an op may
3947    # be created before its inputs.
3948    new_ops = [
3949        self._create_op_from_tf_operation(c_op, compute_device=compute_devices)
3950        for c_op in c_api_util.new_tf_operations(self)
3951    ]
3952
3953    # pylint: disable=protected-access
3954    for op in new_ops:
3955      new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
3956      op._add_control_inputs(new_control_inputs)
3957      op._control_flow_post_processing()
3958    # pylint: enable=protected-access
3959
3960    return new_ops
3961
3962  def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
3963    """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
3964
3965    This function validates that `obj` represents an element of this
3966    graph, and gives an informative error message if it is not.
3967
3968    This function is the canonical way to get/validate an object of
3969    one of the allowed types from an external argument reference in the
3970    Session API.
3971
3972    This method may be called concurrently from multiple threads.
3973
3974    Args:
3975      obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. Can
3976        also be any object with an `_as_graph_element()` method that returns a
3977        value of one of these types. Note: `_as_graph_element` will be called
3978        inside the graph's lock and so may not modify the graph.
3979      allow_tensor: If true, `obj` may refer to a `Tensor`.
3980      allow_operation: If true, `obj` may refer to an `Operation`.
3981
3982    Returns:
3983      The `Tensor` or `Operation` in the Graph corresponding to `obj`.
3984
3985    Raises:
3986      TypeError: If `obj` is of a type we do not support converting to a
3987        `Tensor` or `Operation`.
3988      ValueError: If `obj` is of an appropriate type but invalid. For
3989        example, an invalid string.
3990      KeyError: If `obj` is not an object in the graph.
3991    """
3992    if self._finalized:
3993      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3994
3995    with self._lock:
3996      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3997
3998  def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
3999    """See `Graph.as_graph_element()` for details."""
4000    # The vast majority of this function is figuring
4001    # out what an API user might be doing wrong, so
4002    # that we can give helpful error messages.
4003    #
4004    # Ideally, it would be nice to split it up, but we
4005    # need context to generate nice error messages.
4006
4007    if allow_tensor and allow_operation:
4008      types_str = "Tensor or Operation"
4009    elif allow_tensor:
4010      types_str = "Tensor"
4011    elif allow_operation:
4012      types_str = "Operation"
4013    else:
4014      raise ValueError("allow_tensor and allow_operation can't both be False.")
4015
4016    temp_obj = _as_graph_element(obj)
4017    if temp_obj is not None:
4018      obj = temp_obj
4019
4020    # If obj appears to be a name...
4021    if isinstance(obj, compat.bytes_or_text_types):
4022      name = compat.as_str(obj)
4023
4024      if ":" in name and allow_tensor:
4025        # Looks like a Tensor name and can be a Tensor.
4026        try:
4027          op_name, out_n = name.split(":")
4028          out_n = int(out_n)
4029        except:
4030          raise ValueError("The name %s looks like a Tensor name, but is "
4031                           "not a valid one. Tensor names must be of the "
4032                           "form \"<op_name>:<output_index>\"." % repr(name))
4033        if op_name in self._nodes_by_name:
4034          op = self._nodes_by_name[op_name]
4035        else:
4036          raise KeyError("The name %s refers to a Tensor which does not "
4037                         "exist. The operation, %s, does not exist in the "
4038                         "graph." % (repr(name), repr(op_name)))
4039        try:
4040          return op.outputs[out_n]
4041        except:
4042          raise KeyError("The name %s refers to a Tensor which does not "
4043                         "exist. The operation, %s, exists but only has "
4044                         "%s outputs." %
4045                         (repr(name), repr(op_name), len(op.outputs)))
4046
4047      elif ":" in name and not allow_tensor:
4048        # Looks like a Tensor name but can't be a Tensor.
4049        raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
4050                         (repr(name), types_str))
4051
4052      elif ":" not in name and allow_operation:
4053        # Looks like an Operation name and can be an Operation.
4054        if name not in self._nodes_by_name:
4055          raise KeyError("The name %s refers to an Operation not in the "
4056                         "graph." % repr(name))
4057        return self._nodes_by_name[name]
4058
4059      elif ":" not in name and not allow_operation:
4060        # Looks like an Operation name but can't be an Operation.
4061        if name in self._nodes_by_name:
4062          # Yep, it's an Operation name
4063          err_msg = ("The name %s refers to an Operation, not a %s." %
4064                     (repr(name), types_str))
4065        else:
4066          err_msg = ("The name %s looks like an (invalid) Operation name, "
4067                     "not a %s." % (repr(name), types_str))
4068        err_msg += (" Tensor names must be of the form "
4069                    "\"<op_name>:<output_index>\".")
4070        raise ValueError(err_msg)
4071
4072    elif isinstance(obj, Tensor) and allow_tensor:
4073      # Actually obj is just the object it's referring to.
4074      if obj.graph is not self:
4075        raise ValueError("Tensor %s is not an element of this graph." % obj)
4076      return obj
4077    elif isinstance(obj, Operation) and allow_operation:
4078      # Actually obj is just the object it's referring to.
4079      if obj.graph is not self:
4080        raise ValueError("Operation %s is not an element of this graph." % obj)
4081      return obj
4082    else:
4083      # We give up!
4084      raise TypeError("Can not convert a %s into a %s." %
4085                      (type(obj).__name__, types_str))
4086
4087  def get_operations(self):
4088    """Return the list of operations in the graph.
4089
4090    You can modify the operations in place, but modifications
4091    to the list, such as inserts/deletes, have no effect on the
4092    list of operations known to the graph.
4093
4094    This method may be called concurrently from multiple threads.
4095
4096    Returns:
4097      A list of Operations.
4098    """
4099    if self._finalized:
4100      return list(self._nodes_by_id.values())
4101
4102    with self._lock:
4103      return list(self._nodes_by_id.values())
4104
4105  def get_operation_by_name(self, name):
4106    """Returns the `Operation` with the given `name`.
4107
4108    This method may be called concurrently from multiple threads.
4109
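    For example (a sketch):

    ```python
    with tf.Graph().as_default() as g:
      c = tf.constant(1.0, name="c")
    assert g.get_operation_by_name("c") is c.op
    ```
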
4110    Args:
4111      name: The name of the `Operation` to return.
4112
4113    Returns:
4114      The `Operation` with the given `name`.
4115
4116    Raises:
4117      TypeError: If `name` is not a string.
4118      KeyError: If `name` does not correspond to an operation in this graph.
4119    """
4120
4121    if not isinstance(name, str):
4122      raise TypeError("Operation names are strings (or similar), not %s." %
4123                      type(name).__name__)
4124    return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
4125
4126  def _get_operation_by_name_unsafe(self, name):
4127    """Returns the `Operation` with the given `name`.
4128
4129    This is an internal, unsafe version of get_operation_by_name. It skips many
4130    checks and does not have user-friendly error messages, but runs considerably
4131    faster. This method may be called concurrently from multiple threads.
4132
4133    Args:
4134      name: The name of the `Operation` to return.
4135
4136    Returns:
4137      The `Operation` with the given `name`.
4138
4139    Raises:
4140      KeyError: If `name` does not correspond to an operation in this graph.
4141    """
4142
4143    if self._finalized:
4144      return self._nodes_by_name[name]
4145
4146    with self._lock:
4147      return self._nodes_by_name[name]
4148
4149  def _get_operation_by_tf_operation(self, tf_oper):
4150    op_name = pywrap_tf_session.TF_OperationName(tf_oper)
4151    return self._get_operation_by_name_unsafe(op_name)
4152
4153  def get_tensor_by_name(self, name):
4154    """Returns the `Tensor` with the given `name`.
4155
4156    This method may be called concurrently from multiple threads.
4157
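    For example (a sketch):

    ```python
    with tf.Graph().as_default() as g:
      c = tf.constant(1.0, name="c")
    assert g.get_tensor_by_name("c:0") is c
    ```
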
4158    Args:
4159      name: The name of the `Tensor` to return.
4160
4161    Returns:
4162      The `Tensor` with the given `name`.
4163
4164    Raises:
4165      TypeError: If `name` is not a string.
4166      KeyError: If `name` does not correspond to a tensor in this graph.
4167    """
4168    # Names should be strings.
4169    if not isinstance(name, str):
4170      raise TypeError("Tensor names are strings (or similar), not %s." %
4171                      type(name).__name__)
4172    return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
4173
4174  def _get_tensor_by_tf_output(self, tf_output):
4175    """Returns the `Tensor` representing `tf_output`.
4176
4177    Note that there is only one such `Tensor`, i.e. multiple calls to this
4178    function with the same TF_Output value will always return the same `Tensor`
4179    object.
4180
4181    Args:
4182      tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
4183
4184    Returns:
4185      The `Tensor` that represents `tf_output`.
4186    """
4187    op = self._get_operation_by_tf_operation(tf_output.oper)
4188    return op.outputs[tf_output.index]
4189
4190  @property
4191  def _last_id(self):
4192    return self._next_id_counter
4193
4194  def _get_op_def(self, type):  # pylint: disable=redefined-builtin
4195    """Returns the `OpDef` proto for `type`. `type` is a string."""
4196    # NOTE: No locking is required because the lookup and insertion operations
4197    # on Python dictionaries are atomic.
4198    try:
4199      return self._op_def_cache[type]
4200    except KeyError:
4201      with c_api_util.tf_buffer() as buf:
4202        with self._c_graph.get() as c_graph:
4203          pywrap_tf_session.TF_GraphGetOpDef(c_graph, compat.as_bytes(type),
4204                                             buf)
4205        data = pywrap_tf_session.TF_GetBuffer(buf)
4206      op_def = op_def_pb2.OpDef()
4207      op_def.ParseFromString(compat.as_bytes(data))
4208      self._op_def_cache[type] = op_def
4209      return op_def
4210
4211  def as_default(self):
4212    """Returns a context manager that makes this `Graph` the default graph.
4213
4214    This method should be used if you want to create multiple graphs
4215    in the same process. For convenience, a global default graph is
4216    provided, and all ops will be added to this graph if you do not
4217    create a new graph explicitly.
4218
4219    Use this method with the `with` keyword to specify that ops created within
4220    the scope of a block should be added to this graph. In this case, once
4221    the scope of the `with` is exited, the previous default graph is set again
4222    as default. There is a stack, so it's ok to have multiple nested levels
4223    of `as_default` calls.
4224
4225    The default graph is a property of the current thread. If you
4226    create a new thread, and wish to use the default graph in that
4227    thread, you must explicitly add a `with g.as_default():` in that
4228    thread's function.
4229
4230    The following code examples are equivalent:
4231
4232    ```python
4233    # 1. Using Graph.as_default():
4234    g = tf.Graph()
4235    with g.as_default():
4236      c = tf.constant(5.0)
4237      assert c.graph is g
4238
4239    # 2. Constructing and making default:
4240    with tf.Graph().as_default() as g:
4241      c = tf.constant(5.0)
4242      assert c.graph is g
4243    ```
4244
4245    If eager execution is enabled, ops created under this context manager will
4246    be added to the graph instead of executed eagerly.
4247
4248    Returns:
4249      A context manager for using this graph as the default graph.
4250    """
4251    return _default_graph_stack.get_controller(self)
4252
4253  @property
4254  def collections(self):
4255    """Returns the names of the collections known to this graph."""
4256    return list(self._collections)
4257
4258  def add_to_collection(self, name, value):
4259    """Stores `value` in the collection with the given `name`.
4260
4261    Note that collections are not sets, so it is possible to add a value to
4262    a collection several times.
4263
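    For example, a minimal sketch (the collection name is arbitrary):

    ```python
    g = tf.Graph()
    with g.as_default():
      c = tf.constant(1.0)
    g.add_to_collection("my_collection", c)
    g.add_to_collection("my_collection", c)  # Collections are not sets.
    assert g.get_collection("my_collection") == [c, c]
    ```
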
4264    Args:
4265      name: The key for the collection. The `GraphKeys` class contains many
4266        standard names for collections.
4267      value: The value to add to the collection.
4268    """  # pylint: disable=g-doc-exception
4269    self._check_not_finalized()
4270    with self._lock:
4271      if name not in self._collections:
4272        self._collections[name] = [value]
4273      else:
4274        self._collections[name].append(value)
4275
4276  def add_to_collections(self, names, value):
4277    """Stores `value` in the collections given by `names`.
4278
4279    Note that collections are not sets, so it is possible to add a value to
4280    a collection several times. This function makes sure that duplicates in
4281    `names` are ignored, but it will not check for pre-existing membership of
4282    `value` in any of the collections in `names`.
4283
4284    `names` can be any iterable, but if `names` is a string, it is treated as a
4285    single collection name.
4286
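    For example, a minimal sketch (the collection names are arbitrary):

    ```python
    g = tf.Graph()
    with g.as_default():
      c = tf.constant(0.0)
    g.add_to_collections(["collection_a", "collection_b"], c)
    assert g.get_collection("collection_a") == [c]
    assert g.get_collection("collection_b") == [c]
    ```
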
4287    Args:
4288      names: The keys for the collections to add to. The `GraphKeys` class
4289        contains many standard names for collections.
4290      value: The value to add to the collections.
4291    """
4292    # Make sure names are unique, but treat strings as a single collection name
4293    names = (names,) if isinstance(names, str) else set(names)
4294    for name in names:
4295      self.add_to_collection(name, value)
4296
4297  def get_collection_ref(self, name):
4298    """Returns a list of values in the collection with the given `name`.
4299
4300    If the collection exists, this returns the list itself, which can
4301    be modified in place to change the collection.  If the collection does
4302    not exist, it is created as an empty list and the list is returned.
4303
4304    This is different from `get_collection()` which always returns a copy of
4305    the collection list if it exists and never creates an empty collection.
4306
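    For example, a minimal sketch of the difference (the collection name is
    arbitrary):

    ```python
    g = tf.Graph()
    with g.as_default():
      loss = tf.constant(1.0)
    losses = g.get_collection_ref("losses")  # Creates the (empty) collection.
    losses.append(loss)                      # Mutates the collection in place.
    assert g.get_collection("losses") == [loss]
    ```
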
4307    Args:
4308      name: The key for the collection. For example, the `GraphKeys` class
4309        contains many standard names for collections.
4310
4311    Returns:
4312      The list of values in the collection with the given `name`, or an empty
4313      list if no value has been added to that collection.
4314    """  # pylint: disable=g-doc-exception
4315    with self._lock:
4316      coll_list = self._collections.get(name, None)
4317      if coll_list is None:
4318        coll_list = []
4319        self._collections[name] = coll_list
4320      return coll_list
4321
4322  def get_collection(self, name, scope=None):
4323    """Returns a list of values in the collection with the given `name`.
4324
4325    This is different from `get_collection_ref()`, which always returns the
4326    actual collection list if it exists: this method returns a new list each
4327    time it is called.
4328
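    For example, a minimal sketch of scope filtering (the collection name is
    arbitrary):

    ```python
    g = tf.Graph()
    with g.as_default():
      with g.name_scope("scope1"):
        a = tf.constant(1.0, name="a")
      with g.name_scope("scope2"):
        b = tf.constant(2.0, name="b")
    g.add_to_collection("my_tensors", a)
    g.add_to_collection("my_tensors", b)
    assert g.get_collection("my_tensors", scope="scope1") == [a]
    ```
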
4329    Args:
4330      name: The key for the collection. For example, the `GraphKeys` class
4331        contains many standard names for collections.
4332      scope: (Optional.) A string. If supplied, the resulting list is filtered
4333        to include only items whose `name` attribute matches `scope` using
4334        `re.match`. Items without a `name` attribute are never returned if a
4335        scope is supplied. The choice of `re.match` means that a `scope` without
4336        special tokens filters by prefix.
4337
4338    Returns:
4339      The list of values in the collection with the given `name`, or
4340      an empty list if no value has been added to that collection. The
4341      list contains the values in the order in which they were
4342      collected.
4343    """  # pylint: disable=g-doc-exception
4344    with self._lock:
4345      collection = self._collections.get(name, None)
4346      if collection is None:
4347        return []
4348      if scope is None:
4349        return list(collection)
4350      else:
4351        c = []
4352        regex = re.compile(scope)
4353        for item in collection:
4354          try:
4355            if regex.match(item.name):
4356              c.append(item)
4357          except AttributeError:
4358            # Collection items with no name are ignored.
4359            pass
4360        return c
4361
4362  def get_all_collection_keys(self):
4363    """Returns a list of collections used in this graph."""
4364    with self._lock:
4365      return [x for x in self._collections if isinstance(x, str)]
4366
4367  def clear_collection(self, name):
4368    """Clears all values in a collection.
4369
4370    Args:
4371      name: The key for the collection. The `GraphKeys` class contains many
4372        standard names for collections.
4373    """
4374    self._check_not_finalized()
4375    with self._lock:
4376      if name in self._collections:
4377        del self._collections[name]
4378
4379  @tf_contextlib.contextmanager
4380  def _original_op(self, op):
4381    """Python 'with' handler to help annotate ops with their originator.
4382
4383    An op may have an 'original_op' property that indicates the op on which
4384    it was based. For example, a replica op is based on the op that was
4385    replicated, and a gradient op is based on the op that was differentiated.
4386
4387    All ops created in the scope of this 'with' handler will have
4388    the given 'op' as their original op.
4389
4390    Args:
4391      op: The Operation that all ops created in this scope will have as their
4392        original op.
4393
4394    Yields:
4395      Nothing.
4396    """
4397    old_original_op = self._default_original_op
4398    self._default_original_op = op
4399    try:
4400      yield
4401    finally:
4402      self._default_original_op = old_original_op
4403
4404  @property
4405  def _name_stack(self):
4406    # This may be called from a thread where name_stack doesn't yet exist.
4407    if not hasattr(self._thread_local, "_name_stack"):
4408      self._thread_local._name_stack = ""
4409    return self._thread_local._name_stack
4410
4411  @_name_stack.setter
4412  def _name_stack(self, name_stack):
4413    self._thread_local._name_stack = name_stack
4414
4415  # pylint: disable=g-doc-return-or-yield,line-too-long
4416  @tf_contextlib.contextmanager
4417  def name_scope(self, name):
4418    """Returns a context manager that creates hierarchical names for operations.
4419
4420    A graph maintains a stack of name scopes. A `with name_scope(...):`
4421    statement pushes a new name onto the stack for the lifetime of the context.
4422
4423    The `name` argument will be interpreted as follows:
4424
4425    * A string (not ending with '/') will create a new name scope, in which
4426      `name` is appended to the prefix of all operations created in the
4427      context. If `name` has been used before, it will be made unique by
4428      calling `self.unique_name(name)`.
4429    * A scope previously captured from a `with g.name_scope(...) as
4430      scope:` statement will be treated as an "absolute" name scope, which
4431      makes it possible to re-enter existing scopes.
4432    * A value of `None` or the empty string will reset the current name scope
4433      to the top-level (empty) name scope.
4434
4435    For example:
4436
4437    ```python
4438    with tf.Graph().as_default() as g:
4439      c = tf.constant(5.0, name="c")
4440      assert c.op.name == "c"
4441      c_1 = tf.constant(6.0, name="c")
4442      assert c_1.op.name == "c_1"
4443
4444      # Creates a scope called "nested"
4445      with g.name_scope("nested") as scope:
4446        nested_c = tf.constant(10.0, name="c")
4447        assert nested_c.op.name == "nested/c"
4448
4449        # Creates a nested scope called "inner".
4450        with g.name_scope("inner"):
4451          nested_inner_c = tf.constant(20.0, name="c")
4452          assert nested_inner_c.op.name == "nested/inner/c"
4453
4454        # Creates a nested scope called "inner_1".
4455        with g.name_scope("inner"):
4456          nested_inner_1_c = tf.constant(30.0, name="c")
4457          assert nested_inner_1_c.op.name == "nested/inner_1/c"
4458
4459          # Treats `scope` as an absolute name scope, and
4460          # switches to the "nested/" scope.
4461          with g.name_scope(scope):
4462            nested_d = tf.constant(40.0, name="d")
4463            assert nested_d.op.name == "nested/d"
4464
4465            with g.name_scope(""):
4466              e = tf.constant(50.0, name="e")
4467              assert e.op.name == "e"
4468    ```
4469
4470    The name of the scope itself can be captured by `with
4471    g.name_scope(...) as scope:`, which stores the name of the scope
4472    in the variable `scope`. This value can be used to name an
4473    operation that represents the overall result of executing the ops
4474    in a scope. For example:
4475
4476    ```python
4477    inputs = tf.constant(...)
4478    with g.name_scope('my_layer') as scope:
4479      weights = tf.Variable(..., name="weights")
4480      biases = tf.Variable(..., name="biases")
4481      affine = tf.matmul(inputs, weights) + biases
4482      output = tf.nn.relu(affine, name=scope)
4483    ```
4484
4485    NOTE: This constructor validates the given `name`. Valid scope
4486    names match one of the following regular expressions:
4487
4488        [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
4489        [A-Za-z0-9_.\\-/]* (for other scopes)
4490
4491    Args:
4492      name: A name for the scope.
4493
4494    Returns:
4495      A context manager that installs `name` as a new name scope.
4496
4497    Raises:
4498      ValueError: If `name` is not a valid scope name, according to the rules
4499        above.
4500    """
4501    if name:
4502      if isinstance(name, compat.bytes_or_text_types):
4503        name = compat.as_str(name)
4504
4505      if self._name_stack:
4506        # Scopes created in a nested scope may have initial characters
4507        # that are illegal as the initial character of an op name
4508        # (viz. '-', '\', '/', and '_').
4509        if not _VALID_SCOPE_NAME_REGEX.match(name):
4510          raise ValueError(
4511              f"'{name}' is not a valid scope name. A scope name has to match "
4512              f"the following pattern: {_VALID_SCOPE_NAME_REGEX.pattern}")
4513      else:
4514        # Scopes created in the root must match the more restrictive
4515        # op name regex, which constrains the initial character.
4516        if not _VALID_OP_NAME_REGEX.match(name):
4517          raise ValueError(
4518              f"'{name}' is not a valid root scope name. A root scope name has "
4519              f"to match the following pattern: {_VALID_OP_NAME_REGEX.pattern}")
4520    old_stack = self._name_stack
4521    if not name:  # Both for name=None and name="" we re-set to empty scope.
4522      new_stack = ""
4523      returned_scope = ""
4524    elif name[-1] == "/":
4525      new_stack = name_from_scope_name(name)
4526      returned_scope = name
4527    else:
4528      new_stack = self.unique_name(name)
4529      returned_scope = new_stack + "/"
4530    self._name_stack = new_stack
4531    try:
4532      yield returned_scope
4533    finally:
4534      self._name_stack = old_stack
4535
4536  # pylint: enable=g-doc-return-or-yield,line-too-long
4537
4538  def unique_name(self, name, mark_as_used=True):
4539    """Return a unique operation name for `name`.
4540
4541    Note: You rarely need to call `unique_name()` directly.  Most of
4542    the time you just need to create `with g.name_scope()` blocks to
4543    generate structured names.
4544
4545    `unique_name` is used to generate structured names, separated by
4546    `"/"`, to help identify operations when debugging a graph.
4547    Operation names are displayed in error messages reported by the
4548    TensorFlow runtime, and in various visualization tools such as
4549    TensorBoard.
4550
4551    If `mark_as_used` is set to `True`, which is the default, a new
4552    unique name is created and marked as in use. If it's set to `False`,
4553    the unique name is returned without actually being marked as used.
4554    This is useful when the caller simply wants to know what name would be
4555    generated, without reserving it.
4556
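    For example, a minimal sketch using a freshly created graph:

    ```python
    g = tf.Graph()
    assert g.unique_name("foo", mark_as_used=False) == "foo"
    assert g.unique_name("foo") == "foo"    # Marks "foo" as used.
    assert g.unique_name("foo") == "foo_1"  # Later names get a numeric suffix.
    ```
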
4557    Args:
4558      name: The name for an operation.
4559      mark_as_used: Whether to mark this name as being used.
4560
4561    Returns:
4562      A string to be passed to `create_op()` that will be used
4563      to name the operation being created.
4564    """
4565    if self._name_stack:
4566      name = self._name_stack + "/" + name
4567
4568    # For the sake of checking for names in use, we treat names as case
4569    # insensitive (e.g. foo = Foo).
4570    name_key = name.lower()
4571    i = self._names_in_use.get(name_key, 0)
4572    # Increment the number for "name_key".
4573    if mark_as_used:
4574      self._names_in_use[name_key] = i + 1
4575    if i > 0:
4576      base_name_key = name_key
4577      # Make sure the composed name key is not already used.
4578      while name_key in self._names_in_use:
4579        name_key = "%s_%d" % (base_name_key, i)
4580        i += 1
4581      # Mark the composed name_key as used in case someone wants
4582      # to call unique_name("name_1").
4583      if mark_as_used:
4584        self._names_in_use[name_key] = 1
4585
4586      # Return the new name with the original capitalization of the given name.
4587      name = "%s_%d" % (name, i - 1)
4588    return name
4589
4590  def get_name_scope(self):
4591    """Returns the current name scope.
4592
4593    For example:
4594
4595    ```python
4596    with tf.name_scope('scope1'):
4597      with tf.name_scope('scope2'):
4598        print(tf.compat.v1.get_default_graph().get_name_scope())
4599    ```
4600    would print the string `scope1/scope2`.
4601
4602    Returns:
4603      A string representing the current name scope.
4604    """
4605    return self._name_stack
4606
4607  @tf_contextlib.contextmanager
4608  def _colocate_with_for_gradient(self, op, gradient_uid,
4609                                  ignore_existing=False):
4610    with self.colocate_with(op, ignore_existing):
4611      if gradient_uid is not None:
4612        ctx = _get_enclosing_context(self)
4613        if ctx is not None:
4614          ctx.EnterGradientColocation(op, gradient_uid)
4615          try:
4616            yield
4617          finally:
4618            ctx.ExitGradientColocation(op, gradient_uid)
4619        else:
4620          yield
4621      else:
4622        yield
4623
4624  @tf_contextlib.contextmanager
4625  def colocate_with(self, op, ignore_existing=False):
4626    """Returns a context manager that specifies an op to colocate with.
4627
4628    Note: this function is not for public use, only for internal libraries.
4629
4630    For example:
4631
4632    ```python
4633    a = tf.Variable([1.0])
4634    with g.colocate_with(a):
4635      b = tf.constant(1.0)
4636      c = tf.add(a, b)
4637    ```
4638
4639    `b` and `c` will always be colocated with `a`, no matter where `a`
4640    is eventually placed.
4641
4642    **NOTE** Using a colocation scope resets any existing device constraints.
4643
4644    If `op` is `None` then `ignore_existing` must be `True` and the new
4645    scope resets all colocation and device constraints.
4646
4647    Args:
4648      op: The op to colocate all created ops with, or `None`.
4649      ignore_existing: If true, only applies colocation of this op within the
4650        context, rather than applying all colocation properties on the stack.
4651        If `op` is `None`, this value must be `True`.
4652
4653    Raises:
4654      ValueError: if op is None but ignore_existing is False.
4655
4656    Yields:
4657      A context manager that specifies the op with which to colocate
4658      newly created ops.
4659    """
4660    if op is None and not ignore_existing:
4661      raise ValueError("Trying to reset colocation (op is None) but "
4662                       "ignore_existing is not True")
4663    op, device_only_candidate = _op_to_colocate_with(op, self)
4664
4665    # By default, colocate_with resets the device function stack,
4666    # since colocate_with is typically used in specific internal
4667    # library functions where colocation is intended to be "stronger"
4668    # than device functions.
4669    #
4670    # In the future, a caller may specify that device_functions win
4671    # over colocation, in which case we can add support.
4672    device_fn_tmp = self._device_function_stack
4673    self._device_function_stack = traceable_stack.TraceableStack()
4674
4675    if ignore_existing:
4676      current_stack = self._colocation_stack
4677      self._colocation_stack = traceable_stack.TraceableStack()
4678
4679    if op is not None:
4680      # offset refers to the stack frame used for storing code location.
4681      # We use 4, the sum of 1 to use our caller's stack frame and 3
4682      # to jump over layers of context managers above us.
4683      if device_only_candidate is not None:
4684        self._colocation_stack.push_obj(device_only_candidate, offset=4)
4685      self._colocation_stack.push_obj(op, offset=4)
4686    elif not ignore_existing:
4687      raise ValueError("Trying to reset colocation (op is None) but "
4688                       "ignore_existing is not True")
4689    try:
4690      yield
4691    finally:
4692      # Restore device function stack
4693      self._device_function_stack = device_fn_tmp
4694      if op is not None:
4695        self._colocation_stack.pop_obj()
4696        if device_only_candidate is not None:
4697          self._colocation_stack.pop_obj()
4698
4699      # Reset the colocation stack if requested.
4700      if ignore_existing:
4701        self._colocation_stack = current_stack
4702
4703  def _add_device_to_stack(self, device_name_or_function, offset=0):
4704    """Add device to stack manually, separate from a context manager."""
4705    total_offset = 1 + offset
4706    spec = _UserDeviceSpec(device_name_or_function)
4707    self._device_function_stack.push_obj(spec, offset=total_offset)
4708    return spec
4709
4710  @tf_contextlib.contextmanager
4711  def device(self, device_name_or_function):
4712    # pylint: disable=line-too-long
4713    """Returns a context manager that specifies the default device to use.
4714
4715    The `device_name_or_function` argument may either be a device name
4716    string, a device function, or None:
4717
4718    * If it is a device name string, all operations constructed in
4719      this context will be assigned to the device with that name, unless
4720      overridden by a nested `device()` context.
4721    * If it is a function, it will be treated as a function from
4722      Operation objects to device name strings, and invoked each time
4723      a new Operation is created. The Operation will be assigned to
4724      the device with the returned name.
4725    * If it is None, all `device()` invocations from the enclosing context
4726      will be ignored.
4727
4728    For information about the valid syntax of device name strings, see
4729    the documentation in
4730    [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).
4731
4732    For example:
4733
4734    ```python
4735    with g.device('/device:GPU:0'):
4736      # All operations constructed in this context will be placed
4737      # on GPU 0.
4738      with g.device(None):
4739        # All operations constructed in this context will have no
4740        # assigned device.
4741
4742    # Defines a function from `Operation` to device string.
4743    def matmul_on_gpu(n):
4744      if n.type == "MatMul":
4745        return "/device:GPU:0"
4746      else:
4747        return "/cpu:0"
4748
4749    with g.device(matmul_on_gpu):
4750      # All operations of type "MatMul" constructed in this context
4751      # will be placed on GPU 0; all other operations will be placed
4752      # on CPU 0.
4753    ```
4754
4755    **N.B.** The device scope may be overridden by op wrappers or
4756    other library code. For example, a variable assignment op
4757    `v.assign()` must be colocated with the `tf.Variable` `v`, and
4758    incompatible device scopes will be ignored.
4759
4760    Args:
4761      device_name_or_function: The device name or function to use in the
4762        context.
4763
4764    Yields:
4765      A context manager that specifies the default device to use for newly
4766      created ops.
4767
4768    Raises:
4769      RuntimeError: If device scopes are not properly nested.
4770    """
4771    self._add_device_to_stack(device_name_or_function, offset=2)
4772    old_top_of_stack = self._device_function_stack.peek_top_obj()
4773    try:
4774      yield
4775    finally:
4776      new_top_of_stack = self._device_function_stack.peek_top_obj()
4777      if old_top_of_stack is not new_top_of_stack:
4778        raise RuntimeError("Exiting device scope without proper scope nesting.")
4779      self._device_function_stack.pop_obj()
4780
4781  def _apply_device_functions(self, op):
4782    """Applies the current device function stack to the given operation."""
4783    # Apply any device functions in LIFO order, so that the most recently
4784    # pushed function has the first chance to apply a device to the op.
4785    # We apply here because the result can depend on the Operation's
4786    # signature, which is computed in the Operation constructor.
4787    # pylint: disable=protected-access
4788    prior_device_string = None
4789    for device_spec in self._device_function_stack.peek_objs():
4790      if device_spec.is_null_merge:
4791        continue
4792
4793      if device_spec.function is None:
4794        break
4795
4796      device_string = device_spec.string_merge(op)
4797
4798      # Take advantage of the fact that None is a singleton and Python interns
4799      # strings, since identity checks are faster than equality checks.
4800      if device_string is not prior_device_string:
4801        op._set_device_from_string(device_string)
4802        prior_device_string = device_string
4803    op._device_code_locations = self._snapshot_device_function_stack_metadata()
4804    # pylint: enable=protected-access
4805
4806  # pylint: disable=g-doc-return-or-yield
4807  @tf_contextlib.contextmanager
4808  def container(self, container_name):
4809    """Returns a context manager that specifies the resource container to use.
4810
4811    Stateful operations, such as variables and queues, can maintain their
4812    states on devices so that they can be shared by multiple processes.
4813    A resource container is a string name under which these stateful
4814    operations are tracked. These resources can be released or cleared
4815    with `tf.Session.reset()`.
4816
4817    For example:
4818
4819    ```python
4820    with g.container('experiment0'):
4821      # All stateful Operations constructed in this context will be placed
4822      # in resource container "experiment0".
4823      v1 = tf.Variable([1.0])
4824      v2 = tf.Variable([2.0])
4825      with g.container("experiment1"):
4826        # All stateful Operations constructed in this context will be
4827        # placed in resource container "experiment1".
4828        v3 = tf.Variable([3.0])
4829        q1 = tf.queue.FIFOQueue(10, tf.float32)
4830      # All stateful Operations constructed in this context will be
4831      # placed in resource container "experiment0".
4832      v4 = tf.Variable([4.0])
4833      q2 = tf.queue.FIFOQueue(20, tf.float32)
4834      with g.container(""):
4835        # All stateful Operations constructed in this context will be
4836        # placed in the default resource container.
4837        v5 = tf.Variable([5.0])
4838        q3 = tf.queue.FIFOQueue(30, tf.float32)
4839
4840    # Resets container "experiment0", after which the state of v1, v2, v4, q2
4841    # will become undefined (such as uninitialized).
4842    tf.Session.reset(target, ["experiment0"])
4843    ```
4844
4845    Args:
4846      container_name: container name string.
4847
4848    Returns:
4849      A context manager for defining resource containers for stateful ops,
4850        yields the container name.
4851    """
4852    original_container = self._container
4853    self._container = container_name
4854    try:
4855      yield self._container
4856    finally:
4857      self._container = original_container
4858
4859  # pylint: enable=g-doc-return-or-yield
4860
4861  class _ControlDependenciesController(object):
4862    """Context manager for `control_dependencies()`."""
4863
4864    def __init__(self, graph, control_inputs):
4865      """Create a new `_ControlDependenciesController`.
4866
4867      A `_ControlDependenciesController` is the context manager for
4868      `with tf.control_dependencies()` blocks.  These normally nest,
4869      as described in the documentation for `control_dependencies()`.
4870
4871      The `control_inputs` argument lists control dependencies that must be
4872      added to the current set of control dependencies.  Because of
4873      uniquification the set can be empty even if the caller passed a list of
4874      ops.  The special value `None` indicates that we want to start a new
4875      empty set of control dependencies instead of extending the current set.
4876
4877      In that case we also clear the current control flow context, which is an
4878      additional mechanism to add control dependencies.
4879
4880      Args:
4881        graph: The graph that this controller is managing.
4882        control_inputs: List of ops to use as control inputs in addition to the
4883          current control dependencies.  None to indicate that the dependencies
4884          should be cleared.
4885      """
4886      self._graph = graph
4887      if control_inputs is None:
4888        self._control_inputs_val = []
4889        self._new_stack = True
4890      else:
4891        self._control_inputs_val = control_inputs
4892        self._new_stack = False
4893      self._seen_nodes = set()
4894      self._old_stack = None
4895      self._old_control_flow_context = None
4896
4897# pylint: disable=protected-access
4898
4899    def __enter__(self):
4900      if self._new_stack:
4901        # Clear the control_dependencies graph.
4902        self._old_stack = self._graph._control_dependencies_stack
4903        self._graph._control_dependencies_stack = []
4904        # Clear the control_flow_context too.
4905        self._old_control_flow_context = self._graph._get_control_flow_context()
4906        self._graph._set_control_flow_context(None)
4907      self._graph._push_control_dependencies_controller(self)
4908
4909    def __exit__(self, unused_type, unused_value, unused_traceback):
4910      self._graph._pop_control_dependencies_controller(self)
4911      if self._new_stack:
4912        self._graph._control_dependencies_stack = self._old_stack
4913        self._graph._set_control_flow_context(self._old_control_flow_context)
4914
4915# pylint: enable=protected-access
4916
4917    @property
4918    def control_inputs(self):
4919      return self._control_inputs_val
4920
4921    def add_op(self, op):
4922      if isinstance(op, Tensor):
4923        op = op.ref()
4924      self._seen_nodes.add(op)
4925
4926    def op_in_group(self, op):
4927      if isinstance(op, Tensor):
4928        op = op.ref()
4929      return op in self._seen_nodes
4930
4931  def _push_control_dependencies_controller(self, controller):
4932    self._control_dependencies_stack.append(controller)
4933
4934  def _pop_control_dependencies_controller(self, controller):
4935    assert self._control_dependencies_stack[-1] is controller
4936    self._control_dependencies_stack.pop()
4937
4938  def _current_control_dependencies(self):
4939    ret = set()
4940    for controller in self._control_dependencies_stack:
4941      for op in controller.control_inputs:
4942        ret.add(op)
4943    return ret
4944
4945  def _control_dependencies_for_inputs(self, input_ops):
4946    """For an op that takes `input_ops` as inputs, compute control inputs.
4947
4948    The returned control dependencies should yield an execution that
4949    is equivalent to adding all control inputs in
4950    self._control_dependencies_stack to a newly created op. However,
4951    this function attempts to prune the returned control dependencies
4952    by observing that nodes created within the same `with
4953    control_dependencies(...):` block may have data dependencies that make
4954    the explicit approach redundant.
4955
4956    Args:
4957      input_ops: The data input ops for an op to be created.
4958
4959    Returns:
4960      A list of control inputs for the op to be created.
4961    """
4962    ret = []
4963    for controller in self._control_dependencies_stack:
4964      # If any of the input_ops already depends on the inputs from controller,
4965      # we say that the new op is dominated (by that input), and we therefore
4966      # do not need to add control dependencies for this controller's inputs.
4967      dominated = False
4968      for op in input_ops:
4969        if controller.op_in_group(op):
4970          dominated = True
4971          break
4972      if not dominated:
4973        # Don't add a control input if we already have a data dependency on it.
4974        # NOTE(mrry): We do not currently track transitive data dependencies,
4975        #   so we may add redundant control inputs.
4976        ret.extend(c for c in controller.control_inputs if c not in input_ops)
4977    return ret
4978
4979  def _record_op_seen_by_control_dependencies(self, op):
4980    """Record that the given op depends on all registered control dependencies.
4981
4982    Args:
4983      op: An Operation.
4984    """
4985    for controller in self._control_dependencies_stack:
4986      controller.add_op(op)
4987
4988  def control_dependencies(self, control_inputs):
4989    """Returns a context manager that specifies control dependencies.
4990
4991    Use with the `with` keyword to specify that all operations constructed
4992    within the context should have control dependencies on
4993    `control_inputs`. For example:
4994
4995    ```python
4996    with g.control_dependencies([a, b, c]):
4997      # `d` and `e` will only run after `a`, `b`, and `c` have executed.
4998      d = ...
4999      e = ...
5000    ```
5001
5002    Multiple calls to `control_dependencies()` can be nested, and in
5003    that case a new `Operation` will have control dependencies on the union
5004    of `control_inputs` from all active contexts.
5005
5006    ```python
5007    with g.control_dependencies([a, b]):
5008      # Ops constructed here run after `a` and `b`.
5009      with g.control_dependencies([c, d]):
5010        # Ops constructed here run after `a`, `b`, `c`, and `d`.
5011    ```
5012
5013    You can pass None to clear the control dependencies:
5014
5015    ```python
5016    with g.control_dependencies([a, b]):
5017      # Ops constructed here run after `a` and `b`.
5018      with g.control_dependencies(None):
5019        # Ops constructed here run normally, not waiting for either `a` or `b`.
5020        with g.control_dependencies([c, d]):
5021          # Ops constructed here run after `c` and `d`, also not waiting
5022          # for either `a` or `b`.
5023    ```
5024
5025    *N.B.* The control dependencies context applies *only* to ops that
5026    are constructed within the context. Merely using an op or tensor
5027    in the context does not add a control dependency. The following
5028    example illustrates this point:
5029
5030    ```python
5031    # WRONG
5032    def my_func(pred, tensor):
5033      t = tf.matmul(tensor, tensor)
5034      with tf.control_dependencies([pred]):
5035        # The matmul op is created outside the context, so no control
5036        # dependency will be added.
5037        return t
5038
5039    # RIGHT
5040    def my_func(pred, tensor):
5041      with tf.control_dependencies([pred]):
5042        # The matmul op is created in the context, so a control dependency
5043        # will be added.
5044        return tf.matmul(tensor, tensor)
5045    ```
5046
5047    Also note that though execution of ops created under this scope will trigger
5048    execution of the dependencies, the ops created under this scope might still
5049    be pruned from a normal TensorFlow graph. For example, in the following
5050    snippet of code the dependencies are never executed:
5051
5052    ```python
5053      loss = model.loss()
5054      with tf.control_dependencies(dependencies):
5055        loss = loss + tf.constant(1)  # note: dependencies ignored in the
5056                                      # backward pass
5057      return tf.gradients(loss, model.variables)
5058    ```
5059
5060    This is because evaluating the gradient graph does not require evaluating
5061    the constant(1) op created in the forward pass.
5062
5063    Args:
5064      control_inputs: A list of `Operation` or `Tensor` objects which must be
5065        executed or computed before running the operations defined in the
5066        context.  Can also be `None` to clear the control dependencies.
5067
5068    Returns:
5069     A context manager that specifies control dependencies for all
5070     operations constructed within the context.
5071
5072    Raises:
5073      TypeError: If `control_inputs` is not a list of `Operation` or
5074        `Tensor` objects.
5075    """
5076    if control_inputs is None:
5077      return self._ControlDependenciesController(self, None)
5078    # First convert the inputs to ops, and deduplicate them.
5079    # NOTE(mrry): Other than deduplication, we do not currently track direct
5080    #   or indirect dependencies between control_inputs, which may result in
5081    #   redundant control inputs.
5082    control_ops = []
5083    current = self._current_control_dependencies()
5084    for c in control_inputs:
5085      # The hasattr(handle) is designed to match ResourceVariables. This is so
5086      # control dependencies on a variable or on an unread variable don't
5087      # trigger reads.
5088      if (isinstance(c, IndexedSlices) or
5089          (hasattr(c, "_handle") and hasattr(c, "op"))):
5090        c = c.op
5091      c = self.as_graph_element(c)
5092      if isinstance(c, Tensor):
5093        c = c.op
5094      elif not isinstance(c, Operation):
5095        raise TypeError("Control input must be Operation or Tensor: %s" % c)
5096      if c not in current:
5097        control_ops.append(c)
5098        current.add(c)
5099        # Mark this op with an attribute indicating that it is used as a manual
5100        # control dep, in order to allow tracking how common the use of manual
5101        # control deps is in graphs run through the MLIR Bridge. See
5102        # go/manual-control-dependencies-bridge for details.
5103        # pylint: disable=protected-access
5104        c._set_attr("_has_manual_control_dependencies",
5105                    attr_value_pb2.AttrValue(b=True))
5106        # pylint: enable=protected-access
5107    return self._ControlDependenciesController(self, control_ops)
5108
5109  # pylint: disable=g-doc-return-or-yield
5110  @tf_contextlib.contextmanager
5111  def _attr_scope(self, attr_map):
5112    """EXPERIMENTAL: A context manager for setting attributes on operators.
5113
5114    This context manager can be used to add additional
5115    attributes to operators within the scope of the context.
5116
5117    For example:
5118
5119       with ops.Graph().as_default() as g:
5120         f_1 = Foo()  # No extra attributes
5121         with g._attr_scope({"_a": attr_value_pb2.AttrValue(b=False)}):
5122           f_2 = Foo()  # Additional attribute _a=False
5123           with g._attr_scope({"_a": attr_value_pb2.AttrValue(b=True)}):
5124             f_3 = Foo()  # Additional attribute _a=True
5125             with g._attr_scope({"_a": None}):
5126               f_4 = Foo()  # No additional attributes.
5127
5128    Args:
5129      attr_map: A dictionary mapping attr name strings to AttrValue protocol
5130        buffers or None.
5131
5132    Returns:
5133      A context manager that sets the additional attributes for one or more
5134      ops created in that context.
5135
5136    Raises:
5137      TypeError: If attr_map is not a dictionary mapping
5138        strings to AttrValue protobufs.
5139    """
5140    if not isinstance(attr_map, dict):
5141      raise TypeError("attr_map must be a dictionary mapping "
5142                      "strings to AttrValue protocol buffers")
5143    # The saved_attrs dictionary stores any currently-set labels that
5144    # will be overridden by this context manager.
5145    saved_attrs = {}
5146    # Install the given attribute
5147    for name, attr in attr_map.items():
5148      if not (isinstance(name, str) and
5149              (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
5150               callable(attr))):
5151        raise TypeError("attr_map must be a dictionary mapping "
5152                        "strings to AttrValue protocol buffers or "
5153                        "callables that emit AttrValue protocol buffers")
5154      try:
5155        saved_attrs[name] = self._attr_scope_map[name]
5156      except KeyError:
5157        pass
5158      if attr is None:
5159        del self._attr_scope_map[name]
5160      else:
5161        self._attr_scope_map[name] = attr
5162    try:
5163      yield  # The code within the context runs here.
5164    finally:
5165      # Remove the attributes set for this context, and restore any saved
5166      # attributes.
5167      for name, attr in attr_map.items():
5168        try:
5169          self._attr_scope_map[name] = saved_attrs[name]
5170        except KeyError:
5171          del self._attr_scope_map[name]
5172
5173  # pylint: enable=g-doc-return-or-yield
5174
5175  # pylint: disable=g-doc-return-or-yield
5176  @tf_contextlib.contextmanager
5177  def _kernel_label_map(self, op_to_kernel_label_map):
5178    """EXPERIMENTAL: A context manager for setting kernel labels.
5179
5180    This context manager can be used to select particular
5181    implementations of kernels within the scope of the context.
5182
5183    For example:
5184
5185        with ops.Graph().as_default() as g:
5186          f_1 = Foo()  # Uses the default registered kernel for the Foo op.
5187          with g._kernel_label_map({"Foo": "v_2"}):
5188            f_2 = Foo()  # Uses the registered kernel with label "v_2"
5189                         # for the Foo op.
5190            with g._kernel_label_map({"Foo": "v_3"}):
5191              f_3 = Foo()  # Uses the registered kernel with label "v_3"
5192                           # for the Foo op.
5193              with g._kernel_label_map({"Foo": ""}):
5194                f_4 = Foo()  # Uses the default registered kernel
5195                             # for the Foo op.
5196
5197    Args:
5198      op_to_kernel_label_map: A dictionary mapping op type strings to kernel
5199        label strings.
5200
5201    Returns:
5202      A context manager that sets the kernel label to be used for one or more
5203      ops created in that context.
5204
5205    Raises:
5206      TypeError: If op_to_kernel_label_map is not a dictionary mapping
5207        strings to strings.
5208    """
5209    if not isinstance(op_to_kernel_label_map, dict):
5210      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
5211                      "strings to strings")
5212    # The saved_labels dictionary stores any currently-set labels that
5213    # will be overridden by this context manager.
5214    saved_labels = {}
5215    # Install the given label
5216    for op_type, label in op_to_kernel_label_map.items():
5217      if not (isinstance(op_type, str) and
5218              isinstance(label, str)):
5219        raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
5220                        "strings to strings")
5221      try:
5222        saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
5223      except KeyError:
5224        pass
5225      self._op_to_kernel_label_map[op_type] = label
5226    try:
5227      yield  # The code within the context runs here.
5228    finally:
5229      # Remove the labels set for this context, and restore any saved labels.
5230      for op_type, label in op_to_kernel_label_map.items():
5231        try:
5232          self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
5233        except KeyError:
5234          del self._op_to_kernel_label_map[op_type]
5235
5236  # pylint: enable=g-doc-return-or-yield
5237
5238  @tf_contextlib.contextmanager
5239  def _override_gradient_function(self, gradient_function_map):
5240    """Specify gradient function for the given op type."""
5241
5242    # This is an internal API and we don't need nested context for this.
5243    # TODO(mdan): make it a proper context manager.
5244    assert not self._gradient_function_map
5245    self._gradient_function_map = gradient_function_map
5246    try:
5247      yield
5248    finally:
5249      self._gradient_function_map = {}
5250
5251  # pylint: disable=g-doc-return-or-yield
5252  @tf_contextlib.contextmanager
5253  def gradient_override_map(self, op_type_map):
5254    """EXPERIMENTAL: A context manager for overriding gradient functions.
5255
5256    This context manager can be used to override the gradient function
5257    that will be used for ops within the scope of the context.
5258
5259    For example:
5260
5261    ```python
5262    @tf.RegisterGradient("CustomSquare")
5263    def _custom_square_grad(op, grad):
5264      # ...
5265
5266    with tf.Graph().as_default() as g:
5267      c = tf.constant(5.0)
5268      s_1 = tf.square(c)  # Uses the default gradient for tf.square.
5269      with g.gradient_override_map({"Square": "CustomSquare"}):
5270        s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
5271                            # gradient of s_2.
5272    ```
5273
5274    Args:
5275      op_type_map: A dictionary mapping op type strings to alternative op type
5276        strings.
5277
5278    Returns:
5279      A context manager that sets the alternative op type to be used for one
5280      or more ops created in that context.
5281
5282    Raises:
5283      TypeError: If `op_type_map` is not a dictionary mapping strings to
5284        strings.
5285    """
5286    if not isinstance(op_type_map, dict):
5287      raise TypeError("op_type_map must be a dictionary mapping "
5288                      "strings to strings")
5289    # The saved_mappings dictionary stores any currently-set mappings that
5290    # will be overridden by this context manager.
5291    saved_mappings = {}
5292    # Install the given override mapping
5293    for op_type, mapped_op_type in op_type_map.items():
5294      if not (isinstance(op_type, str) and
5295              isinstance(mapped_op_type, str)):
5296        raise TypeError("op_type_map must be a dictionary mapping "
5297                        "strings to strings")
5298      try:
5299        saved_mappings[op_type] = self._gradient_override_map[op_type]
5300      except KeyError:
5301        pass
5302      self._gradient_override_map[op_type] = mapped_op_type
5303    try:
5304      yield  # The code within the context runs here.
5305    finally:
5306      # Remove the mappings set for this context, and restore any saved mappings.
5307      for op_type, mapped_op_type in op_type_map.items():
5308        try:
5309          self._gradient_override_map[op_type] = saved_mappings[op_type]
5310        except KeyError:
5311          del self._gradient_override_map[op_type]
5312
5313  # pylint: enable=g-doc-return-or-yield
5314
5315  def prevent_feeding(self, tensor):
5316    """Marks the given `tensor` as unfeedable in this graph."""
5317    self._unfeedable_tensors.add(tensor)
5318
5319  def is_feedable(self, tensor):
5320    """Returns `True` if and only if `tensor` is feedable."""
5321    return tensor not in self._unfeedable_tensors
5322
5323  def prevent_fetching(self, op):
5324    """Marks the given `op` as unfetchable in this graph."""
5325    self._unfetchable_ops.add(op)
5326
5327  def is_fetchable(self, tensor_or_op):
5328    """Returns `True` if and only if `tensor_or_op` is fetchable."""
5329    if isinstance(tensor_or_op, Tensor):
5330      return tensor_or_op.op not in self._unfetchable_ops
5331    else:
5332      return tensor_or_op not in self._unfetchable_ops
5333
5334  def switch_to_thread_local(self):
5335    """Make device, colocation and dependencies stacks thread-local.
5336
5337    Device, colocation and dependencies stacks are not thread-local by default.
5338    If multiple threads access them, then the state is shared.  This means that
5339    one thread may affect the behavior of another thread.
5340
5341    After this method is called, the stacks become thread-local.  If multiple
5342    threads access them, then the state is not shared.  Each thread uses its own
5343    value; a thread doesn't affect other threads by mutating such a stack.
5344
5345    The initial value for every thread's stack is set to the current value
5346    of the stack when `switch_to_thread_local()` was first called.
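
    For example, a minimal sketch (the worker function below is illustrative
    only):

    ```python
    import threading

    g = tf.Graph()
    g.switch_to_thread_local()

    def worker():
      # Device scopes entered here affect only this thread's stack.
      with g.device("/cpu:0"):
        with g.as_default():
          c = tf.constant(1.0)  # Placed using this thread's device stack.

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    ```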
5347    """
5348    if not self._stack_state_is_thread_local:
5349      self._stack_state_is_thread_local = True
5350
5351  @property
5352  def _device_function_stack(self):
5353    if self._stack_state_is_thread_local:
5354      # This may be called from a thread where device_function_stack doesn't yet
5355      # exist.
5356      # pylint: disable=protected-access
5357      if not hasattr(self._thread_local, "_device_function_stack"):
5358        stack_copy_for_this_thread = self._graph_device_function_stack.copy()
5359        self._thread_local._device_function_stack = stack_copy_for_this_thread
5360      return self._thread_local._device_function_stack
5361      # pylint: enable=protected-access
5362    else:
5363      return self._graph_device_function_stack
5364
5365  @property
5366  def _device_functions_outer_to_inner(self):
5367    user_device_specs = self._device_function_stack.peek_objs()
5368    device_functions = [spec.function for spec in user_device_specs]
5369    device_functions_outer_to_inner = list(reversed(device_functions))
5370    return device_functions_outer_to_inner
5371
5372  def _snapshot_device_function_stack_metadata(self):
5373    """Return device function stack as a list of TraceableObjects.
5374
5375    Returns:
5376      [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj
5377      member is a displayable name for the user's argument to Graph.device, and
5378      the filename and lineno members point to the code location where
5379      Graph.device was called directly or indirectly by the user.
5380    """
5381    snapshot = []
5382    for obj in self._device_function_stack.peek_traceable_objs():
5383      obj_copy = obj.copy_metadata()
5384      obj_copy.obj = obj.obj.display_name
5385      snapshot.append(obj_copy)
5386    return snapshot
5387
5388  @_device_function_stack.setter
5389  def _device_function_stack(self, device_function_stack):
5390    if self._stack_state_is_thread_local:
5391      # pylint: disable=protected-access
5392      self._thread_local._device_function_stack = device_function_stack
5393      # pylint: enable=protected-access
5394    else:
5395      self._graph_device_function_stack = device_function_stack
5396
5397  @property
5398  def _colocation_stack(self):
5399    """Return thread-local copy of colocation stack."""
5400    if self._stack_state_is_thread_local:
5401      # This may be called from a thread where colocation_stack doesn't yet
5402      # exist.
5403      # pylint: disable=protected-access
5404      if not hasattr(self._thread_local, "_colocation_stack"):
5405        stack_copy_for_this_thread = self._graph_colocation_stack.copy()
5406        self._thread_local._colocation_stack = stack_copy_for_this_thread
5407      return self._thread_local._colocation_stack
5408      # pylint: enable=protected-access
5409    else:
5410      return self._graph_colocation_stack
5411
5412  def _snapshot_colocation_stack_metadata(self):
5413    """Return colocation stack metadata as a dictionary."""
5414    return {
5415        traceable_obj.obj.name: traceable_obj.copy_metadata()
5416        for traceable_obj in self._colocation_stack.peek_traceable_objs()
5417    }
5418
5419  @_colocation_stack.setter
5420  def _colocation_stack(self, colocation_stack):
5421    if self._stack_state_is_thread_local:
5422      # pylint: disable=protected-access
5423      self._thread_local._colocation_stack = colocation_stack
5424      # pylint: enable=protected-access
5425    else:
5426      self._graph_colocation_stack = colocation_stack
5427
5428  @property
5429  def _control_dependencies_stack(self):
5430    if self._stack_state_is_thread_local:
5431      # This may be called from a thread where control_dependencies_stack
5432      # doesn't yet exist.
5433      if not hasattr(self._thread_local, "_control_dependencies_stack"):
5434        self._thread_local._control_dependencies_stack = (
5435            self._graph_control_dependencies_stack[:])
5436      return self._thread_local._control_dependencies_stack
5437    else:
5438      return self._graph_control_dependencies_stack
5439
5440  @_control_dependencies_stack.setter
5441  def _control_dependencies_stack(self, control_dependencies):
5442    if self._stack_state_is_thread_local:
5443      self._thread_local._control_dependencies_stack = control_dependencies
5444    else:
5445      self._graph_control_dependencies_stack = control_dependencies
5446
5447  @property
5448  def _distribution_strategy_stack(self):
5449    """A stack to maintain distribution strategy context for each thread."""
5450    if not hasattr(self._thread_local, "_distribution_strategy_stack"):
5451      self._thread_local._distribution_strategy_stack = []  # pylint: disable=protected-access
5452    return self._thread_local._distribution_strategy_stack  # pylint: disable=protected-access
5453
5454  @_distribution_strategy_stack.setter
5455  def _distribution_strategy_stack(self, _distribution_strategy_stack):
5456    self._thread_local._distribution_strategy_stack = (  # pylint: disable=protected-access
5457        _distribution_strategy_stack)
5458
5459  @property
5460  def _global_distribute_strategy_scope(self):
5461    """For implementing `tf.distribute.set_strategy()`."""
5462    if not hasattr(self._thread_local, "distribute_strategy_scope"):
5463      self._thread_local.distribute_strategy_scope = None
5464    return self._thread_local.distribute_strategy_scope
5465
5466  @_global_distribute_strategy_scope.setter
5467  def _global_distribute_strategy_scope(self, distribute_strategy_scope):
5468    self._thread_local.distribute_strategy_scope = (distribute_strategy_scope)
5469
5470  def _mutation_lock(self):
5471    """Returns a lock to guard code that creates & mutates ops.
5472
5473    See the comment for self._group_lock for more info.
5474    """
5475    return self._group_lock.group(_MUTATION_LOCK_GROUP)
5476
5477  def _session_run_lock(self):
5478    """Returns a lock to guard code for Session.run.
5479
5480    See the comment for self._group_lock for more info.
5481    """
5482    return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
5483
5484
5485# TODO(agarwal): currently device directives in an outer eager scope will not
5486# apply to inner graph mode code. Fix that.
5487
5488
5489@tf_export(v1=["device"])
5490def device(device_name_or_function):
5491  """Wrapper for `Graph.device()` using the default graph.
5492
5493  See `tf.Graph.device` for more details.
5494
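  For example, a minimal sketch (graph mode, using the default graph):

  ```python
  with tf.Graph().as_default():
    with tf.compat.v1.device("/device:CPU:0"):
      a = tf.constant([1.0, 2.0])
    assert "CPU:0" in a.device
  ```
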
5495  Args:
5496    device_name_or_function: The device name or function to use in the context.
5497
5498  Returns:
5499    A context manager that specifies the default device to use for newly
5500    created ops.
5501
5502  Raises:
5503    RuntimeError: If eager execution is enabled and a function is passed in.
5504  """
5505  if context.executing_eagerly():
5506    if callable(device_name_or_function):
5507      raise RuntimeError(
5508          "tf.device does not support functions when eager execution "
5509          "is enabled.")
5510    return context.device(device_name_or_function)
5511  elif executing_eagerly_outside_functions():
5512    @tf_contextlib.contextmanager
5513    def combined(device_name_or_function):
5514      with get_default_graph().device(device_name_or_function):
5515        if not callable(device_name_or_function):
5516          with context.device(device_name_or_function):
5517            yield
5518        else:
5519          yield
5520    return combined(device_name_or_function)
5521  else:
5522    return get_default_graph().device(device_name_or_function)
5523
5524
5525@tf_export("device", v1=[])
5526def device_v2(device_name):
5527  """Specifies the device for ops created/executed in this context.
5528
5529  This function specifies the device to be used for ops created/executed in a
5530  particular context. Nested contexts will inherit and also create/execute
5531  their ops on the specified device. If a specific device is not required,
5532  consider not using this function so that a device can be automatically
5533  assigned.  In general the use of this function is optional. `device_name` can
5534  be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially
5535  specified, containing only a subset of the "/"-separated fields. Any fields
5536  which are specified will override device annotations from outer scopes.
5537
5538  For example:
5539
5540  ```python
5541  with tf.device('/job:foo'):
5542    # ops created here have devices with /job:foo
5543    with tf.device('/job:bar/task:0/device:gpu:2'):
5544      # ops created here have the fully specified device above
5545    with tf.device('/device:gpu:1'):
5546      # ops created here have the device '/job:foo/device:gpu:1'
5547  ```
5548
5549  Args:
5550    device_name: The device name to use in the context.
5551
5552  Returns:
5553    A context manager that specifies the default device to use for newly
5554    created ops.
5555
5556  Raises:
5557    RuntimeError: If a function is passed in.
5558  """
5559  if callable(device_name):
5560    raise RuntimeError("tf.device does not support functions.")
5561  return device(device_name)
5562
5563
5564@tf_export(v1=["container"])
5565def container(container_name):
5566  """Wrapper for `Graph.container()` using the default graph.
5567
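  For example, a minimal sketch (graph mode; the container name is arbitrary):

  ```python
  with tf.Graph().as_default():
    with tf.compat.v1.container("experiment0"):
      # Stateful ops created here are tracked in container "experiment0".
      v = tf.compat.v1.Variable([1.0])
  ```
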
5568  Args:
5569    container_name: The container string to use in the context.
5570
5571  Returns:
5572    A context manager that specifies the default container to use for newly
5573    created stateful ops.
5574  """
5575  return get_default_graph().container(container_name)
5576
5577
5578def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
5579  if context.executing_eagerly():
5580    if op is not None:
5581      if not hasattr(op, "device"):
5582        op = internal_convert_to_tensor_or_indexed_slices(op)
5583      return device(op.device)
5584    else:
5585      return NullContextmanager()
5586  else:
5587    default_graph = get_default_graph()
5588    if isinstance(op, EagerTensor):
5589      if default_graph.building_function:
5590        return default_graph.device(op.device)
5591      else:
5592        raise ValueError("Encountered an Eager-defined Tensor during graph "
5593                         "construction, but a function was not being built.")
5594    return default_graph._colocate_with_for_gradient(
5595        op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)
5596
5597
5598# Internal interface to colocate_with. colocate_with has been deprecated from
5599# public API. There are still a few internal uses of colocate_with. Add internal
5600# only API for those uses to avoid deprecation warning.
5601def colocate_with(op, ignore_existing=False):
5602  return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing)
5603
5604
5605@deprecation.deprecated(
5606    date=None, instructions="Colocations handled automatically by placer.")
5607@tf_export(v1=["colocate_with"])
5608def _colocate_with(op, ignore_existing=False):
5609  return colocate_with(op, ignore_existing)
5610
5611
5612@tf_export("control_dependencies")
5613def control_dependencies(control_inputs):
5614  """Wrapper for `Graph.control_dependencies()` using the default graph.
5615
5616  See `tf.Graph.control_dependencies` for more details.
5617
5618  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
5619  this method, as ops execute in the expected order thanks to automatic control
5620  dependencies.* Only use `tf.control_dependencies` when working with v1
5621  `tf.Graph` code.
5622
5623  When eager execution is enabled, any callable object in the `control_inputs`
5624  list will be called.
5625
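  For example, a minimal graph-mode sketch (TF1-style; the constants are only
  illustrative):

  ```python
  g = tf.Graph()
  with g.as_default():
    a = tf.constant(1.0)
    b = tf.constant(2.0)
    with tf.control_dependencies([a, b]):
      # Ops created here only run after both `a` and `b` have run.
      c = a + b
  ```
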
5626  Args:
5627    control_inputs: A list of `Operation` or `Tensor` objects which must be
5628      executed or computed before running the operations defined in the context.
5629      Can also be `None` to clear the control dependencies. If eager execution
5630      is enabled, any callable object in the `control_inputs` list will be
5631      called.
5632
5633  Returns:
5634   A context manager that specifies control dependencies for all
5635   operations constructed within the context.
5636  """
5637  if context.executing_eagerly():
5638    if control_inputs:
5639      # Execute any pending callables.
5640      for control in control_inputs:
5641        if callable(control):
5642          control()
5643    return NullContextmanager()
5644  else:
5645    return get_default_graph().control_dependencies(control_inputs)
5646
5647
5648class _DefaultStack(threading.local):
5649  """A thread-local stack of objects for providing implicit defaults."""
5650
5651  def __init__(self):
5652    super(_DefaultStack, self).__init__()
5653    self._enforce_nesting = True
5654    self.stack = []
5655
5656  def get_default(self):
5657    return self.stack[-1] if self.stack else None
5658
5659  def reset(self):
5660    self.stack = []
5661
5662  def is_cleared(self):
5663    return not self.stack
5664
5665  @property
5666  def enforce_nesting(self):
5667    return self._enforce_nesting
5668
5669  @enforce_nesting.setter
5670  def enforce_nesting(self, value):
5671    self._enforce_nesting = value
5672
5673  @tf_contextlib.contextmanager
5674  def get_controller(self, default):
5675    """A context manager for manipulating a default stack."""
5676    self.stack.append(default)
5677    try:
5678      yield default
5679    finally:
5680      # stack may be empty if reset() was called
5681      if self.stack:
5682        if self._enforce_nesting:
5683          if self.stack[-1] is not default:
5684            raise AssertionError(
5685                "Nesting violated for default stack of %s objects" %
5686                type(default))
5687          self.stack.pop()
5688        else:
5689          self.stack.remove(default)
5690
5691
5692_default_session_stack = _DefaultStack()  # pylint: disable=protected-access
5693
5694
5695def default_session(session):
5696  """Python "with" handler for defining a default session.
5697
5698  This function provides a means of registering a session for handling
5699  Tensor.eval() and Operation.run() calls. It is primarily intended for use
5700  by session.Session, but can be used with any object that implements
5701  the Session.run() interface.
5702
5703  Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
5704  invocations within the scope of a block should be executed by a particular
5705  session.
5706
5707  The default session applies to the current thread only, so it is always
5708  possible to inspect the call stack and determine the scope of a default
5709  session. If you create a new thread, and wish to use the default session
5710  in that thread, you must explicitly add a "with ops.default_session(sess):"
5711  block in that thread's function.
5712
5713  Example:
5714    The following code examples are equivalent:
5715
5716    # 1. Using the Session object directly:
5717    sess = ...
5718    c = tf.constant(5.0)
5719    sess.run(c)
5720
5721    # 2. Using default_session():
5722    sess = ...
5723    with ops.default_session(sess):
5724      c = tf.constant(5.0)
5725      result = c.eval()
5726
5727    # 3. Overriding default_session():
5728    sess = ...
5729    with ops.default_session(sess):
5730      c = tf.constant(5.0)
5731      with ops.default_session(...):
5732        c.eval(session=sess)
5733
5734  Args:
5735    session: The session to be installed as the default session.
5736
5737  Returns:
5738    A context manager for the default session.
5739  """
5740  return _default_session_stack.get_controller(session)
5741
5742
5743@tf_export(v1=["get_default_session"])
5744def get_default_session():
5745  """Returns the default session for the current thread.
5746
5747  The returned `Session` will be the innermost session on which a
5748  `Session` or `Session.as_default()` context has been entered.
5749
5750  NOTE: The default session is a property of the current thread. If you
5751  create a new thread, and wish to use the default session in that
5752  thread, you must explicitly add a `with sess.as_default():` in that
5753  thread's function.
5754
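  For example, a minimal TF1-style sketch:

  ```python
  with tf.Graph().as_default():
    c = tf.constant(5.0)
    with tf.compat.v1.Session() as sess:
      assert tf.compat.v1.get_default_session() is sess
      print(c.eval())  # Evaluated using the default session.
  ```
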
5755  Returns:
5756    The default `Session` being used in the current thread.
5757  """
5758  return _default_session_stack.get_default()
5759
5760
5761def _eval_using_default_session(tensors, feed_dict, graph, session=None):
5762  """Uses the default session to evaluate one or more tensors.
5763
5764  Args:
5765    tensors: A single Tensor, or a list of Tensor objects.
5766    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5767      numpy ndarrays, TensorProtos, or strings.
5768    graph: The graph in which the tensors are defined.
5769    session: (Optional) A different session to use to evaluate "tensors".
5770
5771  Returns:
5772    Either a single numpy ndarray if "tensors" is a single tensor; or a list
5773    of numpy ndarrays that each correspond to the respective element in
5774    "tensors".
5775
5776  Raises:
5777    ValueError: If no default session is available; the default session
5778      does not have "graph" as its graph; or if "session" is specified,
5779      and it does not have "graph" as its graph.
5780  """
5781  if session is None:
5782    session = get_default_session()
5783    if session is None:
5784      raise ValueError("Cannot evaluate tensor using `eval()`: No default "
5785                       "session is registered. Use `with "
5786                       "sess.as_default()` or pass an explicit session to "
5787                       "`eval(session=sess)`")
5788    if session.graph is not graph:
5789      raise ValueError("Cannot use the default session to evaluate tensor: "
5790                       "the tensor's graph is different from the session's "
5791                       "graph. Pass an explicit session to "
5792                       "`eval(session=sess)`.")
5793  else:
5794    if session.graph is not graph:
5795      raise ValueError("Cannot use the given session to evaluate tensor: "
5796                       "the tensor's graph is different from the session's "
5797                       "graph.")
5798  return session.run(tensors, feed_dict)
5799
5800
5801def _run_using_default_session(operation, feed_dict, graph, session=None):
5802  """Uses the default session to run "operation".
5803
5804  Args:
5805    operation: The Operation to be run.
5806    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5807      numpy ndarrays, TensorProtos, or strings.
5808    graph: The graph in which "operation" is defined.
5809    session: (Optional) A different session to use to run "operation".
5810
5811  Raises:
5812    ValueError: If no default session is available; the default session
5813      does not have "graph" as its graph; or if "session" is specified,
5814      and it does not have "graph" as its graph.
5815  """
5816  if session is None:
5817    session = get_default_session()
5818    if session is None:
5819      raise ValueError("Cannot execute operation using `run()`: No default "
5820                       "session is registered. Use `with "
5821                       "sess.as_default():` or pass an explicit session to "
5822                       "`run(session=sess)`")
5823    if session.graph is not graph:
5824      raise ValueError("Cannot use the default session to execute operation: "
5825                       "the operation's graph is different from the "
5826                       "session's graph. Pass an explicit session to "
5827                       "run(session=sess).")
5828  else:
5829    if session.graph is not graph:
5830      raise ValueError("Cannot use the given session to execute operation: "
5831                       "the operation's graph is different from the session's "
5832                       "graph.")
5833  session.run(operation, feed_dict)
5834
5835
5836class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
5837  """A thread-local stack of objects for providing an implicit default graph."""
5838
5839  def __init__(self):
5840    super(_DefaultGraphStack, self).__init__()
5841    self._global_default_graph = None
5842
5843  def get_default(self):
5844    """Override that returns a global default if the stack is empty."""
5845    if self.stack:
5846      return self.stack[-1]
5847    elif self._global_default_graph:
5848      return self._global_default_graph
5849    else:
5850      self._global_default_graph = Graph()
5851      return self._global_default_graph
5852
5853  def _GetGlobalDefaultGraph(self):
5854    if self._global_default_graph is None:
5855      # TODO(mrry): Perhaps log that the default graph is being used, or
5856      #   provide some other feedback to prevent confusion when a mixture of
5857      #   the global default graph and an explicit graph are combined in the
5858      #   same process.
5859      self._global_default_graph = Graph()
5860    return self._global_default_graph
5861
5862  def reset(self):
5863    super(_DefaultGraphStack, self).reset()
5864    self._global_default_graph = None
5865
5866  @tf_contextlib.contextmanager
5867  def get_controller(self, default):
5868    context.context().context_switches.push(default.building_function,
5869                                            default.as_default,
5870                                            default._device_function_stack)
5871    try:
5872      with super(_DefaultGraphStack,
5873                 self).get_controller(default) as g, context.graph_mode():
5874        yield g
5875    finally:
5876      # If an exception is raised here it may be hiding a related exception in
5877      # the try-block (just above).
5878      context.context().context_switches.pop()
5879
5880
5881_default_graph_stack = _DefaultGraphStack()
5882
5883
5884# Shared helper used in init_scope and executing_eagerly_outside_functions
5885# to obtain the outermost context that is not building a function, and the
5886# innermost non-empty device stack.
5887def _get_outer_context_and_inner_device_stack():
5888  """Get the outermost context not building a function."""
5889  default_graph = get_default_graph()
5890  outer_context = None
5891  innermost_nonempty_device_stack = default_graph._device_function_stack  # pylint: disable=protected-access
5892
5893  if not _default_graph_stack.stack:
5894    # If the default graph stack is empty, then we cannot be building a
5895    # function. Install the global graph (which, in this case, is also the
5896    # default graph) as the outer context.
5897    if default_graph.building_function:
5898      raise RuntimeError("The global graph is building a function.")
5899    outer_context = default_graph.as_default
5900  else:
5901    # Find a context that is not building a function.
5902    for stack_entry in reversed(context.context().context_switches.stack):
5903      if not innermost_nonempty_device_stack:
5904        innermost_nonempty_device_stack = stack_entry.device_stack
5905      if not stack_entry.is_building_function:
5906        outer_context = stack_entry.enter_context_fn
5907        break
5908
5909    if outer_context is None:
5910      # As a last resort, obtain the global default graph; this graph doesn't
5911      # necessarily live on the graph stack (and hence it doesn't necessarily
5912      # live on the context stack), but it is stored in the graph stack's
5913      # encapsulating object.
5914      outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default  # pylint: disable=protected-access
5915
5916  if outer_context is None:
5917    # Sanity check; this shouldn't be triggered.
5918    raise RuntimeError("All graphs are building functions, and no "
5919                       "eager context was previously active.")
5920
5921  return outer_context, innermost_nonempty_device_stack
5922
5923
5924# pylint: disable=g-doc-return-or-yield,line-too-long
5925@tf_export("init_scope")
5926@tf_contextlib.contextmanager
5927def init_scope():
5928  """A context manager that lifts ops out of control-flow scopes and function-building graphs.
5929
5930  There is often a need to lift variable initialization ops out of control-flow
5931  scopes, function-building graphs, and gradient tapes. Entering an
5932  `init_scope` is a mechanism for satisfying these desiderata. In particular,
5933  entering an `init_scope` has three effects:
5934
5935    (1) All control dependencies are cleared the moment the scope is entered;
5936        this is equivalent to entering the context manager returned from
5937        `control_dependencies(None)`, which has the side-effect of exiting
5938        control-flow scopes like `tf.cond` and `tf.while_loop`.
5939
5940    (2) All operations that are created while the scope is active are lifted
5941        into the lowest context on the `context_stack` that is not building a
5942        graph function. Here, a context is defined as either a graph or an eager
5943        context. Every context switch, i.e., every installation of a graph as
5944        the default graph and every switch into eager mode, is logged in a
5945        thread-local stack called `context_switches`; the log entry for a
5946        context switch is popped from the stack when the context is exited.
5947        Entering an `init_scope` is equivalent to crawling up
5948        `context_switches`, finding the first context that is not building a
5949        graph function, and entering it. A caveat is that if graph mode is
5950        enabled but the default graph stack is empty, then entering an
5951        `init_scope` will simply install a fresh graph as the default one.
5952
5953    (3) The gradient tape is paused while the scope is active.
5954
5955  When eager execution is enabled, code inside an init_scope block runs with
5956  eager execution enabled even when tracing a `tf.function`. For example:
5957
5958  ```python
5959  tf.compat.v1.enable_eager_execution()
5960
5961  @tf.function
5962  def func():
5963    # A function constructs TensorFlow graphs,
5964    # it does not execute eagerly.
5965    assert not tf.executing_eagerly()
5966    with tf.init_scope():
5967      # Initialization runs with eager execution enabled
5968      assert tf.executing_eagerly()
5969  ```
5970
5971  Raises:
5972    RuntimeError: if graph state is incompatible with this initialization.
5973  """
5974  # pylint: enable=g-doc-return-or-yield,line-too-long
5975
5976  if context.executing_eagerly():
5977    # Fastpath.
5978    with tape.stop_recording():
5979      yield
5980  else:
5981    # Retrieve the active name scope: entering an `init_scope` preserves
5982    # the name scope of the current context.
5983    scope = get_default_graph().get_name_scope()
5984    if scope and scope[-1] != "/":
5985      # Names that end with trailing slashes are treated by `name_scope` as
5986      # absolute.
5987      scope = scope + "/"
5988
5989    outer_context, innermost_nonempty_device_stack = (
5990        _get_outer_context_and_inner_device_stack())
5991
5992    outer_graph = None
5993    outer_device_stack = None
5994    try:
5995      with outer_context(), name_scope(
5996          scope, skip_on_eager=False), control_dependencies(
5997              None), tape.stop_recording():
5998        context_manager = NullContextmanager
5999        context_manager_input = None
6000        if not context.executing_eagerly():
6001          # The device stack is preserved when lifting into a graph. Eager
6002          # execution doesn't implement device stacks and in particular it
6003          # doesn't support device functions, so in general it's not possible
6004          # to do the same when lifting into the eager context.
6005          outer_graph = get_default_graph()
6006          outer_device_stack = outer_graph._device_function_stack  # pylint: disable=protected-access
6007          outer_graph._device_function_stack = innermost_nonempty_device_stack  # pylint: disable=protected-access
6008        elif innermost_nonempty_device_stack is not None:
6009          for device_spec in innermost_nonempty_device_stack.peek_objs():
6010            if device_spec.function is None:
6011              break
6012            if device_spec.raw_string:
6013              context_manager = context.device
6014              context_manager_input = device_spec.raw_string
6015              break
6016            # It is currently not possible to have a device function in V2,
6017            # but in V1 we are unable to apply device functions in eager mode.
6018            # This means that we will silently skip some of the entries on the
6019            # device stack in V1 + eager mode.
6020
6021        with context_manager(context_manager_input):
6022          yield
6023    finally:
6024      # If an exception is raised here it may be hiding a related exception in
6025      # try-block (just above).
6026      if outer_graph is not None:
6027        outer_graph._device_function_stack = outer_device_stack  # pylint: disable=protected-access
6028
6029
6030@tf_export(v1=["executing_eagerly_outside_functions"])
6031def executing_eagerly_outside_functions():
6032  """Returns True if executing eagerly, even if inside a graph function.
6033
6034  This function will check the outermost context for the program and see if
6035  it is in eager mode. It is useful comparing to `tf.executing_eagerly()`,
6036  which checks the current context and will return `False` within a
6037  `tf.function` body. It can be used to build library that behave differently
6038  in eager runtime and v1 session runtime (deprecated).
6039
6040  Example:
6041
6042  >>> tf.compat.v1.enable_eager_execution()
6043  >>> @tf.function
6044  ... def func():
6045  ...   # A function constructs TensorFlow graphs, it does not execute eagerly,
6046  ...   # but the outer most context is still eager.
6047  ...   assert not tf.executing_eagerly()
6048  ...   return tf.compat.v1.executing_eagerly_outside_functions()
6049  >>> func()
6050  <tf.Tensor: shape=(), dtype=bool, numpy=True>
6051
6052  Returns:
6053    boolean, whether the outermost context is in eager mode.
6054  """
6055  if context.executing_eagerly():
6056    return True
6057  else:
6058    outer_context, _ = _get_outer_context_and_inner_device_stack()
6059    with outer_context():
6060      return context.executing_eagerly()
6061
6062
6063@tf_export("inside_function", v1=[])
6064def inside_function():
6065  """Indicates whether the caller code is executing inside a `tf.function`.
6066
6067  Returns:
6068    Boolean, True if the caller code is executing inside a `tf.function`
6069    rather than eagerly.
6070
6071  Example:
6072
6073  >>> tf.inside_function()
6074  False
6075  >>> @tf.function
6076  ... def f():
6077  ...   print(tf.inside_function())
6078  >>> f()
6079  True
6080  """
6081  return get_default_graph().building_function
6082
6083
6084@tf_export(v1=["enable_eager_execution"])
6085def enable_eager_execution(config=None, device_policy=None,
6086                           execution_mode=None):
6087  """Enables eager execution for the lifetime of this program.
6088
6089  Eager execution provides an imperative interface to TensorFlow. With eager
6090  execution enabled, TensorFlow functions execute operations immediately (as
6091  opposed to adding to a graph to be executed later in a `tf.compat.v1.Session`)
6092  and return concrete values (as opposed to symbolic references to a node in a
6093  computational graph).
6095
6096  For example:
6097
6098  ```python
6099  tf.compat.v1.enable_eager_execution()
6100
6101  # After eager execution is enabled, operations are executed as they are
6102  # defined and Tensor objects hold concrete values, which can be accessed as
6103  # numpy.ndarray`s through the numpy() method.
6104  assert tf.multiply(6, 7).numpy() == 42
6105  ```
6106
6107  Eager execution cannot be enabled after TensorFlow APIs have been used to
6108  create or execute graphs. It is typically recommended to invoke this function
6109  at program startup and not in a library (as most libraries should be usable
6110  both with and without eager execution).
6111
6112  @compatibility(TF2)
6113  This function is not necessary if you are using TF2. Eager execution is
6114  enabled by default.
6115  @end_compatibility
6116
6117  Args:
6118    config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the
6119      environment in which operations are executed. Note that
6120      `tf.compat.v1.ConfigProto` is also used to configure graph execution (via
6121      `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto`
6122      are not implemented (or are irrelevant) when eager execution is enabled.
6123    device_policy: (Optional.) Policy controlling how operations requiring
6124      inputs on a specific device (e.g., GPU 0) handle inputs on a different
6125      device (e.g., GPU 1 or CPU). When set to None, an appropriate value will
6126      be picked automatically. The value picked may change between TensorFlow
6127      releases.
6128      Valid values:
6129      - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the
6130        placement is not correct.
6131      - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not
6132        on the right device but logs a warning.
6133      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors.
6134        Note that this may hide performance problems as there is no notification
6135        provided when operations are blocked on the tensor being copied between
6136        devices.
6137      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies
6138        int32 tensors, raising errors on the other ones.
6139    execution_mode: (Optional.) Policy controlling how operations dispatched are
6140      actually executed. When set to None, an appropriate value will be picked
6141      automatically. The value picked may change between TensorFlow releases.
6142      Valid values:
6143      - tf.contrib.eager.SYNC: executes each operation synchronously.
6144      - tf.contrib.eager.ASYNC: executes each operation asynchronously. These
6145        operations may return "non-ready" handles.
6146
6147  Raises:
6148    ValueError: If eager execution is enabled after creating/executing a
6149     TensorFlow graph, or if options provided conflict with a previous call
6150     to this function.
6151  """
6152  _api_usage_gauge.get_cell().set(True)
6153  logging.vlog(1, "Enabling eager execution")
6154  if context.default_execution_mode != context.EAGER_MODE:
6155    return enable_eager_execution_internal(
6156        config=config,
6157        device_policy=device_policy,
6158        execution_mode=execution_mode,
6159        server_def=None)
6160
6161
6162@tf_export(v1=["disable_eager_execution"])
6163def disable_eager_execution():
6164  """Disables eager execution.
6165
6166  This function can only be called before any Graphs, Ops, or Tensors have been
6167  created.
6168
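  For example, a minimal sketch (call this before any other TensorFlow API
  creates graphs, ops, or tensors):

  ```python
  import tensorflow as tf

  tf.compat.v1.disable_eager_execution()
  assert not tf.executing_eagerly()
  ```
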
6169  @compatibility(TF2)
6170  This function is not necessary if you are using TF2. Eager execution is
6171  enabled by default. If you want to use Graph mode please consider
6172  [tf.function](https://www.tensorflow.org/api_docs/python/tf/function).
6173  @end_compatibility
6174  """
6175  _api_usage_gauge.get_cell().set(False)
6176  logging.vlog(1, "Disabling eager execution")
6177  context.default_execution_mode = context.GRAPH_MODE
6178  c = context.context_safe()
6179  if c is not None:
6180    c._thread_local_data.is_eager = False  # pylint: disable=protected-access
6181
6182
6183def enable_eager_execution_internal(config=None,
6184                                    device_policy=None,
6185                                    execution_mode=None,
6186                                    server_def=None):
6187  """Enables eager execution for the lifetime of this program.
6188
6189  Most of the doc string for enable_eager_execution is relevant here as well.
6190
6191  Args:
6192    config: See enable_eager_execution doc string
6193    device_policy: See enable_eager_execution doc string
6194    execution_mode: See enable_eager_execution doc string
6195    server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on
6196      remote devices. GrpcServers need to be started by creating an identical
6197      server_def to this, and setting the appropriate task_indexes, so that the
6198      servers can communicate. It will then be possible to execute operations on
6199      remote devices.
6200
6201  Raises:
6202    ValueError: If eager execution is enabled after creating/executing a
6203     TensorFlow graph, or if options conflict with an active eager context.
6204  """
6205  if config is not None and not isinstance(config, config_pb2.ConfigProto):
6206    raise TypeError("config must be a tf.ConfigProto, but got %s" %
6207                    type(config))
6208  if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
6209                           context.DEVICE_PLACEMENT_WARN,
6210                           context.DEVICE_PLACEMENT_SILENT,
6211                           context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
6212    raise ValueError(
6213        "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
6214    )
6215  if execution_mode not in (None, context.SYNC, context.ASYNC):
6216    raise ValueError(
6217        "execution_mode must be one of None, tf.contrib.eager.SYNC, "
6218        "tf.contrib.eager.ASYNC")
6219  if context.default_execution_mode == context.GRAPH_MODE:
6220    graph_mode_has_been_used = (
6221        _default_graph_stack._global_default_graph is not None)  # pylint: disable=protected-access
6222    if graph_mode_has_been_used:
6223      raise ValueError(
6224          "tf.enable_eager_execution must be called at program startup.")
6225  context.default_execution_mode = context.EAGER_MODE
6226  # pylint: disable=protected-access
6227  with context._context_lock:
6228    if context._context is None:
6229      context._set_context_locked(context.Context(
6230          config=config,
6231          device_policy=device_policy,
6232          execution_mode=execution_mode,
6233          server_def=server_def))
6234    elif ((config is not None and config is not context._context._config) or
6235          (device_policy is not None and
6236           device_policy is not context._context._device_policy) or
6237          (execution_mode is not None and
6238           execution_mode is not context._context._execution_mode)):
6239      raise ValueError(
6240          "Trying to change the options of an active eager"
6241          " execution. Context config: %s, specified config:"
6242          " %s. Context device policy: %s, specified device"
6243          " policy: %s. Context execution mode: %s, "
6244          " specified execution mode %s." %
6245          (context._context._config, config, context._context._device_policy,
6246           device_policy, context._context._execution_mode, execution_mode))
6247    else:
6248      # We already created everything, so update the thread local data.
6249      context._context._thread_local_data.is_eager = True
6250
6251  # Monkey patch to get rid of an unnecessary conditional since the context is
6252  # now initialized.
6253  context.context = context.context_safe
6254
6255
6256def eager_run(main=None, argv=None):
6257  """Runs the program with an optional main function and argv list.
6258
6259  The program will run with eager execution enabled.
6260
6261  Example:
6262  ```python
6263  import tensorflow as tf
6264  # Import subject to future changes:
6265  from tensorflow.contrib.eager.python import tfe
6266
6267  def main(_):
6268    u = tf.constant(6.0)
6269    v = tf.constant(7.0)
6270    print(u * v)
6271
6272  if __name__ == "__main__":
6273    tfe.run()
6274  ```
6275
6276  Args:
6277    main: the main function to run.
6278    argv: the arguments to pass to it.
6279  """
6280  enable_eager_execution()
6281  app.run(main, argv)
6282
6283
6284@tf_export(v1=["reset_default_graph"])
6285def reset_default_graph():
6286  """Clears the default graph stack and resets the global default graph.
6287
6288  NOTE: The default graph is a property of the current thread. This
6289  function applies only to the current thread.  Calling this function while
6290  a `tf.compat.v1.Session` or `tf.compat.v1.InteractiveSession` is active will
6291  result in undefined behavior. Using any previously created `tf.Operation` or
6292  `tf.Tensor` objects after calling this function will result in undefined
6293  behavior.
6294
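  For example, a minimal TF1-style sketch:

  ```python
  g1 = tf.compat.v1.get_default_graph()
  tf.compat.v1.reset_default_graph()
  g2 = tf.compat.v1.get_default_graph()
  assert g1 is not g2  # A fresh global default graph was installed.
  ```
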
6295  @compatibility(TF2)
6296  `reset_default_graph` does not work with either eager execution or
6297  `tf.function`, and you should not invoke it directly. To migrate code that
6298  uses Graph-related functions to TF2, rewrite the code without them. See the
6299  [migration guide](https://www.tensorflow.org/guide/migrate) for more
6300  description about the behavior and semantic changes between Tensorflow 1 and
6301  Tensorflow 2.
6302  @end_compatibility
6303
6304  Raises:
6305    AssertionError: If this function is called within a nested graph.
6306  """
6307  if not _default_graph_stack.is_cleared():
6308    raise AssertionError("Do not use tf.reset_default_graph() to clear "
6309                         "nested graphs. If you need a cleared graph, "
6310                         "exit the nesting and create a new graph.")
6311  _default_graph_stack.reset()
6312
6313
6314@tf_export(v1=["get_default_graph"])
6315def get_default_graph():
6316  """Returns the default graph for the current thread.
6317
6318  The returned graph will be the innermost graph on which a
6319  `Graph.as_default()` context has been entered, or a global default
6320  graph if none has been explicitly created.
6321
6322  NOTE: The default graph is a property of the current thread. If you
6323  create a new thread, and wish to use the default graph in that
6324  thread, you must explicitly add a `with g.as_default():` in that
6325  thread's function.
6326
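  For example, a minimal TF1-style sketch:

  ```python
  g = tf.Graph()
  with g.as_default():
    assert tf.compat.v1.get_default_graph() is g
  ```
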
6327  @compatibility(TF2)
6328  `get_default_graph` does not work with either eager execution or
6329  `tf.function`, and you should not invoke it directly. To migrate code that
6330  uses Graph-related functions to TF2, rewrite the code without them. See the
6331  [migration guide](https://www.tensorflow.org/guide/migrate) for more details
6332  on the behavior and semantic changes between TensorFlow 1 and
6333  TensorFlow 2.
6334  @end_compatibility
6335
6336  Returns:
6337    The default `Graph` being used in the current thread.
6338  """
6339  return _default_graph_stack.get_default()
6340
6341
6342def has_default_graph():
6343  """Returns True if there is a default graph."""
6344  return len(_default_graph_stack.stack) >= 1
6345
6346
6347# Exported due to b/171079555
6348@tf_export("__internal__.get_name_scope", v1=[])
6349def get_name_scope():
6350  """Returns the current name scope in the default_graph.
6351
6352  For example:
6353
6354  ```python
6355  with tf.name_scope('scope1'):
6356    with tf.name_scope('scope2'):
6357      print(tf.get_name_scope())
6358  ```
6359  would print the string `scope1/scope2`.
6360
6361  Returns:
6362    A string representing the current name scope.
6363  """
6364  if context.executing_eagerly():
6365    return context.context().scope_name.rstrip("/")
6366  return get_default_graph().get_name_scope()
6367
6368
6369def _assert_same_graph(original_item, item):
6370  """Fail if the 2 items are from different graphs.
6371
6372  Args:
6373    original_item: Original item to check against.
6374    item: Item to check.
6375
6376  Raises:
6377    ValueError: if graphs do not match.
6378  """
6379  original_graph = getattr(original_item, "graph", None)
6380  graph = getattr(item, "graph", None)
6381  if original_graph and graph and original_graph is not graph:
6382    raise ValueError(
6383        "%s must be from the same graph as %s (graphs are %s and %s)." %
6384        (item, original_item, graph, original_graph))
6385
6386
6387def _get_graph_from_inputs(op_input_list, graph=None):
6388  """Returns the appropriate graph to use for the given inputs.
6389
6390  This library method provides a consistent algorithm for choosing the graph
6391  in which an Operation should be constructed:
6392
6393  1. If the default graph is being used to construct a function, we
6394     use the default graph.
6395  2. If the "graph" is specified explicitly, we validate that all of the inputs
6396     in "op_input_list" are compatible with that graph.
6397  3. Otherwise, we attempt to select a graph from the first Operation-
6398     or Tensor-valued input in "op_input_list", and validate that all other
6399     such inputs are in the same graph.
6400  4. If the graph was not specified and it could not be inferred from
6401     "op_input_list", we attempt to use the default graph.
6402
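  For example, a rough sketch of rule 3 above (assuming graph-mode tensors and
  no function being built):

  ```python
  g = tf.Graph()
  with g.as_default():
    t = tf.constant(1.0)
  # With no explicit `graph` argument, the graph is inferred from the inputs.
  assert _get_graph_from_inputs([t]) is g
  ```
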
6403  Args:
6404    op_input_list: A list of inputs to an operation, which may include `Tensor`,
6405      `Operation`, and other objects that may be converted to a graph element.
6406    graph: (Optional) The explicit graph to use.
6407
6408  Raises:
6409    TypeError: If op_input_list is not a list or tuple, or if graph is not a
6410      Graph.
6411    ValueError: If a graph is explicitly passed and not all inputs are from it,
6412      or if the inputs are from multiple graphs, or we could not find a graph
6413      and there was no default graph.
6414
6415  Returns:
6416    The appropriate graph to use for the given inputs.
6417
6418  """
6419  current_default_graph = get_default_graph()
6420  if current_default_graph.building_function:
6421    return current_default_graph
6422
6423  op_input_list = tuple(op_input_list)  # Handle generators correctly
6424  if graph and not isinstance(graph, Graph):
6425    raise TypeError("Input graph needs to be a Graph: %s" % (graph,))
6426
6427  # 1. We validate that all of the inputs are from the same graph. This is
6428  #    either the supplied graph parameter, or the first one selected from one
6429  #    of the graph-element-valued inputs. In the latter case, we hold onto
6430  #    that input in original_graph_element so we can provide a more
6431  #    informative error if a mismatch is found.
6432  original_graph_element = None
6433  for op_input in op_input_list:
6434    # Determine if this is a valid graph_element.
6435    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
6436    # up.
6437    graph_element = None
6438    if (isinstance(op_input, (Operation, internal.NativeObject)) and
6439        ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)):  # pylint: disable=unidiomatic-typecheck
6440      graph_element = op_input
6441    else:
6442      graph_element = _as_graph_element(op_input)
6443
6444    if graph_element is not None:
6445      if not graph:
6446        original_graph_element = graph_element
6447        graph = getattr(graph_element, "graph", None)
6448      elif original_graph_element is not None:
6449        _assert_same_graph(original_graph_element, graph_element)
6450      elif graph_element.graph is not graph:
6451        raise ValueError("%s is not from the passed-in graph." % graph_element)
6452
6453  # 2. If all else fails, we use the default graph, which is always there.
6454  return graph or current_default_graph
6455
6456
6457@tf_export(v1=["GraphKeys"])
6458class GraphKeys(object):
6459  """Standard names to use for graph collections.
6460
6461  The standard library uses various well-known names to collect and
6462  retrieve values associated with a graph. For example, the
6463  `tf.Optimizer` subclasses default to optimizing the variables
6464  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
6465  specified, but it is also possible to pass an explicit list of
6466  variables.
6467
6468  The following standard keys are defined:
6469
6470  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
6471    across distributed environment (model variables are subset of these). See
6472    `tf.compat.v1.global_variables`
6473    for more details.
6474    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
6475    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
6476  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
6477    machine. Usually used for temporary variables, like counters.
6478    Note: use `tf.contrib.framework.local_variable` to add to this collection.
6479  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
6480    model for inference (feed forward). Note: use
6481    `tf.contrib.framework.model_variable` to add to this collection.
6482  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
6483    be trained by an optimizer. See
6484    `tf.compat.v1.trainable_variables`
6485    for more details.
6486  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
6487    graph. See
6488    `tf.compat.v1.summary.merge_all`
6489    for more details.
6490  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
6491    produce input for a computation. See
6492    `tf.compat.v1.train.start_queue_runners`
6493    for more details.
6494  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
6495    keep moving averages.  See
6496    `tf.compat.v1.moving_average_variables`
6497    for more details.
6498  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
6499    construction.
6500
6501  The following standard keys are _defined_, but their collections are **not**
6502  automatically populated as many of the others are:
6503
6504  * `WEIGHTS`
6505  * `BIASES`
6506  * `ACTIVATIONS`
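
  For example, a rough sketch of reading a standard collection (TF1-style
  graph mode):

  ```python
  with tf.Graph().as_default():
    v = tf.compat.v1.get_variable("v", shape=[1])
    trainable = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
    assert v in trainable
  ```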
6507  """
6508
6509  # Key to collect Variable objects that are global (shared across machines).
6510  # Default collection for all variables, except local ones.
6511  GLOBAL_VARIABLES = "variables"
6512  # Key to collect local variables that are local to the machine and are not
6513  # saved/restored.
6514  LOCAL_VARIABLES = "local_variables"
6515  # Key to collect local variables which are used to accumulate internal state
6516  # to be used in tf.metrics.*.
6517  METRIC_VARIABLES = "metric_variables"
6518  # Key to collect model variables defined by layers.
6519  MODEL_VARIABLES = "model_variables"
6520  # Key to collect Variable objects that will be trained by the
6521  # optimizers.
6522  TRAINABLE_VARIABLES = "trainable_variables"
6523  # Key to collect summaries.
6524  SUMMARIES = "summaries"
6525  # Key to collect QueueRunners.
6526  QUEUE_RUNNERS = "queue_runners"
6527  # Key to collect table initializers.
6528  TABLE_INITIALIZERS = "table_initializer"
6529  # Key to collect asset filepaths. An asset represents an external resource
6530  # like a vocabulary file.
6531  ASSET_FILEPATHS = "asset_filepaths"
6532  # Key to collect Variable objects that keep moving averages.
6533  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
6534  # Key to collect regularization losses at graph construction.
6535  REGULARIZATION_LOSSES = "regularization_losses"
6536  # Key to collect concatenated sharded variables.
6537  CONCATENATED_VARIABLES = "concatenated_variables"
6538  # Key to collect savers.
6539  SAVERS = "savers"
6540  # Key to collect weights
6541  WEIGHTS = "weights"
6542  # Key to collect biases
6543  BIASES = "biases"
6544  # Key to collect activations
6545  ACTIVATIONS = "activations"
6546  # Key to collect update_ops
6547  UPDATE_OPS = "update_ops"
6548  # Key to collect losses
6549  LOSSES = "losses"
6550  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
6551  SAVEABLE_OBJECTS = "saveable_objects"
6552  # Key to collect all shared resources used by the graph which need to be
6553  # initialized once per cluster.
6554  RESOURCES = "resources"
6555  # Key to collect all shared resources used in this graph which need to be
6556  # initialized once per session.
6557  LOCAL_RESOURCES = "local_resources"
6558  # Trainable resource-style variables.
6559  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"
6560
6561  # Key to indicate various ops.
6562  INIT_OP = "init_op"
6563  LOCAL_INIT_OP = "local_init_op"
6564  READY_OP = "ready_op"
6565  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
6566  SUMMARY_OP = "summary_op"
6567  GLOBAL_STEP = "global_step"
6568
6569  # Used to count the number of evaluations performed during a single evaluation
6570  # run.
6571  EVAL_STEP = "eval_step"
6572  TRAIN_OP = "train_op"
6573
6574  # Key for control flow context.
6575  COND_CONTEXT = "cond_context"
6576  WHILE_CONTEXT = "while_context"
6577
6578  # Used to store v2 summary names.
6579  _SUMMARY_COLLECTION = "_SUMMARY_V2"
6580
6581  # List of all collections that keep track of variables.
6582  _VARIABLE_COLLECTIONS = [
6583      GLOBAL_VARIABLES,
6584      LOCAL_VARIABLES,
6585      METRIC_VARIABLES,
6586      MODEL_VARIABLES,
6587      TRAINABLE_VARIABLES,
6588      MOVING_AVERAGE_VARIABLES,
6589      CONCATENATED_VARIABLES,
6590      TRAINABLE_RESOURCE_VARIABLES,
6591  ]
6592
6593  # Key for streaming model ports.
6594  # NOTE(yuanbyu): internal and experimental.
6595  _STREAMING_MODEL_PORTS = "streaming_model_ports"
6596
6597  @decorator_utils.classproperty
6598  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
6599  def VARIABLES(cls):  # pylint: disable=no-self-argument
6600    return cls.GLOBAL_VARIABLES
6601
6602
6603def dismantle_graph(graph):
6604  """Cleans up reference cycles from a `Graph`.
6605
6606  Helpful for making sure the garbage collector doesn't need to run after a
6607  temporary `Graph` is no longer needed.
6608
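  A rough usage sketch:

  ```python
  g = tf.Graph()
  with g.as_default():
    tf.constant(1.0)
  dismantle_graph(g)  # Neither `g` nor its ops may be used after this call.
  ```
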
6609  Args:
6610    graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
6611      after this function runs.
6612  """
6613  memory.dismantle_ordered_dict(graph._functions)  # pylint: disable=protected-access
6614
6615  # Now clean up Operation<->Graph reference cycles by clearing all of the
6616  # attributes for the Graph and its ops.
6617  graph_operations = graph.get_operations()
6618  for op in graph_operations:
6619    op.__dict__ = {}
6620  graph.__dict__ = {}
6621
6622
6623@tf_export(v1=["add_to_collection"])
6624def add_to_collection(name, value):
6625  """Wrapper for `Graph.add_to_collection()` using the default graph.
6626
6627  See `tf.Graph.add_to_collection`
6628  for more details.
6629
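  For example, a minimal TF1-style sketch (the collection name is only
  illustrative):

  ```python
  with tf.Graph().as_default():
    c = tf.constant(1.0)
    tf.compat.v1.add_to_collection("my_collection", c)
    assert c in tf.compat.v1.get_collection("my_collection")
  ```
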
6630  Args:
6631    name: The key for the collection. For example, the `GraphKeys` class
6632      contains many standard names for collections.
6633    value: The value to add to the collection.
6634
6635  @compatibility(eager)
6636  Collections are only supported in eager when variables are created inside
6637  an EagerVariableStore (e.g. as part of a layer or template).
6638  @end_compatibility
6639  """
6640  get_default_graph().add_to_collection(name, value)
6641
6642
6643@tf_export(v1=["add_to_collections"])
6644def add_to_collections(names, value):
6645  """Wrapper for `Graph.add_to_collections()` using the default graph.
6646
6647  See `tf.Graph.add_to_collections`
6648  for more details.
6649
6650  Args:
6651    names: The key for the collections. The `GraphKeys` class contains many
6652      standard names for collections.
6653    value: The value to add to the collections.
6654
6655  @compatibility(eager)
6656  Collections are only supported in eager when variables are created inside
6657  an EagerVariableStore (e.g. as part of a layer or template).
6658  @end_compatibility
6659  """
6660  get_default_graph().add_to_collections(names, value)
6661
6662
6663@tf_export(v1=["get_collection_ref"])
6664def get_collection_ref(key):
6665  """Wrapper for `Graph.get_collection_ref()` using the default graph.
6666
6667  See `tf.Graph.get_collection_ref`
6668  for more details.
6669
6670  Args:
6671    key: The key for the collection. For example, the `GraphKeys` class contains
6672      many standard names for collections.
6673
6674  Returns:
6675    The list of values in the collection with the given `name`, or an empty
6676    list if no value has been added to that collection.  Note that this returns
6677    the collection list itself, which can be modified in place to change the
6678    collection.
6679
6680  @compatibility(eager)
6681  Collections are not supported when eager execution is enabled.
6682  @end_compatibility
6683  """
6684  return get_default_graph().get_collection_ref(key)
6685
6686
6687@tf_export(v1=["get_collection"])
6688def get_collection(key, scope=None):
6689  """Wrapper for `Graph.get_collection()` using the default graph.
6690
6691  See `tf.Graph.get_collection`
6692  for more details.
6693
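  For example, a rough sketch of filtering a collection by scope (TF1-style
  graph mode; the scope name "layer1" is only illustrative):

  ```python
  with tf.Graph().as_default():
    with tf.compat.v1.variable_scope("layer1"):
      v = tf.compat.v1.get_variable("v", shape=[1])
    in_scope = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="layer1")
    assert in_scope == [v]
  ```
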
6694  Args:
6695    key: The key for the collection. For example, the `GraphKeys` class contains
6696      many standard names for collections.
6697    scope: (Optional.) If supplied, the resulting list is filtered to include
6698      only items whose `name` attribute matches using `re.match`. Items without
6699      a `name` attribute are never returned if a scope is supplied, and the
6700      choice of `re.match` means that a `scope` without special tokens filters
6701      by prefix.
6702
6703  Returns:
6704    The list of values in the collection with the given `name`, or
6705    an empty list if no value has been added to that collection. The
6706    list contains the values in the order under which they were
6707    collected.
6708
6709  @compatibility(eager)
6710  Collections are not supported when eager execution is enabled.
6711  @end_compatibility
6712  """
6713  return get_default_graph().get_collection(key, scope)
6714
6715
6716def get_all_collection_keys():
6717  """Returns a list of collections used in the default graph."""
6718  return get_default_graph().get_all_collection_keys()
6719
6720
6721def name_scope(name, default_name=None, values=None, skip_on_eager=True):
6722  """Internal-only entry point for `name_scope*`.
6723
6724  Internal ops do not use the public API and instead rely on
6725  `ops.name_scope` regardless of the execution mode. This function
6726  dispatches to the correct `name_scope*` implementation based on
6727  the arguments provided and the current mode. Specifically,
6728
6729  * if `values` contains a graph tensor, `Graph.name_scope` is used;
6730  * `name_scope_v1` is used in graph mode;
6731  * `name_scope_v2` -- in eager mode.
6732
6733  Args:
6734    name: The name argument that is passed to the op function.
6735    default_name: The default name to use if the `name` argument is `None`.
6736    values: The list of `Tensor` arguments that are passed to the op function.
6737    skip_on_eager: Indicates to return NullContextmanager if executing eagerly.
6738      By default this is True since naming tensors and operations in eager mode
6739      has little use and causes unnecessary performance overhead. However, it is
6740      important to preserve variable names since they are often useful for
6741      debugging and saved models.
6742
6743  Returns:
6744    `name_scope*` context manager.
6745  """
6746  if not context.executing_eagerly():
6747    return internal_name_scope_v1(name, default_name, values)
6748
6749  if skip_on_eager:
6750    return NullContextmanager()
6751
6752  name = default_name if name is None else name
6753  if values:
6754    # The presence of a graph tensor in `values` overrides the context.
6755    # TODO(slebedev): this is Keras-specific and should be removed.
6756    # pylint: disable=unidiomatic-typecheck
6757    graph_value = next((value for value in values if type(value) == Tensor),
6758                       None)
6759    # pylint: enable=unidiomatic-typecheck
6760    if graph_value is not None:
6761      return graph_value.graph.name_scope(name)
6762
6763  return name_scope_v2(name or "")
6764
6765
6766class internal_name_scope_v1(object):  # pylint: disable=invalid-name
6767  """Graph-only version of `name_scope_v1`."""
6768
6769  @property
6770  def name(self):
6771    return self._name
6772
6773  def __init__(self, name, default_name=None, values=None):
6774    """Initialize the context manager.
6775
6776    Args:
6777      name: The name argument that is passed to the op function.
6778      default_name: The default name to use if the `name` argument is `None`.
6779      values: The list of `Tensor` arguments that are passed to the op function.
6780
6781    Raises:
6782      TypeError: if `default_name` is passed in but not a string.
6783    """
6784    if not (default_name is None or isinstance(default_name, str)):
6785      raise TypeError(
6786          "`default_name` type (%s) is not a string type. You likely meant to "
6787          "pass this into the `values` kwarg." % type(default_name))
6788    self._name = default_name if name is None else name
6789    self._default_name = default_name
6790    self._values = values
6791
6792  def __enter__(self):
6793    """Start the scope block.
6794
6795    Returns:
6796      The scope name.
6797
6798    Raises:
6799      ValueError: if neither `name` nor `default_name` is provided
6800        but `values` are.
6801    """
6802    if self._name is None and self._values is not None:
6803      # We only raise an error if values is not None (provided) because
6804      # currently tf.name_scope(None) (values=None then) is sometimes used as
6805      # an idiom to reset to top scope.
6806      raise ValueError(
6807          "At least one of name (%s) and default_name (%s) must be provided."
6808          % (self._name, self._default_name))
6809
6810    g = get_default_graph()
6811    if self._values and not g.building_function:
6812      # Specialize based on the knowledge that `_get_graph_from_inputs()`
6813      # ignores `inputs` when building a function.
6814      g_from_inputs = _get_graph_from_inputs(self._values)
6815      if g_from_inputs is not g:
6816        g = g_from_inputs
6817        self._g_manager = g.as_default()
6818        self._g_manager.__enter__()
6819      else:
6820        self._g_manager = None
6821    else:
6822      self._g_manager = None
6823
6824    try:
6825      self._name_scope = g.name_scope(self._name)
6826      return self._name_scope.__enter__()
6827    except:
6828      if self._g_manager is not None:
6829        self._g_manager.__exit__(*sys.exc_info())
6830      raise
6831
6832  def __exit__(self, *exc_info):
6833    self._name_scope.__exit__(*exc_info)
6834    if self._g_manager is not None:
6835      self._g_manager.__exit__(*exc_info)
6836
6837
6838# Named like a function for backwards compatibility with the
6839# @tf_contextlib.contextmanager version, which was switched to a class to avoid
6840# some object creation overhead.
6841@tf_export(v1=["name_scope"])
6842class name_scope_v1(object):  # pylint: disable=invalid-name
6843  """A context manager for use when defining a Python op.
6844
6845  This context manager validates that the given `values` are from the
6846  same graph, makes that graph the default graph, and pushes a
6847  name scope in that graph (see
6848  `tf.Graph.name_scope`
6849  for more details on that).
6850
6851  For example, to define a new Python op called `my_op`:
6852
6853  ```python
6854  def my_op(a, b, c, name=None):
6855    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
6856      a = tf.convert_to_tensor(a, name="a")
6857      b = tf.convert_to_tensor(b, name="b")
6858      c = tf.convert_to_tensor(c, name="c")
6859      # Define some computation that uses `a`, `b`, and `c`.
6860      return foo_op(..., name=scope)
6861  ```
6862  """
6863
6864  __slots__ = ["_name", "_name_scope"]
6865
6866  @property
6867  def name(self):
6868    return self._name
6869
6870  def __init__(self, name, default_name=None, values=None):
6871    """Initialize the context manager.
6872
6873    Args:
6874      name: The name argument that is passed to the op function.
6875      default_name: The default name to use if the `name` argument is `None`.
6876      values: The list of `Tensor` arguments that are passed to the op function.
6877
6878    Raises:
6879      TypeError: if `default_name` is passed in but not a string.
6880    """
6881    self._name_scope = name_scope(
6882        name, default_name, values, skip_on_eager=False)
6883    self._name = default_name if name is None else name
6884
6885  def __enter__(self):
6886    return self._name_scope.__enter__()
6887
6888  def __exit__(self, *exc_info):
6889    return self._name_scope.__exit__(*exc_info)
6890
6891
6892@tf_export("get_current_name_scope", v1=[])
6893def get_current_name_scope():
6894  """Returns current full name scope specified by `tf.name_scope(...)`s.
6895
6896  For example,
6897  ```python
6898  with tf.name_scope("outer"):
6899    tf.get_current_name_scope()  # "outer"
6900
6901    with tf.name_scope("inner"):
6902      tf.get_current_name_scope()  # "outer/inner"
6903  ```
6904
6905  In other words, `tf.get_current_name_scope()` returns the name prefix that
6906  will be prepended to the name of any op created at that point.
6907
6908  Note that `@tf.function` resets the name scope stack as shown below.
6909
6910  ```
6911  with tf.name_scope("outer"):
6912
6913    @tf.function
6914    def foo(x):
6915      with tf.name_scope("inner"):
6916        return tf.add(x, x)  # Op name is "inner/Add", not "outer/inner/Add"
6917  ```
6918  """
6919
6920  ctx = context.context()
6921  if ctx.executing_eagerly():
6922    return ctx.scope_name.rstrip("/")
6923  else:
6924    return get_default_graph().get_name_scope()
6925
6926
6927@tf_export("name_scope", v1=[])
6928class name_scope_v2(object):
6929  """A context manager for use when defining a Python op.
6930
6931  This context manager pushes a name scope, which will make the name of all
6932  operations added within it have a prefix.
6933
6934  For example, to define a new Python op called `my_op`:
6935
6936  ```python
6937  def my_op(a, b, c, name=None):
6938    with tf.name_scope("MyOp") as scope:
6939      a = tf.convert_to_tensor(a, name="a")
6940      b = tf.convert_to_tensor(b, name="b")
6941      c = tf.convert_to_tensor(c, name="c")
6942      # Define some computation that uses `a`, `b`, and `c`.
6943      return foo_op(..., name=scope)
6944  ```
6945
6946  When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
6947  and `MyOp/c`.
6948
6949  Inside a `tf.function`, if the scope name already exists, the name will be
6950  made unique by appending `_n`. For example, calling `my_op` the second time
6951  will generate `MyOp_1/a`, etc.
6952  """
6953
6954  __slots__ = ["_name", "_exit_fns"]
6955
6956  def __init__(self, name):
6957    """Initialize the context manager.
6958
6959    Args:
6960      name: The prefix to use on all names created within the name scope.
6961
6962    Raises:
6963      ValueError: If name is not a string.
6964    """
6965    if not isinstance(name, str):
6966      raise ValueError("name for name_scope must be a string.")
6967    self._name = name
6968    self._exit_fns = []
6969
6970  @property
6971  def name(self):
6972    return self._name
6973
6974  def __enter__(self):
6975    """Start the scope block.
6976
6977    Returns:
6978      The scope name.
6979    """
6980    ctx = context.context()
6981    if ctx.executing_eagerly():
6982      # Names are not auto-incremented in eager mode.
6983      # A trailing slash breaks out of nested name scopes, indicating a
6984      # fully specified scope name, for compatibility with Graph.name_scope.
6985      # This also prevents auto-incrementing.
6986      old_name = ctx.scope_name
6987      name = self._name
6988      if not name:
6989        scope_name = ""
6990      elif name[-1] == "/":
6991        scope_name = name
6992      elif old_name:
6993        scope_name = old_name + name + "/"
6994      else:
6995        scope_name = name + "/"
6996      ctx.scope_name = scope_name
6997
6998      def _restore_name_scope(*_):
6999        ctx.scope_name = old_name
7000
7001      self._exit_fns.append(_restore_name_scope)
7002    else:
7003      scope = get_default_graph().name_scope(self._name)
7004      scope_name = scope.__enter__()
7005      self._exit_fns.append(scope.__exit__)
7006    return scope_name
7007
7008  def __exit__(self, type_arg, value_arg, traceback_arg):
7009    self._exit_fns.pop()(type_arg, value_arg, traceback_arg)
7010    return False  # False values do not suppress exceptions
7011
7012  def __getstate__(self):
7013    return self._name, self._exit_fns
7014
7015  def __setstate__(self, state):
7016    self._name = state[0]
7017    self._exit_fns = state[1]
7018
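# Illustrative sketch of the trailing-slash behavior handled in __enter__ above:
# a scope name ending in "/" is treated as fully specified and re-entered as-is
# rather than nested under the current scope (matching Graph.name_scope). The
# snippet assumes eager execution via the public `tf.name_scope` export.
#
#   import tensorflow as tf
#
#   with tf.name_scope("outer") as outer:      # outer == "outer/"
#     with tf.name_scope("inner") as inner:    # inner == "outer/inner/"
#       pass
#   with tf.name_scope(inner):                 # re-enter the saved scope
#     assert tf.get_current_name_scope() == "outer/inner"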
7019
7020def strip_name_scope(name, export_scope):
7021  """Removes name scope from a name.
7022
7023  Args:
7024    name: A `string` name.
7025    export_scope: Optional `string`. Name scope to remove.
7026
7027  Returns:
7028    Name with name scope removed, or the original name if export_scope
7029    is None.
7030  """
7031  if export_scope:
7032    if export_scope[-1] == "/":
7033      export_scope = export_scope[:-1]
7034
7035    try:
7036      # Strips export_scope/, export_scope///,
7037      # ^export_scope/, loc:@export_scope/.
7038      str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)"
7039      return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
7040    except TypeError as e:
7041      # If the name is not of a type we can process, simply return it.
7042      logging.warning(e)
7043      return name
7044  else:
7045    return name
7046
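# Illustrative examples of what the substitution in `strip_name_scope` above
# produces (a sketch for readers, not a test; names are made up):
#
#   strip_name_scope("export/MyOp/add:0", "export")    # -> "MyOp/add:0"
#   strip_name_scope("^export/MyOp/add", "export")     # -> "^MyOp/add"
#   strip_name_scope("loc:@export/var", "export/")     # -> "loc:@var"
#   strip_name_scope("other/MyOp/add", "export")       # -> unchanged
#   strip_name_scope("other/MyOp/add", None)           # -> unchanged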
7047
7048def prepend_name_scope(name, import_scope):
7049  """Prepends name scope to a name.
7050
7051  Args:
7052    name: A `string` name.
7053    import_scope: Optional `string`. Name scope to add.
7054
7055  Returns:
7056    Name with name scope added, or the original name if import_scope
7057    is None.
7058  """
7059  if import_scope:
7060    if import_scope[-1] == "/":
7061      import_scope = import_scope[:-1]
7062
7063    try:
7064      str_to_replace = r"([\^]|loc:@|^)(.*)"
7065      return re.sub(str_to_replace, r"\1" + import_scope + r"/\2",
7066                    compat.as_str(name))
7067    except TypeError as e:
7068      # If the name is not of a type we can process, simply return it.
7069      logging.warning(e)
7070      return name
7071  else:
7072    return name
7073
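# Illustrative examples of what the substitution in `prepend_name_scope` above
# produces (a sketch for readers; names are made up):
#
#   prepend_name_scope("MyOp/add:0", "import")    # -> "import/MyOp/add:0"
#   prepend_name_scope("^MyOp/add", "import/")    # -> "^import/MyOp/add"
#   prepend_name_scope("loc:@var", "import")      # -> "loc:@import/var"
#   prepend_name_scope("MyOp/add", None)          # -> "MyOp/add"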
7074
7075# pylint: disable=g-doc-return-or-yield
7076# pylint: disable=not-context-manager
7077@tf_export(v1=["op_scope"])
7078@tf_contextlib.contextmanager
7079def op_scope(values, name, default_name=None):
7080  """DEPRECATED. Same as name_scope above, just different argument order."""
7081  logging.warn("tf.op_scope(values, name, default_name) is deprecated,"
7082               " use tf.name_scope(name, default_name, values)")
7083  with name_scope(name, default_name=default_name, values=values) as scope:
7084    yield scope
7085
7086
7087_proto_function_registry = registry.Registry("proto functions")
7088
7089
7090def register_proto_function(collection_name,
7091                            proto_type=None,
7092                            to_proto=None,
7093                            from_proto=None):
7094  """Registers `to_proto` and `from_proto` functions for collection_name.
7095
7096  `to_proto` function converts a Python object to the corresponding protocol
7097  buffer, and returns the protocol buffer.
7098
7099  `from_proto` function converts a protocol buffer into a Python object, and
7100  returns the object.
7101
7102  Args:
7103    collection_name: Name of the collection.
7104    proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
7105      `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`.
7106    to_proto: Function that implements Python object to protobuf conversion.
7107    from_proto: Function that implements protobuf to Python object conversion.
7108  """
7109  if to_proto and not callable(to_proto):
7110    raise TypeError("to_proto must be callable.")
7111  if from_proto and not callable(from_proto):
7112    raise TypeError("from_proto must be callable.")
7113
7114  _proto_function_registry.register((proto_type, to_proto, from_proto),
7115                                    collection_name)
7116
7117
7118def get_collection_proto_type(collection_name):
7119  """Returns the proto_type for collection_name."""
7120  try:
7121    return _proto_function_registry.lookup(collection_name)[0]
7122  except LookupError:
7123    return None
7124
7125
7126def get_to_proto_function(collection_name):
7127  """Returns the to_proto function for collection_name."""
7128  try:
7129    return _proto_function_registry.lookup(collection_name)[1]
7130  except LookupError:
7131    return None
7132
7133
7134def get_from_proto_function(collection_name):
7135  """Returns the from_proto function for collection_name."""
7136  try:
7137    return _proto_function_registry.lookup(collection_name)[2]
7138  except LookupError:
7139    return None
7140
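# Illustrative sketch of how the proto-function registry above is used. The
# collection name, converter functions, and proto type here are hypothetical
# placeholders, not real registrations:
#
#   def _my_obj_to_proto(obj, export_scope=None):
#     ...  # build and return the protobuf representing `obj`
#
#   def _my_obj_from_proto(proto, import_scope=None):
#     ...  # rebuild the Python object from `proto`
#
#   register_proto_function(
#       "my_collection",
#       proto_type=None,                 # e.g. some_pb2.MyDef in real usage
#       to_proto=_my_obj_to_proto,
#       from_proto=_my_obj_from_proto)
#
#   get_collection_proto_type("my_collection")   # -> None (as registered above)
#   get_to_proto_function("my_collection")       # -> _my_obj_to_proto
#   get_from_proto_function("unregistered")      # -> None (LookupError caught)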
7141
7142def _op_to_colocate_with(v, graph):
7143  """Operation object corresponding to v to use for colocation constraints."""
7144  if v is None:
7145    return None, None
7146  if isinstance(v, Operation):
7147    return v, None
7148
7149  # We always want to colocate with the reference op.
7150  # When 'v' is a ResourceVariable, the reference op is the handle creating op.
7151  #
7152  # What this should be is:
7153  # if isinstance(v, ResourceVariable):
7154  #   return v.handle.op, v
7155  # However, that would require a circular import dependency.
7156  # As of October 2018, there were attempts underway to remove
7157  # colocation constraints altogether. Assuming that will
7158  # happen soon, perhaps this hack to work around the circular
7159  # import dependency is acceptable.
7160  if hasattr(v, "handle") and isinstance(v.handle, Tensor):
7161    device_only_candidate = lambda: None
7162    device_only_candidate.device = v.device
7163    device_only_candidate.name = v.name
7164    if graph.building_function:
7165      return graph.capture(v.handle).op, device_only_candidate
7166    else:
7167      return v.handle.op, device_only_candidate
7168  return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op, None
7169
7170
7171def _is_keras_symbolic_tensor(x):
7172  return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"
7173
7174
7175# These symbols were originally defined in this module; import them for
7176# backwards compatibility until all references have been updated to access
7177# them from the indexed_slices.py module.
7178IndexedSlices = indexed_slices.IndexedSlices
7179IndexedSlicesValue = indexed_slices.IndexedSlicesValue
7180convert_to_tensor_or_indexed_slices = \
7181    indexed_slices.convert_to_tensor_or_indexed_slices
7182convert_n_to_tensor_or_indexed_slices = \
7183    indexed_slices.convert_n_to_tensor_or_indexed_slices
7184internal_convert_to_tensor_or_indexed_slices = \
7185    indexed_slices.internal_convert_to_tensor_or_indexed_slices
7186internal_convert_n_to_tensor_or_indexed_slices = \
7187    indexed_slices.internal_convert_n_to_tensor_or_indexed_slices
7188register_tensor_conversion_function = \
7189    tensor_conversion_registry.register_tensor_conversion_function
7190
7191
7192# Helper functions for op wrapper modules generated by `python_op_gen`.
7193
7194
7195def to_raw_op(f):
7196  """Make a given op wrapper function `f` raw.
7197
7198  Raw op wrappers can only be called with keyword arguments.
7199
7200  Args:
7201    f: An op wrapper function to make raw.
7202
7203  Returns:
7204    Raw `f`.
7205  """
7206  # Copy `f` to get a new `__dict__`, otherwise `tf_export` will fail
7207  # due to double-registration.
7208  f = types.FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__,
7209                         f.__closure__)
7210  return kwarg_only(f)
7211
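# Illustrative sketch of what `to_raw_op` produces; `some_op_wrapper` is a
# hypothetical generated op wrapper, not a real symbol:
#
#   raw = to_raw_op(some_op_wrapper)
#   raw(x=t1, y=t2, name="n")   # OK: keyword arguments only
#   raw(t1, t2)                 # TypeError: positional args are rejected by
#                               # the `kwarg_only` decorator applied above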
7212
7213def raise_from_not_ok_status(e, name):
7214  e.message += (" name: " + name if name is not None else "")
7215  raise core._status_to_exception(e) from None  # pylint: disable=protected-access
7216
7217
7218def add_exit_callback_to_default_func_graph(fn):
7219  """Add a callback to run when the default function graph goes out of scope.
7220
7221  Usage:
7222
7223  ```python
7224  @tf.function
7225  def fn(x, v):
7226    expensive = expensive_object(v)
7227    add_exit_callback_to_default_func_graph(lambda: expensive.release())
7228    return g(x, expensive)
7229
7230  fn(x=tf.constant(...), v=...)
7231  # `expensive` has been released.
7232  ```
7233
7234  Args:
7235    fn: A callable that takes no arguments and whose output is ignored; it is
7236      executed when the func graph scope is exited.
7237
7238  Raises:
7239    RuntimeError: If executed when the current default graph is not a FuncGraph,
7240      or not currently executing in function creation mode (e.g., if inside
7241      an init_scope).
7242  """
7243  default_graph = get_default_graph()
7244  if not default_graph._building_function:  # pylint: disable=protected-access
7245    raise RuntimeError(
7246        "Cannot add scope exit callbacks when not building a function.  "
7247        "Default graph: {}".format(default_graph))
7248  default_graph._add_scope_exit_callback(fn)  # pylint: disable=protected-access
7249
7250
7251def _reconstruct_sequence_inputs(op_def, inputs, attrs):
7252  """Regroups a flat list of input tensors into scalar and sequence inputs.
7253
7254  Args:
7255    op_def: The `op_def_pb2.OpDef` (for knowing the input types)
7256    inputs: a list of input `Tensor`s to the op.
7257    attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
7258      how long each sequence is)
7259
7260  Returns:
7261    A list of `Tensor`s (corresponding to scalar inputs) and lists of
7262    `Tensor`s (corresponding to sequence inputs).
7263  """
7264  grouped_inputs = []
7265  i = 0
7266  for input_arg in op_def.input_arg:
7267    if input_arg.number_attr:
7268      input_len = attrs[input_arg.number_attr].i
7269      is_sequence = True
7270    elif input_arg.type_list_attr:
7271      input_len = len(attrs[input_arg.type_list_attr].list.type)
7272      is_sequence = True
7273    else:
7274      input_len = 1
7275      is_sequence = False
7276
7277    if is_sequence:
7278      grouped_inputs.append(inputs[i:i + input_len])
7279    else:
7280      grouped_inputs.append(inputs[i])
7281    i += input_len
7282
7283  assert i == len(inputs)
7284  return grouped_inputs
7285
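# Illustrative worked example for `_reconstruct_sequence_inputs`: for a
# hypothetical op whose OpDef declares a scalar input `x` followed by a
# sequence input `values` with number_attr "N", attrs {"N": AttrValue(i=2)}
# and flat inputs [t0, t1, t2] are regrouped into [t0, [t1, t2]].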
7286
7287_numpy_style_type_promotion = False
7288
7289
7290def enable_numpy_style_type_promotion():
7291  """If called, follows NumPy's rules for type promotion.
7292
7293  Used for enabling NumPy behavior on methods for TF NumPy.
7294  """
7295  global _numpy_style_type_promotion
7296  _numpy_style_type_promotion = True
7297
7298
7299_numpy_style_slicing = False
7300
7301
7302def enable_numpy_style_slicing():
7303  """If called, follows NumPy's rules for slicing Tensors.
7304
7305  Used for enabling NumPy behavior on slicing for TF NumPy.
7306  """
7307  global _numpy_style_slicing
7308  _numpy_style_slicing = True
7309
7310
7311class _TensorIterator(object):
7312  """Iterates over the leading dim of a Tensor. Performs no error checks."""
7313
7314  __slots__ = ["_tensor", "_index", "_limit"]
7315
7316  def __init__(self, tensor, dim0):
7317    self._tensor = tensor
7318    self._index = 0
7319    self._limit = dim0
7320
7321  def __iter__(self):
7322    return self
7323
7324  def __next__(self):
7325    if self._index == self._limit:
7326      raise StopIteration
7327    result = self._tensor[self._index]
7328    self._index += 1
7329    return result
7330
7331  next = __next__  # python2.x compatibility.
7332
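# Illustrative sketch of `_TensorIterator` (internal, no bounds checking).
# Assuming `t` is a Tensor whose leading dimension is 3:
#
#   for row in _TensorIterator(t, dim0=3):
#     ...  # yields t[0], t[1], t[2] in order, then raises StopIteration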
7333
7334def set_int_list_attr(op, attr_name, ints):
7335  """TF internal method used to set a list(int) attribute in the node_def."""
7336  ints_list = attr_value_pb2.AttrValue.ListValue(i=ints)
7337  op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list))  # pylint:disable=protected-access
7338
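# Illustrative sketch of `set_int_list_attr` (TF-internal; the attr name below
# is a hypothetical placeholder). After
#
#   set_int_list_attr(op, "_example_int_list", [1, 2, 3])
#
# the op's NodeDef has attr["_example_int_list"].list.i == [1, 2, 3].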
7339
7340def _get_enclosing_context(graph):
7341  # pylint: disable=protected-access
7342  if graph is None:
7343    return None
7344
7345  if graph._control_flow_context is not None:
7346    return graph._control_flow_context
7347
7348  if graph.building_function and hasattr(graph, "outer_graph"):
7349    return _get_enclosing_context(graph.outer_graph)
7350
7351
7352def get_resource_handle_data(graph_op):
7353  assert type(graph_op) == Tensor  # pylint: disable=unidiomatic-typecheck
7354
7355  with graph_op.graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
7356    handle_data = pywrap_tf_session.GetHandleShapeAndType(
7357        c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
7358
7359  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
7360      compat.as_bytes(handle_data))
7361
7362
7363def _copy_handle_data_to_arg_def(tensor, arg_def):
7364  handle_data = get_resource_handle_data(tensor)
7365  if handle_data.shape_and_type:
7366    shape_and_type = handle_data.shape_and_type[0]
7367    proto = arg_def.handle_data.add()
7368    proto.dtype = shape_and_type.dtype
7369    proto.shape.CopyFrom(handle_data.shape_and_type[0].shape)
7370