1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions used to construct graphs."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import re
23import sys
24import threading
25import types
26
27import numpy as np
28import six
29from six.moves import map  # pylint: disable=redefined-builtin
30from six.moves import xrange  # pylint: disable=redefined-builtin
31
32from tensorflow.core.framework import attr_value_pb2
33from tensorflow.core.framework import function_pb2
34from tensorflow.core.framework import graph_pb2
35from tensorflow.core.framework import node_def_pb2
36from tensorflow.core.framework import op_def_pb2
37from tensorflow.core.framework import versions_pb2
38from tensorflow.core.protobuf import config_pb2
39# pywrap_tensorflow must be imported first to avoid protobuf issues.
40# (b/143110113)
41# pylint: disable=invalid-import-order,g-bad-import-order,unused-import
42from tensorflow.python import pywrap_tensorflow
43from tensorflow.python import pywrap_tfe
44# pylint: enable=invalid-import-order,g-bad-import-order,unused-import
45from tensorflow.python import tf2
46from tensorflow.python.client import pywrap_tf_session
47from tensorflow.python.eager import context
48from tensorflow.python.eager import core
49from tensorflow.python.eager import monitoring
50from tensorflow.python.eager import tape
51from tensorflow.python.framework import c_api_util
52from tensorflow.python.framework import composite_tensor
53from tensorflow.python.framework import device as pydev
54from tensorflow.python.framework import dtypes
55from tensorflow.python.framework import errors
56from tensorflow.python.framework import indexed_slices
57from tensorflow.python.framework import registry
58from tensorflow.python.framework import tensor_conversion_registry
59from tensorflow.python.framework import tensor_like
60from tensorflow.python.framework import tensor_shape
61from tensorflow.python.framework import traceable_stack
62from tensorflow.python.framework import versions
63from tensorflow.python.ops import control_flow_util
64from tensorflow.python.platform import app
65from tensorflow.python.platform import tf_logging as logging
66from tensorflow.python.util import compat
67from tensorflow.python.util import decorator_utils
68from tensorflow.python.util import deprecation
69from tensorflow.python.util import function_utils
70from tensorflow.python.util import lock_util
71from tensorflow.python.util import memory
72from tensorflow.python.util import object_identity
73from tensorflow.python.util import tf_contextlib
74from tensorflow.python.util import tf_stack
75from tensorflow.python.util.compat import collections_abc
76from tensorflow.python.util.deprecation import deprecated_args
77from tensorflow.python.util.lazy_loader import LazyLoader
78from tensorflow.python.util.tf_export import kwarg_only
79from tensorflow.python.util.tf_export import tf_export
80from tensorflow.tools.docs.doc_controls import do_not_generate_docs
81
82ag_ctx = LazyLoader(
83    "ag_ctx", globals(),
84    "tensorflow.python.autograph.core.ag_ctx")
85
86
87# Temporary global switches determining if we should enable the work-in-progress
88# calls to the C API. These will be removed once all functionality is supported.
89_USE_C_API = True
90_USE_C_SHAPES = True
91
92_api_usage_gauge = monitoring.BoolGauge(
93    "/tensorflow/api/ops_eager_execution",
94    "Whether ops.enable_eager_execution() is called.")
95
96
97# pylint: disable=protected-access
98_TensorLike = tensor_like._TensorLike
99_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE
100# pylint: enable=protected-access
101
102
103def tensor_id(tensor):
104  """Returns a unique identifier for this Tensor."""
105  return tensor._id  # pylint: disable=protected-access
106
107
108class _UserDeviceSpec(object):
109  """Store user-specified device and provide computation of merged device."""
110
111  def __init__(self, device_name_or_function):
112    self._device_name_or_function = device_name_or_function
113    self.display_name = str(self._device_name_or_function)
114    self.function = device_name_or_function
115    self.raw_string = None
116
117    if isinstance(device_name_or_function, pydev.MergeDevice):
118      self.is_null_merge = device_name_or_function.is_null_merge
119
120    elif callable(device_name_or_function):
121      self.is_null_merge = False
122      dev_func = self._device_name_or_function
123      func_name = function_utils.get_func_name(dev_func)
124      func_code = function_utils.get_func_code(dev_func)
125      if func_code:
126        fname = func_code.co_filename
127        lineno = func_code.co_firstlineno
128      else:
129        fname = "unknown"
130        lineno = -1
131      self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
132
133    elif device_name_or_function is None:
134      # NOTE(taylorrobie): This MUST be False. None signals a break in the
135      #   device stack, so `is_null_merge` must be False for such a case to
136      #   allow callers to safely skip over null merges without missing a None.
137      self.is_null_merge = False
138
139    else:
140      self.raw_string = device_name_or_function
141      self.function = pydev.merge_device(device_name_or_function)
142      self.is_null_merge = self.function.is_null_merge
143
144    # We perform this check in __init__ because it is of non-trivial cost,
145    # and self.string_merge is typically called many times.
146    self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)
147
148  def string_merge(self, node_def):
149    if self.fast_string_merge:
150      return self.function.shortcut_string_merge(node_def)
151
152    return compat.as_str(_device_string(self.function(node_def)))
153
154
155class NullContextmanager(object):
156
157  def __init__(self, *args, **kwargs):
158    pass
159
160  def __enter__(self):
161    pass
162
163  def __exit__(self, type_arg, value_arg, traceback_arg):
164    return False  # False values do not suppress exceptions
165
166
167def _override_helper(clazz_object, operator, func):
168  """Overrides (string) operator on Tensors to call func.
169
170  Args:
171    clazz_object: the class to override for; either Tensor or SparseTensor.
172    operator: the string name of the operator to override.
173    func: the function that replaces the overridden operator.
174
175  Raises:
176    ValueError: If operator has already been overwritten,
177      or if operator is not allowed to be overwritten.
178  """
179  existing = getattr(clazz_object, operator, None)
180  if existing is not None:
181    # Check to see if this is a default method-wrapper or slot wrapper which
182    # will be true for the comparison operators.
183    if not isinstance(existing, type(object.__lt__)):
184      raise ValueError("operator %s cannot be overwritten again on class %s." %
185                       (operator, clazz_object))
186  if operator not in Tensor.OVERLOADABLE_OPERATORS:
187    raise ValueError("Overriding %s is disallowed" % operator)
188  setattr(clazz_object, operator, func)
189
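# Illustrative sketch (not part of this module): operator overloads on Tensor
# are registered by the op libraries through Tensor._override_operator, which
# delegates to _override_helper. Roughly, math_ops does something like:
#
#   Tensor._override_operator("__neg__", gen_math_ops.neg)
#
# and any operator not listed in Tensor.OVERLOADABLE_OPERATORS is rejected.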
190
191def _as_graph_element(obj):
192  """Convert `obj` to a graph element if possible, otherwise return `None`.
193
194  Args:
195    obj: Object to convert.
196
197  Returns:
198    The result of `obj._as_graph_element()` if that method is available;
199        otherwise `None`.
200  """
201  conv_fn = getattr(obj, "_as_graph_element", None)
202  if conv_fn and callable(conv_fn):
203    return conv_fn()
204  return None
205
206
207_TENSOR_LIKE_TYPES = tuple()
208
209
210def is_dense_tensor_like(t):
211  """EXPERIMENTAL: Returns true if `t` implements the tensor interface.
212
213  See `register_dense_tensor_like_type()` for the current definition of a
214  "tensor-like type".
215
216  Args:
217    t: An object.
218
219  Returns:
220    True iff `t` is an instance of one of the registered "tensor-like" types.
221  """
222  return isinstance(t, _TENSOR_LIKE_TYPES)
223
224
225def register_dense_tensor_like_type(tensor_type):
226  """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
227
228  A "tensor-like type" can represent a single dense tensor, and implements
229  the `name`, `dtype` and `shape` properties.
230
231  Args:
232    tensor_type: A type implementing the tensor interface.
233
234  Raises:
235    TypeError: If `tensor_type` does not implement the tensor interface.
236  """
237  if not (hasattr(tensor_type, "name") and
238          isinstance(tensor_type.name, property)):
239    raise TypeError("Type %s does not define a `name` property" %
240                    tensor_type.__name__)
241  if not (hasattr(tensor_type, "dtype") and
242          isinstance(tensor_type.dtype, property)):
243    raise TypeError("Type %s does not define a `dtype` property" %
244                    tensor_type.__name__)
245  if not (hasattr(tensor_type, "shape") and
246          isinstance(tensor_type.shape, property)):
247    raise TypeError("Type %s does not define a `shape` property" %
248                    tensor_type.__name__)
249  # We expect this list to be small, so choose quadratic complexity
250  # for registration, so that we have a tuple that can be used for
251  # more efficient `isinstance` checks later.
252  global _TENSOR_LIKE_TYPES
253  _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
254
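# Illustrative sketch (not part of this module): a minimal "tensor-like" type
# that satisfies the property checks above, using a hypothetical class name.
#
#   class _FakeDenseTensor(object):
#
#     @property
#     def name(self):
#       return "fake:0"
#
#     @property
#     def dtype(self):
#       return dtypes.float32
#
#     @property
#     def shape(self):
#       return tensor_shape.TensorShape([2, 2])
#
#   register_dense_tensor_like_type(_FakeDenseTensor)
#   assert is_dense_tensor_like(_FakeDenseTensor())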
255
256def uid():
257  """A unique (within this program execution) integer."""
258  return pywrap_tfe.TFE_Py_UID()
259
260
261def numpy_text(tensor, is_repr=False):
262  """Human readable representation of a tensor's numpy value."""
263  if tensor.dtype.is_numpy_compatible:
264    # pylint: disable=protected-access
265    text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
266    # pylint: enable=protected-access
267  else:
268    text = "<unprintable>"
269  if "\n" in text:
270    text = "\n" + text
271  return text
272
273@tf_export(v1=["enable_tensor_equality"])
274def enable_tensor_equality():
275  """Compare Tensors with element-wise comparison and thus be unhashable.
276
277  Comparing tensors element-wise allows comparisons such as
278  `tf.Variable(1.0) == 1.0`. Element-wise equality implies that tensors are
279  unhashable. Thus tensors can no longer be directly used in sets or as a key in
280  a dictionary.
281  """
282  Tensor._USE_EQUALITY = True  # pylint: disable=protected-access
283
284@tf_export(v1=["disable_tensor_equality"])
285def disable_tensor_equality():
286  """Compare Tensors by their id, making them hashable.
287
288  This is a legacy behaviour of TensorFlow and is highly discouraged.
289  """
290  Tensor._USE_EQUALITY = False  # pylint: disable=protected-access
291
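# Illustrative sketch (not part of this module) of the behavioral difference,
# assuming eager execution and the tf.compat.v1 endpoints exported above:
#
#   tf.compat.v1.enable_tensor_equality()
#   tf.constant(1.0) == 1.0    # element-wise: tf.Tensor(True, ...)
#   # hash(tf.constant(1.0))   # would raise TypeError (unhashable)
#
#   tf.compat.v1.disable_tensor_equality()
#   tf.constant(1.0) == 1.0    # compared by id: False
#   hash(tf.constant(1.0))     # hashable again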
292
293@tf_export("Tensor")
294class Tensor(_TensorLike):
295  """A tensor represents a rectangular array of data.
296
297  When writing a TensorFlow program, the main object you manipulate and pass
298  around is the `tf.Tensor`. A `tf.Tensor` object represents a rectangular array
299  of arbitrary dimension, filled with data of a specific data type.
300
301  A `tf.Tensor` has the following properties:
302
303  * a data type (float32, int32, or string, for example)
304  * a shape
305
306  Each element in the Tensor has the same data type, and the data type is always
307  known.
308
309  In eager execution, which is the default mode in TensorFlow, results are
310  calculated immediately.
311
312  >>> # Compute some values using a Tensor
313  >>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
314  >>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
315  >>> e = tf.matmul(c, d)
316  >>> print(e)
317  tf.Tensor(
318  [[1. 3.]
319   [3. 7.]], shape=(2, 2), dtype=float32)
320
321
322  Note that during eager execution, you may discover your `Tensors` are actually
323  of type `EagerTensor`.  This is an internal detail, but it does give you
324  access to a useful function, `numpy`:
325
326  >>> type(e)
327  <class '...ops.EagerTensor'>
328  >>> print(e.numpy())
329    [[1. 3.]
330     [3. 7.]]
331
332  TensorFlow can define computations without immediately executing them, most
333  commonly inside `tf.function`s, as well as in (legacy) Graph mode. In those
334  cases, the shape (that is, the rank of the Tensor and the size of
335  each dimension) might be only partially known.
336
337  Most operations produce tensors of fully-known shapes if the shapes of their
338  inputs are also fully known, but in some cases it's only possible to find the
339  shape of a tensor at execution time.
340
341  There are specialized tensors; for these, see `tf.Variable`, `tf.constant`,
342  `tf.placeholder`, `tf.SparseTensor`, and `tf.RaggedTensor`.
343
344  For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor).
345  """
346
347  # List of Python operators that we allow to override.
348  OVERLOADABLE_OPERATORS = {
349      # Binary.
350      "__add__",
351      "__radd__",
352      "__sub__",
353      "__rsub__",
354      "__mul__",
355      "__rmul__",
356      "__div__",
357      "__rdiv__",
358      "__truediv__",
359      "__rtruediv__",
360      "__floordiv__",
361      "__rfloordiv__",
362      "__mod__",
363      "__rmod__",
364      "__lt__",
365      "__le__",
366      "__gt__",
367      "__ge__",
368      "__ne__",
369      "__eq__",
370      "__and__",
371      "__rand__",
372      "__or__",
373      "__ror__",
374      "__xor__",
375      "__rxor__",
376      "__getitem__",
377      "__pow__",
378      "__rpow__",
379      # Unary.
380      "__invert__",
381      "__neg__",
382      "__abs__",
383      "__matmul__",
384      "__rmatmul__"
385  }
386
387  # Whether to allow hashing or numpy-style equality
388  _USE_EQUALITY = tf2.enabled()
389
390  def __init__(self, op, value_index, dtype):
391    """Creates a new `Tensor`.
392
393    Args:
394      op: An `Operation`. `Operation` that computes this tensor.
395      value_index: An `int`. Index of the operation's endpoint that produces
396        this tensor.
397      dtype: A `DType`. Type of elements stored in this tensor.
398
399    Raises:
400      TypeError: If the op is not an `Operation`.
401    """
402    if not isinstance(op, Operation):
403      raise TypeError("op needs to be an Operation: %s" % op)
404    self._op = op
405    self._value_index = value_index
406    self._dtype = dtypes.as_dtype(dtype)
407    # This will be set by self._as_tf_output().
408    self._tf_output = None
409    # This will be set by self.shape().
410    self._shape_val = None
411    # List of operations that use this Tensor as input.  We maintain this list
412    # to easily navigate a computation graph.
413    self._consumers = []
414    self._id = uid()
415    self._name = None
416
417  @staticmethod
418  def _create_with_tf_output(op, value_index, dtype, tf_output):
419    ret = Tensor(op, value_index, dtype)
420    ret._tf_output = tf_output
421    return ret
422
423  @property
424  def op(self):
425    """The `Operation` that produces this tensor as an output."""
426    return self._op
427
428  @property
429  def dtype(self):
430    """The `DType` of elements in this tensor."""
431    return self._dtype
432
433  @property
434  def graph(self):
435    """The `Graph` that contains this tensor."""
436    return self._op.graph
437
438  @property
439  def name(self):
440    """The string name of this tensor."""
441    if self._name is None:
442      if not self._op.name:
443        raise ValueError("Operation was not named: %s" % self._op)
444      self._name = "%s:%d" % (self._op.name, self._value_index)
445    return self._name
446
447  @property
448  def device(self):
449    """The name of the device on which this tensor will be produced, or None."""
450    return self._op.device
451
452  @property
453  def shape(self):
454    """Returns the `TensorShape` that represents the shape of this tensor.
455
456    The shape is computed using shape inference functions that are
457    registered in the Op for each `Operation`.  See
458    `tf.TensorShape`
459    for more details of what a shape represents.
460
461    The inferred shape of a tensor is used to provide shape
462    information without having to execute the underlying kernel. This
463    can be used for debugging and providing early error messages. For
464    example:
465
466    ```python
467    >>> c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
468    >>> print(c.shape) # will be TensorShape([2, 3])
469    (2, 3)
470
471    >>> d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
472    >>> print(d.shape)
473    (4, 2)
474
475    # Raises a ValueError, because `c` and `d` do not have compatible
476    # inner dimensions.
477    >>> e = tf.matmul(c, d)
478    Traceback (most recent call last):
479        ...
480    tensorflow.python.framework.errors_impl.InvalidArgumentError: Matrix
481    size-incompatible: In[0]: [2,3], In[1]: [4,2] [Op:MatMul] name: MatMul/
482
483    # This works because we have compatible shapes.
484    >>> f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
485    >>> print(f.shape)
486    (3, 4)
487
488    ```
489
490    In some cases, the inferred shape may have unknown dimensions. If
491    the caller has additional information about the values of these
492    dimensions, `Tensor.set_shape()` can be used to augment the
493    inferred shape.
494
495    Returns:
496      A `tf.TensorShape` representing the shape of this tensor.
497
498    """
499    if self._shape_val is None:
500      self._shape_val = self._c_api_shape()
501    return self._shape_val
502
503  def _c_api_shape(self):
504    """Returns the TensorShape of this tensor according to the C API."""
505    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
506    shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper(
507        c_graph, self._as_tf_output())
508    if unknown_shape:
509      return tensor_shape.unknown_shape()
510    else:
511      shape_vec = [None if d == -1 else d for d in shape_vec]
512      return tensor_shape.TensorShape(shape_vec)
513
514  @property
515  def _shape(self):
516    logging.warning("Tensor._shape is private, use Tensor.shape "
517                    "instead. Tensor._shape will eventually be removed.")
518    return self.shape
519
520  @_shape.setter
521  def _shape(self, value):
522    raise ValueError(
523        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
524
525  def _disallow_when_autograph_disabled(self, task):
526    raise errors.OperatorNotAllowedInGraphError(
527        "{} is not allowed: AutoGraph is disabled in this function."
528        " Try decorating it directly with @tf.function.".format(task))
529
530  def _disallow_when_autograph_enabled(self, task):
531    raise errors.OperatorNotAllowedInGraphError(
532        "{} is not allowed: AutoGraph did not convert this function. Try"
533        " decorating it directly with @tf.function.".format(task))
534
535  def _disallow_in_graph_mode(self, task):
536    raise errors.OperatorNotAllowedInGraphError(
537        "{} is not allowed in Graph execution. Use Eager execution or decorate"
538        " this function with @tf.function.".format(task))
539
540  def _disallow_bool_casting(self):
541    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
542      self._disallow_when_autograph_disabled(
543          "using a `tf.Tensor` as a Python `bool`")
544    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
545      self._disallow_when_autograph_enabled(
546          "using a `tf.Tensor` as a Python `bool`")
547    else:
548      # Default: V1-style Graph execution.
549      self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`")
550
551  def _disallow_iteration(self):
552    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
553      self._disallow_when_autograph_disabled("iterating over `tf.Tensor`")
554    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
555      self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
556    else:
557      # Default: V1-style Graph execution.
558      self._disallow_in_graph_mode("iterating over `tf.Tensor`")
559
560  def __iter__(self):
561    if not context.executing_eagerly():
562      self._disallow_iteration()
563
564    shape = self._shape_tuple()
565    if shape is None:
566      raise TypeError("Cannot iterate over a tensor with unknown shape.")
567    if not shape:
568      raise TypeError("Cannot iterate over a scalar tensor.")
569    if shape[0] is None:
570      raise TypeError(
571          "Cannot iterate over a tensor with unknown first dimension.")
572    for i in xrange(shape[0]):
573      yield self[i]
574
575  def _shape_as_list(self):
576    if self.shape.ndims is not None:
577      return [dim.value for dim in self.shape.dims]
578    else:
579      return None
580
581  def _shape_tuple(self):
582    shape = self._shape_as_list()
583    if shape is None:
584      return None
585    return tuple(shape)
586
587  def _rank(self):
588    """Integer rank of this Tensor, if known, else None.
589
590    Returns:
591      Integer rank or None
592    """
593    return self.shape.ndims
594
595  def get_shape(self):
596    """Alias of `tf.Tensor.shape`."""
597    return self.shape
598
599  def set_shape(self, shape):
600    """Updates the shape of this tensor.
601
602    This method can be called multiple times, and will merge the given
603    `shape` with the current shape of this tensor. It can be used to
604    provide additional information about the shape of this tensor that
605    cannot be inferred from the graph alone. For example, this can be used
606    to provide additional information about the shapes of images:
607
608    ```python
609    _, image_data = tf.compat.v1.TFRecordReader(...).read(...)
610    image = tf.image.decode_png(image_data, channels=3)
611
612    # The height and width dimensions of `image` are data dependent, and
613    # cannot be computed without executing the op.
614    print(image.shape)
615    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
616
617    # We know that each image in this dataset is 28 x 28 pixels.
618    image.set_shape([28, 28, 3])
619    print(image.shape)
620    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
621    ```
622
623    NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
624    result in inconsistencies between the statically-known graph and the runtime
625    value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
626    instead.
627
628    Args:
629      shape: A `TensorShape` representing the shape of this tensor, a
630        `TensorShapeProto`, a list, a tuple, or None.
631
632    Raises:
633      ValueError: If `shape` is not compatible with the current shape of
634        this tensor.
635    """
636    # Reset cached shape.
637    self._shape_val = None
638
639    # We want set_shape to be reflected in the C API graph for when we run it.
640    if not isinstance(shape, tensor_shape.TensorShape):
641      shape = tensor_shape.TensorShape(shape)
642    dim_list = []
643    if shape.dims is None:
644      unknown_shape = True
645    else:
646      unknown_shape = False
647      for dim in shape.dims:
648        if dim.value is None:
649          dim_list.append(-1)
650        else:
651          dim_list.append(dim.value)
652    try:
653      pywrap_tf_session.TF_GraphSetTensorShape_wrapper(
654          self._op._graph._c_graph,  # pylint: disable=protected-access
655          self._as_tf_output(),
656          dim_list,
657          unknown_shape)
658    except errors.InvalidArgumentError as e:
659      # Convert to ValueError for backwards compatibility.
660      raise ValueError(str(e))
661
662  @property
663  def value_index(self):
664    """The index of this tensor in the outputs of its `Operation`."""
665    return self._value_index
666
667  def consumers(self):
668    """Returns a list of `Operation`s that consume this tensor.
669
670    Returns:
671      A list of `Operation`s.
672    """
673    consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper(
674        self._as_tf_output())
675    # pylint: disable=protected-access
676    return [
677        self.graph._get_operation_by_name_unsafe(name)
678        for name in consumer_names
679    ]
680    # pylint: enable=protected-access
681
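  # Illustrative sketch (not part of this class), assuming TF1-style graph
  # construction, of how consumers() links a tensor to the ops that use it:
  #
  #   c = tf.constant(1.0)
  #   d = tf.add(c, c)
  #   c.consumers()   # ==> [<tf.Operation 'Add' type=Add>]  (the op behind d)
  #   d.consumers()   # ==> []
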
682  def _as_node_def_input(self):
683    """Return a value to use for the NodeDef "input" attribute.
684
685    The returned string can be used in a NodeDef "input" attribute
686    to indicate that the NodeDef uses this Tensor as input.
687
688    Raises:
689      ValueError: if this Tensor's Operation does not have a name.
690
691    Returns:
692      a string.
693    """
694    if not self._op.name:
695      raise ValueError("Operation was not named: %s" % self._op)
696    if self._value_index == 0:
697      return self._op.name
698    else:
699      return "%s:%d" % (self._op.name, self._value_index)
700
701  def _as_tf_output(self):
702    # pylint: disable=protected-access
703    # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object
704    # also guarantees that a dictionary of tf_output objects will retain a
705    # deterministic (yet unsorted) order which prevents memory blowup in the
706    # cache of executor(s) stored for every session.
707    if self._tf_output is None:
708      self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index)
709    return self._tf_output
710    # pylint: enable=protected-access
711
712  def __str__(self):
713    return "Tensor(\"%s\"%s%s%s)" % (
714        self.name,
715        (", shape=%s" %
716         self.get_shape()) if self.get_shape().ndims is not None else "",
717        (", dtype=%s" % self._dtype.name) if self._dtype else "",
718        (", device=%s" % self.device) if self.device else "")
719
720  def __repr__(self):
721    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
722                                                   self._dtype.name)
723
724  def __hash__(self):
725    g = getattr(self, "graph", None)
726    if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
727        (g is None or g._building_function)):  # pylint: disable=protected-access
728      raise TypeError("Tensor is unhashable if Tensor equality is enabled. "
729                      "Instead, use tensor.experimental_ref() as the key.")
730    else:
731      return id(self)
732
733  def __copy__(self):
734    # TODO(b/77597810): get rid of Tensor copies.
735    cls = self.__class__
736    result = cls.__new__(cls)
737    result.__dict__.update(self.__dict__)
738    return result
739
740  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
741  # operators to run when the left operand is an ndarray, because it
742  # accords the Tensor class higher priority than an ndarray, or a
743  # numpy matrix.
744  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
745  # mechanism, which allows more control over how Tensors interact
746  # with ndarrays.
747  __array_priority__ = 100
748
749  def __array__(self):
750    raise NotImplementedError("Cannot convert a symbolic Tensor ({}) to a numpy"
751                              " array.".format(self.name))
752
753  def __len__(self):
754    raise TypeError("len is not well defined for symbolic Tensors. ({}) "
755                    "Please call `x.shape` rather than `len(x)` for "
756                    "shape information.".format(self.name))
757
758  @staticmethod
759  def _override_operator(operator, func):
760    _override_helper(Tensor, operator, func)
761
762  def __bool__(self):
763    """Dummy method to prevent a tensor from being used as a Python `bool`.
764
765    This overload raises a `TypeError` when the user inadvertently
766    treats a `Tensor` as a boolean (most commonly in an `if` or `while`
767    statement), in code that was not converted by AutoGraph. For example:
768
769    ```python
770    if tf.constant(True):  # Will raise.
771      # ...
772
773    if tf.constant(5) < tf.constant(7):  # Will raise.
774      # ...
775    ```
776
777    Raises:
778      `TypeError`.
779    """
780    self._disallow_bool_casting()
781
782  def __nonzero__(self):
783    """Dummy method to prevent a tensor from being used as a Python `bool`.
784
785    This is the Python 2.x counterpart to `__bool__()` above.
786
787    Raises:
788      `TypeError`.
789    """
790    self._disallow_bool_casting()
791
792  def eval(self, feed_dict=None, session=None):
793    """Evaluates this tensor in a `Session`.
794
795    Note: If you are not using `compat.v1` libraries, you should not need this
796    (or `feed_dict` or `Session`).  In eager execution (or within `tf.function`)
797    you do not need to call `eval`.
798
799    Calling this method will execute all preceding operations that
800    produce the inputs needed for the operation that produces this
801    tensor.
802
803    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
804    launched in a session, and either a default session must be
805    available, or `session` must be specified explicitly.
806
807    Args:
808      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
809        `tf.Session.run` for a description of the valid feed values.
810      session: (Optional.) The `Session` to be used to evaluate this tensor. If
811        none, the default session will be used.
812
813    Returns:
814      A numpy array corresponding to the value of this tensor.
815    """
816    return _eval_using_default_session(self, feed_dict, self.graph, session)
817
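  # Illustrative sketch (not part of this class), assuming TF1-style graph
  # execution (e.g., after tf.compat.v1.disable_eager_execution()):
  #
  #   c = tf.constant(4.0)
  #   with tf.compat.v1.Session():
  #     print(c.eval())   # ==> 4.0
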
818  def experimental_ref(self):
819    # tf.Variable also has the same experimental_ref() API.  If you update the
820    # documentation here, please update tf.Variable.experimental_ref() as well.
821    """Returns a hashable reference object to this Tensor.
822
823    Warning: Experimental API that could be changed or removed.
824
825    The primary use case for this API is to put tensors in a set/dictionary.
826    We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer
827    available starting TensorFlow 2.0.
828
829    ```python
830    import tensorflow as tf
831
832    x = tf.constant(5)
833    y = tf.constant(10)
834    z = tf.constant(10)
835
836    # The following will raise an exception starting in TensorFlow 2.0
837    # TypeError: Tensor is unhashable if Tensor equality is enabled.
838    tensor_set = {x, y, z}
839    tensor_dict = {x: 'five', y: 'ten', z: 'ten'}
840    ```
841
842    Instead, we can use `tensor.experimental_ref()`.
843
844    ```python
845    tensor_set = {x.experimental_ref(),
846                  y.experimental_ref(),
847                  z.experimental_ref()}
848
849    print(x.experimental_ref() in tensor_set)
850    ==> True
851
852    tensor_dict = {x.experimental_ref(): 'five',
853                   y.experimental_ref(): 'ten',
854                   z.experimental_ref(): 'ten'}
855
856    print(tensor_dict[y.experimental_ref()])
857    ==> ten
858    ```
859
860    Also, the reference object provides a `.deref()` function that returns the
861    original Tensor.
862
863    ```python
864    x = tf.constant(5)
865    print(x.experimental_ref().deref())
866    ==> tf.Tensor(5, shape=(), dtype=int32)
867    ```
868    """
869    return object_identity.Reference(self)
870
871
872# TODO(agarwal): consider getting rid of this.
873class _EagerTensorBase(Tensor):
874  """Base class for EagerTensor."""
875
876  # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and
877  # only work for scalars; values are cast as per numpy.
878  def __complex__(self):
879    return complex(self._numpy())
880
881  def __int__(self):
882    return int(self._numpy())
883
884  def __long__(self):
885    return long(self._numpy())
886
887  def __float__(self):
888    return float(self._numpy())
889
890  def __index__(self):
891    return self._numpy().__index__()
892
893  def __bool__(self):
894    return bool(self._numpy())
895
896  __nonzero__ = __bool__
897
898  def __format__(self, format_spec):
899    return self._numpy().__format__(format_spec)
900
901  def __reduce__(self):
902    return convert_to_tensor, (self._numpy(),)
903
904  def __copy__(self):
905    # Eager Tensors are immutable so it's safe to return themselves as a copy.
906    return self
907
908  def __deepcopy__(self, memo):
909    # Eager Tensors are immutable so it's safe to return themselves as a copy.
910    del memo
911    return self
912
913  def __str__(self):
914    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self), self.shape,
915                                                  self.dtype.name)
916
917  def __repr__(self):
918    return "<tf.Tensor: shape=%s, dtype=%s, numpy=%s>" % (
919        self.shape, self.dtype.name, numpy_text(self, is_repr=True))
920
921  def __len__(self):
922    """Returns the length of the first dimension in the Tensor."""
923    if not self.shape.ndims:
924      raise TypeError("Scalar tensor has no `len()`")
925    # pylint: disable=protected-access
926    try:
927      return self._shape_tuple()[0]
928    except core._NotOkStatusException as e:
929      six.raise_from(core._status_to_exception(e.code, e.message), None)
930
931  def _numpy_internal(self):
932    raise NotImplementedError()
933
934  def _numpy(self):
935    # pylint: disable=protected-access
936    try:
937      return self._numpy_internal()
938    except core._NotOkStatusException as e:
939      six.raise_from(core._status_to_exception(e.code, e.message), None)
940
941  @property
942  def dtype(self):
943    # Note: using the intern table directly here as this is
944    # performance-sensitive in some models.
945    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access
946
947  def numpy(self):
948    """Copy of the contents of this Tensor into a NumPy array or scalar.
949
950    Unlike NumPy arrays, Tensors are immutable, so this method has to copy
951    the contents to ensure safety. Use `memoryview` to get a readonly
952    view of the contents without doing a copy:
953
954    >>> t = tf.constant([42])
955    >>> np.array(memoryview(t))
956    array([42], dtype=int32)
957
958    Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor
959    is on GPU, it will have to be transferred to CPU first in order for
960    `memoryview` to work.
961
962    Returns:
963      A NumPy array of the same shape and dtype or a NumPy scalar, if this
964      Tensor has rank 0.
965
966    Raises:
967      ValueError: If the dtype of this Tensor does not have a compatible
968        NumPy dtype.
969    """
970    # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
971    maybe_arr = self._numpy()  # pylint: disable=protected-access
972    return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
973
974  @property
975  def backing_device(self):
976    """Returns the name of the device holding this tensor's memory.
977
978    `.backing_device` is usually the same as `.device`, which returns
979    the device on which the kernel of the operation that produced this tensor
980    ran. However, some operations can produce tensors on a different device
981    (e.g., an operation that executes on the GPU but produces output tensors
982    in host memory).
983    """
984    raise NotImplementedError()
985
986  def _datatype_enum(self):
987    raise NotImplementedError()
988
989  def _shape_tuple(self):
990    """The shape of this Tensor, as a tuple.
991
992    This is more performant than tuple(shape().as_list()) as it avoids
993    creating two lists and one extra object. Marked private for now because,
994    from an API perspective, it would be better to have a single performant
995    way of getting a shape rather than exposing shape() and shape_tuple()
996    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
997    but ideally one would work things out and remove the need for this method.
998
999    Returns:
1000      tuple with the shape.
1001    """
1002    raise NotImplementedError()
1003
1004  def _rank(self):
1005    """Integer rank of this Tensor.
1006
1007    Unlike regular Tensors, the rank is always known for EagerTensors.
1008
1009    This is more performant than len(self._shape_tuple())
1010
1011    Returns:
1012      Integer rank
1013    """
1014    raise NotImplementedError()
1015
1016  def _num_elements(self):
1017    """Number of elements of this Tensor.
1018
1019    Unlike regular Tensors, the number of elements is always known for
1020    EagerTensors.
1021
1022    This is more performant than tensor.shape.num_elements().
1023
1024    Returns:
1025      Long - num elements in the tensor
1026    """
1027    raise NotImplementedError()
1028
1029  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
1030    raise NotImplementedError()
1031
1032  @staticmethod
1033  def _override_operator(name, func):
1034    setattr(_EagerTensorBase, name, func)
1035
1036  def _copy_nograd(self, ctx=None, device_name=None):
1037    """Copies tensor to dest device, but doesn't record the operation."""
1038    # Creates a new tensor on the dest device.
1039    if ctx is None:
1040      ctx = context.context()
1041    if device_name is None:
1042      device_name = ctx.device_name
1043    # pylint: disable=protected-access
1044    try:
1045      ctx.ensure_initialized()
1046      new_tensor = self._copy_to_device(device_name)
1047    except core._NotOkStatusException as e:
1048      six.raise_from(core._status_to_exception(e.code, e.message), None)
1049    return new_tensor
1050
1051  def _copy(self, ctx=None, device_name=None):
1052    """Copies tensor to dest device."""
1053    new_tensor = self._copy_nograd(ctx, device_name)
1054    # Record the copy on tape and define backprop copy as well.
1055    if context.executing_eagerly():
1056      self_device = self.device
1057
1058      def grad_fun(dresult):
1059        return [
1060            dresult._copy(device_name=self_device)
1061            if hasattr(dresult, "_copy") else dresult
1062        ]
1063
1064      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
1065    return new_tensor
1066    # pylint: enable=protected-access
1067
1068  @property
1069  def shape(self):
1070    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
1071      # pylint: disable=protected-access
1072      try:
1073        # `_tensor_shape` is declared and defined in the definition of
1074        # `EagerTensor`, in C.
1075        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
1076      except core._NotOkStatusException as e:
1077        six.raise_from(core._status_to_exception(e.code, e.message), None)
1078
1079    return self._tensor_shape
1080
1081  def get_shape(self):
1082    """Alias of Tensor.shape."""
1083    return self.shape
1084
1085  def _shape_as_list(self):
1086    """The shape of the tensor as a list."""
1087    return list(self._shape_tuple())
1088
1089  @property
1090  def ndim(self):
1091    """Returns the number of Tensor dimensions."""
1092    return self.shape.ndims
1093
1094  @deprecation.deprecated(None, "Use tf.identity instead.")
1095  def cpu(self):
1096    """A copy of this Tensor with contents backed by host memory."""
1097    return self._copy(context.context(), "CPU:0")
1098
1099  @deprecation.deprecated(None, "Use tf.identity instead.")
1100  def gpu(self, gpu_index=0):
1101    """A copy of this Tensor with contents backed by memory on the GPU.
1102
1103    Arguments:
1104      gpu_index: Identifies which GPU to place the contents of the returned
1105        Tensor on.
1106
1107    Returns:
1108      A GPU-memory backed Tensor object initialized with the same contents
1109      as this Tensor.
1110    """
1111    return self._copy(context.context(), "GPU:" + str(gpu_index))
1112
1113  def set_shape(self, shape):
1114    if not self.shape.is_compatible_with(shape):
1115      raise ValueError(
1116          "Tensor's shape %s is not compatible with supplied shape %s" %
1117          (self.shape, shape))
1118
1119  # Methods not supported / implemented for Eager Tensors.
1120  @property
1121  def op(self):
1122    raise AttributeError(
1123        "Tensor.op is meaningless when eager execution is enabled.")
1124
1125  @property
1126  def graph(self):
1127    raise AttributeError(
1128        "Tensor.graph is meaningless when eager execution is enabled.")
1129
1130  @property
1131  def name(self):
1132    raise AttributeError(
1133        "Tensor.name is meaningless when eager execution is enabled.")
1134
1135  @property
1136  def value_index(self):
1137    raise AttributeError(
1138        "Tensor.value_index is meaningless when eager execution is enabled.")
1139
1140  def consumers(self):
1141    raise NotImplementedError(
1142        "Tensor.consumers is meaningless when eager execution is enabled.")
1143
1144  def _add_consumer(self, consumer):
1145    raise NotImplementedError(
1146        "_add_consumer not supported when eager execution is enabled.")
1147
1148  def _as_node_def_input(self):
1149    raise NotImplementedError(
1150        "_as_node_def_input not supported when eager execution is enabled.")
1151
1152  def _as_tf_output(self):
1153    raise NotImplementedError(
1154        "_as_tf_output not supported when eager execution is enabled.")
1155
1156  def eval(self, feed_dict=None, session=None):
1157    raise NotImplementedError(
1158        "eval is not supported when eager execution is enabled, "
1159        "is .numpy() what you're looking for?")
1160
1161
1162# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
1163# registers it with the current module.
1164EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
1165
1166
1167register_dense_tensor_like_type(Tensor)
1168
1169
1170@tf_export(v1=["convert_to_tensor"])
1171def convert_to_tensor_v1(value,
1172                         dtype=None,
1173                         name=None,
1174                         preferred_dtype=None,
1175                         dtype_hint=None):
1176  """Converts the given `value` to a `Tensor`.
1177
1178  This function converts Python objects of various types to `Tensor`
1179  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
1180  and Python scalars. For example:
1181
1182  ```python
1183  import numpy as np
1184
1185  def my_func(arg):
1186    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
1187    return tf.matmul(arg, arg) + arg
1188
1189  # The following calls are equivalent.
1190  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
1191  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
1192  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
1193  ```
1194
1195  This function can be useful when composing a new operation in Python
1196  (such as `my_func` in the example above). All standard Python op
1197  constructors apply this function to each of their Tensor-valued
1198  inputs, which allows those ops to accept numpy arrays, Python lists,
1199  and scalars in addition to `Tensor` objects.
1200
1201  Note: This function diverges from default Numpy behavior for `float` and
1202    `string` types when `None` is present in a Python list or scalar. Rather
1203    than silently converting `None` values, an error will be thrown.
1204
1205  Args:
1206    value: An object whose type has a registered `Tensor` conversion function.
1207    dtype: Optional element type for the returned tensor. If missing, the type
1208      is inferred from the type of `value`.
1209    name: Optional name to use if a new `Tensor` is created.
1210    preferred_dtype: Optional element type for the returned tensor, used when
1211      dtype is None. In some cases, a caller may not have a dtype in mind when
1212      converting to a tensor, so preferred_dtype can be used as a soft
1213      preference.  If the conversion to `preferred_dtype` is not possible, this
1214      argument has no effect.
1215    dtype_hint: same meaning as preferred_dtype, and overrides it.
1216
1217  Returns:
1218    A `Tensor` based on `value`.
1219
1220  Raises:
1221    TypeError: If no conversion function is registered for `value` to `dtype`.
1222    RuntimeError: If a registered conversion function returns an invalid value.
1223    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
1224  """
1225  preferred_dtype = deprecation.deprecated_argument_lookup(
1226      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
1227  return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
1228
1229
1230@tf_export("convert_to_tensor", v1=[])
1231def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
1232  """Converts the given `value` to a `Tensor`.
1233
1234  This function converts Python objects of various types to `Tensor`
1235  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
1236  and Python scalars. For example:
1237
1238  >>> def my_func(arg):
1239  ...   arg = tf.convert_to_tensor(arg, dtype=tf.float32)
1240  ...   return arg
1241
1242  >>> # The following calls are equivalent.
1243  >>> value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
1244  >>> print(value_1)
1245  tf.Tensor(
1246    [[1. 2.]
1247     [3. 4.]], shape=(2, 2), dtype=float32)
1248  >>> value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
1249  >>> print(value_2)
1250  tf.Tensor(
1251    [[1. 2.]
1252     [3. 4.]], shape=(2, 2), dtype=float32)
1253  >>> value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
1254  >>> print(value_3)
1255  tf.Tensor(
1256    [[1. 2.]
1257     [3. 4.]], shape=(2, 2), dtype=float32)
1258
1259  This function can be useful when composing a new operation in Python
1260  (such as `my_func` in the example above). All standard Python op
1261  constructors apply this function to each of their Tensor-valued
1262  inputs, which allows those ops to accept numpy arrays, Python lists,
1263  and scalars in addition to `Tensor` objects.
1264
1265  Note: This function diverges from default Numpy behavior for `float` and
1266    `string` types when `None` is present in a Python list or scalar. Rather
1267    than silently converting `None` values, an error will be thrown.
1268
1269  Args:
1270    value: An object whose type has a registered `Tensor` conversion function.
1271    dtype: Optional element type for the returned tensor. If missing, the type
1272      is inferred from the type of `value`.
1273    dtype_hint: Optional element type for the returned tensor, used when dtype
1274      is None. In some cases, a caller may not have a dtype in mind when
1275      converting to a tensor, so dtype_hint can be used as a soft preference.
1276      If the conversion to `dtype_hint` is not possible, this argument has no
1277      effect.
1278    name: Optional name to use if a new `Tensor` is created.
1279
1280  Returns:
1281    A `Tensor` based on `value`.
1282
1283  Raises:
1284    TypeError: If no conversion function is registered for `value` to `dtype`.
1285    RuntimeError: If a registered conversion function returns an invalid value.
1286    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
1287  """
1288  return convert_to_tensor(
1289      value=value,
1290      dtype=dtype,
1291      name=name,
1292      preferred_dtype=dtype_hint,
1293      as_ref=False)
1294
1295
1296def _error_prefix(name):
1297  return "" if name is None else "%s: " % name
1298
1299
1300def convert_to_tensor(value,
1301                      dtype=None,
1302                      name=None,
1303                      as_ref=False,
1304                      preferred_dtype=None,
1305                      dtype_hint=None,
1306                      ctx=None,
1307                      accepted_result_types=(Tensor,)):
1308  """Implementation of the public convert_to_tensor."""
1309  # TODO(b/142518781): Fix all call-sites and remove redundant arg
1310  preferred_dtype = preferred_dtype or dtype_hint
1311  if isinstance(value, EagerTensor):
1312    if ctx is None:
1313      ctx = context.context()
1314    if not ctx.executing_eagerly():
1315      graph = get_default_graph()
1316      if not graph.building_function:
1317        raise RuntimeError("Attempting to capture an EagerTensor without "
1318                           "building a function.")
1319      return graph.capture(value, name=name)
1320
1321  if dtype is not None:
1322    dtype = dtypes.as_dtype(dtype)
1323  if isinstance(value, Tensor):
1324    if dtype is not None and not dtype.is_compatible_with(value.dtype):
1325      raise ValueError(
1326          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
1327          (dtype.name, value.dtype.name, value))
1328    return value
1329
1330  if preferred_dtype is not None:
1331    preferred_dtype = dtypes.as_dtype(preferred_dtype)
1332  for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
1333    # If dtype is None but preferred_dtype is not None, we try to
1334    # cast to preferred_dtype first.
1335    ret = None
1336    if dtype is None and preferred_dtype is not None:
1337      try:
1338        ret = conversion_func(
1339            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
1340      except (TypeError, ValueError):
1341        # Could not coerce the conversion to use the preferred dtype.
1342        pass
1343      else:
1344        if (ret is not NotImplemented and
1345            ret.dtype.base_dtype != preferred_dtype.base_dtype):
1346          raise TypeError("convert_to_tensor did not convert to "
1347                          "the preferred dtype: %s vs %s " %
1348                          (ret.dtype.base_dtype, preferred_dtype.base_dtype))
1349
1350    if ret is None:
1351      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1352
1353    if ret is NotImplemented:
1354      continue
1355
1356    if not isinstance(ret, accepted_result_types):
1357      raise RuntimeError(
1358          "%sConversion function %r for type %s returned non-Tensor: %r" %
1359          (_error_prefix(name), conversion_func, base_type, ret))
1360    if dtype and not dtype.is_compatible_with(ret.dtype):
1361      raise RuntimeError(
1362          "%sConversion function %r for type %s returned incompatible "
1363          "dtype: requested = %s, actual = %s" %
1364          (_error_prefix(name), conversion_func, base_type, dtype.name,
1365           ret.dtype.name))
1366    return ret
1367  raise TypeError("%sCannot convert %r with type %s to Tensor: "
1368                  "no conversion function registered." %
1369                  (_error_prefix(name), value, type(value)))
1370
1371
1372internal_convert_to_tensor = convert_to_tensor
1373
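# Illustrative sketch (not part of this module) of the conversion protocol that
# convert_to_tensor drives: a registered conversion function receives
# (value, dtype, name, as_ref) and may return NotImplemented to defer to the
# next registered function. The wrapper class and function names below are
# hypothetical.
#
#   class _MyWrapper(object):
#
#     def __init__(self, values):
#       self.values = values
#
#   def _my_wrapper_to_tensor(value, dtype=None, name=None, as_ref=False):
#     del as_ref  # refs are not supported for this type
#     return convert_to_tensor(value.values, dtype=dtype, name=name)
#
#   tensor_conversion_registry.register_tensor_conversion_function(
#       _MyWrapper, _my_wrapper_to_tensor)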
1374
1375def internal_convert_n_to_tensor(values,
1376                                 dtype=None,
1377                                 name=None,
1378                                 as_ref=False,
1379                                 preferred_dtype=None,
1380                                 ctx=None):
1381  """Converts `values` to a list of `Tensor` objects.
1382
1383  Args:
1384    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1385    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1386    name: (Optional.) A name prefix to be used when a new `Tensor` is created, in
1387      which case element `i` will be given the name `name + '_' + i`.
1388    as_ref: True if the caller wants the results as ref tensors.
1389    preferred_dtype: Optional element type for the returned tensors, used when
1390      dtype is None. In some cases, a caller may not have a dtype in mind when
1391      converting to a tensor, so preferred_dtype can be used as a soft
1392      preference.  If the conversion to `preferred_dtype` is not possible, this
1393      argument has no effect.
1394    ctx: The value of context.context().
1395
1396  Returns:
1397    A list of `Tensor` and/or `IndexedSlices` objects.
1398
1399  Raises:
1400    TypeError: If no conversion function is registered for an element in
1401      `values`.
1402    RuntimeError: If a registered conversion function returns an invalid
1403      value.
1404  """
1405  if not isinstance(values, collections_abc.Sequence):
1406    raise TypeError("values must be a sequence.")
1407  ret = []
1408  if ctx is None:
1409    ctx = context.context()
1410  for i, value in enumerate(values):
1411    n = None if name is None else "%s_%d" % (name, i)
1412    ret.append(
1413        convert_to_tensor(
1414            value,
1415            dtype=dtype,
1416            name=n,
1417            as_ref=as_ref,
1418            preferred_dtype=preferred_dtype,
1419            ctx=ctx))
1420  return ret
1421
1422
1423def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
1424  """Converts `values` to a list of `Tensor` objects.
1425
1426  Args:
1427    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1428    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1429    name: (Optional.) A name prefix to be used when a new `Tensor` is created, in
1430      which case element `i` will be given the name `name + '_' + i`.
1431    preferred_dtype: Optional element type for the returned tensors, used when
1432      dtype is None. In some cases, a caller may not have a dtype in mind when
1433      converting to a tensor, so preferred_dtype can be used as a soft
1434      preference.  If the conversion to `preferred_dtype` is not possible, this
1435      argument has no effect.
1436
1437  Returns:
1438    A list of `Tensor` and/or `IndexedSlices` objects.
1439
1440  Raises:
1441    TypeError: If no conversion function is registered for an element in
1442      `values`.
1443    RuntimeError: If a registered conversion function returns an invalid
1444      value.
1445  """
1446  return internal_convert_n_to_tensor(
1447      values=values,
1448      dtype=dtype,
1449      name=name,
1450      preferred_dtype=preferred_dtype,
1451      as_ref=False)
1452
1453
1454def convert_to_tensor_or_composite(value, dtype=None, name=None):
1455  """Converts the given object to a `Tensor` or `CompositeTensor`.
1456
1457  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
1458  is converted to a `Tensor` using `convert_to_tensor()`.
1459
1460  Args:
1461    value: A `CompositeTensor` or an object that can be consumed by
1462      `convert_to_tensor()`.
1463    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1464      `CompositeTensor`.
1465    name: (Optional.) A name to use if a new `Tensor` is created.
1466
1467  Returns:
1468    A `Tensor` or `CompositeTensor`, based on `value`.
1469
1470  Raises:
1471    ValueError: If `dtype` does not match the element type of `value`.
1472  """
1473  return internal_convert_to_tensor_or_composite(
1474      value=value, dtype=dtype, name=name, as_ref=False)
1475
1476
1477def internal_convert_to_tensor_or_composite(value,
1478                                            dtype=None,
1479                                            name=None,
1480                                            as_ref=False):
1481  """Converts the given object to a `Tensor` or `CompositeTensor`.
1482
1483  If `value` is a `CompositeTensor` it is returned unmodified.  Otherwise, it
1484  is converted to a `Tensor` using `convert_to_tensor()`.
1485
1486  Args:
1487    value: A `CompositeTensor`, or an object that can be consumed by
1488      `convert_to_tensor()`.
1489    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1490      `CompositeTensor`.
1491    name: (Optional.) A name to use if a new `Tensor` is created.
1492    as_ref: True if the caller wants the results as ref tensors.
1493
1494  Returns:
1495    A `Tensor` or `CompositeTensor`, based on `value`.
1496
1497  Raises:
1498    ValueError: If `dtype` does not match the element type of `value`.
1499  """
1500  if isinstance(value, composite_tensor.CompositeTensor):
1501    value_dtype = getattr(value, "dtype", None)
1502    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value_dtype):
1503      raise ValueError(
1504          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
1505          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
1506    return value
1507  else:
1508    return convert_to_tensor(
1509        value,
1510        dtype=dtype,
1511        name=name,
1512        as_ref=as_ref,
1513        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))
1514
1515
1516def internal_convert_n_to_tensor_or_composite(values,
1517                                              dtype=None,
1518                                              name=None,
1519                                              as_ref=False):
1520  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.
1521
1522  Any `CompositeTensor` objects in `values` are returned unmodified.
1523
1524  Args:
1525    values: A list of `None`, `CompositeTensor`, or objects that can be consumed
1526      by `convert_to_tensor()`.
1527    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
1528      `CompositeTensor`s.
1529    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1530      which case element `i` will be given the name `name + '_' + i`.
1531    as_ref: True if the caller wants the results as ref tensors.
1532
1533  Returns:
1534    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.
1535
1536  Raises:
1537    TypeError: If no conversion function is registered for an element in
1538      `values`.
1539    RuntimeError: If a registered conversion function returns an invalid
1540      value.
1541  """
1542  if not isinstance(values, collections_abc.Sequence):
1543    raise TypeError("values must be a sequence.")
1544  ret = []
1545  for i, value in enumerate(values):
1546    if value is None:
1547      ret.append(value)
1548    else:
1549      n = None if name is None else "%s_%d" % (name, i)
1550      ret.append(
1551          internal_convert_to_tensor_or_composite(
1552              value, dtype=dtype, name=n, as_ref=as_ref))
1553  return ret
1554
1555
1556def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
1557  """Converts `values` to a list of `Output` or `CompositeTensor` objects.
1558
1559  Any `CompositeTensor` objects in `values` are returned unmodified.
1560
1561  Args:
1562    values: A list of `None`, `CompositeTensor`, or objects that can be
1563      consumed by `convert_to_tensor()`.
1564    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
1565      `CompositeTensor`s.
1566    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1567      which case element `i` will be given the name `name + '_' + i`.
1568
1569  Returns:
1570    A list of `Tensor` and/or `CompositeTensor` objects.
1571
1572  Raises:
1573    TypeError: If no conversion function is registered for an element in
1574      `values`.
1575    RuntimeError: If a registered conversion function returns an invalid
1576      value.
1577  """
1578  return internal_convert_n_to_tensor_or_composite(
1579      values=values, dtype=dtype, name=name, as_ref=False)
1580
1581
1582def _device_string(dev_spec):
1583  if pydev.is_device_spec(dev_spec):
1584    return dev_spec.to_string()
1585  else:
1586    return dev_spec
1587
1588
1589def _NodeDef(op_type, name, attrs=None):
1590  """Create a NodeDef proto.
1591
1592  Args:
1593    op_type: Value for the "op" attribute of the NodeDef proto.
1594    name: Value for the "name" attribute of the NodeDef proto.
1595    attrs: Dictionary where the key is the attribute name (a string)
1596      and the value is the respective "attr" attribute of the NodeDef proto (an
1597      AttrValue).
1598
1599  Returns:
1600    A node_def_pb2.NodeDef protocol buffer.
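
  For example, a minimal illustrative sketch (the attribute name and value are
  arbitrary):

  ```python
  nd = _NodeDef("NoOp", "my_no_op",
                attrs={"_priority": attr_value_pb2.AttrValue(i=1)})
  ```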
1601  """
1602  node_def = node_def_pb2.NodeDef(op=compat.as_bytes(op_type),
1603                                  name=compat.as_bytes(name))
1604  if attrs:
1605    for k, v in six.iteritems(attrs):
1606      node_def.attr[k].CopyFrom(v)
1607  return node_def
1608
1609
1610# Copied from core/framework/node_def_util.cc
1611# TODO(mrry,josh11b): Consolidate this validation in C++ code.
1612_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/>]*$")
1613_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/>]*$")
1614
1615
1616def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
1617  """Creates a TF_Operation.
1618
1619  Args:
1620    graph: a `Graph`.
1621    node_def: `node_def_pb2.NodeDef` for the operation to create.
1622    inputs: A flattened list of `Tensor`s. This function handles grouping
1623      tensors into lists as per attributes in the `node_def`.
1624    control_inputs: A list of `Operation`s to set as control dependencies.
1625    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
1626      specified, is looked up from the `graph` using `node_def.op`.
1627
1628  Returns:
1629    A wrapped TF_Operation*.
1630  """
1631  if op_def is None:
1632    op_def = graph._get_op_def(node_def.op)  # pylint: disable=protected-access
1633  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
1634  # Refactor so we don't have to do this here.
1635  inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
1636  # pylint: disable=protected-access
1637  op_desc = pywrap_tf_session.TF_NewOperation(graph._c_graph,
1638                                              compat.as_str(node_def.op),
1639                                              compat.as_str(node_def.name))
1640  if node_def.device:
1641    pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
1642  # Add inputs
1643  for op_input in inputs:
1644    if isinstance(op_input, (list, tuple)):
1645      pywrap_tf_session.TF_AddInputList(op_desc,
1646                                        [t._as_tf_output() for t in op_input])
1647    else:
1648      pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())
1649
1650  # Add control inputs
1651  for control_input in control_inputs:
1652    pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
1653  # pylint: enable=protected-access
1654
1655  # Add attrs
1656  for name, attr_value in node_def.attr.items():
1657    serialized = attr_value.SerializeToString()
1658    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
1659    # It might be worth creating a convenient way to re-use the same status.
1660    pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
1661                                           serialized)
1662
1663  try:
1664    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
1665  except errors.InvalidArgumentError as e:
1666    # Convert to ValueError for backwards compatibility.
1667    raise ValueError(str(e))
1668
1669  return c_op
1670
1671
1672@tf_export("Operation")
1673class Operation(object):
1674  """Represents a graph node that performs computation on tensors.
1675
1676  An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor`
1677  objects as input, and produces zero or more `Tensor` objects as output.
1678  Objects of type `Operation` are created by calling a Python op constructor
1679  (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default`
1680  context manager.
1681
1682  For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an
1683  `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and
1684  produces `c` as output.
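
  A minimal sketch of that example (the printed values assume default op
  naming):

  ```python
  @tf.function
  def my_matmul(a, b):
    c = tf.matmul(a, b)
    print(c.op.type)   # Prints "MatMul" while tracing.
    print(c.op.name)   # Prints the op's name, e.g. "MatMul".
    return c
  ```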
1685
1686  If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be
1687  executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for
1688  calling `tf.compat.v1.get_default_session().run(op)`.
1689  """
1690
1691  def __init__(self,
1692               node_def,
1693               g,
1694               inputs=None,
1695               output_types=None,
1696               control_inputs=None,
1697               input_types=None,
1698               original_op=None,
1699               op_def=None):
1700    r"""Creates an `Operation`.
1701
1702    NOTE: This constructor validates the name of the `Operation` (passed
1703    as `node_def.name`). Valid `Operation` names match the following
1704    regular expression:
1705
1706        [A-Za-z0-9.][A-Za-z0-9_.\\-/>]*
1707
1708    Args:
1709      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`. Used for
1710        attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and
1711        `device`.  The `input` attribute is irrelevant here as it will be
1712        computed when generating the model.
1713      g: `Graph`. The parent graph.
1714      inputs: list of `Tensor` objects. The inputs to this `Operation`.
1715      output_types: list of `DType` objects.  List of the types of the `Tensors`
1716        computed by this operation.  The length of this list indicates the
1717        number of output endpoints of the `Operation`.
1718      control_inputs: list of operations or tensors from which to have a control
1719        dependency.
1720      input_types: List of `DType` objects representing the types of the tensors
1721        accepted by the `Operation`.  By default uses `[x.dtype.base_dtype for x
1722        in inputs]`.  Operations that expect reference-typed inputs must specify
1723        these explicitly.
1724      original_op: Optional. Used to associate the new `Operation` with an
1725        existing `Operation` (for example, a replica with the op that was
1726        replicated).
1727      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type
1728        that this `Operation` represents.
1729
1730    Raises:
1731      TypeError: if control inputs are not Operations or Tensors,
1732        or if `node_def` is not a `NodeDef`,
1733        or if `g` is not a `Graph`,
1734        or if `inputs` are not tensors,
1735        or if `inputs` and `input_types` are incompatible.
1736      ValueError: if the `node_def` name is not valid.
1737    """
1738    # For internal use only: `node_def` can be set to a TF_Operation to create
1739    # an Operation for that op. This is useful for creating Operations for ops
1740    # indirectly created by C API methods, e.g. the ops created by
1741    # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
1742    # should be None.
1743
1744    if isinstance(node_def, node_def_pb2.NodeDef):
1745      if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
1746        raise ValueError(
1747            "Cannot create a tensor proto whose content is larger than 2GB.")
1748      if not _VALID_OP_NAME_REGEX.match(node_def.name):
1749        raise ValueError("'%s' is not a valid node name" % node_def.name)
1750      c_op = None
1751    elif type(node_def).__name__ == "TF_Operation":
1752      assert inputs is None
1753      assert output_types is None
1754      assert control_inputs is None
1755      assert input_types is None
1756      assert original_op is None
1757      assert op_def is None
1758      c_op = node_def
1759    else:
1760      raise TypeError("node_def needs to be a NodeDef: %s" % node_def)
1761
1762    if not isinstance(g, Graph):
1763      raise TypeError("g needs to be a Graph: %s" % g)
1764    self._graph = g
1765
1766    if inputs is None:
1767      inputs = []
1768    elif not isinstance(inputs, list):
1769      raise TypeError("inputs needs to be a list of Tensors: %s" % inputs)
1770    for a in inputs:
1771      if not isinstance(a, Tensor):
1772        raise TypeError("input needs to be a Tensor: %s" % a)
1773    if input_types is None:
1774      input_types = [i.dtype.base_dtype for i in inputs]
1775    else:
1776      if not all(
1777          x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)):
1778        raise TypeError("In op '%s', input types (%s) are not compatible "
1779                        "with expected types (%s)" %
1780                        (node_def.name, [i.dtype for i in inputs], input_types))
1781
1782    # Build the list of control inputs.
1783    control_input_ops = []
1784    if control_inputs:
1785      for c in control_inputs:
1786        control_op = None
1787        if isinstance(c, Operation):
1788          control_op = c
1789        elif isinstance(c, (Tensor, IndexedSlices)):
1790          control_op = c.op
1791        else:
1792          raise TypeError("Control input must be an Operation, "
1793                          "a Tensor, or IndexedSlices: %s" % c)
1794        control_input_ops.append(control_op)
1795
1796    # This will be set by self.inputs.
1797    self._inputs_val = None
1798
1799    # pylint: disable=protected-access
1800    self._original_op = original_op
1801    self._traceback = tf_stack.extract_stack()
1802
1803    # List of _UserDevSpecs holding code location of device context manager
1804    # invocations and the users original argument to them.
1805    self._device_code_locations = None
1806    # Dict mapping op name to file and line information for op colocation
1807    # context managers.
1808    self._colocation_code_locations = None
1809    self._control_flow_context = self.graph._get_control_flow_context()
1810
1811    # Gradient function for this op. There are three ways to specify gradient
1812    # function, and first available gradient gets used, in the following order.
1813    # 1. self._gradient_function
1814    # 2. Gradient name registered by "_gradient_op_type" attribute.
1815    # 3. Gradient name registered by op.type.
1816    self._gradient_function = None
1817
1818    # Initialize self._c_op.
1819    if c_op:
1820      self._c_op = c_op
1821      op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op))
1822      name = self.name
1823    else:
1824      if op_def is None:
1825        op_def = self._graph._get_op_def(node_def.op)
1826      self._c_op = _create_c_op(self._graph, node_def, inputs,
1827                                control_input_ops, op_def)
1828      name = compat.as_str(node_def.name)
1829    # pylint: enable=protected-access
1830
1831    self._is_stateful = op_def.is_stateful
1832
1833    # Initialize self._outputs.
1834    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
1835    self._outputs = []
1836    for i in range(num_outputs):
1837      tf_output = c_api_util.tf_output(self._c_op, i)
1838      output_type = pywrap_tf_session.TF_OperationOutputType(tf_output)
1839      tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output)  # pylint: disable=protected-access
1840      self._outputs.append(tensor)
1841
1842    self._id_value = self._graph._add_op(self, name)  # pylint: disable=protected-access
1843
1844    if not c_op:
1845      self._control_flow_post_processing(input_tensors=inputs)
1846
1847  def _control_flow_post_processing(self, input_tensors=None):
1848    """Add this op to its control flow context.
1849
1850    This may add new ops and change this op's inputs. self.inputs must be
1851    available before calling this method.
1852
1853    Args:
1854      input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
1855        of this op, which should be equivalent to `self.inputs`. Pass this
1856        argument to avoid evaluating `self.inputs` unnecessarily.
1857    """
1858    if input_tensors is None:
1859      input_tensors = self.inputs
1860    for input_tensor in input_tensors:
1861      control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
1862    if self._control_flow_context is not None:
1863      self._control_flow_context.AddOp(self)
1864
1865  def colocation_groups(self):
1866    """Returns the list of colocation groups of the op."""
1867    default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)]
1868    try:
1869      class_attr = self.get_attr("_class")
1870    except ValueError:
1871      # This op has no explicit colocation group, so it is itself its
1872      # own root of a colocation group.
1873      return default_colocation_group
1874
1875    attr_groups = [
1876        class_name for class_name in class_attr
1877        if class_name.startswith(b"loc:@")
1878    ]
1879
1880    # If there are no colocation groups in the explicit _class field,
1881    # return the default colocation group.
1882    return attr_groups if attr_groups else default_colocation_group
1883
1884  def values(self):
1885    """DEPRECATED: Use outputs."""
1886    return tuple(self.outputs)
1887
1888  def _get_control_flow_context(self):
1889    """Returns the control flow context of this op.
1890
1891    Returns:
1892      A context object.
1893    """
1894    return self._control_flow_context
1895
1896  def _set_control_flow_context(self, ctx):
1897    """Sets the current control flow context of this op.
1898
1899    Args:
1900      ctx: a context object.
1901    """
1902    self._control_flow_context = ctx
1903
1904  @property
1905  def name(self):
1906    """The full name of this operation."""
1907    return pywrap_tf_session.TF_OperationName(self._c_op)
1908
1909  @property
1910  def _id(self):
1911    """The unique integer id of this operation."""
1912    return self._id_value
1913
1914  @property
1915  def device(self):
1916    """The name of the device to which this op has been assigned, if any.
1917
1918    Returns:
1919      The string name of the device to which this op has been
1920      assigned, or an empty string if it has not been assigned to a
1921      device.
1922    """
1923    return pywrap_tf_session.TF_OperationDevice(self._c_op)
1924
1925  @property
1926  def _device_assignments(self):
1927    """Code locations for device context managers active at op creation.
1928
1929    This property will return a list of traceable_stack.TraceableObject
1930    instances where .obj is a string representing the assigned device
1931    (or information about the function that would be applied to this op
1932    to compute the desired device) and the filename and lineno members
1933    record the location of the relevant device context manager.
1934
1935    For example, suppose file_a contained these lines:
1936
1937      file_a.py:
1938        15: with tf.device('/gpu:0'):
1939        16:   node_b = tf.constant(4, name='NODE_B')
1940
1941    Then a TraceableObject t_obj representing the device context manager
1942    would have these member values:
1943
1944      t_obj.obj -> '/gpu:0'
1945      t_obj.filename = 'file_a.py'
1946      t_obj.lineno = 15
1947
1948    and node_b.op._device_assignments would return the list [t_obj].
1949
1950    Returns:
1951      [str: traceable_stack.TraceableObject, ...] as per this method's
1952      description, above.
1953    """
1954    return self._device_code_locations or []
1955
1956  @property
1957  def _colocation_dict(self):
1958    """Code locations for colocation context managers active at op creation.
1959
1960    This property will return a dictionary for which the keys are nodes with
1961    which this Operation is colocated, and for which the values are
1962    traceable_stack.TraceableObject instances.  The TraceableObject instances
1963    record the location of the relevant colocation context manager but have the
1964    "obj" field set to None to prevent leaking private data.
1965
1966    For example, suppose file_a contained these lines:
1967
1968      file_a.py:
1969        14: node_a = tf.constant(3, name='NODE_A')
1970        15: with tf.compat.v1.colocate_with(node_a):
1971        16:   node_b = tf.constant(4, name='NODE_B')
1972
1973    Then a TraceableObject t_obj representing the colocation context manager
1974    would have these member values:
1975
1976      t_obj.obj -> None
1977      t_obj.filename = 'file_a.py'
1978      t_obj.lineno = 15
1979
1980    and node_b.op._colocation_dict would return the dictionary
1981
1982      { 'NODE_A': t_obj }
1983
1984    Returns:
1985      {str: traceable_stack.TraceableObject} as per this method's description,
1986      above.
1987    """
1988    locations_dict = self._colocation_code_locations or {}
1989    return locations_dict.copy()
1990
1991  @property
1992  def _output_types(self):
1993    """List this operation's output types.
1994
1995    Returns:
1996      List of the types of the Tensors computed by this operation.
1997      Each element in the list is an integer whose value is one of
1998      the TF_DataType enums defined in pywrap_tf_session.h
1999      The length of this list indicates the number of output endpoints
2000      of the operation.
2001    """
2002    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
2003    output_types = [
2004        int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i)))
2005        for i in xrange(num_outputs)
2006    ]
2007
2008    return output_types
2009
2010  def _tf_output(self, output_idx):
2011    """Create and return a new TF_Output for output_idx'th output of this op."""
2012    tf_output = pywrap_tf_session.TF_Output()
2013    tf_output.oper = self._c_op
2014    tf_output.index = output_idx
2015    return tf_output
2016
2017  def _tf_input(self, input_idx):
2018    """Create and return a new TF_Input for input_idx'th input of this op."""
2019    tf_input = pywrap_tf_session.TF_Input()
2020    tf_input.oper = self._c_op
2021    tf_input.index = input_idx
2022    return tf_input
2023
2024  def _set_device(self, device):  # pylint: disable=redefined-outer-name
2025    """Set the device of this operation.
2026
2027    Args:
2028      device: string or device. The device to set.
2029    """
2030    self._set_device_from_string(compat.as_str(_device_string(device)))
2031
2032  def _set_device_from_string(self, device_str):
2033    """Fast path to set device if the type is known to be a string.
2034
2035    This function is called frequently enough during graph construction that
2036    there are non-trivial performance gains if the caller can guarantee that
2037    the specified device is already a string.
2038
2039    Args:
2040      device_str: A string specifying where to place this op.
2041    """
2042    pywrap_tf_session.SetRequestedDevice(
2043        self._graph._c_graph,  # pylint: disable=protected-access
2044        self._c_op,  # pylint: disable=protected-access
2045        device_str)
2046
2047  def _update_input(self, index, tensor):
2048    """Update the input to this operation at the given index.
2049
2050    NOTE: This is for TF internal use only. Please don't use it.
2051
2052    Args:
2053      index: the index of the input to update.
2054      tensor: the Tensor to be used as the input at the given index.
2055
2056    Raises:
2057      TypeError: if tensor is not a Tensor,
2058        or if input tensor type is not convertible to dtype.
2059      ValueError: if the Tensor is from a different graph.
2060    """
2061    if not isinstance(tensor, Tensor):
2062      raise TypeError("tensor must be a Tensor: %s" % tensor)
2063    _assert_same_graph(self, tensor)
2064
2065    # Reset cached inputs.
2066    self._inputs_val = None
2067    pywrap_tf_session.UpdateEdge(
2068        self._graph._c_graph,  # pylint: disable=protected-access
2069        tensor._as_tf_output(),  # pylint: disable=protected-access
2070        self._tf_input(index))
2071
2072  def _add_while_inputs(self, tensors):
2073    """See AddWhileInputHack in python_api.h.
2074
2075    NOTE: This is for TF internal use only. Please don't use it.
2076
2077    Args:
2078      tensors: list of Tensors
2079
2080    Raises:
2081      TypeError: if tensor is not a Tensor,
2082        or if input tensor type is not convertible to dtype.
2083      ValueError: if the Tensor is from a different graph.
2084    """
2085    for tensor in tensors:
2086      if not isinstance(tensor, Tensor):
2087        raise TypeError("tensor must be a Tensor: %s" % tensor)
2088      _assert_same_graph(self, tensor)
2089
2090      # Reset cached inputs.
2091      self._inputs_val = None
2092      pywrap_tf_session.AddWhileInputHack(
2093          self._graph._c_graph,  # pylint: disable=protected-access
2094          tensor._as_tf_output(),  # pylint: disable=protected-access
2095          self._c_op)
2096
2097  def _add_control_inputs(self, ops):
2098    """Add a list of new control inputs to this operation.
2099
2100    Args:
2101      ops: the list of Operations to add as control input.
2102
2103    Raises:
2104      TypeError: if ops is not a list of Operations.
2105      ValueError: if any op in ops is from a different graph.
2106    """
2107    for op in ops:
2108      if not isinstance(op, Operation):
2109        raise TypeError("op must be an Operation: %s" % op)
2110      pywrap_tf_session.AddControlInput(
2111          self._graph._c_graph,  # pylint: disable=protected-access
2112          self._c_op,  # pylint: disable=protected-access
2113          op._c_op)  # pylint: disable=protected-access
2114
2115  def _add_control_input(self, op):
2116    """Add a new control input to this operation.
2117
2118    Args:
2119      op: the Operation to add as control input.
2120
2121    Raises:
2122      TypeError: if op is not an Operation.
2123      ValueError: if op is from a different graph.
2124    """
2125    if not isinstance(op, Operation):
2126      raise TypeError("op must be an Operation: %s" % op)
2127    pywrap_tf_session.AddControlInput(
2128        self._graph._c_graph,  # pylint: disable=protected-access
2129        self._c_op,  # pylint: disable=protected-access
2130        op._c_op)  # pylint: disable=protected-access
2131
2132  def _remove_all_control_inputs(self):
2133    """Removes any control inputs to this operation."""
2134    pywrap_tf_session.RemoveAllControlInputs(self._graph._c_graph, self._c_op)  # pylint: disable=protected-access
2135
2136  def _add_outputs(self, types, shapes):
2137    """Adds new Tensors to self.outputs.
2138
2139    Note: this is generally unsafe to use. This is used in certain situations in
2140    conjunction with _set_type_list_attr.
2141
2142    Args:
2143      types: list of DTypes
2144      shapes: list of TensorShapes
2145    """
2146    assert len(types) == len(shapes)
2147    orig_num_outputs = len(self.outputs)
2148    for i in range(len(types)):
2149      t = Tensor(self, orig_num_outputs + i, types[i])
2150      self._outputs.append(t)
2151      t.set_shape(shapes[i])
2152
2153  def __str__(self):
2154    return str(self.node_def)
2155
2156  def __repr__(self):
2157    return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
2158
2159  @property
2160  def outputs(self):
2161    """The list of `Tensor` objects representing the outputs of this op."""
2162    return self._outputs
2163
2164  @property
2165  def inputs(self):
2166    """The sequence of `Tensor` objects representing the data inputs of this op."""
2167    if self._inputs_val is None:
2168      # pylint: disable=protected-access
2169      self._inputs_val = tuple(
2170          map(self.graph._get_tensor_by_tf_output,
2171              pywrap_tf_session.GetOperationInputs(self._c_op)))
2172      # pylint: enable=protected-access
2173    return self._inputs_val
2174
2175  @property
2176  def _input_types(self):
2177    num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op)
2178    input_types = [
2179        dtypes.as_dtype(
2180            pywrap_tf_session.TF_OperationInputType(self._tf_input(i)))
2181        for i in xrange(num_inputs)
2182    ]
2183    return input_types
2184
2185  @property
2186  def control_inputs(self):
2187    """The `Operation` objects on which this op has a control dependency.
2188
2189    Before this op is executed, TensorFlow will ensure that the
2190    operations in `self.control_inputs` have finished executing. This
2191    mechanism can be used to run ops sequentially for performance
2192    reasons, or to ensure that the side effects of an op are observed
2193    in the correct order.
2194
2195    Returns:
2196      A list of `Operation` objects.
2197
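    For example, an illustrative sketch using graph-mode construction:

    ```python
    with tf.Graph().as_default():
      a = tf.constant(1.0)
      with tf.control_dependencies([a]):
        b = tf.constant(2.0)
      b.op.control_inputs   # [a.op]
    ```
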
2198    """
2199    control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper(
2200        self._c_op)
2201    # pylint: disable=protected-access
2202    return [
2203        self.graph._get_operation_by_name_unsafe(
2204            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
2205    ]
2206    # pylint: enable=protected-access
2207
2208  @property
2209  def _control_outputs(self):
2210    """The `Operation` objects which have a control dependency on this op.
2211
2212    Before any of the ops in `self._control_outputs` can execute, TensorFlow will
2213    ensure that `self` has finished executing.
2214
2215    Returns:
2216      A list of `Operation` objects.
2217
2218    """
2219    control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
2220        self._c_op)
2221    # pylint: disable=protected-access
2222    return [
2223        self.graph._get_operation_by_name_unsafe(
2224            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
2225    ]
2226    # pylint: enable=protected-access
2227
2228  @property
2229  def type(self):
2230    """The type of the op (e.g. `"MatMul"`)."""
2231    return pywrap_tf_session.TF_OperationOpType(self._c_op)
2232
2233  @property
2234  def graph(self):
2235    """The `Graph` that contains this operation."""
2236    return self._graph
2237
2238  @property
2239  def node_def(self):
2240    # pylint: disable=line-too-long
2241    """Returns the `NodeDef` representation of this operation.
2242
2243    Returns:
2244      A
2245      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
2246      protocol buffer.
2247    """
2248    # pylint: enable=line-too-long
2249    with c_api_util.tf_buffer() as buf:
2250      pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
2251      data = pywrap_tf_session.TF_GetBuffer(buf)
2252    node_def = node_def_pb2.NodeDef()
2253    node_def.ParseFromString(compat.as_bytes(data))
2254    return node_def
2255
2256  @property
2257  def op_def(self):
2258    # pylint: disable=line-too-long
2259    """Returns the `OpDef` proto that represents the type of this op.
2260
2261    Returns:
2262      An
2263      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
2264      protocol buffer.
2265    """
2266    # pylint: enable=line-too-long
2267    return self._graph._get_op_def(self.type)
2268
2269  @property
2270  def traceback(self):
2271    """Returns the call stack from when this operation was constructed."""
2272    return self._traceback
2273
2274  def _set_attr(self, attr_name, attr_value):
2275    """Private method used to set an attribute in the node_def."""
2276    buf = pywrap_tf_session.TF_NewBufferFromString(
2277        compat.as_bytes(attr_value.SerializeToString()))
2278    try:
2279      self._set_attr_with_buf(attr_name, buf)
2280    finally:
2281      pywrap_tf_session.TF_DeleteBuffer(buf)
2282
2283  def _set_attr_with_buf(self, attr_name, attr_buf):
2284    """Set an attr in the node_def with a pre-allocated buffer."""
2285    # pylint: disable=protected-access
2286    pywrap_tf_session.SetAttr(self._graph._c_graph, self._c_op, attr_name,
2287                              attr_buf)
2288    # pylint: enable=protected-access
2289
2290  def _set_func_attr(self, attr_name, func_name):
2291    """Private method used to set a function attribute in the node_def."""
2292    func = attr_value_pb2.NameAttrList(name=func_name)
2293    self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
2294
2295  def _set_func_list_attr(self, attr_name, func_names):
2296    """Private method used to set a list(function) attribute in the node_def."""
2297    funcs = [attr_value_pb2.NameAttrList(name=func_name)
2298             for func_name in func_names]
2299    funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
2300    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))
2301
2302  def _set_type_list_attr(self, attr_name, types):
2303    """Private method used to set a list(type) attribute in the node_def."""
2304    if not types:
2305      return
2306    if isinstance(types[0], dtypes.DType):
2307      types = [dt.as_datatype_enum for dt in types]
2308    types_list = attr_value_pb2.AttrValue.ListValue(type=types)
2309    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
2310
2311  def _set_shape_list_attr(self, attr_name, shapes):
2312    """Private method used to set a list(shape) attribute in the node_def."""
2313    shapes = [s.as_proto() for s in shapes]
2314    shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
2315    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
2316
2317  def _clear_attr(self, attr_name):
2318    """Private method used to clear an attribute in the node_def."""
2319    # pylint: disable=protected-access
2320    pywrap_tf_session.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
2321    # pylint: enable=protected-access
2322
2323  def get_attr(self, name):
2324    """Returns the value of the attr of this op with the given `name`.
2325
2326    Args:
2327      name: The name of the attr to fetch.
2328
2329    Returns:
2330      The value of the attr, as a Python object.
2331
2332    Raises:
2333      ValueError: If this op does not have an attr with the given `name`.
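
    For example, an illustrative sketch using a constant op:

    ```python
    with tf.Graph().as_default():
      c = tf.constant(1.0)
      c.op.get_attr("dtype")   # tf.float32
      c.op.get_attr("value")   # A TensorProto holding 1.0.
    ```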
2334    """
2335    fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
2336    try:
2337      with c_api_util.tf_buffer() as buf:
2338        pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
2339        data = pywrap_tf_session.TF_GetBuffer(buf)
2340    except errors.InvalidArgumentError as e:
2341      # Convert to ValueError for backwards compatibility.
2342      raise ValueError(str(e))
2343    x = attr_value_pb2.AttrValue()
2344    x.ParseFromString(data)
2345
2346    oneof_value = x.WhichOneof("value")
2347    if oneof_value is None:
2348      return []
2349    if oneof_value == "list":
2350      for f in fields:
2351        if getattr(x.list, f):
2352          if f == "type":
2353            return [dtypes.as_dtype(t) for t in x.list.type]
2354          else:
2355            return list(getattr(x.list, f))
2356      return []
2357    if oneof_value == "type":
2358      return dtypes.as_dtype(x.type)
2359    assert oneof_value in fields, "Unsupported field type in " + str(x)
2360    return getattr(x, oneof_value)
2361
2362  def _get_attr_type(self, name):
2363    """Returns the `DType` value of the attr of this op with the given `name`."""
2364    try:
2365      dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
2366      return _DTYPES_INTERN_TABLE[dtype_enum]
2367    except errors.InvalidArgumentError as e:
2368      # Convert to ValueError for backwards compatibility.
2369      raise ValueError(str(e))
2370
2371  def _get_attr_bool(self, name):
2372    """Returns the `bool` value of the attr of this op with the given `name`."""
2373    try:
2374      return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
2375    except errors.InvalidArgumentError as e:
2376      # Convert to ValueError for backwards compatibility.
2377      raise ValueError(str(e))
2378
2379  def _get_attr_int(self, name):
2380    """Returns the `int` value of the attr of this op with the given `name`."""
2381    try:
2382      return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
2383    except errors.InvalidArgumentError as e:
2384      # Convert to ValueError for backwards compatibility.
2385      raise ValueError(str(e))
2386
2387  def run(self, feed_dict=None, session=None):
2388    """Runs this operation in a `Session`.
2389
2390    Calling this method will execute all preceding operations that
2391    produce the inputs needed for this operation.
2392
2393    *N.B.* Before invoking `Operation.run()`, its graph must have been
2394    launched in a session, and either a default session must be
2395    available, or `session` must be specified explicitly.
2396
2397    Args:
2398      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
2399        `tf.Session.run` for a description of the valid feed values.
2400      session: (Optional.) The `Session` to be used to run this operation. If
2401        none, the default session will be used.
2402    """
2403    _run_using_default_session(self, feed_dict, self.graph, session)
2404
2405_gradient_registry = registry.Registry("gradient")
2406
2407
2408@tf_export("RegisterGradient")
2409class RegisterGradient(object):
2410  """A decorator for registering the gradient function for an op type.
2411
2412  This decorator is only used when defining a new op type. For an op
2413  with `m` inputs and `n` outputs, the gradient function is a function
2414  that takes the original `Operation` and `n` `Tensor` objects
2415  (representing the gradients with respect to each output of the op),
2416  and returns `m` `Tensor` objects (representing the partial gradients
2417  with respect to each input of the op).
2418
2419  For example, assuming that operations of type `"Sub"` take two
2420  inputs `x` and `y`, and return a single output `x - y`, the
2421  following gradient function would be registered:
2422
2423  ```python
2424  @tf.RegisterGradient("Sub")
2425  def _sub_grad(unused_op, grad):
2426    return grad, tf.negative(grad)
2427  ```
2428
2429  The decorator argument `op_type` is the string type of an
2430  operation. This corresponds to the `OpDef.name` field for the proto
2431  that defines the operation.
2432  """
2433
2434  def __init__(self, op_type):
2435    """Creates a new decorator with `op_type` as the Operation type.
2436
2437    Args:
2438      op_type: The string type of an operation. This corresponds to the
2439        `OpDef.name` field for the proto that defines the operation.
2440
2441    Raises:
2442      TypeError: If `op_type` is not string.
2443    """
2444    if not isinstance(op_type, six.string_types):
2445      raise TypeError("op_type must be a string")
2446    self._op_type = op_type
2447
2448  def __call__(self, f):
2449    """Registers the function `f` as gradient function for `op_type`."""
2450    _gradient_registry.register(f, self._op_type)
2451    return f
2452
2453
2454@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient")
2455@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"])
2456def no_gradient(op_type):
2457  """Specifies that ops of type `op_type` is not differentiable.
2458
2459  This function should *not* be used for operations that have a
2460  well-defined gradient that is not yet implemented.
2461
2462  This function is only used when defining a new op type. It may be
2463  used for ops such as `tf.size()` that are not differentiable.  For
2464  example:
2465
2466  ```python
2467  tf.no_gradient("Size")
2468  ```
2469
2470  The gradient computed for 'op_type' will then propagate zeros.
2471
2472  For ops that have a well-defined gradient but are not yet implemented,
2473  no declaration should be made, and an error *must* be thrown if
2474  an attempt to request its gradient is made.
2475
2476  Args:
2477    op_type: The string type of an operation. This corresponds to the
2478      `OpDef.name` field for the proto that defines the operation.
2479
2480  Raises:
2481    TypeError: If `op_type` is not a string.
2482
2483  """
2484  if not isinstance(op_type, six.string_types):
2485    raise TypeError("op_type must be a string")
2486  _gradient_registry.register(None, op_type)
2487
2488
2489# Aliases for the old names, will be eventually removed.
2490NoGradient = no_gradient
2491NotDifferentiable = no_gradient
2492
2493
2494def get_gradient_function(op):
2495  """Returns the function that computes gradients for "op"."""
2496  if not op.inputs:
2497    return None
2498
2499  gradient_function = op._gradient_function  # pylint: disable=protected-access
2500  if gradient_function:
2501    return gradient_function
2502
2503  try:
2504    op_type = op.get_attr("_gradient_op_type")
2505  except ValueError:
2506    op_type = op.type
2507  return _gradient_registry.lookup(op_type)
2508
2509
2510def set_shape_and_handle_data_for_outputs(_):
2511  """No op. TODO(b/74620627): Remove this."""
2512  pass
2513
2514
2515class OpStats(object):
2516  """A holder for statistics about an operator.
2517
2518  This class holds information about the resource requirements for an op,
2519  including the size of its weight parameters on-disk and how many FLOPS it
2520  requires to execute forward inference.
2521
2522  If you define a new operation, you can create a function that will return a
2523  set of information about its usage of the CPU and disk space when serialized.
2524  The function itself takes a Graph object that's been set up so you can call
2525  methods like get_tensor_by_name to help calculate the results, and a NodeDef
2526  argument.
2527
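  A minimal illustrative sketch of how `OpStats` values accumulate:

  ```python
  stats = OpStats("flops", 1024)
  stats += OpStats("flops", 2048)
  stats.value  # 3072
  ```
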
2528  """
2529
2530  def __init__(self, statistic_type, value=None):
2531    """Sets up the initial placeholders for the statistics."""
2532    self.statistic_type = statistic_type
2533    self.value = value
2534
2535  @property
2536  def statistic_type(self):
2537    return self._statistic_type
2538
2539  @statistic_type.setter
2540  def statistic_type(self, statistic_type):
2541    self._statistic_type = statistic_type
2542
2543  @property
2544  def value(self):
2545    return self._value
2546
2547  @value.setter
2548  def value(self, value):
2549    self._value = value
2550
2551  def __iadd__(self, other):
2552    if other.statistic_type != self.statistic_type:
2553      raise ValueError("Can't add an OpStat of type %s to one of %s." %
2554                       (self.statistic_type, other.statistic_type))
2555    if self.value is None:
2556      self.value = other.value
2557    elif other.value is not None:
2558      self._value += other.value
2559    return self
2560
2561
2562_stats_registry = registry.Registry("statistical functions")
2563
2564
2565class RegisterStatistics(object):
2566  """A decorator for registering the statistics function for an op type.
2567
2568  This decorator registers a function for an op type so that it gives a
2569  report on the resources used by an instance of an operator, in the
2570  form of an OpStats object.
2571
2572  Well-known types of statistics include these so far:
2573
2574  - flops: When running a graph, the bulk of the computation happens doing
2575    numerical calculations like matrix multiplications. This type allows a node
2576    to return how many floating-point operations it takes to complete. The
2577    total number of FLOPs for a graph is a good guide to its expected latency.
2578
2579  You can add your own statistics just by picking a new type string, registering
2580  functions for the ops you care about, and then calling get_stats_for_node_def.
2581
2582  If a statistic for an op is registered multiple times, a KeyError will be
2583  raised.
2584
2585  Since statistics are counted on a per-op basis, they are not suitable for
2586  model parameters (capacity), which are expected to be counted only once,
2587  even when they are shared by multiple ops (e.g. in an RNN).
2588
2589  For example, you can define a new metric called doohickey for a Foo operation
2590  by placing this in your code:
2591
2592  ```python
2593  @ops.RegisterStatistics("Foo", "doohickey")
2594  def _calc_foo_bojangles(unused_graph, unused_node_def):
2595    return ops.OpStats("doohickey", 20)
2596  ```
2597
2598  Then in client code you can retrieve the value by making this call:
2599
2600  ```python
2601  doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
2602  ```
2603
2604  If the NodeDef is for an op with a registered doohickey function, you'll get
2605  back the calculated amount in doohickey.value, or None if it's not defined.
2606
2607  """
2608
2609  def __init__(self, op_type, statistic_type):
2610    """Saves the `op_type` as the `Operation` type."""
2611    if not isinstance(op_type, six.string_types):
2612      raise TypeError("op_type must be a string.")
2613    if "," in op_type:
2614      raise TypeError("op_type must not contain a comma.")
2615    self._op_type = op_type
2616    if not isinstance(statistic_type, six.string_types):
2617      raise TypeError("statistic_type must be a string.")
2618    if "," in statistic_type:
2619      raise TypeError("statistic_type must not contain a comma.")
2620    self._statistic_type = statistic_type
2621
2622  def __call__(self, f):
2623    """Registers "f" as the statistics function for "op_type"."""
2624    _stats_registry.register(f, self._op_type + "," + self._statistic_type)
2625    return f
2626
2627
2628def get_stats_for_node_def(graph, node, statistic_type):
2629  """Looks up the node's statistics function in the registry and calls it.
2630
2631  This function takes a Graph object and a NodeDef from a GraphDef, and if
2632  there's an associated statistics method, calls it and returns a result. If no
2633  function has been registered for the particular node type, it returns an empty
2634  statistics object.
2635
2636  Args:
2637    graph: A Graph object that's been set up with the node's graph.
2638    node: A NodeDef describing the operator.
2639    statistic_type: A string identifying the statistic we're interested in.
2640
2641  Returns:
2642    An OpStats object containing information about resource usage.
2643  """
2644
2645  try:
2646    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
2647    result = stats_func(graph, node)
2648  except LookupError:
2649    result = OpStats(statistic_type)
2650  return result
2651
2652
2653def name_from_scope_name(name):
2654  """Returns the name of an op given the name of its scope.
2655
2656  Args:
2657    name: the name of the scope.
2658
2659  Returns:
2660    the name of the op (equal to scope name minus any trailing slash).
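
  For example (illustrative):

  ```python
  name_from_scope_name("foo/bar/")  # "foo/bar"
  name_from_scope_name("foo/bar")   # "foo/bar"
  ```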
2661  """
2662  return name[:-1] if (name and name[-1] == "/") else name
2663
2664
2665_MUTATION_LOCK_GROUP = 0
2666_SESSION_RUN_LOCK_GROUP = 1
2667
2668
2669@tf_export("Graph")
2670class Graph(object):
2671  """A TensorFlow computation, represented as a dataflow graph.
2672
2673  Graphs are used by `tf.function`s to represent the function's computations.
2674  Each graph contains a set of `tf.Operation` objects, which represent units of
2675  computation; and `tf.Tensor` objects, which represent the units of data that
2676  flow between operations.
2677
2678  ### Using graphs directly (deprecated)
2679
2680  A `tf.Graph` can be constructed and used directly without a `tf.function`, as
2681  was required in TensorFlow 1, but this is deprecated and it is recommended to
2682  use a `tf.function` instead. If a graph is directly used, other deprecated
2683  TensorFlow 1 classes are also required to execute the graph, such as a
2684  `tf.compat.v1.Session`.
2685
2686  A default graph can be registered with the `tf.Graph.as_default` context
2687  manager. Then, operations will be added to the graph instead of being executed
2688  eagerly. For example:
2689
2690  ```python
2691  g = tf.Graph()
2692  with g.as_default():
2693    # Define operations and tensors in `g`.
2694    c = tf.constant(30.0)
2695    assert c.graph is g
2696  ```
2697
2698  `tf.compat.v1.get_default_graph()` can be used to obtain the default graph.
2699
2700  Important note: This class *is not* thread-safe for graph construction. All
2701  operations should be created from a single thread, or external
2702  synchronization must be provided. Unless otherwise specified, all methods
2703  are not thread-safe.
2704
2705  A `Graph` instance supports an arbitrary number of "collections"
2706  that are identified by name. For convenience when building a large
2707  graph, collections can store groups of related objects: for
2708  example, the `tf.Variable` uses a collection (named
2709  `tf.GraphKeys.GLOBAL_VARIABLES`) for
2710  all variables that are created during the construction of a graph. The caller
2711  may define additional collections by specifying a new name.
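
  For example, an illustrative sketch (the collection name `"my_constants"` is
  arbitrary):

  ```python
  g = tf.Graph()
  with g.as_default():
    c = tf.constant(30.0)
    g.add_to_collection("my_constants", c)
  g.get_collection("my_constants")  # [c]
  ```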
2712  """
2713
2714  def __init__(self):
2715    """Creates a new, empty Graph."""
2716    # Protects core state that can be returned via public accessors.
2717    # Thread-safety is provided on a best-effort basis to support buggy
2718    # programs, and is not guaranteed by the public `tf.Graph` API.
2719    #
2720    # NOTE(mrry): This does not protect the various stacks. A warning will
2721    # be reported if these are used from multiple threads
2722    self._lock = threading.RLock()
2723    # The group lock synchronizes Session.run calls with methods that create
2724    # and mutate ops (e.g. Graph.create_op()). This synchronization is
2725    # necessary because it's illegal to modify an operation after it's been run.
2726    # The group lock allows any number of threads to mutate ops at the same time
2727    # but if any modification is going on, all Session.run calls have to wait.
2728    # Similarly, if one or more Session.run calls are going on, all mutate ops
2729    # have to wait until all Session.run calls have finished.
2730    self._group_lock = lock_util.GroupLock(num_groups=2)
2731    self._nodes_by_id = {}  # GUARDED_BY(self._lock)
2732    self._next_id_counter = 0  # GUARDED_BY(self._lock)
2733    self._nodes_by_name = {}  # GUARDED_BY(self._lock)
2734    self._version = 0  # GUARDED_BY(self._lock)
2735    # Maps a name used in the graph to the next id to use for that name.
2736    self._names_in_use = {}
2737    self._stack_state_is_thread_local = False
2738    self._thread_local = threading.local()
2739    # Functions that will be applied to choose a device if none is specified.
2740    # In TF2.x or after switch_to_thread_local(),
2741    # self._thread_local._device_function_stack is used instead.
2742    self._graph_device_function_stack = traceable_stack.TraceableStack()
2743    # Default original_op applied to new ops.
2744    self._default_original_op = None
2745    # Current control flow context. It could be either CondContext or
2746    # WhileContext defined in ops/control_flow_ops.py
2747    self._control_flow_context = None
2748    # A new node will depend on the union of all of the nodes in the stack.
2749    # In TF2.x or after switch_to_thread_local(),
2750    # self._thread_local._control_dependencies_stack is used instead.
2751    self._graph_control_dependencies_stack = []
2752    # Arbitrary collections of objects.
2753    self._collections = {}
2754    # The graph-level random seed
2755    self._seed = None
2756    # A dictionary of attributes that should be applied to all ops.
2757    self._attr_scope_map = {}
2758    # A map from op type to the kernel label that should be used.
2759    self._op_to_kernel_label_map = {}
2760    # A map from op type to an alternative op type that should be used when
2761    # computing gradients.
2762    self._gradient_override_map = {}
2763    # A map from op type to a gradient function that should be used instead.
2764    self._gradient_function_map = {}
2765    # True if the graph is considered "finalized".  In that case no
2766    # new operations can be added.
2767    self._finalized = False
2768    # Functions defined in the graph
2769    self._functions = collections.OrderedDict()
2770    # Default GraphDef versions
2771    self._graph_def_versions = versions_pb2.VersionDef(
2772        producer=versions.GRAPH_DEF_VERSION,
2773        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
2774    self._building_function = False
2775    # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
2776    # self._thread_local._colocation_stack is used instead.
2777    self._graph_colocation_stack = traceable_stack.TraceableStack()
2778    # Set of tensors that are dangerous to feed!
2779    self._unfeedable_tensors = object_identity.ObjectIdentitySet()
2780    # Set of operations that are dangerous to fetch!
2781    self._unfetchable_ops = set()
2782    # A map of tensor handle placeholder to tensor dtype.
2783    self._handle_feeders = {}
2784    # A map from tensor handle to its read op.
2785    self._handle_readers = {}
2786    # A map from tensor handle to its move op.
2787    self._handle_movers = {}
2788    # A map from tensor handle to its delete op.
2789    self._handle_deleters = {}
2790    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
2791    # will be shared when defining function graphs, for example, so optimizers
2792    # being called inside function definitions behave as if they were seeing the
2793    # actual outside graph).
2794    self._graph_key = "grap-key-%d/" % (uid(),)
2795    # A string with the last reduction method passed to
2796    # losses.compute_weighted_loss(), or None. This is required only for
2797    # backward compatibility with Estimator and optimizer V1 use cases.
2798    self._last_loss_reduction = None
2799    # Flag that is used to indicate whether loss has been scaled by optimizer.
2800    # If this flag has been set, then the estimator uses it to scale the loss back
2801    # before reporting. This is required only for backward compatibility with
2802    # Estimator and optimizer V1 use cases.
2803    self._is_loss_scaled_by_optimizer = False
2804    self._container = ""
2805    # Set to True if this graph is being built in an
2806    # AutomaticControlDependencies context.
2807    self._add_control_dependencies = False
2808    # Cache for OpDef protobufs retrieved via the C API.
2809    self._op_def_cache = {}
2810    # Cache for constant results of `broadcast_gradient_args()`. The keys are
2811    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
2812    # values are tuples of reduction indices: (rx, ry).
2813    self._bcast_grad_args_cache = {}
2814    # Cache for constant results of `reduced_shape()`. The keys are pairs of
2815    # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
2816    # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
2817    self._reduced_shape_cache = {}
2818
2819    # TODO(skyewm): fold as much of the above as possible into the C
2820    # implementation
2821    self._scoped_c_graph = c_api_util.ScopedTFGraph()
2822    # The C API requires all ops to have shape functions. Disable this
2823    # requirement (many custom ops do not have shape functions, and we don't
2824    # want to break these existing cases).
2825    pywrap_tf_session.SetRequireShapeInferenceFns(self._c_graph, False)
2826    if tf2.enabled():
2827      self.switch_to_thread_local()
2828
2829  # Note: this method is private because the API of tf.Graph() is public and
2830  # frozen, and this functionality is still not ready for public visibility.
2831  @tf_contextlib.contextmanager
2832  def _variable_creator_scope(self, creator, priority=100):
2833    """Scope which defines a variable creation function.
2834
2835    Args:
2836      creator: A callable taking `next_creator` and `kwargs`. See the
2837        `tf.variable_creator_scope` docstring.
2838      priority: Creators with a higher `priority` are called first. Within the
2839        same priority, creators are called inner-to-outer.
2840
2841    Yields:
2842      `_variable_creator_scope` is a context manager with a side effect, but
2843      doesn't return a value.
2844
2845    Raises:
2846      RuntimeError: If variable creator scopes are not properly nested.
2847    """
2848    # This step keeps a reference to the existing stack, and it also initializes
2849    # self._thread_local._variable_creator_stack if it doesn't exist yet.
2850    old = self._variable_creator_stack
2851    new = list(old)
2852    new.append((priority, creator))
2853    # Sorting is stable, so we'll put higher-priority creators later in the list
2854    # but otherwise maintain registration order.
2855    new.sort(key=lambda item: item[0])
2856    self._thread_local._variable_creator_stack = new  # pylint: disable=protected-access
2857    try:
2858      yield
2859    finally:
2860      if self._thread_local._variable_creator_stack is not new:  # pylint: disable=protected-access
2861        raise RuntimeError(
2862            "Exiting variable_creator_scope without proper nesting.")
2863      self._thread_local._variable_creator_stack = old  # pylint: disable=protected-access
2864
2865  # Note: this method is private because the API of tf.Graph() is public and
2866  # frozen, and this functionality is still not ready for public visibility.
2867  @property
2868  def _variable_creator_stack(self):
2869    if not hasattr(self._thread_local, "_variable_creator_stack"):
2870      self._thread_local._variable_creator_stack = []  # pylint: disable=protected-access
2871
2872    # This previously returned a copy of the stack instead of the stack itself,
2873    # to guard against accidental mutation. Consider, however, code that wants
2874    # to save and restore the variable creator stack:
2875    #     def f():
2876    #       original_stack = graph._variable_creator_stack
2877    #       graph._variable_creator_stack = new_stack
2878    #       ...  # Some code
2879    #       graph._variable_creator_stack = original_stack
2880    #
2881    # And let's say you have some code that calls this function with some
2882    # variable_creator:
2883    #     def g():
2884    #       with variable_scope.variable_creator_scope(creator):
2885    #         f()
2886    # When exiting the variable creator scope, it would see a different stack
2887    # object than it expected, leading to an "Exiting variable_creator_scope
2888    # without proper nesting" error.
2889    return self._thread_local._variable_creator_stack  # pylint: disable=protected-access
2890
2891  @_variable_creator_stack.setter
2892  def _variable_creator_stack(self, variable_creator_stack):
2893    self._thread_local._variable_creator_stack = variable_creator_stack  # pylint: disable=protected-access
2894
2895  def _check_not_finalized(self):
2896    """Check if the graph is finalized.
2897
2898    Raises:
2899      RuntimeError: If the graph is finalized.
2900    """
2901    if self._finalized:
2902      raise RuntimeError("Graph is finalized and cannot be modified.")
2903
2904  def _add_op(self, op, op_name):
2905    """Adds 'op' to the graph and returns the unique ID for the added Operation.
2906
2907    Args:
2908      op: the Operation to add.
2909      op_name: the name of the Operation.
2910
2911    Returns:
2912      An integer that is a unique ID for the added Operation.
2913    """
2914    self._check_not_finalized()
2915    with self._lock:
2916      self._next_id_counter += 1
2917      op_id = self._next_id_counter
2918      self._nodes_by_id[op_id] = op
2919      self._nodes_by_name[op_name] = op
2920      self._version = max(self._version, op_id)
2921      return op_id
2922
2923  @property
2924  def _c_graph(self):
2925    if self._scoped_c_graph:
2926      return self._scoped_c_graph.graph
2927    return None
2928
2929  @property
2930  def version(self):
2931    """Returns a version number that increases as ops are added to the graph.
2932
2933    Note that this is unrelated to
2934    `tf.Graph.graph_def_versions`.
2935
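    For example, a brief sketch:

    ```python
    g = tf.Graph()
    v0 = g.version
    with g.as_default():
      tf.constant(1.0)
    assert g.version > v0  # Adding an op bumped the version.
    ```
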
2936    Returns:
2937       An integer version that increases as ops are added to the graph.
2938    """
2939    if self._finalized:
2940      return self._version
2941
2942    with self._lock:
2943      return self._version
2944
2945  @property
2946  def graph_def_versions(self):
2947    # pylint: disable=line-too-long
2948    """The GraphDef version information of this graph.
2949
2950    For details on the meaning of each version, see
2951    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).
2952
2953    Returns:
2954      A `VersionDef`.
2955    """
2956    # pylint: enable=line-too-long
2957    with c_api_util.tf_buffer() as buf:
2958      pywrap_tf_session.TF_GraphVersions(self._c_graph, buf)
2959      data = pywrap_tf_session.TF_GetBuffer(buf)
2960    version_def = versions_pb2.VersionDef()
2961    version_def.ParseFromString(compat.as_bytes(data))
2962    return version_def
2963
2964  @property
2965  def seed(self):
2966    """The graph-level random seed of this graph."""
2967    return self._seed
2968
2969  @seed.setter
2970  def seed(self, seed):
2971    self._seed = seed
2972
2973  @property
2974  def finalized(self):
2975    """True if this graph has been finalized."""
2976    return self._finalized
2977
2978  def finalize(self):
2979    """Finalizes this graph, making it read-only.
2980
2981    After calling `g.finalize()`, no new operations can be added to
2982    `g`.  This method is used to ensure that no operations are added
2983    to a graph when it is shared between multiple threads, for example
2984    when using a `tf.compat.v1.train.QueueRunner`.
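
    For example, a brief sketch:

    ```python
    g = tf.Graph()
    with g.as_default():
      tf.constant(1.0)
    g.finalize()
    assert g.finalized
    # Any further attempt to add ops to `g` raises
    # "RuntimeError: Graph is finalized and cannot be modified."
    ```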
2985    """
2986    self._finalized = True
2987
2988  def _unsafe_unfinalize(self):
2989    """Opposite of `finalize`.
2990
2991    Internal interface.
2992
2993    NOTE: Unfinalizing a graph could have a negative impact on performance,
2994    especially in a multi-threaded environment.  Unfinalizing a graph
2995    when it is in use by a Session may lead to undefined behavior. Ensure
2996    that all sessions using a graph are closed before calling this method.
2997    """
2998    self._finalized = False
2999
3000  def _get_control_flow_context(self):
3001    """Returns the current control flow context.
3002
3003    Returns:
3004      A context object.
3005    """
3006    return self._control_flow_context
3007
3008  def _set_control_flow_context(self, ctx):
3009    """Sets the current control flow context.
3010
3011    Args:
3012      ctx: a context object.
3013    """
3014    self._control_flow_context = ctx
3015
3016  def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
3017    """If this graph contains functions, copy them to `graph_def`."""
3018    bytesize = starting_bytesize
3019    for f in self._functions.values():
3020      bytesize += f.definition.ByteSize()
3021      if bytesize >= (1 << 31) or bytesize < 0:
3022        raise ValueError("GraphDef cannot be larger than 2GB.")
3023      graph_def.library.function.extend([f.definition])
3024      if f.grad_func_name:
3025        grad_def = function_pb2.GradientDef()
3026        grad_def.function_name = f.name
3027        grad_def.gradient_func = f.grad_func_name
3028        graph_def.library.gradient.extend([grad_def])
3029
3030  def _as_graph_def(self, from_version=None, add_shapes=False):
3031    # pylint: disable=line-too-long
3032    """Returns a serialized `GraphDef` representation of this graph.
3033
3034    The serialized `GraphDef` can be imported into another `Graph`
3035    (using `tf.import_graph_def`) or used with the
3036    [C++ Session API](../../../../api_docs/cc/index.md).
3037
3038    This method is thread-safe.
3039
3040    Args:
3041      from_version: Optional.  If this is set, returns a `GraphDef` containing
3042        only the nodes that were added to this graph since its `version`
3043        property had the given value.
3044      add_shapes: If true, adds an "_output_shapes" list attr to each node with
3045        the inferred shapes of each of its outputs.
3046
3047    Returns:
3048      A tuple containing a
3049      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3050      protocol buffer, and the version of the graph to which that
3051      `GraphDef` corresponds.
3052
3053    Raises:
3054      ValueError: If the `graph_def` would be too large.
3055
3056    """
3057    # pylint: enable=line-too-long
3058    with self._lock:
3059      with c_api_util.tf_buffer() as buf:
3060        pywrap_tf_session.TF_GraphToGraphDef(self._c_graph, buf)
3061        data = pywrap_tf_session.TF_GetBuffer(buf)
3062      graph = graph_pb2.GraphDef()
3063      graph.ParseFromString(compat.as_bytes(data))
3064      # Strip the experimental library field iff it's empty.
3065      if not graph.library.function:
3066        graph.ClearField("library")
3067
3068      if add_shapes:
3069        for node in graph.node:
3070          op = self._nodes_by_name[node.name]
3071          if op.outputs:
3072            node.attr["_output_shapes"].list.shape.extend(
3073                [output.get_shape().as_proto() for output in op.outputs])
3074        for function_def in graph.library.function:
3075          defined_function = self._functions[function_def.signature.name]
3076          try:
3077            func_graph = defined_function.graph
3078          except AttributeError:
3079            # _DefinedFunction doesn't have a graph, _EagerDefinedFunction
3080            # does. Both rely on ops.py, so we can't really isinstance check
3081            # them.
3082            continue
3083          input_shapes = function_def.attr["_input_shapes"]
3084          try:
3085            func_graph_inputs = func_graph.inputs
3086          except AttributeError:
3087            continue
3088          # TODO(b/141471245): Fix the inconsistency when inputs of func graph
3089          # are appended during gradient computation of while/cond.
3090          for input_tensor, _ in zip(func_graph_inputs,
3091                                     function_def.signature.input_arg):
3092            if input_tensor.dtype == dtypes.resource:
3093              # TODO(allenl): Save and restore handle data, then save the
3094              # resource placeholder's shape. Right now some shape functions get
3095              # confused if we set the shape of the resource placeholder (to a
3096              # scalar of course) and there isn't any handle data.
3097              input_shapes.list.shape.add().CopyFrom(
3098                  tensor_shape.TensorShape(None).as_proto())
3099            else:
3100              input_shapes.list.shape.add().CopyFrom(
3101                  input_tensor.get_shape().as_proto())
3102          for node in function_def.node_def:
3103            try:
3104              op = func_graph.get_operation_by_name(node.name)
3105            except KeyError:
3106              continue
3107            outputs = op.outputs
3108
3109            if op.type == "StatefulPartitionedCall":
3110              # Filter out any extra outputs (possibly added by function
3111              # backpropagation rewriting).
3112              num_outputs = len(node.attr["Tout"].list.type)
3113              outputs = outputs[:num_outputs]
3114
3115            node.attr["_output_shapes"].list.shape.extend(
3116                [output.get_shape().as_proto() for output in outputs])
3117
3118    return graph, self._version
3119
3120  def as_graph_def(self, from_version=None, add_shapes=False):
3121    # pylint: disable=line-too-long
3122    """Returns a serialized `GraphDef` representation of this graph.
3123
3124    The serialized `GraphDef` can be imported into another `Graph`
3125    (using `tf.import_graph_def`) or used with the
3126    [C++ Session API](../../api_docs/cc/index.md).
3127
3128    This method is thread-safe.
3129
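    For example, a brief sketch:

    ```python
    with tf.Graph().as_default() as g:
      c = tf.constant(5.0, name="c")

    graph_def = g.as_graph_def(add_shapes=True)
    print([node.name for node in graph_def.node])  # -> ["c"]
    # With `add_shapes=True`, each node also carries an "_output_shapes" attr.
    ```
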
3130    Args:
3131      from_version: Optional.  If this is set, returns a `GraphDef` containing
3132        only the nodes that were added to this graph since its `version`
3133        property had the given value.
3134      add_shapes: If true, adds an "_output_shapes" list attr to each node with
3135        the inferred shapes of each of its outputs.
3136
3137    Returns:
3138      A
3139      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3140      protocol buffer.
3141
3142    Raises:
3143      ValueError: If the `graph_def` would be too large.
3144    """
3145    # pylint: enable=line-too-long
3146    result, _ = self._as_graph_def(from_version, add_shapes)
3147    return result
3148
3149  def _is_function(self, name):
3150    """Tests whether 'name' is registered in this graph's function library.
3151
3152    Args:
3153      name: string op name.
3154
3155    Returns:
3156      bool indicating whether or not 'name' is registered in function library.
3157    """
3158    return compat.as_str(name) in self._functions
3159
3160  def _get_function(self, name):
3161    """Returns the function definition for 'name'.
3162
3163    Args:
3164      name: string function name.
3165
3166    Returns:
3167      The function def proto.
3168    """
3169    return self._functions.get(compat.as_str(name), None)
3170
3171  def _add_function(self, function):
3172    """Adds a function to the graph.
3173
3174    After the function has been added, you can call the function by
3175    passing the function name in place of an op name to
3176    `Graph.create_op()`.
3177
3178    Args:
3179      function: A `_DefinedFunction` object.
3180
3181    Raises:
3182      ValueError: if another function is defined with the same name.
3183    """
3184    name = function.name
3185    # Sanity checks on gradient definition.
3186    if (function.grad_func_name is not None) and (function.python_grad_func is
3187                                                  not None):
3188      raise ValueError("Gradient defined twice for function %s" % name)
3189
3190    # Add function to graph
3191    # pylint: disable=protected-access
3192    gradient = (
3193        function._grad_func._c_func.func if function._grad_func else None)
3194    pywrap_tf_session.TF_GraphCopyFunction(self._c_graph, function._c_func.func,
3195                                           gradient)
3196    # pylint: enable=protected-access
3197
3198    self._functions[compat.as_str(name)] = function
3199
3200    # Need a new-enough consumer to support the functions we add to the graph.
3201    if self._graph_def_versions.min_consumer < 12:
3202      self._graph_def_versions.min_consumer = 12
3203
3204  @property
3205  def building_function(self):
3206    """Returns True iff this graph represents a function."""
3207    return self._building_function
3208
3209  # Helper functions to create operations.
3210  @deprecated_args(None,
3211                   "Shapes are always computed; don't use the compute_shapes "
3212                   "as it has no effect.", "compute_shapes")
3213  def create_op(
3214      self,
3215      op_type,
3216      inputs,
3217      dtypes=None,  # pylint: disable=redefined-outer-name
3218      input_types=None,
3219      name=None,
3220      attrs=None,
3221      op_def=None,
3222      compute_shapes=True,
3223      compute_device=True):
3224    """Creates an `Operation` in this graph.
3225
3226    This is a low-level interface for creating an `Operation`. Most
3227    programs will not call this method directly, and instead use the
3228    Python op constructors, such as `tf.constant()`, which add ops to
3229    the default graph.
3230
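    A rough sketch of direct use (most code should prefer the higher-level op
    constructors; the op type and names below are only illustrative):

    ```python
    with tf.Graph().as_default() as g:
      c = tf.constant(1.0, name="c")
      # Low-level equivalent of `tf.identity(c)`.
      op = g.create_op("Identity", [c], dtypes=[c.dtype], name="my_identity")
      t = op.outputs[0]  # Tensor "my_identity:0" with dtype float32.
    ```
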
3231    Args:
3232      op_type: The `Operation` type to create. This corresponds to the
3233        `OpDef.name` field for the proto that defines the operation.
3234      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3235      dtypes: (Optional) A list of `DType` objects that will be the types of the
3236        tensors that the operation produces.
3237      input_types: (Optional.) A list of `DType`s that will be the types of the
3238        tensors that the operation consumes. By default, uses the base `DType`
3239        of each input in `inputs`. Operations that expect reference-typed inputs
3240        must specify `input_types` explicitly.
3241      name: (Optional.) A string name for the operation. If not specified, a
3242        name is generated based on `op_type`.
3243      attrs: (Optional.) A dictionary where the key is the attribute name (a
3244        string) and the value is the respective `attr` attribute of the
3245        `NodeDef` proto that will represent the operation (an `AttrValue`
3246        proto).
3247      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3248        the operation will have.
3249      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
3250        computed).
3251      compute_device: (Optional.) If True, device functions will be executed to
3252        compute the device property of the Operation.
3253
3254    Raises:
3255      TypeError: if any of the inputs is not a `Tensor`.
3256      ValueError: if colocation conflicts with existing device assignment.
3257
3258    Returns:
3259      An `Operation` object.
3260    """
3261    del compute_shapes
3262    for idx, a in enumerate(inputs):
3263      if not isinstance(a, Tensor):
3264        raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
3265    return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
3266                                    attrs, op_def, compute_device)
3267
3268  def _create_op_internal(
3269      self,
3270      op_type,
3271      inputs,
3272      dtypes=None,  # pylint: disable=redefined-outer-name
3273      input_types=None,
3274      name=None,
3275      attrs=None,
3276      op_def=None,
3277      compute_device=True):
3278    """Creates an `Operation` in this graph.
3279
3280    Implements `Graph.create_op()` without the overhead of the deprecation
3281    wrapper.
3282
3283    Args:
3284      op_type: The `Operation` type to create. This corresponds to the
3285        `OpDef.name` field for the proto that defines the operation.
3286      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3287      dtypes: (Optional) A list of `DType` objects that will be the types of the
3288        tensors that the operation produces.
3289      input_types: (Optional.) A list of `DType`s that will be the types of the
3290        tensors that the operation consumes. By default, uses the base `DType`
3291        of each input in `inputs`. Operations that expect reference-typed inputs
3292        must specify `input_types` explicitly.
3293      name: (Optional.) A string name for the operation. If not specified, a
3294        name is generated based on `op_type`.
3295      attrs: (Optional.) A dictionary where the key is the attribute name (a
3296        string) and the value is the respective `attr` attribute of the
3297        `NodeDef` proto that will represent the operation (an `AttrValue`
3298        proto).
3299      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3300        the operation will have.
3301      compute_device: (Optional.) If True, device functions will be executed to
3302        compute the device property of the Operation.
3303
3304    Raises:
3305      ValueError: if colocation conflicts with existing device assignment.
3306
3307    Returns:
3308      An `Operation` object.
3309    """
3310    self._check_not_finalized()
3311    if name is None:
3312      name = op_type
3313    # If a name ends with a '/' it is a "name scope" and we use it as-is,
3314    # after removing the trailing '/'.
3315    if name and name[-1] == "/":
3316      name = name_from_scope_name(name)
3317    else:
3318      name = self.unique_name(name)
3319
3320    node_def = _NodeDef(op_type, name, attrs)
3321
3322    input_ops = set(t.op for t in inputs)
3323    control_inputs = self._control_dependencies_for_inputs(input_ops)
3324    # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
3325    # Session.run call cannot occur between creating and mutating the op.
3326    with self._mutation_lock():
3327      ret = Operation(
3328          node_def,
3329          self,
3330          inputs=inputs,
3331          output_types=dtypes,
3332          control_inputs=control_inputs,
3333          input_types=input_types,
3334          original_op=self._default_original_op,
3335          op_def=op_def)
3336      self._create_op_helper(ret, compute_device=compute_device)
3337    return ret
3338
3339  def _create_op_from_tf_operation(self, c_op, compute_device=True):
3340    """Creates an `Operation` in this graph from the supplied TF_Operation.
3341
3342    This method is like create_op() except the new Operation is constructed
3343    using `c_op`. The returned Operation will have `c_op` as its _c_op
3344    field. This is used to create Operation objects around TF_Operations created
3345    indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
3346
3347    This function does not call Operation._control_flow_post_processing or
3348    Graph._control_dependencies_for_inputs (since the inputs may not be
3349    available yet). The caller is responsible for calling these methods.
3350
3351    Args:
3352      c_op: a wrapped TF_Operation
3353      compute_device: (Optional.) If True, device functions will be executed to
3354        compute the device property of the Operation.
3355
3356    Returns:
3357      An `Operation` object.
3358    """
3359    self._check_not_finalized()
3360    ret = Operation(c_op, self)
3361    # If a name_scope was created with ret.name but no nodes were created in it,
3362    # the name will still appear in _names_in_use even though the name hasn't
3363    # been used. This is ok, just leave _names_in_use as-is in this case.
3364    # TODO(skyewm): make the C API guarantee no name conflicts.
3365    name_key = ret.name.lower()
3366    if name_key not in self._names_in_use:
3367      self._names_in_use[name_key] = 1
3368    self._create_op_helper(ret, compute_device=compute_device)
3369    return ret
3370
3371  def _create_op_helper(self, op, compute_device=True):
3372    """Common logic for creating an op in this graph."""
3373    # Apply any additional attributes requested. Do not overwrite any existing
3374    # attributes.
3375    for key, value in self._attr_scope_map.items():
3376      try:
3377        op.get_attr(key)
3378      except ValueError:
3379        if callable(value):
3380          value = value(op.node_def)
3381          if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
3382            raise TypeError(
3383                "Callable for scope map key '%s' must return either None or "
3384                "an AttrValue protocol buffer; but it returned: %s" %
3385                (key, value))
3386        if value:
3387          op._set_attr(key, value)  # pylint: disable=protected-access
3388
3389    # Apply a kernel label if one has been specified for this op type.
3390    try:
3391      kernel_label = self._op_to_kernel_label_map[op.type]
3392      op._set_attr("_kernel",  # pylint: disable=protected-access
3393                   attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
3394    except KeyError:
3395      pass
3396
3397    op._gradient_function = self._gradient_function_map.get(op.type)  # pylint: disable=protected-access
3398
3399    # Apply the overriding op type for gradients if one has been specified for
3400    # this op type.
3401    try:
3402      mapped_op_type = self._gradient_override_map[op.type]
3403      op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
3404                   attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
3405    except KeyError:
3406      pass
3407
3408    self._record_op_seen_by_control_dependencies(op)
3409
3410    if compute_device:
3411      self._apply_device_functions(op)
3412
3413    # Snapshot the colocation stack metadata before we might generate error
3414    # messages using it.  Note that this snapshot depends on the actual stack
3415    # and is independent of the op's _class attribute.
3416    # pylint: disable=protected-access
3417    op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
3418    # pylint: enable=protected-access
3419
3420    if self._colocation_stack:
3421      all_colocation_groups = []
3422      for colocation_op in self._colocation_stack.peek_objs():
3423        all_colocation_groups.extend(colocation_op.colocation_groups())
3424        if colocation_op.device:
3425          # pylint: disable=protected-access
3426          op._set_device(colocation_op.device)
3427          # pylint: enable=protected-access
3428
3429      all_colocation_groups = sorted(set(all_colocation_groups))
3430      # pylint: disable=protected-access
3431      op._set_attr(
3432          "_class",
3433          attr_value_pb2.AttrValue(
3434              list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
3435      # pylint: enable=protected-access
3436
3437    # Sets "container" attribute if
3438    # (1) self._container is not None
3439    # (2) "is_stateful" is set in OpDef
3440    # (3) "container" attribute is in OpDef
3441    # (4) the "container" attribute is not already set
3442    if self._container and op._is_stateful:  # pylint: disable=protected-access
3443      try:
3444        container_attr = op.get_attr("container")
3445      except ValueError:
3446        # "container" attribute is not in OpDef
3447        pass
3448      else:
3449        if not container_attr:
3450          op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
3451              s=compat.as_bytes(self._container)))
3452
3453  def _add_new_tf_operations(self, compute_devices=True):
3454    """Creates `Operations` in this graph for any new TF_Operations.
3455
3456    This is useful for when TF_Operations are indirectly created by the C API
3457    outside of the Operation constructor (e.g. by TF_ImportGraphDef,
3458    TF_FinishWhile). This ensures there are corresponding Operations for all
3459    TF_Operations in the underlying TF_Graph.
3460
3461    Args:
3462      compute_devices: (Optional.) If True, device functions will be executed to
3463        compute the device properties of each new Operation.
3464
3465    Returns:
3466      A list of the new `Operation` objects.
3467    """
3468    # Create all Operation objects before accessing their inputs since an op may
3469    # be created before its inputs.
3470    new_ops = [
3471        self._create_op_from_tf_operation(c_op, compute_device=compute_devices)
3472        for c_op in c_api_util.new_tf_operations(self)
3473    ]
3474
3475    # pylint: disable=protected-access
3476    for op in new_ops:
3477      new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
3478      op._add_control_inputs(new_control_inputs)
3479      op._control_flow_post_processing()
3480    # pylint: enable=protected-access
3481
3482    return new_ops
3483
3484  def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
3485    """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
3486
3487    This function validates that `obj` represents an element of this
3488    graph, and gives an informative error message if it is not.
3489
3490    This function is the canonical way to get/validate an object of
3491    one of the allowed types from an external argument reference in the
3492    Session API.
3493
3494    This method may be called concurrently from multiple threads.
3495
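    For example, a brief sketch:

    ```python
    with tf.Graph().as_default() as g:
      c = tf.constant(5.0, name="c")

    assert g.as_graph_element("c") is c.op  # An operation name.
    assert g.as_graph_element("c:0") is c   # A tensor name.
    assert g.as_graph_element(c) is c       # A Tensor passes through unchanged.
    ```
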
3496    Args:
3497      obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. Can
3498        also be any object with an `_as_graph_element()` method that returns a
3499        value of one of these types. Note: `_as_graph_element` will be called
3500        inside the graph's lock and so may not modify the graph.
3501      allow_tensor: If true, `obj` may refer to a `Tensor`.
3502      allow_operation: If true, `obj` may refer to an `Operation`.
3503
3504    Returns:
3505      The `Tensor` or `Operation` in the Graph corresponding to `obj`.
3506
3507    Raises:
3508      TypeError: If `obj` is not of a type that we support converting to
3509        the allowed types.
3510      ValueError: If `obj` is of an appropriate type but invalid. For
3511        example, an invalid string.
3512      KeyError: If `obj` is not an object in the graph.
3513    """
3514    if self._finalized:
3515      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3516
3517    with self._lock:
3518      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3519
3520  def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
3521    """See `Graph.as_graph_element()` for details."""
3522    # The vast majority of this function is figuring
3523    # out what an API user might be doing wrong, so
3524    # that we can give helpful error messages.
3525    #
3526    # Ideally, it would be nice to split it up, but we
3527    # need context to generate nice error messages.
3528
3529    if allow_tensor and allow_operation:
3530      types_str = "Tensor or Operation"
3531    elif allow_tensor:
3532      types_str = "Tensor"
3533    elif allow_operation:
3534      types_str = "Operation"
3535    else:
3536      raise ValueError("allow_tensor and allow_operation can't both be False.")
3537
3538    temp_obj = _as_graph_element(obj)
3539    if temp_obj is not None:
3540      obj = temp_obj
3541
3542    # If obj appears to be a name...
3543    if isinstance(obj, compat.bytes_or_text_types):
3544      name = compat.as_str(obj)
3545
3546      if ":" in name and allow_tensor:
3547        # Looks like a Tensor name and can be a Tensor.
3548        try:
3549          op_name, out_n = name.split(":")
3550          out_n = int(out_n)
3551        except:
3552          raise ValueError("The name %s looks like a Tensor name, but is "
3553                           "not a valid one. Tensor names must be of the "
3554                           "form \"<op_name>:<output_index>\"." % repr(name))
3555        if op_name in self._nodes_by_name:
3556          op = self._nodes_by_name[op_name]
3557        else:
3558          raise KeyError("The name %s refers to a Tensor which does not "
3559                         "exist. The operation, %s, does not exist in the "
3560                         "graph." % (repr(name), repr(op_name)))
3561        try:
3562          return op.outputs[out_n]
3563        except:
3564          raise KeyError("The name %s refers to a Tensor which does not "
3565                         "exist. The operation, %s, exists but only has "
3566                         "%s outputs." %
3567                         (repr(name), repr(op_name), len(op.outputs)))
3568
3569      elif ":" in name and not allow_tensor:
3570        # Looks like a Tensor name but can't be a Tensor.
3571        raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
3572                         (repr(name), types_str))
3573
3574      elif ":" not in name and allow_operation:
3575        # Looks like an Operation name and can be an Operation.
3576        if name not in self._nodes_by_name:
3577          raise KeyError("The name %s refers to an Operation not in the "
3578                         "graph." % repr(name))
3579        return self._nodes_by_name[name]
3580
3581      elif ":" not in name and not allow_operation:
3582        # Looks like an Operation name but can't be an Operation.
3583        if name in self._nodes_by_name:
3584          # Yep, it's an Operation name
3585          err_msg = ("The name %s refers to an Operation, not a %s." %
3586                     (repr(name), types_str))
3587        else:
3588          err_msg = ("The name %s looks like an (invalid) Operation name, "
3589                     "not a %s." % (repr(name), types_str))
3590        err_msg += (" Tensor names must be of the form "
3591                    "\"<op_name>:<output_index>\".")
3592        raise ValueError(err_msg)
3593
3594    elif isinstance(obj, Tensor) and allow_tensor:
3595      # Actually obj is just the object it's referring to.
3596      if obj.graph is not self:
3597        raise ValueError("Tensor %s is not an element of this graph." % obj)
3598      return obj
3599    elif isinstance(obj, Operation) and allow_operation:
3600      # Actually obj is just the object it's referring to.
3601      if obj.graph is not self:
3602        raise ValueError("Operation %s is not an element of this graph." % obj)
3603      return obj
3604    else:
3605      # We give up!
3606      raise TypeError("Can not convert a %s into a %s." %
3607                      (type(obj).__name__, types_str))
3608
3609  def get_operations(self):
3610    """Return the list of operations in the graph.
3611
3612    You can modify the operations in place, but modifications
3613    to the list such as inserts/deletes have no effect on the
3614    list of operations known to the graph.
3615
3616    This method may be called concurrently from multiple threads.
3617
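    For example, a brief sketch:

    ```python
    with tf.Graph().as_default() as g:
      tf.constant(1.0, name="a")
      tf.constant(2.0, name="b")

    print([op.name for op in g.get_operations()])  # -> ["a", "b"]
    ```
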
3618    Returns:
3619      A list of Operations.
3620    """
3621    if self._finalized:
3622      return list(self._nodes_by_id.values())
3623
3624    with self._lock:
3625      return list(self._nodes_by_id.values())
3626
3627  def get_operation_by_name(self, name):
3628    """Returns the `Operation` with the given `name`.
3629
3630    This method may be called concurrently from multiple threads.
3631
3632    Args:
3633      name: The name of the `Operation` to return.
3634
3635    Returns:
3636      The `Operation` with the given `name`.
3637
3638    Raises:
3639      TypeError: If `name` is not a string.
3640      KeyError: If `name` does not correspond to an operation in this graph.
3641    """
3642
3643    if not isinstance(name, six.string_types):
3644      raise TypeError("Operation names are strings (or similar), not %s." %
3645                      type(name).__name__)
3646    return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
3647
3648  def _get_operation_by_name_unsafe(self, name):
3649    """Returns the `Operation` with the given `name`.
3650
3651    This is an internal, unsafe version of get_operation_by_name. It skips many
3652    checks and does not have user-friendly error messages, but runs considerably
3653    faster. This method may be called concurrently from multiple threads.
3654
3655    Args:
3656      name: The name of the `Operation` to return.
3657
3658    Returns:
3659      The `Operation` with the given `name`.
3660
3661    Raises:
3662      KeyError: If `name` does not correspond to an operation in this graph.
3663    """
3664
3665    if self._finalized:
3666      return self._nodes_by_name[name]
3667
3668    with self._lock:
3669      return self._nodes_by_name[name]
3670
3671  def _get_operation_by_tf_operation(self, tf_oper):
3672    op_name = pywrap_tf_session.TF_OperationName(tf_oper)
3673    return self._get_operation_by_name_unsafe(op_name)
3674
3675  def get_tensor_by_name(self, name):
3676    """Returns the `Tensor` with the given `name`.
3677
3678    This method may be called concurrently from multiple threads.
3679
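    For example, a brief sketch:

    ```python
    with tf.Graph().as_default() as g:
      c = tf.constant(5.0, name="c")

    # Tensor names have the form "<op_name>:<output_index>".
    assert g.get_tensor_by_name("c:0") is c
    ```
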
3680    Args:
3681      name: The name of the `Tensor` to return.
3682
3683    Returns:
3684      The `Tensor` with the given `name`.
3685
3686    Raises:
3687      TypeError: If `name` is not a string.
3688      KeyError: If `name` does not correspond to a tensor in this graph.
3689    """
3690    # Names should be strings.
3691    if not isinstance(name, six.string_types):
3692      raise TypeError("Tensor names are strings (or similar), not %s." %
3693                      type(name).__name__)
3694    return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
3695
3696  def _get_tensor_by_tf_output(self, tf_output):
3697    """Returns the `Tensor` representing `tf_output`.
3698
3699    Note that there is only one such `Tensor`, i.e. multiple calls to this
3700    function with the same TF_Output value will always return the same `Tensor`
3701    object.
3702
3703    Args:
3704      tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
3705
3706    Returns:
3707      The `Tensor` that represents `tf_output`.
3708    """
3709    op = self._get_operation_by_tf_operation(tf_output.oper)
3710    return op.outputs[tf_output.index]
3711
3712  @property
3713  def _last_id(self):
3714    return self._next_id_counter
3715
3716  def _get_op_def(self, type):  # pylint: disable=redefined-builtin
3717    """Returns the `OpDef` proto for `type`. `type` is a string."""
3718    # NOTE: No locking is required because the lookup and insertion operations
3719    # on Python dictionaries are atomic.
3720    try:
3721      return self._op_def_cache[type]
3722    except KeyError:
3723      with c_api_util.tf_buffer() as buf:
3724        # pylint: disable=protected-access
3725        pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type),
3726                                           buf)
3727        # pylint: enable=protected-access
3728        data = pywrap_tf_session.TF_GetBuffer(buf)
3729      op_def = op_def_pb2.OpDef()
3730      op_def.ParseFromString(compat.as_bytes(data))
3731      self._op_def_cache[type] = op_def
3732      return op_def
3733
3734  def as_default(self):
3735    """Returns a context manager that makes this `Graph` the default graph.
3736
3737    This method should be used if you want to create multiple graphs
3738    in the same process. For convenience, a global default graph is
3739    provided, and all ops will be added to this graph if you do not
3740    create a new graph explicitly.
3741
3742    Use this method with the `with` keyword to specify that ops created within
3743    the scope of a block should be added to this graph. In this case, once
3744    the scope of the `with` is exited, the previous default graph is set again
3745    as default. There is a stack, so it's ok to have multiple nested levels
3746    of `as_default` calls.
3747
3748    The default graph is a property of the current thread. If you
3749    create a new thread, and wish to use the default graph in that
3750    thread, you must explicitly add a `with g.as_default():` in that
3751    thread's function.
3752
3753    The following code examples are equivalent:
3754
3755    ```python
3756    # 1. Using Graph.as_default():
3757    g = tf.Graph()
3758    with g.as_default():
3759      c = tf.constant(5.0)
3760      assert c.graph is g
3761
3762    # 2. Constructing and making default:
3763    with tf.Graph().as_default() as g:
3764      c = tf.constant(5.0)
3765      assert c.graph is g
3766    ```
3767
3768    If eager execution is enabled, ops created under this context manager will be
3769    added to the graph instead of executed eagerly.
3770
3771    Returns:
3772      A context manager for using this graph as the default graph.
3773    """
3774    return _default_graph_stack.get_controller(self)
3775
3776  @property
3777  def collections(self):
3778    """Returns the names of the collections known to this graph."""
3779    return list(self._collections)
3780
3781  def add_to_collection(self, name, value):
3782    """Stores `value` in the collection with the given `name`.
3783
3784    Note that collections are not sets, so it is possible to add a value to
3785    a collection several times.
3786
3787    Args:
3788      name: The key for the collection. The `GraphKeys` class contains many
3789        standard names for collections.
3790      value: The value to add to the collection.
3791    """  # pylint: disable=g-doc-exception
3792    self._check_not_finalized()
3793    with self._lock:
3794      if name not in self._collections:
3795        self._collections[name] = [value]
3796      else:
3797        self._collections[name].append(value)
3798
3799  def add_to_collections(self, names, value):
3800    """Stores `value` in the collections given by `names`.
3801
3802    Note that collections are not sets, so it is possible to add a value to
3803    a collection several times. This function makes sure that duplicates in
3804    `names` are ignored, but it will not check for pre-existing membership of
3805    `value` in any of the collections in `names`.
3806
3807    `names` can be any iterable, but if `names` is a string, it is treated as a
3808    single collection name.
3809
3810    Args:
3811      names: The keys for the collections to add to. The `GraphKeys` class
3812        contains many standard names for collections.
3813      value: The value to add to the collections.
3814    """
3815    # Make sure names are unique, but treat strings as a single collection name
3816    names = (names,) if isinstance(names, six.string_types) else set(names)
3817    for name in names:
3818      self.add_to_collection(name, value)
3819
3820  def get_collection_ref(self, name):
3821    """Returns a list of values in the collection with the given `name`.
3822
3823    If the collection exists, this returns the list itself, which can
3824    be modified in place to change the collection.  If the collection does
3825    not exist, it is created as an empty list and the list is returned.
3826
3827    This is different from `get_collection()` which always returns a copy of
3828    the collection list if it exists and never creates an empty collection.
3829
3830    Args:
3831      name: The key for the collection. For example, the `GraphKeys` class
3832        contains many standard names for collections.
3833
3834    Returns:
3835      The list of values in the collection with the given `name`, or an empty
3836      list if no value has been added to that collection.
3837    """  # pylint: disable=g-doc-exception
3838    with self._lock:
3839      coll_list = self._collections.get(name, None)
3840      if coll_list is None:
3841        coll_list = []
3842        self._collections[name] = coll_list
3843      return coll_list
3844
3845  def get_collection(self, name, scope=None):
3846    """Returns a list of values in the collection with the given `name`.
3847
3848    Unlike `get_collection_ref()`, which returns the actual collection list
3849    itself if it exists, this method returns a new copy of the list each
3850    time it is called.
3851
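    For example, a brief sketch:

    ```python
    with tf.Graph().as_default() as g:
      with g.name_scope("layer1"):
        c = tf.constant(1.0, name="c")  # Op name becomes "layer1/c".
      g.add_to_collection("activations", c)

    assert g.get_collection("activations")[0] is c
    assert g.get_collection("activations", scope="layer1")[0] is c
    assert g.get_collection("activations", scope="layer2") == []
    ```
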
3852    Args:
3853      name: The key for the collection. For example, the `GraphKeys` class
3854        contains many standard names for collections.
3855      scope: (Optional.) A string. If supplied, the resulting list is filtered
3856        to include only items whose `name` attribute matches `scope` using
3857        `re.match`. Items without a `name` attribute are never returned if a
3858        scope is supplied. The choice of `re.match` means that a `scope` without
3859        special tokens filters by prefix.
3860
3861    Returns:
3862      The list of values in the collection with the given `name`, or
3863      an empty list if no value has been added to that collection. The
3864      list contains the values in the order under which they were
3865      collected.
3866    """  # pylint: disable=g-doc-exception
3867    with self._lock:
3868      collection = self._collections.get(name, None)
3869      if collection is None:
3870        return []
3871      if scope is None:
3872        return list(collection)
3873      else:
3874        c = []
3875        regex = re.compile(scope)
3876        for item in collection:
3877          try:
3878            if regex.match(item.name):
3879              c.append(item)
3880          except AttributeError:
3881            # Collection items with no name are ignored.
3882            pass
3883        return c
3884
3885  def get_all_collection_keys(self):
3886    """Returns a list of collections used in this graph."""
3887    with self._lock:
3888      return [x for x in self._collections if isinstance(x, six.string_types)]
3889
3890  def clear_collection(self, name):
3891    """Clears all values in a collection.
3892
3893    Args:
3894      name: The key for the collection. The `GraphKeys` class contains many
3895        standard names for collections.
3896    """
3897    self._check_not_finalized()
3898    with self._lock:
3899      if name in self._collections:
3900        del self._collections[name]
3901
3902  @tf_contextlib.contextmanager
3903  def _original_op(self, op):
3904    """Python 'with' handler to help annotate ops with their originator.
3905
3906    An op may have an 'original_op' property that indicates the op on which
3907    it was based. For example a replica op is based on the op that was
3908    replicated and a gradient op is based on the op that was differentiated.
3909
3910    All ops created in the scope of this 'with' handler will have
3911    the given 'op' as their original op.
3912
3913    Args:
3914      op: The Operation that all ops created in this scope will have as their
3915        original op.
3916
3917    Yields:
3918      Nothing.
3919    """
3920    old_original_op = self._default_original_op
3921    self._default_original_op = op
3922    try:
3923      yield
3924    finally:
3925      self._default_original_op = old_original_op
3926
3927  @property
3928  def _name_stack(self):
3929    # This may be called from a thread where name_stack doesn't yet exist.
3930    if not hasattr(self._thread_local, "_name_stack"):
3931      self._thread_local._name_stack = ""
3932    return self._thread_local._name_stack
3933
3934  @_name_stack.setter
3935  def _name_stack(self, name_stack):
3936    self._thread_local._name_stack = name_stack
3937
3938  # pylint: disable=g-doc-return-or-yield,line-too-long
3939  @tf_contextlib.contextmanager
3940  def name_scope(self, name):
3941    """Returns a context manager that creates hierarchical names for operations.
3942
3943    A graph maintains a stack of name scopes. A `with name_scope(...):`
3944    statement pushes a new name onto the stack for the lifetime of the context.
3945
3946    The `name` argument will be interpreted as follows:
3947
3948    * A string (not ending with '/') will create a new name scope, in which
3949      `name` is appended to the prefix of all operations created in the
3950      context. If `name` has been used before, it will be made unique by
3951      calling `self.unique_name(name)`.
3952    * A scope previously captured from a `with g.name_scope(...) as
3953      scope:` statement will be treated as an "absolute" name scope, which
3954      makes it possible to re-enter existing scopes.
3955    * A value of `None` or the empty string will reset the current name scope
3956      to the top-level (empty) name scope.
3957
3958    For example:
3959
3960    ```python
3961    with tf.Graph().as_default() as g:
3962      c = tf.constant(5.0, name="c")
3963      assert c.op.name == "c"
3964      c_1 = tf.constant(6.0, name="c")
3965      assert c_1.op.name == "c_1"
3966
3967      # Creates a scope called "nested"
3968      with g.name_scope("nested") as scope:
3969        nested_c = tf.constant(10.0, name="c")
3970        assert nested_c.op.name == "nested/c"
3971
3972        # Creates a nested scope called "inner".
3973        with g.name_scope("inner"):
3974          nested_inner_c = tf.constant(20.0, name="c")
3975          assert nested_inner_c.op.name == "nested/inner/c"
3976
3977        # Create a nested scope called "inner_1".
3978        with g.name_scope("inner"):
3979          nested_inner_1_c = tf.constant(30.0, name="c")
3980          assert nested_inner_1_c.op.name == "nested/inner_1/c"
3981
3982          # Treats `scope` as an absolute name scope, and
3983          # switches to the "nested/" scope.
3984          with g.name_scope(scope):
3985            nested_d = tf.constant(40.0, name="d")
3986            assert nested_d.op.name == "nested/d"
3987
3988            with g.name_scope(""):
3989              e = tf.constant(50.0, name="e")
3990              assert e.op.name == "e"
3991    ```
3992
3993    The name of the scope itself can be captured by `with
3994    g.name_scope(...) as scope:`, which stores the name of the scope
3995    in the variable `scope`. This value can be used to name an
3996    operation that represents the overall result of executing the ops
3997    in a scope. For example:
3998
3999    ```python
4000    inputs = tf.constant(...)
4001    with g.name_scope('my_layer') as scope:
4002      weights = tf.Variable(..., name="weights")
4003      biases = tf.Variable(..., name="biases")
4004      affine = tf.matmul(inputs, weights) + biases
4005      output = tf.nn.relu(affine, name=scope)
4006    ```
4007
4008    NOTE: This constructor validates the given `name`. Valid scope
4009    names match one of the following regular expressions:
4010
4011        [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
4012        [A-Za-z0-9_.\\-/]* (for other scopes)
4013
4014    Args:
4015      name: A name for the scope.
4016
4017    Returns:
4018      A context manager that installs `name` as a new name scope.
4019
4020    Raises:
4021      ValueError: If `name` is not a valid scope name, according to the rules
4022        above.
4023    """
4024    if name:
4025      if isinstance(name, compat.bytes_or_text_types):
4026        name = compat.as_str(name)
4027
4028      if self._name_stack:
4029        # Scopes created in a nested scope may have initial characters
4030        # that are illegal as the initial character of an op name
4031        # (viz. '-', '\', '/', and '_').
4032        if not _VALID_SCOPE_NAME_REGEX.match(name):
4033          raise ValueError("'%s' is not a valid scope name" % name)
4034      else:
4035        # Scopes created in the root must match the more restrictive
4036        # op name regex, which constrains the initial character.
4037        if not _VALID_OP_NAME_REGEX.match(name):
4038          raise ValueError("'%s' is not a valid scope name" % name)
4039    old_stack = self._name_stack
4040    if not name:  # Both for name=None and name="" we re-set to empty scope.
4041      new_stack = None
4042    elif name[-1] == "/":
4043      new_stack = name_from_scope_name(name)
4044    else:
4045      new_stack = self.unique_name(name)
4046    self._name_stack = new_stack
4047    try:
4048      yield "" if new_stack is None else new_stack + "/"
4049    finally:
4050      self._name_stack = old_stack
4051
4052  # pylint: enable=g-doc-return-or-yield,line-too-long
4053
4054  def unique_name(self, name, mark_as_used=True):
4055    """Return a unique operation name for `name`.
4056
4057    Note: You rarely need to call `unique_name()` directly.  Most of
4058    the time you just need to create `with g.name_scope()` blocks to
4059    generate structured names.
4060
4061    `unique_name` is used to generate structured names, separated by
4062    `"/"`, to help identify operations when debugging a graph.
4063    Operation names are displayed in error messages reported by the
4064    TensorFlow runtime, and in various visualization tools such as
4065    TensorBoard.
4066
4067    If `mark_as_used` is set to `True`, which is the default, a new
4068    unique name is created and marked as in use. If it's set to `False`,
4069    the unique name is returned without actually being marked as used.
4070    This is useful when the caller simply wants to know what name would be
4071    generated, without reserving it.
4072
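    For example, a brief sketch:

    ```python
    g = tf.Graph()
    assert g.unique_name("foo", mark_as_used=False) == "foo"
    assert g.unique_name("foo") == "foo"    # Marks "foo" as used.
    assert g.unique_name("foo") == "foo_1"  # Later uses get a numeric suffix.

    with g.name_scope("scope"):
      assert g.unique_name("foo") == "scope/foo"
    ```
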
4073    Args:
4074      name: The name for an operation.
4075      mark_as_used: Whether to mark this name as being used.
4076
4077    Returns:
4078      A string to be passed to `create_op()` that will be used
4079      to name the operation being created.
4080    """
4081    if self._name_stack:
4082      name = self._name_stack + "/" + name
4083
4084    # For the sake of checking for names in use, we treat names as case
4085    # insensitive (e.g. foo = Foo).
4086    name_key = name.lower()
4087    i = self._names_in_use.get(name_key, 0)
4088    # Increment the number for "name_key".
4089    if mark_as_used:
4090      self._names_in_use[name_key] = i + 1
4091    if i > 0:
4092      base_name_key = name_key
4093      # Make sure the composed name key is not already used.
4094      while name_key in self._names_in_use:
4095        name_key = "%s_%d" % (base_name_key, i)
4096        i += 1
4097      # Mark the composed name_key as used in case someone wants
4098      # to call unique_name("name_1").
4099      if mark_as_used:
4100        self._names_in_use[name_key] = 1
4101
4102      # Return the new name with the original capitalization of the given name.
4103      name = "%s_%d" % (name, i - 1)
4104    return name
4105
4106  def get_name_scope(self):
4107    """Returns the current name scope.
4108
4109    For example:
4110
4111    ```python
4112    with tf.name_scope('scope1'):
4113      with tf.name_scope('scope2'):
4114        print(tf.compat.v1.get_default_graph().get_name_scope())
4115    ```
4116    would print the string `scope1/scope2`.
4117
4118    Returns:
4119      A string representing the current name scope.
4120    """
4121    return self._name_stack
4122
4123  @tf_contextlib.contextmanager
4124  def _colocate_with_for_gradient(self, op, gradient_uid,
4125                                  ignore_existing=False):
4126    with self.colocate_with(op, ignore_existing):
4127      if gradient_uid is not None and self._control_flow_context is not None:
4128        self._control_flow_context.EnterGradientColocation(op, gradient_uid)
4129        try:
4130          yield
4131        finally:
4132          self._control_flow_context.ExitGradientColocation(op, gradient_uid)
4133      else:
4134        yield
4135
4136  @tf_contextlib.contextmanager
4137  def colocate_with(self, op, ignore_existing=False):
4138    """Returns a context manager that specifies an op to colocate with.
4139
4140    Note: this function is not for public use, only for internal libraries.
4141
4142    For example:
4143
4144    ```python
4145    a = tf.Variable([1.0])
4146    with g.colocate_with(a):
4147      b = tf.constant(1.0)
4148      c = tf.add(a, b)
4149    ```
4150
4151    `b` and `c` will always be colocated with `a`, no matter where `a`
4152    is eventually placed.
4153
4154    **NOTE** Using a colocation scope resets any existing device constraints.
4155
4156    If `op` is `None` then `ignore_existing` must be `True` and the new
4157    scope resets all colocation and device constraints.
4158
4159    Args:
4160      op: The op to colocate all created ops with, or `None`.
4161      ignore_existing: If true, only applies colocation of this op within the
4162        context, rather than applying all colocation properties on the stack.
4163        If `op` is `None`, this value must be `True`.
4164
4165    Raises:
4166      ValueError: if op is None but ignore_existing is False.
4167
4168    Yields:
4169      A context manager that specifies the op with which to colocate
4170      newly created ops.
4171    """
4172    if op is None and not ignore_existing:
4173      raise ValueError("Trying to reset colocation (op is None) but "
4174                       "ignore_existing is not True")
4175    op = _op_to_colocate_with(op, self)
4176
4177    # By default, colocate_with resets the device function stack,
4178    # since colocate_with is typically used in specific internal
4179    # library functions where colocation is intended to be "stronger"
4180    # than device functions.
4181    #
4182    # In the future, a caller may specify that device_functions win
4183    # over colocation, in which case we can add support.
4184    device_fn_tmp = self._device_function_stack
4185    self._device_function_stack = traceable_stack.TraceableStack()
4186
4187    if ignore_existing:
4188      current_stack = self._colocation_stack
4189      self._colocation_stack = traceable_stack.TraceableStack()
4190
4191    if op is not None:
4192      # offset refers to the stack frame used for storing code location.
4193      # We use 4, the sum of 1 to use our caller's stack frame and 3
4194      # to jump over layers of context managers above us.
4195      self._colocation_stack.push_obj(op, offset=4)
4196
4197    try:
4198      yield
4199    finally:
4200      # Restore device function stack
4201      self._device_function_stack = device_fn_tmp
4202      if op is not None:
4203        self._colocation_stack.pop_obj()
4204
4205      # Reset the colocation stack if requested.
4206      if ignore_existing:
4207        self._colocation_stack = current_stack
4208
4209  def _add_device_to_stack(self, device_name_or_function, offset=0):
4210    """Add device to stack manually, separate from a context manager."""
4211    total_offset = 1 + offset
4212    spec = _UserDeviceSpec(device_name_or_function)
4213    self._device_function_stack.push_obj(spec, offset=total_offset)
4214    return spec
4215
4216  @tf_contextlib.contextmanager
4217  def device(self, device_name_or_function):
4218    # pylint: disable=line-too-long
4219    """Returns a context manager that specifies the default device to use.
4220
4221    The `device_name_or_function` argument may either be a device name
4222    string, a device function, or None:
4223
4224    * If it is a device name string, all operations constructed in
4225      this context will be assigned to the device with that name, unless
4226      overridden by a nested `device()` context.
4227    * If it is a function, it will be treated as a function from
4228      Operation objects to device name strings, and invoked each time
4229      a new Operation is created. The Operation will be assigned to
4230      the device with the returned name.
4231    * If it is None, all `device()` invocations from the enclosing context
4232      will be ignored.
4233
4234    For information about the valid syntax of device name strings, see
4235    the documentation in
4236    [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).
4237
4238    For example:
4239
4240    ```python
4241    with g.device('/device:GPU:0'):
4242      # All operations constructed in this context will be placed
4243      # on GPU 0.
4244      with g.device(None):
4245        # All operations constructed in this context will have no
4246        # assigned device.
4247
4248    # Defines a function from `Operation` to device string.
4249    def matmul_on_gpu(n):
4250      if n.type == "MatMul":
4251        return "/device:GPU:0"
4252      else:
4253        return "/cpu:0"
4254
4255    with g.device(matmul_on_gpu):
4256      # All operations of type "MatMul" constructed in this context
4257      # will be placed on GPU 0; all other operations will be placed
4258      # on CPU 0.
4259    ```
4260
4261    **N.B.** The device scope may be overridden by op wrappers or
4262    other library code. For example, a variable assignment op
4263    `v.assign()` must be colocated with the `tf.Variable` `v`, and
4264    incompatible device scopes will be ignored.
4265
4266    Args:
4267      device_name_or_function: The device name or function to use in the
4268        context.
4269
4270    Yields:
4271      A context manager that specifies the default device to use for newly
4272      created ops.
4273
4274    Raises:
4275      RuntimeError: If device scopes are not properly nested.
4276    """
4277    self._add_device_to_stack(device_name_or_function, offset=2)
4278    old_top_of_stack = self._device_function_stack.peek_top_obj()
4279    try:
4280      yield
4281    finally:
4282      new_top_of_stack = self._device_function_stack.peek_top_obj()
4283      if old_top_of_stack is not new_top_of_stack:
4284        raise RuntimeError("Exiting device scope without proper scope nesting.")
4285      self._device_function_stack.pop_obj()
4286
4287  def _apply_device_functions(self, op):
4288    """Applies the current device function stack to the given operation."""
4289    # Apply any device functions in LIFO order, so that the most recently
4290    # pushed function has the first chance to apply a device to the op.
4291    # We apply here because the result can depend on the Operation's
4292    # signature, which is computed in the Operation constructor.
4293    # pylint: disable=protected-access
4294    prior_device_string = None
4295    for device_spec in self._device_function_stack.peek_objs():
4296      if device_spec.is_null_merge:
4297        continue
4298
4299      if device_spec.function is None:
4300        break
4301
4302      device_string = device_spec.string_merge(op)
4303
4304      # Take advantage of the fact that None is a singleton and Python interns
4305      # strings, since identity checks are faster than equality checks.
4306      if device_string is not prior_device_string:
4307        op._set_device_from_string(device_string)
4308        prior_device_string = device_string
4309    op._device_code_locations = self._snapshot_device_function_stack_metadata()
4310    # pylint: enable=protected-access
4311
4312  # pylint: disable=g-doc-return-or-yield
4313  @tf_contextlib.contextmanager
4314  def container(self, container_name):
4315    """Returns a context manager that specifies the resource container to use.
4316
4317    Stateful operations, such as variables and queues, can maintain their
4318    states on devices so that they can be shared by multiple processes.
4319    A resource container is a string name under which these stateful
4320    operations are tracked. These resources can be released or cleared
4321    with `tf.Session.reset()`.
4322
4323    For example:
4324
4325    ```python
4326    with g.container('experiment0'):
4327      # All stateful Operations constructed in this context will be placed
4328      # in resource container "experiment0".
4329      v1 = tf.Variable([1.0])
4330      v2 = tf.Variable([2.0])
4331      with g.container("experiment1"):
4332        # All stateful Operations constructed in this context will be
4333        # placed in resource container "experiment1".
4334        v3 = tf.Variable([3.0])
4335        q1 = tf.queue.FIFOQueue(10, tf.float32)
4336      # All stateful Operations constructed in this context will be
4337      # placed in resource container "experiment0".
4338      v4 = tf.Variable([4.0])
4339      q2 = tf.queue.FIFOQueue(20, tf.float32)
4340      with g.container(""):
4341        # All stateful Operations constructed in this context will be
4342        # placed in the default resource container.
4343        v5 = tf.Variable([5.0])
4344        q3 = tf.queue.FIFOQueue(30, tf.float32)
4345
4346    # Resets container "experiment0", after which the state of v1, v2, v4, q2
4347    # will become undefined (such as uninitialized).
4348    tf.Session.reset(target, ["experiment0"])
4349    ```
4350
4351    Args:
4352      container_name: container name string.
4353
4354    Returns:
4355      A context manager for defining resource containers for stateful ops,
4356        yields the container name.
4357    """
4358    original_container = self._container
4359    self._container = container_name
4360    try:
4361      yield self._container
4362    finally:
4363      self._container = original_container
4364
4365  # pylint: enable=g-doc-return-or-yield
4366
4367  class _ControlDependenciesController(object):
4368    """Context manager for `control_dependencies()`."""
4369
4370    def __init__(self, graph, control_inputs):
4371      """Create a new `_ControlDependenciesController`.
4372
4373      A `_ControlDependenciesController` is the context manager for
4374      `with tf.control_dependencies()` blocks.  These normally nest,
4375      as described in the documentation for `control_dependencies()`.
4376
4377      The `control_inputs` argument lists control dependencies that must be
4378      added to the current set of control dependencies.  Because of
4379      uniquification the set can be empty even if the caller passed a list of
4380      ops.  The special value `None` indicates that we want to start a new
4381      empty set of control dependencies instead of extending the current set.
4382
4383      In that case we also clear the current control flow context, which is an
4384      additional mechanism to add control dependencies.
4385
4386      Args:
4387        graph: The graph that this controller is managing.
4388        control_inputs: List of ops to use as control inputs in addition to the
4389          current control dependencies.  None to indicate that the dependencies
4390          should be cleared.
4391      """
4392      self._graph = graph
4393      if control_inputs is None:
4394        self._control_inputs_val = []
4395        self._new_stack = True
4396      else:
4397        self._control_inputs_val = control_inputs
4398        self._new_stack = False
4399      self._seen_nodes = set()
4400      self._old_stack = None
4401      self._old_control_flow_context = None
4402
4403# pylint: disable=protected-access
4404
4405    def __enter__(self):
4406      if self._new_stack:
4407        # Clear the control_dependencies graph.
4408        self._old_stack = self._graph._control_dependencies_stack
4409        self._graph._control_dependencies_stack = []
4410        # Clear the control_flow_context too.
4411        self._old_control_flow_context = self._graph._get_control_flow_context()
4412        self._graph._set_control_flow_context(None)
4413      self._graph._push_control_dependencies_controller(self)
4414
4415    def __exit__(self, unused_type, unused_value, unused_traceback):
4416      self._graph._pop_control_dependencies_controller(self)
4417      if self._new_stack:
4418        self._graph._control_dependencies_stack = self._old_stack
4419        self._graph._set_control_flow_context(self._old_control_flow_context)
4420
4421# pylint: enable=protected-access
4422
4423    @property
4424    def control_inputs(self):
4425      return self._control_inputs_val
4426
4427    def add_op(self, op):
4428      if isinstance(op, Tensor):
4429        op = op.experimental_ref()
4430      self._seen_nodes.add(op)
4431
4432    def op_in_group(self, op):
4433      if isinstance(op, Tensor):
4434        op = op.experimental_ref()
4435      return op in self._seen_nodes
4436
4437  def _push_control_dependencies_controller(self, controller):
4438    self._control_dependencies_stack.append(controller)
4439
4440  def _pop_control_dependencies_controller(self, controller):
4441    assert self._control_dependencies_stack[-1] is controller
4442    self._control_dependencies_stack.pop()
4443
4444  def _current_control_dependencies(self):
4445    ret = set()
4446    for controller in self._control_dependencies_stack:
4447      for op in controller.control_inputs:
4448        ret.add(op)
4449    return ret
4450
4451  def _control_dependencies_for_inputs(self, input_ops):
4452    """For an op that takes `input_ops` as inputs, compute control inputs.
4453
4454    The returned control dependencies should yield an execution that
4455    is equivalent to adding all control inputs in
4456    self._control_dependencies_stack to a newly created op. However,
4457    this function attempts to prune the returned control dependencies
4458    by observing that nodes created within the same `with
4459    control_dependencies(...):` block may have data dependencies that make
4460    the explicit approach redundant.
4461
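    For illustration, a minimal sketch of the observable effect (assuming `g`
    is this graph and `a` is an op or tensor created outside the block):

    ```python
    with g.control_dependencies([a]):
      b = tf.identity(a)    # `a` is already a data input, so no control edge
      c = tf.identity(b)    # dominated through `b`, so no explicit control edge
      d = tf.constant(1.0)  # no data path to `a`; a control edge on `a` is added
    ```
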
4462    Args:
4463      input_ops: The data input ops for an op to be created.
4464
4465    Returns:
4466      A list of control inputs for the op to be created.
4467    """
4468    ret = []
4469    for controller in self._control_dependencies_stack:
4470      # If any of the input_ops already depends on the inputs from controller,
4471      # we say that the new op is dominated (by that input), and we therefore
4472      # do not need to add control dependencies for this controller's inputs.
4473      dominated = False
4474      for op in input_ops:
4475        if controller.op_in_group(op):
4476          dominated = True
4477          break
4478      if not dominated:
4479        # Don't add a control input if we already have a data dependency on it.
4480        # NOTE(mrry): We do not currently track transitive data dependencies,
4481        #   so we may add redundant control inputs.
4482        ret.extend(c for c in controller.control_inputs if c not in input_ops)
4483    return ret
4484
4485  def _record_op_seen_by_control_dependencies(self, op):
4486    """Record that the given op depends on all registered control dependencies.
4487
4488    Args:
4489      op: An Operation.
4490    """
4491    for controller in self._control_dependencies_stack:
4492      controller.add_op(op)
4493
4494  def control_dependencies(self, control_inputs):
4495    """Returns a context manager that specifies control dependencies.
4496
4497    Use with the `with` keyword to specify that all operations constructed
4498    within the context should have control dependencies on
4499    `control_inputs`. For example:
4500
4501    ```python
4502    with g.control_dependencies([a, b, c]):
4503      # `d` and `e` will only run after `a`, `b`, and `c` have executed.
4504      d = ...
4505      e = ...
4506    ```
4507
4508    Multiple calls to `control_dependencies()` can be nested, and in
4509    that case a new `Operation` will have control dependencies on the union
4510    of `control_inputs` from all active contexts.
4511
4512    ```python
4513    with g.control_dependencies([a, b]):
4514      # Ops constructed here run after `a` and `b`.
4515      with g.control_dependencies([c, d]):
4516        # Ops constructed here run after `a`, `b`, `c`, and `d`.
4517    ```
4518
4519    You can pass None to clear the control dependencies:
4520
4521    ```python
4522    with g.control_dependencies([a, b]):
4523      # Ops constructed here run after `a` and `b`.
4524      with g.control_dependencies(None):
4525        # Ops constructed here run normally, not waiting for either `a` or `b`.
4526        with g.control_dependencies([c, d]):
4527          # Ops constructed here run after `c` and `d`, also not waiting
4528          # for either `a` or `b`.
4529    ```
4530
4531    *N.B.* The control dependencies context applies *only* to ops that
4532    are constructed within the context. Merely using an op or tensor
4533    in the context does not add a control dependency. The following
4534    example illustrates this point:
4535
4536    ```python
4537    # WRONG
4538    def my_func(pred, tensor):
4539      t = tf.matmul(tensor, tensor)
4540      with tf.control_dependencies([pred]):
4541        # The matmul op is created outside the context, so no control
4542        # dependency will be added.
4543        return t
4544
4545    # RIGHT
4546    def my_func(pred, tensor):
4547      with tf.control_dependencies([pred]):
4548        # The matmul op is created in the context, so a control dependency
4549        # will be added.
4550        return tf.matmul(tensor, tensor)
4551    ```
4552
4553    Note also that although executing an op created under this scope will
4554    trigger execution of the dependencies, the ops created under this scope
4555    might still be pruned from a normal TensorFlow graph. For example, in the
4556    following snippet of code the dependencies are never executed:
4557
4558    ```python
4559      loss = model.loss()
4560      with tf.control_dependencies(dependencies):
4561        loss = loss + tf.constant(1)  # note: dependencies ignored in the
4562                                      # backward pass
4563      return tf.gradients(loss, model.variables)
4564    ```
4565
4566    This is because evaluating the gradient graph does not require evaluating
4567    the constant(1) op created in the forward pass.
4568
4569    Args:
4570      control_inputs: A list of `Operation` or `Tensor` objects which must be
4571        executed or computed before running the operations defined in the
4572        context.  Can also be `None` to clear the control dependencies.
4573
4574    Returns:
4575     A context manager that specifies control dependencies for all
4576     operations constructed within the context.
4577
4578    Raises:
4579      TypeError: If `control_inputs` is not a list of `Operation` or
4580        `Tensor` objects.
4581    """
4582    if control_inputs is None:
4583      return self._ControlDependenciesController(self, None)
4584    # First convert the inputs to ops, and deduplicate them.
4585    # NOTE(mrry): Other than deduplication, we do not currently track direct
4586    #   or indirect dependencies between control_inputs, which may result in
4587    #   redundant control inputs.
4588    control_ops = []
4589    current = self._current_control_dependencies()
4590    for c in control_inputs:
4591      # The hasattr(c, "_handle") check is designed to match ResourceVariables. This is so
4592      # control dependencies on a variable or on an unread variable don't
4593      # trigger reads.
4594      if (isinstance(c, IndexedSlices) or
4595          (hasattr(c, "_handle") and hasattr(c, "op"))):
4596        c = c.op
4597      c = self.as_graph_element(c)
4598      if isinstance(c, Tensor):
4599        c = c.op
4600      elif not isinstance(c, Operation):
4601        raise TypeError("Control input must be Operation or Tensor: %s" % c)
4602      if c not in current:
4603        control_ops.append(c)
4604        current.add(c)
4605    return self._ControlDependenciesController(self, control_ops)
4606
4607  # pylint: disable=g-doc-return-or-yield
4608  @tf_contextlib.contextmanager
4609  def _attr_scope(self, attr_map):
4610    """EXPERIMENTAL: A context manager for setting attributes on operators.
4611
4612    This context manager can be used to add additional
4613    attributes to operators within the scope of the context.
4614
4615    For example:
4616
4617       with ops.Graph().as_default() as g:
4618         f_1 = Foo()  # No extra attributes
4619         with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
4620           f_2 = Foo()  # Additional attribute _a=False
4621           with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
4622             f_3 = Foo()  # Additional attribute _a=True
4623             with g._attr_scope({"_a": None}):
4624               f_4 = Foo()  # No additional attributes.
4625
4626    Args:
4627      attr_map: A dictionary mapping attr name strings to AttrValue protocol
4628        buffers or None.
4629
4630    Returns:
4631      A context manager that sets the additional attributes for one or more
4632      ops created in that context.
4633
4634    Raises:
4635      TypeError: If attr_map is not a dictionary mapping
4636        strings to AttrValue protobufs.
4637    """
4638    if not isinstance(attr_map, dict):
4639      raise TypeError("attr_map must be a dictionary mapping "
4640                      "strings to AttrValue protocol buffers")
4641    # The saved_attrs dictionary stores any currently-set labels that
4642    # will be overridden by this context manager.
4643    saved_attrs = {}
4644    # Install the given attribute
4645    for name, attr in attr_map.items():
4646      if not (isinstance(name, six.string_types) and
4647              (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
4648               callable(attr))):
4649        raise TypeError("attr_map must be a dictionary mapping "
4650                        "strings to AttrValue protocol buffers or "
4651                        "callables that emit AttrValue protocol buffers")
4652      try:
4653        saved_attrs[name] = self._attr_scope_map[name]
4654      except KeyError:
4655        pass
4656      if attr is None:
4657        del self._attr_scope_map[name]
4658      else:
4659        self._attr_scope_map[name] = attr
4660    try:
4661      yield  # The code within the context runs here.
4662    finally:
4663      # Remove the attributes set for this context, and restore any saved
4664      # attributes.
4665      for name, attr in attr_map.items():
4666        try:
4667          self._attr_scope_map[name] = saved_attrs[name]
4668        except KeyError:
4669          del self._attr_scope_map[name]
4670
4671  # pylint: enable=g-doc-return-or-yield
4672
4673  # pylint: disable=g-doc-return-or-yield
4674  @tf_contextlib.contextmanager
4675  def _kernel_label_map(self, op_to_kernel_label_map):
4676    """EXPERIMENTAL: A context manager for setting kernel labels.
4677
4678    This context manager can be used to select particular
4679    implementations of kernels within the scope of the context.
4680
4681    For example:
4682
4683        with ops.Graph().as_default() as g:
4684          f_1 = Foo()  # Uses the default registered kernel for the Foo op.
4685          with g._kernel_label_map({"Foo": "v_2"}):
4686            f_2 = Foo()  # Uses the registered kernel with label "v_2"
4687                         # for the Foo op.
4688            with g._kernel_label_map({"Foo": "v_3"}):
4689              f_3 = Foo()  # Uses the registered kernel with label "v_3"
4690                           # for the Foo op.
4691              with g._kernel_label_map({"Foo": ""}):
4692                f_4 = Foo()  # Uses the default registered kernel
4693                             # for the Foo op.
4694
4695    Args:
4696      op_to_kernel_label_map: A dictionary mapping op type strings to kernel
4697        label strings.
4698
4699    Returns:
4700      A context manager that sets the kernel label to be used for one or more
4701      ops created in that context.
4702
4703    Raises:
4704      TypeError: If op_to_kernel_label_map is not a dictionary mapping
4705        strings to strings.
4706    """
4707    if not isinstance(op_to_kernel_label_map, dict):
4708      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
4709                      "strings to strings")
4710    # The saved_labels dictionary stores any currently-set labels that
4711    # will be overridden by this context manager.
4712    saved_labels = {}
4713    # Install the given label
4714    for op_type, label in op_to_kernel_label_map.items():
4715      if not (isinstance(op_type, six.string_types) and
4716              isinstance(label, six.string_types)):
4717        raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
4718                        "strings to strings")
4719      try:
4720        saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
4721      except KeyError:
4722        pass
4723      self._op_to_kernel_label_map[op_type] = label
4724    try:
4725      yield  # The code within the context runs here.
4726    finally:
4727      # Remove the labels set for this context, and restore any saved labels.
4728      for op_type, label in op_to_kernel_label_map.items():
4729        try:
4730          self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
4731        except KeyError:
4732          del self._op_to_kernel_label_map[op_type]
4733
4734  # pylint: enable=g-doc-return-or-yield
4735
4736  @tf_contextlib.contextmanager
4737  def _override_gradient_function(self, gradient_function_map):
4738    """Specify gradient function for the given op type."""
4739
4740    # This is an internal API and we don't need nested context for this.
4741    assert not self._gradient_function_map
4742    self._gradient_function_map = gradient_function_map
4743    yield
4744    self._gradient_function_map = {}
4745
4746  # pylint: disable=g-doc-return-or-yield
4747  @tf_contextlib.contextmanager
4748  def gradient_override_map(self, op_type_map):
4749    """EXPERIMENTAL: A context manager for overriding gradient functions.
4750
4751    This context manager can be used to override the gradient function
4752    that will be used for ops within the scope of the context.
4753
4754    For example:
4755
4756    ```python
4757    @tf.RegisterGradient("CustomSquare")
4758    def _custom_square_grad(op, grad):
4759      # ...
4760
4761    with tf.Graph().as_default() as g:
4762      c = tf.constant(5.0)
4763      s_1 = tf.square(c)  # Uses the default gradient for tf.square.
4764      with g.gradient_override_map({"Square": "CustomSquare"}):
4765        s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
4766                              # gradient of s_2.
4767    ```
4768
4769    Args:
4770      op_type_map: A dictionary mapping op type strings to alternative op type
4771        strings.
4772
4773    Returns:
4774      A context manager that sets the alternative op type to be used for one
4775      or more ops created in that context.
4776
4777    Raises:
4778      TypeError: If `op_type_map` is not a dictionary mapping strings to
4779        strings.
4780    """
4781    if not isinstance(op_type_map, dict):
4782      raise TypeError("op_type_map must be a dictionary mapping "
4783                      "strings to strings")
4784    # The saved_mappings dictionary stores any currently-set mappings that
4785    # will be overridden by this context manager.
4786    saved_mappings = {}
4787    # Install the given mapping
4788    for op_type, mapped_op_type in op_type_map.items():
4789      if not (isinstance(op_type, six.string_types) and
4790              isinstance(mapped_op_type, six.string_types)):
4791        raise TypeError("op_type_map must be a dictionary mapping "
4792                        "strings to strings")
4793      try:
4794        saved_mappings[op_type] = self._gradient_override_map[op_type]
4795      except KeyError:
4796        pass
4797      self._gradient_override_map[op_type] = mapped_op_type
4798    try:
4799      yield  # The code within the context runs here.
4800    finally:
4801      # Remove the mappings set for this context, and restore any saved mappings.
4802      for op_type, mapped_op_type in op_type_map.items():
4803        try:
4804          self._gradient_override_map[op_type] = saved_mappings[op_type]
4805        except KeyError:
4806          del self._gradient_override_map[op_type]
4807
4808  # pylint: enable=g-doc-return-or-yield
4809
4810  def prevent_feeding(self, tensor):
4811    """Marks the given `tensor` as unfeedable in this graph."""
4812    self._unfeedable_tensors.add(tensor)
4813
4814  def is_feedable(self, tensor):
4815    """Returns `True` if and only if `tensor` is feedable."""
4816    return tensor not in self._unfeedable_tensors
4817
4818  def prevent_fetching(self, op):
4819    """Marks the given `op` as unfetchable in this graph."""
4820    self._unfetchable_ops.add(op)
4821
4822  def is_fetchable(self, tensor_or_op):
4823    """Returns `True` if and only if `tensor_or_op` is fetchable."""
4824    if isinstance(tensor_or_op, Tensor):
4825      return tensor_or_op.op not in self._unfetchable_ops
4826    else:
4827      return tensor_or_op not in self._unfetchable_ops
4828
4829  def switch_to_thread_local(self):
4830    """Make device, colocation and dependencies stacks thread-local.
4831
4832    Device, colocation and dependencies stacks are not thread-local by default.
4833    If multiple threads access them, then the state is shared.  This means that
4834    one thread may affect the behavior of another thread.
4835
4836    After this method is called, the stacks become thread-local.  If multiple
4837    threads access them, then the state is not shared.  Each thread uses its own
4838    value; a thread doesn't affect other threads by mutating such a stack.
4839
4840    The initial value for every thread's stack is set to the current value
4841    of the stack when `switch_to_thread_local()` was first called.
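
    For example, a minimal sketch (the worker body and threading details here
    are illustrative only):

    ```python
    import threading
    import tensorflow as tf

    g = tf.Graph()
    g.switch_to_thread_local()

    def worker():
      # This device scope is now visible only to the thread that enters it.
      with g.as_default(), g.device("/cpu:0"):
        tf.constant(1.0)

    threading.Thread(target=worker).start()
    ```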
4842    """
4843    if not self._stack_state_is_thread_local:
4844      self._stack_state_is_thread_local = True
4845
4846  @property
4847  def _device_function_stack(self):
4848    if self._stack_state_is_thread_local:
4849      # This may be called from a thread where device_function_stack doesn't yet
4850      # exist.
4851      # pylint: disable=protected-access
4852      if not hasattr(self._thread_local, "_device_function_stack"):
4853        stack_copy_for_this_thread = self._graph_device_function_stack.copy()
4854        self._thread_local._device_function_stack = stack_copy_for_this_thread
4855      return self._thread_local._device_function_stack
4856      # pylint: enable=protected-access
4857    else:
4858      return self._graph_device_function_stack
4859
4860  @property
4861  def _device_functions_outer_to_inner(self):
4862    user_device_specs = self._device_function_stack.peek_objs()
4863    device_functions = [spec.function for spec in user_device_specs]
4864    device_functions_outer_to_inner = list(reversed(device_functions))
4865    return device_functions_outer_to_inner
4866
4867  def _snapshot_device_function_stack_metadata(self):
4868    """Return device function stack as a list of TraceableObjects.
4869
4870    Returns:
4871      [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj
4872      member is a displayable name for the user's argument to Graph.device, and
4873      the filename and lineno members point to the code location where
4874      Graph.device was called directly or indirectly by the user.
4875    """
4876    snapshot = []
4877    for obj in self._device_function_stack.peek_traceable_objs():
4878      obj_copy = obj.copy_metadata()
4879      obj_copy.obj = obj.obj.display_name
4880      snapshot.append(obj_copy)
4881    return snapshot
4882
4883  @_device_function_stack.setter
4884  def _device_function_stack(self, device_function_stack):
4885    if self._stack_state_is_thread_local:
4886      # pylint: disable=protected-access
4887      self._thread_local._device_function_stack = device_function_stack
4888      # pylint: enable=protected-access
4889    else:
4890      self._graph_device_function_stack = device_function_stack
4891
4892  @property
4893  def _colocation_stack(self):
4894    """Return thread-local copy of colocation stack."""
4895    if self._stack_state_is_thread_local:
4896      # This may be called from a thread where colocation_stack doesn't yet
4897      # exist.
4898      # pylint: disable=protected-access
4899      if not hasattr(self._thread_local, "_colocation_stack"):
4900        stack_copy_for_this_thread = self._graph_colocation_stack.copy()
4901        self._thread_local._colocation_stack = stack_copy_for_this_thread
4902      return self._thread_local._colocation_stack
4903      # pylint: enable=protected-access
4904    else:
4905      return self._graph_colocation_stack
4906
4907  def _snapshot_colocation_stack_metadata(self):
4908    """Return colocation stack metadata as a dictionary."""
4909    return {
4910        traceable_obj.obj.name: traceable_obj.copy_metadata()
4911        for traceable_obj in self._colocation_stack.peek_traceable_objs()
4912    }
4913
4914  @_colocation_stack.setter
4915  def _colocation_stack(self, colocation_stack):
4916    if self._stack_state_is_thread_local:
4917      # pylint: disable=protected-access
4918      self._thread_local._colocation_stack = colocation_stack
4919      # pylint: enable=protected-access
4920    else:
4921      self._graph_colocation_stack = colocation_stack
4922
4923  @property
4924  def _control_dependencies_stack(self):
4925    if self._stack_state_is_thread_local:
4926      # This may be called from a thread where control_dependencies_stack
4927      # doesn't yet exist.
4928      if not hasattr(self._thread_local, "_control_dependencies_stack"):
4929        self._thread_local._control_dependencies_stack = (
4930            self._graph_control_dependencies_stack[:])
4931      return self._thread_local._control_dependencies_stack
4932    else:
4933      return self._graph_control_dependencies_stack
4934
4935  @_control_dependencies_stack.setter
4936  def _control_dependencies_stack(self, control_dependencies):
4937    if self._stack_state_is_thread_local:
4938      self._thread_local._control_dependencies_stack = control_dependencies
4939    else:
4940      self._graph_control_dependencies_stack = control_dependencies
4941
4942  @property
4943  def _distribution_strategy_stack(self):
4944    """A stack to maintain distribution strategy context for each thread."""
4945    if not hasattr(self._thread_local, "_distribution_strategy_stack"):
4946      self._thread_local._distribution_strategy_stack = []  # pylint: disable=protected-access
4947    return self._thread_local._distribution_strategy_stack  # pylint: disable=protected-access
4948
4949  @_distribution_strategy_stack.setter
4950  def _distribution_strategy_stack(self, _distribution_strategy_stack):
4951    self._thread_local._distribution_strategy_stack = (  # pylint: disable=protected-access
4952        _distribution_strategy_stack)
4953
4954  @property
4955  def _global_distribute_strategy_scope(self):
4956    """For implementing `tf.distribute.set_strategy()`."""
4957    if not hasattr(self._thread_local, "distribute_strategy_scope"):
4958      self._thread_local.distribute_strategy_scope = None
4959    return self._thread_local.distribute_strategy_scope
4960
4961  @_global_distribute_strategy_scope.setter
4962  def _global_distribute_strategy_scope(self, distribute_strategy_scope):
4963    self._thread_local.distribute_strategy_scope = (distribute_strategy_scope)
4964
4965  @property
4966  def _auto_cast_variable_read_dtype(self):
4967    """The dtype that instances of `AutoCastVariable` will be casted to.
4968
4969    This is None if `AutoCastVariables` should not be cast.
4970
4971    See `AutoCastVariable` for more information.
4972
4973    Returns:
4974      The dtype that instances of `AutoCastVariable` will be cast to.
4975    """
4976    if not hasattr(self._thread_local, "_auto_cast_variable_read_dtype"):
4977      self._thread_local._auto_cast_variable_read_dtype = None  # pylint: disable=protected-access
4978    return self._thread_local._auto_cast_variable_read_dtype  # pylint: disable=protected-access
4979
4980  @_auto_cast_variable_read_dtype.setter
4981  def _auto_cast_variable_read_dtype(self, dtype):
4982    if dtype:
4983      dtype = dtypes.as_dtype(dtype)
4984    self._thread_local._auto_cast_variable_read_dtype = dtype  # pylint: disable=protected-access
4985
4986  @tf_contextlib.contextmanager
4987  def _enable_auto_casting_variables(self, dtype):
4988    """Context manager to automatically cast AutoCastVariables.
4989
4990    If an AutoCastVariable `var` is used under this context manager, it will be
4991    cast to `dtype` before being used.
4992
4993    See `AutoCastVariable` for more information.
4994
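    A minimal sketch (assuming `var` is an `AutoCastVariable`, e.g. one created
    under a mixed-precision policy, and `graph` is the enclosing `Graph`):

    ```python
    with graph._enable_auto_casting_variables(tf.float16):
      y = var * 2.0  # `var` is read as float16 within this block
    ```
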
4995    Args:
4996      dtype: The dtype that AutoCastVariables should be cast to.
4997
4998    Yields:
4999      Nothing.
5000    """
5001    prev_read_dtype = self._auto_cast_variable_read_dtype
5002    try:
5003      self._auto_cast_variable_read_dtype = dtype
5004      yield
5005    finally:
5006      self._auto_cast_variable_read_dtype = prev_read_dtype
5007
5008  def _mutation_lock(self):
5009    """Returns a lock to guard code that creates & mutates ops.
5010
5011    See the comment for self._group_lock for more info.
5012    """
5013    return self._group_lock.group(_MUTATION_LOCK_GROUP)
5014
5015  def _session_run_lock(self):
5016    """Returns a lock to guard code for Session.run.
5017
5018    See the comment for self._group_lock for more info.
5019    """
5020    return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
5021
5022
5023# TODO(agarwal): currently device directives in an outer eager scope will not
5024# apply to inner graph mode code. Fix that.
5025
5026
5027@tf_export(v1=["device"])
5028def device(device_name_or_function):
5029  """Wrapper for `Graph.device()` using the default graph.
5030
5031  See `tf.Graph.device` for more details.
5032
5033  Args:
5034    device_name_or_function: The device name or function to use in the context.
5035
5036  Returns:
5037    A context manager that specifies the default device to use for newly
5038    created ops.
5039
5040  Raises:
5041    RuntimeError: If eager execution is enabled and a function is passed in.
5042  """
5043  if context.executing_eagerly():
5044    if callable(device_name_or_function):
5045      raise RuntimeError(
5046          "tf.device does not support functions when eager execution "
5047          "is enabled.")
5048    return context.device(device_name_or_function)
5049  elif executing_eagerly_outside_functions():
5050    @tf_contextlib.contextmanager
5051    def combined(device_name_or_function):
5052      with get_default_graph().device(device_name_or_function):
5053        if not callable(device_name_or_function):
5054          with context.device(device_name_or_function):
5055            yield
5056        else:
5057          yield
5058    return combined(device_name_or_function)
5059  else:
5060    return get_default_graph().device(device_name_or_function)
5061
5062
5063@tf_export("device", v1=[])
5064def device_v2(device_name):
5065  """Specifies the device for ops created/executed in this context.
5066
5067  This function specifies the device to be used for ops created/executed in a
5068  particular context. Nested contexts will inherit and also create/execute
5069  their ops on the specified device. If a specific device is not required,
5070  consider not using this function so that a device can be assigned
5071  automatically; in general its use is optional. `device_name` can
5072  be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially
5073  specified, containing only a subset of the "/"-separated fields. Any fields
5074  which are specified will override device annotations from outer scopes.
5075
5076  For example:
5077
5078  ```python
5079  with tf.device('/job:foo'):
5080    # ops created here have devices with /job:foo
5081    with tf.device('/job:bar/task:0/device:gpu:2'):
5082      # ops created here have the fully specified device above
5083    with tf.device('/device:gpu:1'):
5084      # ops created here have the device '/job:foo/device:gpu:1'
5085  ```
5086
5087  Args:
5088    device_name: The device name to use in the context.
5089
5090  Returns:
5091    A context manager that specifies the default device to use for newly
5092    created ops.
5093
5094  Raises:
5095    RuntimeError: If a function is passed in.
5096  """
5097  if callable(device_name):
5098    raise RuntimeError("tf.device does not support functions.")
5099  return device(device_name)
5100
5101
5102@tf_export(v1=["container"])
5103def container(container_name):
5104  """Wrapper for `Graph.container()` using the default graph.
5105
5106  Args:
5107    container_name: The container string to use in the context.
5108
5109  Returns:
5110    A context manager that specifies the default container to use for newly
5111    created stateful ops.
5112  """
5113  return get_default_graph().container(container_name)
5114
5115
5116def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
5117  if context.executing_eagerly():
5118    if op is not None:
5119      if not hasattr(op, "device"):
5120        op = internal_convert_to_tensor_or_indexed_slices(op)
5121      return device(op.device)
5122    else:
5123      return NullContextmanager()
5124  else:
5125    default_graph = get_default_graph()
5126    if isinstance(op, EagerTensor):
5127      if default_graph.building_function:
5128        return default_graph.device(op.device)
5129      else:
5130        raise ValueError("Encountered an Eager-defined Tensor during graph "
5131                         "construction, but a function was not being built.")
5132    return default_graph._colocate_with_for_gradient(
5133        op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)
5134
5135
5136# Internal interface to colocate_with. colocate_with has been deprecated from
5137# public API. There are still a few internal uses of colocate_with. Add internal
5138# only API for those uses to avoid deprecation warning.
5139def colocate_with(op, ignore_existing=False):
5140  return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing)
5141
5142
5143@deprecation.deprecated(
5144    date=None, instructions="Colocations handled automatically by placer.")
5145@tf_export(v1=["colocate_with"])
5146def _colocate_with(op, ignore_existing=False):
5147  return colocate_with(op, ignore_existing)
5148
5149
5150@tf_export("control_dependencies")
5151def control_dependencies(control_inputs):
5152  """Wrapper for `Graph.control_dependencies()` using the default graph.
5153
5154  See `tf.Graph.control_dependencies`
5155  for more details.
5156
5157  When eager execution is enabled, any callable object in the `control_inputs`
5158  list will be called.
5159
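  For example, a minimal sketch of the eager behavior (the callable below is
  purely illustrative):

  ```python
  def notify():
    print("dependencies requested")  # runs immediately under eager execution

  with tf.control_dependencies([notify]):
    x = tf.constant(1.0)  # no graph-level control edge is created in eager mode
  ```
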
5160  Args:
5161    control_inputs: A list of `Operation` or `Tensor` objects which must be
5162      executed or computed before running the operations defined in the context.
5163      Can also be `None` to clear the control dependencies. If eager execution
5164      is enabled, any callable object in the `control_inputs` list will be
5165      called.
5166
5167  Returns:
5168   A context manager that specifies control dependencies for all
5169   operations constructed within the context.
5170  """
5171  if context.executing_eagerly():
5172    if control_inputs:
5173      # Execute any pending callables.
5174      for control in control_inputs:
5175        if callable(control):
5176          control()
5177    return NullContextmanager()
5178  else:
5179    return get_default_graph().control_dependencies(control_inputs)
5180
5181
5182class _DefaultStack(threading.local):
5183  """A thread-local stack of objects for providing implicit defaults."""
5184
5185  def __init__(self):
5186    super(_DefaultStack, self).__init__()
5187    self._enforce_nesting = True
5188    self.stack = []
5189
5190  def get_default(self):
5191    return self.stack[-1] if len(self.stack) >= 1 else None
5192
5193  def reset(self):
5194    self.stack = []
5195
5196  def is_cleared(self):
5197    return not self.stack
5198
5199  @property
5200  def enforce_nesting(self):
5201    return self._enforce_nesting
5202
5203  @enforce_nesting.setter
5204  def enforce_nesting(self, value):
5205    self._enforce_nesting = value
5206
5207  @tf_contextlib.contextmanager
5208  def get_controller(self, default):
5209    """A context manager for manipulating a default stack."""
5210    self.stack.append(default)
5211    try:
5212      yield default
5213    finally:
5214      # stack may be empty if reset() was called
5215      if self.stack:
5216        if self._enforce_nesting:
5217          if self.stack[-1] is not default:
5218            raise AssertionError(
5219                "Nesting violated for default stack of %s objects" %
5220                type(default))
5221          self.stack.pop()
5222        else:
5223          self.stack.remove(default)
5224
5225
5226_default_session_stack = _DefaultStack()  # pylint: disable=protected-access
5227
5228
5229def default_session(session):
5230  """Python "with" handler for defining a default session.
5231
5232  This function provides a means of registering a session for handling
5233  Tensor.eval() and Operation.run() calls. It is primarily intended for use
5234  by session.Session, but can be used with any object that implements
5235  the Session.run() interface.
5236
5237  Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
5238  invocations within the scope of a block should be executed by a particular
5239  session.
5240
5241  The default session applies to the current thread only, so it is always
5242  possible to inspect the call stack and determine the scope of a default
5243  session. If you create a new thread, and wish to use the default session
5244  in that thread, you must explicitly add a "with ops.default_session(sess):"
5245  block in that thread's function.
5246
5247  Example:
5248    The following code examples are equivalent:
5249
5250    # 1. Using the Session object directly:
5251    sess = ...
5252    c = tf.constant(5.0)
5253    sess.run(c)
5254
5255    # 2. Using default_session():
5256    sess = ...
5257    with ops.default_session(sess):
5258      c = tf.constant(5.0)
5259      result = c.eval()
5260
5261    # 3. Overriding default_session():
5262    sess = ...
5263    with ops.default_session(sess):
5264      c = tf.constant(5.0)
5265      with ops.default_session(...):
5266        c.eval(session=sess)
5267
5268  Args:
5269    session: The session to be installed as the default session.
5270
5271  Returns:
5272    A context manager for the default session.
5273  """
5274  return _default_session_stack.get_controller(session)
5275
5276
5277@tf_export(v1=["get_default_session"])
5278def get_default_session():
5279  """Returns the default session for the current thread.
5280
5281  The returned `Session` will be the innermost session on which a
5282  `Session` or `Session.as_default()` context has been entered.
5283
5284  NOTE: The default session is a property of the current thread. If you
5285  create a new thread, and wish to use the default session in that
5286  thread, you must explicitly add a `with sess.as_default():` in that
5287  thread's function.
5288
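  For example, a minimal sketch:

  ```python
  sess = tf.compat.v1.Session()
  with sess.as_default():
    assert tf.compat.v1.get_default_session() is sess
  ```
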
5289  Returns:
5290    The default `Session` being used in the current thread.
5291  """
5292  return _default_session_stack.get_default()
5293
5294
5295def _eval_using_default_session(tensors, feed_dict, graph, session=None):
5296  """Uses the default session to evaluate one or more tensors.
5297
5298  Args:
5299    tensors: A single Tensor, or a list of Tensor objects.
5300    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5301      numpy ndarrays, TensorProtos, or strings.
5302    graph: The graph in which the tensors are defined.
5303    session: (Optional) A different session to use to evaluate "tensors".
5304
5305  Returns:
5306    Either a single numpy ndarray if "tensors" is a single tensor; or a list
5307    of numpy ndarrays that each correspond to the respective element in
5308    "tensors".
5309
5310  Raises:
5311    ValueError: If no default session is available; the default session
5312      does not have "graph" as its graph; or if "session" is specified,
5313      and it does not have "graph" as its graph.
5314  """
5315  if session is None:
5316    session = get_default_session()
5317    if session is None:
5318      raise ValueError("Cannot evaluate tensor using `eval()`: No default "
5319                       "session is registered. Use `with "
5320                       "sess.as_default()` or pass an explicit session to "
5321                       "`eval(session=sess)`")
5322    if session.graph is not graph:
5323      raise ValueError("Cannot use the default session to evaluate tensor: "
5324                       "the tensor's graph is different from the session's "
5325                       "graph. Pass an explicit session to "
5326                       "`eval(session=sess)`.")
5327  else:
5328    if session.graph is not graph:
5329      raise ValueError("Cannot use the given session to evaluate tensor: "
5330                       "the tensor's graph is different from the session's "
5331                       "graph.")
5332  return session.run(tensors, feed_dict)
5333
5334
5335def _run_using_default_session(operation, feed_dict, graph, session=None):
5336  """Uses the default session to run "operation".
5337
5338  Args:
5339    operation: The Operation to be run.
5340    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5341      numpy ndarrays, TensorProtos, or strings.
5342    graph: The graph in which "operation" is defined.
5343    session: (Optional) A different session to use to run "operation".
5344
5345  Raises:
5346    ValueError: If no default session is available; the default session
5347      does not have "graph" as its graph; or if "session" is specified,
5348      and it does not have "graph" as its graph.
5349  """
5350  if session is None:
5351    session = get_default_session()
5352    if session is None:
5353      raise ValueError("Cannot execute operation using `run()`: No default "
5354                       "session is registered. Use `with "
5355                       "sess.as_default():` or pass an explicit session to "
5356                       "`run(session=sess)`")
5357    if session.graph is not graph:
5358      raise ValueError("Cannot use the default session to execute operation: "
5359                       "the operation's graph is different from the "
5360                       "session's graph. Pass an explicit session to "
5361                       "run(session=sess).")
5362  else:
5363    if session.graph is not graph:
5364      raise ValueError("Cannot use the given session to execute operation: "
5365                       "the operation's graph is different from the session's "
5366                       "graph.")
5367  session.run(operation, feed_dict)
5368
5369
5370class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
5371  """A thread-local stack of objects for providing an implicit default graph."""
5372
5373  def __init__(self):
5374    super(_DefaultGraphStack, self).__init__()
5375    self._global_default_graph = None
5376
5377  def get_default(self):
5378    """Override that returns a global default if the stack is empty."""
5379    ret = super(_DefaultGraphStack, self).get_default()
5380    if ret is None:
5381      ret = self._GetGlobalDefaultGraph()
5382    return ret
5383
5384  def _GetGlobalDefaultGraph(self):
5385    if self._global_default_graph is None:
5386      # TODO(mrry): Perhaps log that the default graph is being used, or
5387      #   provide some other feedback to prevent confusion when a mixture of
5388      #   the global default graph and an explicit graph are combined in the
5389      #   same process.
5390      self._global_default_graph = Graph()
5391    return self._global_default_graph
5392
5393  def reset(self):
5394    super(_DefaultGraphStack, self).reset()
5395    self._global_default_graph = None
5396
5397  @tf_contextlib.contextmanager
5398  def get_controller(self, default):
5399    context.context().context_switches.push(default.building_function,
5400                                            default.as_default,
5401                                            default._device_function_stack)
5402    try:
5403      with super(_DefaultGraphStack,
5404                 self).get_controller(default) as g, context.graph_mode():
5405        yield g
5406    finally:
5407      # If an exception is raised here it may be hiding a related exception in
5408      # the try-block (just above).
5409      context.context().context_switches.pop()
5410
5411
5412_default_graph_stack = _DefaultGraphStack()
5413
5414
5415# Shared helper used in init_scope and executing_eagerly_outside_functions
5416# to obtain the outermost context that is not building a function, and the
5417# innermost non-empty device stack.
5418def _get_outer_context_and_inner_device_stack():
5419  """Get the outermost context not building a function."""
5420  default_graph = get_default_graph()
5421  outer_context = None
5422  innermost_nonempty_device_stack = default_graph._device_function_stack  # pylint: disable=protected-access
5423
5424  if not _default_graph_stack.stack:
5425    # If the default graph stack is empty, then we cannot be building a
5426    # function. Install the global graph (which, in this case, is also the
5427    # default graph) as the outer context.
5428    if default_graph.building_function:
5429      raise RuntimeError("The global graph is building a function.")
5430    outer_context = default_graph.as_default
5431  else:
5432    # Find a context that is not building a function.
5433    for stack_entry in reversed(context.context().context_switches.stack):
5434      if not innermost_nonempty_device_stack:
5435        innermost_nonempty_device_stack = stack_entry.device_stack
5436      if not stack_entry.is_building_function:
5437        outer_context = stack_entry.enter_context_fn
5438        break
5439
5440    if outer_context is None:
5441      # As a last resort, obtain the global default graph; this graph doesn't
5442      # necessarily live on the graph stack (and hence it doesn't necessarily
5443      # live on the context stack), but it is stored in the graph stack's
5444      # encapsulating object.
5445      outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default  # pylint: disable=protected-access
5446
5447  if outer_context is None:
5448    # Sanity check; this shouldn't be triggered.
5449    raise RuntimeError("All graphs are building functions, and no "
5450                       "eager context was previously active.")
5451
5452  return outer_context, innermost_nonempty_device_stack
5453
5454
5455# pylint: disable=g-doc-return-or-yield,line-too-long
5456@tf_export("init_scope")
5457@tf_contextlib.contextmanager
5458def init_scope():
5459  """A context manager that lifts ops out of control-flow scopes and function-building graphs.
5460
5461  There is often a need to lift variable initialization ops out of control-flow
5462  scopes, function-building graphs, and gradient tapes. Entering an
5463  `init_scope` is a mechanism for satisfying these desiderata. In particular,
5464  entering an `init_scope` has three effects:
5465
5466    (1) All control dependencies are cleared the moment the scope is entered;
5467        this is equivalent to entering the context manager returned from
5468        `control_dependencies(None)`, which has the side-effect of exiting
5469        control-flow scopes like `tf.cond` and `tf.while_loop`.
5470
5471    (2) All operations that are created while the scope is active are lifted
5472        into the lowest context on the `context_stack` that is not building a
5473        graph function. Here, a context is defined as either a graph or an eager
5474        context. Every context switch, i.e., every installation of a graph as
5475        the default graph and every switch into eager mode, is logged in a
5476        thread-local stack called `context_switches`; the log entry for a
5477        context switch is popped from the stack when the context is exited.
5478        Entering an `init_scope` is equivalent to crawling up
5479        `context_switches`, finding the first context that is not building a
5480        graph function, and entering it. A caveat is that if graph mode is
5481        enabled but the default graph stack is empty, then entering an
5482        `init_scope` will simply install a fresh graph as the default one.
5483
5484    (3) The gradient tape is paused while the scope is active.
5485
5486  When eager execution is enabled, code inside an init_scope block runs with
5487  eager execution enabled even when tracing a `tf.function`. For example:
5488
5489  ```python
5490  tf.compat.v1.enable_eager_execution()
5491
5492  @tf.function
5493  def func():
5494    # A function constructs TensorFlow graphs,
5495    # it does not execute eagerly.
5496    assert not tf.executing_eagerly()
5497    with tf.init_scope():
5498      # Initialization runs with eager execution enabled
5499      assert tf.executing_eagerly()
5500  ```
5501
5502  Raises:
5503    RuntimeError: if graph state is incompatible with this initialization.
5504  """
5505  # pylint: enable=g-doc-return-or-yield,line-too-long
5506
5507  if context.executing_eagerly():
5508    # Fastpath.
5509    with tape.stop_recording():
5510      yield
5511  else:
5512    # Retrieve the active name scope: entering an `init_scope` preserves
5513    # the name scope of the current context.
5514    scope = get_default_graph().get_name_scope()
5515    if scope and scope[-1] != "/":
5516      # Names that end with trailing slashes are treated by `name_scope` as
5517      # absolute.
5518      scope = scope + "/"
5519
5520    outer_context, innermost_nonempty_device_stack = (
5521        _get_outer_context_and_inner_device_stack())
5522
5523    outer_graph = None
5524    outer_device_stack = None
5525    try:
5526      with outer_context(), name_scope(
5527          scope, skip_on_eager=False), control_dependencies(
5528              None), tape.stop_recording():
5529        context_manager = NullContextmanager
5530        context_manager_input = None
5531        if not context.executing_eagerly():
5532          # The device stack is preserved when lifting into a graph. Eager
5533          # execution doesn't implement device stacks and in particular it
5534          # doesn't support device functions, so in general it's not possible
5535          # to do the same when lifting into the eager context.
5536          outer_graph = get_default_graph()
5537          outer_device_stack = outer_graph._device_function_stack  # pylint: disable=protected-access
5538          outer_graph._device_function_stack = innermost_nonempty_device_stack  # pylint: disable=protected-access
5539        elif innermost_nonempty_device_stack is not None:
5540          for device_spec in innermost_nonempty_device_stack.peek_objs():
5541            if device_spec.function is None:
5542              break
5543            if device_spec.raw_string:
5544              context_manager = context.device
5545              context_manager_input = device_spec.raw_string
5546              break
5547            # It is currently not possible to have a device function in V2,
5548            # but in V1 we are unable to apply device functions in eager mode.
5549            # This means that we will silently skip some of the entries on the
5550            # device stack in V1 + eager mode.
5551
5552        with context_manager(context_manager_input):
5553          yield
5554    finally:
5555      # If an exception is raised here it may be hiding a related exception in
5556      # try-block (just above).
5557      if outer_graph is not None:
5558        outer_graph._device_function_stack = outer_device_stack  # pylint: disable=protected-access
5559
5560
5561def executing_eagerly_outside_functions():
5562  """Returns True if executing eagerly, even if inside a graph function."""
5563  if context.executing_eagerly():
5564    return True
5565  else:
5566    outer_context, _ = _get_outer_context_and_inner_device_stack()
5567    with outer_context():
5568      return context.executing_eagerly()
5569
5570
5571def inside_function():
5572  return get_default_graph().building_function
5573
5574
5575@tf_export(v1=["enable_eager_execution"])
5576def enable_eager_execution(config=None, device_policy=None,
5577                           execution_mode=None):
5578  """Enables eager execution for the lifetime of this program.
5579
5580  Eager execution provides an imperative interface to TensorFlow. With eager
5581  execution enabled, TensorFlow functions execute operations immediately (as
5582  opposed to adding to a graph to be executed later in a `tf.compat.v1.Session`)
5583  and
5584  return concrete values (as opposed to symbolic references to a node in a
5585  computational graph).
5586
5587  For example:
5588
5589  ```python
5590  tf.compat.v1.enable_eager_execution()
5591
5592  # After eager execution is enabled, operations are executed as they are
5593  # defined and Tensor objects hold concrete values, which can be accessed as
5594  # numpy.ndarrays through the numpy() method.
5595  assert tf.multiply(6, 7).numpy() == 42
5596  ```
5597
5598  Eager execution cannot be enabled after TensorFlow APIs have been used to
5599  create or execute graphs. It is typically recommended to invoke this function
5600  at program startup and not in a library (as most libraries should be usable
5601  both with and without eager execution).
5602
5603  Args:
5604    config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the
5605      environment in which operations are executed. Note that
5606      `tf.compat.v1.ConfigProto` is also used to configure graph execution (via
5607      `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto`
5608      are not implemented (or are irrelevant) when eager execution is enabled.
5609    device_policy: (Optional.) Policy controlling how operations requiring
5610      inputs on a specific device (e.g., GPU 0) handle inputs on a different
5611      device (e.g., GPU 1 or CPU). When set to None, an appropriate value will
5612      be picked automatically. The value picked may change between TensorFlow
5613      releases.
5614      Valid values:
5615      - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the
5616        placement is not correct.
5617      - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors that are not
5618        on the right device and logs a warning.
5619      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors.
5620        Note that this may hide performance problems as there is no notification
5621        provided when operations are blocked on the tensor being copied between
5622        devices.
5623      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies
5624        int32 tensors, raising errors for all other dtypes.
5625    execution_mode: (Optional.) Policy controlling how dispatched operations are
5626      actually executed. When set to None, an appropriate value will be picked
5627      automatically. The value picked may change between TensorFlow releases.
5628      Valid values:
5629      - tf.contrib.eager.SYNC: executes each operation synchronously.
5630      - tf.contrib.eager.ASYNC: executes each operation asynchronously. These
5631        operations may return "non-ready" handles.
5632
5633  Raises:
5634    ValueError: If eager execution is enabled after creating/executing a
5635     TensorFlow graph, or if options provided conflict with a previous call
5636     to this function.
5637  """
5638  _api_usage_gauge.get_cell().set(True)
5639  if context.default_execution_mode != context.EAGER_MODE:
5640    return enable_eager_execution_internal(
5641        config=config,
5642        device_policy=device_policy,
5643        execution_mode=execution_mode,
5644        server_def=None)
5645
5646
5647@tf_export(v1=["disable_eager_execution"])
5648def disable_eager_execution():
5649  """Disables eager execution.
5650
5651  This function can only be called before any Graphs, Ops, or Tensors have been
5652  created. It can be used at the beginning of the program for complex migration
5653  projects from TensorFlow 1.x to 2.x.
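
  For example, a minimal sketch (the call must happen before any graph, op, or
  tensor has been created):

  ```python
  import tensorflow as tf

  tf.compat.v1.disable_eager_execution()
  assert not tf.executing_eagerly()
  ```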
5654  """
5655  _api_usage_gauge.get_cell().set(False)
5656  context.default_execution_mode = context.GRAPH_MODE
5657  c = context.context_safe()
5658  if c is not None:
5659    c._thread_local_data.is_eager = False  # pylint: disable=protected-access
5660
5661
5662def enable_eager_execution_internal(config=None,
5663                                    device_policy=None,
5664                                    execution_mode=None,
5665                                    server_def=None):
5666  """Enables eager execution for the lifetime of this program.
5667
5668  Most of the doc string for enable_eager_execution is relevant here as well.
5669
5670  Args:
5671    config: See enable_eager_execution doc string
5672    device_policy: See enable_eager_execution doc string
5673    execution_mode: See enable_eager_execution doc string
5674    server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on
5675      remote devices. GrpcServers need to be started by creating an identical
5676      server_def to this, and setting the appropriate task_indexes, so that the
5677      servers can communicate. It will then be possible to execute operations on
5678      remote devices.
5679
5680  Raises:
5681    ValueError: If `device_policy` or `execution_mode` is invalid, if a graph
5682      has already been created or executed, or if the provided options
5683      conflict with an already-active eager context.

5683  """
5684  if config is not None and not isinstance(config, config_pb2.ConfigProto):
5685    raise TypeError("config must be a tf.ConfigProto, but got %s" %
5686                    type(config))
5687  if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
5688                           context.DEVICE_PLACEMENT_WARN,
5689                           context.DEVICE_PLACEMENT_SILENT,
5690                           context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
5691    raise ValueError(
5692        "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
5693    )
5694  if execution_mode not in (None, context.SYNC, context.ASYNC):
5695    raise ValueError(
5696        "execution_mode must be one of None, tf.contrib.eager.SYNC, "
5697        "tf.contrib.eager.ASYNC")
5698  if context.default_execution_mode == context.GRAPH_MODE:
5699    graph_mode_has_been_used = (
5700        _default_graph_stack._global_default_graph is not None)  # pylint: disable=protected-access
5701    if graph_mode_has_been_used:
5702      raise ValueError(
5703          "tf.enable_eager_execution must be called at program startup.")
5704  context.default_execution_mode = context.EAGER_MODE
5705  # pylint: disable=protected-access
5706  with context._context_lock:
5707    if context._context is None:
5708      context._set_context_locked(context.Context(
5709          config=config,
5710          device_policy=device_policy,
5711          execution_mode=execution_mode,
5712          server_def=server_def))
5713    elif ((config is not None and config is not context._context._config) or
5714          (device_policy is not None and
5715           device_policy is not context._context._device_policy) or
5716          (execution_mode is not None and
5717           execution_mode is not context._context._execution_mode)):
5718      raise ValueError(
5719          "Trying to change the options of an active eager"
5720          " execution. Context config: %s, specified config:"
5721          " %s. Context device policy: %s, specified device"
5722          " policy: %s. Context execution mode: %s, "
5723          " specified execution mode %s." %
5724          (context._context._config, config, context._context._device_policy,
5725           device_policy, context._context._execution_mode, execution_mode))
5726    else:
5727      # We already created everything, so update the thread local data.
5728      context._context._thread_local_data.is_eager = True
5729
5730  # Monkey patch to get rid of an unnecessary conditional since the context is
5731  # now initialized.
5732  context.context = context.context_safe
5733
5734
5735def eager_run(main=None, argv=None):
5736  """Runs the program with an optional main function and argv list.
5737
5738  The program will run with eager execution enabled.
5739
5740  Example:
5741  ```python
5742  import tensorflow as tf
5743  # Import subject to future changes:
5744  from tensorflow.contrib.eager.python import tfe
5745
5746  def main(_):
5747    u = tf.constant(6.0)
5748    v = tf.constant(7.0)
5749    print(u * v)
5750
5751  if __name__ == "__main__":
5752    tfe.run()
5753  ```
5754
5755  Args:
5756    main: the main function to run.
5757    argv: the arguments to pass to it.
5758  """
5759  enable_eager_execution()
5760  app.run(main, argv)
5761
5762
5763@tf_export(v1=["reset_default_graph"])
5764def reset_default_graph():
5765  """Clears the default graph stack and resets the global default graph.
5766
5767  NOTE: The default graph is a property of the current thread. This
5768  function applies only to the current thread. Calling this function while
5769  a `tf.compat.v1.Session` or `tf.compat.v1.InteractiveSession` is active will
5770  result in undefined behavior. Using any previously created `tf.Operation`
5771  or `tf.Tensor` objects after calling this function will result in
5772  undefined behavior.
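
  For example, a minimal sketch (assuming TF1-style graph mode):

  ```python
  g1 = tf.compat.v1.get_default_graph()
  tf.compat.v1.reset_default_graph()
  g2 = tf.compat.v1.get_default_graph()
  assert g1 is not g2
  ```
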
5773  Raises:
5774    AssertionError: If this function is called within a nested graph.
5775  """
5776  if not _default_graph_stack.is_cleared():
5777    raise AssertionError("Do not use tf.reset_default_graph() to clear "
5778                         "nested graphs. If you need a cleared graph, "
5779                         "exit the nesting and create a new graph.")
5780  _default_graph_stack.reset()
5781
5782
5783@tf_export(v1=["get_default_graph"])
5784def get_default_graph():
5785  """Returns the default graph for the current thread.
5786
5787  The returned graph will be the innermost graph on which a
5788  `Graph.as_default()` context has been entered, or a global default
5789  graph if none has been explicitly created.
5790
5791  NOTE: The default graph is a property of the current thread. If you
5792  create a new thread, and wish to use the default graph in that
5793  thread, you must explicitly add a `with g.as_default():` in that
5794  thread's function.
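
  For example, a minimal sketch:

  ```python
  g = tf.Graph()
  with g.as_default():
    assert tf.compat.v1.get_default_graph() is g
  ```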
5795
5796  Returns:
5797    The default `Graph` being used in the current thread.
5798  """
5799  return _default_graph_stack.get_default()
5800
5801
5802def has_default_graph():
5803  """Returns True if there is a default graph."""
5804  return len(_default_graph_stack.stack) >= 1
5805
5806
5807def get_name_scope():
5808  """Returns the current name scope in the default_graph.
5809
5810  For example:
5811
5812  ```python
5813  with tf.name_scope('scope1'):
5814    with tf.name_scope('scope2'):
5815      print(tf.get_name_scope())
5816  ```
5817  would print the string `scope1/scope2`.
5818
5819  Returns:
5820    A string representing the current name scope.
5821  """
5822  if context.executing_eagerly():
5823    return context.context().scope_name.rstrip("/")
5824  return get_default_graph().get_name_scope()
5825
5826
5827def _assert_same_graph(original_item, item):
5828  """Fail if the 2 items are from different graphs.
5829
5830  Args:
5831    original_item: Original item to check against.
5832    item: Item to check.
5833
5834  Raises:
5835    ValueError: if graphs do not match.
5836  """
5837  if original_item.graph is not item.graph:
5838    raise ValueError("%s must be from the same graph as %s." %
5839                     (item, original_item))
5840
5841
5842def _get_graph_from_inputs(op_input_list, graph=None):
5843  """Returns the appropriate graph to use for the given inputs.
5844
5845  This library method provides a consistent algorithm for choosing the graph
5846  in which an Operation should be constructed:
5847
5848  1. If the default graph is being used to construct a function, we
5849     use the default graph.
5850  2. If the "graph" is specified explicitly, we validate that all of the inputs
5851     in "op_input_list" are compatible with that graph.
5852  3. Otherwise, we attempt to select a graph from the first Operation-
5853     or Tensor-valued input in "op_input_list", and validate that all other
5854     such inputs are in the same graph.
5855  4. If the graph was not specified and it could not be inferred from
5856     "op_input_list", we attempt to use the default graph.
5857
5858  Args:
5859    op_input_list: A list of inputs to an operation, which may include `Tensor`,
5860      `Operation`, and other objects that may be converted to a graph element.
5861    graph: (Optional) The explicit graph to use.
5862
5863  Raises:
5864    TypeError: If op_input_list is not a list or tuple, or if graph is not a
5865      Graph.
5866    ValueError: If a graph is explicitly passed and not all inputs are from it,
5867      or if the inputs are from multiple graphs, or we could not find a graph
5868      and there was no default graph.
5869
5870  Returns:
5871    The appropriate graph to use for the given inputs.
5872
5873  """
5874  current_default_graph = get_default_graph()
5875  if current_default_graph.building_function:
5876    return current_default_graph
5877
5878  op_input_list = tuple(op_input_list)  # Handle generators correctly
5879  if graph and not isinstance(graph, Graph):
5880    raise TypeError("Input graph needs to be a Graph: %s" % graph)
5881
5882  # 1. We validate that all of the inputs are from the same graph. This is
5883  #    either the supplied graph parameter, or the first one selected from one
5884  #    the graph-element-valued inputs. In the latter case, we hold onto
5885  #    that input in original_graph_element so we can provide a more
5886  #    informative error if a mismatch is found.
5887  original_graph_element = None
5888  for op_input in op_input_list:
5889    # Determine if this is a valid graph_element.
5890    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
5891    # up.
5892    graph_element = None
5893    if (isinstance(op_input, (Operation, _TensorLike)) and
5894        ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)):  # pylint: disable=unidiomatic-typecheck
5895      graph_element = op_input
5896    else:
5897      graph_element = _as_graph_element(op_input)
5898
5899    if graph_element is not None:
5900      if not graph:
5901        original_graph_element = graph_element
5902        graph = graph_element.graph
5903      elif original_graph_element is not None:
5904        _assert_same_graph(original_graph_element, graph_element)
5905      elif graph_element.graph is not graph:
5906        raise ValueError("%s is not from the passed-in graph." % graph_element)
5907
5908  # 2. If all else fails, we use the default graph, which is always there.
5909  return graph or current_default_graph
5910
5911
5912@tf_export(v1=["GraphKeys"])
5913class GraphKeys(object):
5914  """Standard names to use for graph collections.
5915
5916  The standard library uses various well-known names to collect and
5917  retrieve values associated with a graph. For example, the
5918  `tf.Optimizer` subclasses default to optimizing the variables
5919  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
5920  specified, but it is also possible to pass an explicit list of
5921  variables.
5922
5923  The following standard keys are defined:
5924
5925  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
5926    across a distributed environment (model variables are a subset of these). See
5927    `tf.compat.v1.global_variables`
5928    for more details.
5929    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
5930    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
5931  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
5932    machine. Usually used for temporary variables, like counters.
5933    Note: use `tf.contrib.framework.local_variable` to add to this collection.
5934  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
5935    model for inference (feed forward). Note: use
5936    `tf.contrib.framework.model_variable` to add to this collection.
5937  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
5938    be trained by an optimizer. See
5939    `tf.compat.v1.trainable_variables`
5940    for more details.
5941  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
5942    graph. See
5943    `tf.compat.v1.summary.merge_all`
5944    for more details.
5945  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
5946    produce input for a computation. See
5947    `tf.compat.v1.train.start_queue_runners`
5948    for more details.
5949  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
5950    keep moving averages.  See
5951    `tf.compat.v1.moving_average_variables`
5952    for more details.
5953  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
5954    construction.
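
  For example, a minimal sketch (assuming TF1-style graph mode):

  ```python
  v = tf.Variable(1.0)
  global_vars = tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
  trainable_vars = tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
  assert v in global_vars and v in trainable_vars
  ```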
5955
5956  The following standard keys are _defined_, but their collections are **not**
5957  automatically populated as many of the others are:
5958
5959  * `WEIGHTS`
5960  * `BIASES`
5961  * `ACTIVATIONS`
5962  """
5963
5964  # Key to collect Variable objects that are global (shared across machines).
5965  # Default collection for all variables, except local ones.
5966  GLOBAL_VARIABLES = "variables"
5967  # Key to collect local variables that are local to the machine and are not
5968  # saved/restored.
5969  LOCAL_VARIABLES = "local_variables"
5970  # Key to collect local variables that are used to accumulate internal state
5971  # to be used in tf.metrics.*.
5972  METRIC_VARIABLES = "metric_variables"
5973  # Key to collect model variables defined by layers.
5974  MODEL_VARIABLES = "model_variables"
5975  # Key to collect Variable objects that will be trained by the
5976  # optimizers.
5977  TRAINABLE_VARIABLES = "trainable_variables"
5978  # Key to collect summaries.
5979  SUMMARIES = "summaries"
5980  # Key to collect QueueRunners.
5981  QUEUE_RUNNERS = "queue_runners"
5982  # Key to collect table initializers.
5983  TABLE_INITIALIZERS = "table_initializer"
5984  # Key to collect asset filepaths. An asset represents an external resource
5985  # like a vocabulary file.
5986  ASSET_FILEPATHS = "asset_filepaths"
5987  # Key to collect Variable objects that keep moving averages.
5988  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
5989  # Key to collect regularization losses at graph construction.
5990  REGULARIZATION_LOSSES = "regularization_losses"
5991  # Key to collect concatenated sharded variables.
5992  CONCATENATED_VARIABLES = "concatenated_variables"
5993  # Key to collect savers.
5994  SAVERS = "savers"
5995  # Key to collect weights
5996  WEIGHTS = "weights"
5997  # Key to collect biases
5998  BIASES = "biases"
5999  # Key to collect activations
6000  ACTIVATIONS = "activations"
6001  # Key to collect update_ops
6002  UPDATE_OPS = "update_ops"
6003  # Key to collect losses
6004  LOSSES = "losses"
6005  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
6006  SAVEABLE_OBJECTS = "saveable_objects"
6007  # Key to collect all shared resources used by the graph which need to be
6008  # initialized once per cluster.
6009  RESOURCES = "resources"
6010  # Key to collect all shared resources used in this graph which need to be
6011  # initialized once per session.
6012  LOCAL_RESOURCES = "local_resources"
6013  # Trainable resource-style variables.
6014  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"
6015
6016  # Key to indicate various ops.
6017  INIT_OP = "init_op"
6018  LOCAL_INIT_OP = "local_init_op"
6019  READY_OP = "ready_op"
6020  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
6021  SUMMARY_OP = "summary_op"
6022  GLOBAL_STEP = "global_step"
6023
6024  # Used to count the number of evaluations performed during a single evaluation
6025  # run.
6026  EVAL_STEP = "eval_step"
6027  TRAIN_OP = "train_op"
6028
6029  # Key for control flow context.
6030  COND_CONTEXT = "cond_context"
6031  WHILE_CONTEXT = "while_context"
6032
6033  # Used to store v2 summary names.
6034  _SUMMARY_COLLECTION = "_SUMMARY_V2"
6035
6036  # List of all collections that keep track of variables.
6037  _VARIABLE_COLLECTIONS = [
6038      GLOBAL_VARIABLES,
6039      LOCAL_VARIABLES,
6040      METRIC_VARIABLES,
6041      MODEL_VARIABLES,
6042      TRAINABLE_VARIABLES,
6043      MOVING_AVERAGE_VARIABLES,
6044      CONCATENATED_VARIABLES,
6045      TRAINABLE_RESOURCE_VARIABLES,
6046  ]
6047
6048  # Key for streaming model ports.
6049  # NOTE(yuanbyu): internal and experimental.
6050  _STREAMING_MODEL_PORTS = "streaming_model_ports"
6051
6052  @decorator_utils.classproperty
6053  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
6054  def VARIABLES(cls):  # pylint: disable=no-self-argument
6055    return cls.GLOBAL_VARIABLES
6056
6057
6058def dismantle_graph(graph):
6059  """Cleans up reference cycles from a `Graph`.
6060
6061  Helpful for making sure the garbage collector doesn't need to run after a
6062  temporary `Graph` is no longer needed.
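
  A typical usage sketch:

  ```python
  g = Graph()
  with g.as_default():
    pass  # ... build some temporary ops ...
  dismantle_graph(g)  # `g` and its operations must not be used afterwards.
  ```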
6063
6064  Args:
6065    graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
6066      after this function runs.
6067  """
6068  memory.dismantle_ordered_dict(graph._functions)  # pylint: disable=protected-access
6069
6070  # Now clean up Operation<->Graph reference cycles by clearing all of the
6071  # attributes for the Graph and its ops.
6072  graph_operations = graph.get_operations()
6073  for op in graph_operations:
6074    op.__dict__ = {}
6075  graph.__dict__ = {}
6076
6077
6078@tf_export(v1=["add_to_collection"])
6079def add_to_collection(name, value):
6080  """Wrapper for `Graph.add_to_collection()` using the default graph.
6081
6082  See `tf.Graph.add_to_collection`
6083  for more details.
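
  For example, a minimal sketch (assuming graph mode and a hypothetical `loss`
  tensor):

  ```python
  tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.LOSSES, loss)
  losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES)
  assert loss in losses
  ```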
6084
6085  Args:
6086    name: The key for the collection. For example, the `GraphKeys` class
6087      contains many standard names for collections.
6088    value: The value to add to the collection.  @compatibility(eager)
6089      Collections are only supported in eager when variables are created inside
6090      an EagerVariableStore (e.g. as part of a layer or template).
6091      @end_compatibility
6092  """
6093  get_default_graph().add_to_collection(name, value)
6094
6095
6096@tf_export(v1=["add_to_collections"])
6097def add_to_collections(names, value):
6098  """Wrapper for `Graph.add_to_collections()` using the default graph.
6099
6100  See `tf.Graph.add_to_collections`
6101  for more details.
6102
6103  Args:
6104    names: The key for the collections. The `GraphKeys` class contains many
6105      standard names for collections.
6106    value: The value to add to the collections.  @compatibility(eager)
6107      Collections are only supported in eager when variables are created inside
6108      an EagerVariableStore (e.g. as part of a layer or template).
6109      @end_compatibility
6110  """
6111  get_default_graph().add_to_collections(names, value)
6112
6113
6114@tf_export(v1=["get_collection_ref"])
6115def get_collection_ref(key):
6116  """Wrapper for `Graph.get_collection_ref()` using the default graph.
6117
6118  See `tf.Graph.get_collection_ref`
6119  for more details.
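
  For example, a minimal sketch (assuming graph mode and a hypothetical tensor
  `x`):

  ```python
  refs = tf.compat.v1.get_collection_ref("my_collection")
  refs.append(x)  # Mutates the collection in place.
  assert x in tf.compat.v1.get_collection("my_collection")
  ```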
6120
6121  Args:
6122    key: The key for the collection. For example, the `GraphKeys` class contains
6123      many standard names for collections.
6124
6125  Returns:
6126    The list of values in the collection with the given `name`, or an empty
6127    list if no value has been added to that collection.  Note that this returns
6128    the collection list itself, which can be modified in place to change the
6129    collection.
6130
6131  @compatibility(eager)
6132  Collections are not supported when eager execution is enabled.
6133  @end_compatibility
6134  """
6135  return get_default_graph().get_collection_ref(key)
6136
6137
6138@tf_export(v1=["get_collection"])
6139def get_collection(key, scope=None):
6140  """Wrapper for `Graph.get_collection()` using the default graph.
6141
6142  See `tf.Graph.get_collection`
6143  for more details.
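
  For example, a minimal sketch of the `scope` filter (assuming graph mode):

  ```python
  with tf.compat.v1.variable_scope("scope1"):
    v1 = tf.compat.v1.get_variable("v", [1])
  with tf.compat.v1.variable_scope("scope2"):
    v2 = tf.compat.v1.get_variable("v", [1])

  assert tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="scope1") == [v1]
  ```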
6144
6145  Args:
6146    key: The key for the collection. For example, the `GraphKeys` class contains
6147      many standard names for collections.
6148    scope: (Optional.) If supplied, the resulting list is filtered to include
6149      only items whose `name` attribute matches using `re.match`. Items without
6150      a `name` attribute are never returned if a scope is supplied and the
6151      choice or `re.match` means that a `scope` without special tokens filters
6152      by prefix.
6153
6154  Returns:
6155    The list of values in the collection with the given `name`, or
6156    an empty list if no value has been added to that collection. The
6157    list contains the values in the order under which they were
6158    collected.
6159
6160  @compatibility(eager)
6161  Collections are not supported when eager execution is enabled.
6162  @end_compatibility
6163  """
6164  return get_default_graph().get_collection(key, scope)
6165
6166
6167def get_all_collection_keys():
6168  """Returns a list of collections used in the default graph."""
6169  return get_default_graph().get_all_collection_keys()
6170
6171
6172def name_scope(name, default_name=None, values=None, skip_on_eager=True):
6173  """Internal-only entry point for `name_scope*`.
6174
6175  Internal ops do not use the public API and instead rely on
6176  `ops.name_scope` regardless of the execution mode. This function
6177  dispatches to the correct `name_scope*` implementation based on
6178  the arguments provided and the current mode. Specifically,
6179
6180  * if `values` contains a graph tensor, `Graph.name_scope` is used;
6181  * `name_scope_v1` is used in graph mode;
6182  * `name_scope_v2` -- in eager mode.
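
  For example, an internal op wrapper would typically write (a sketch, with `a`
  and `b` standing in for graph tensors):

  ```python
  with ops.name_scope(name, "MyOp", [a, b]) as scope:
    result = ...  # Build ops under `scope`.
  ```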
6183
6184  Args:
6185    name: The name argument that is passed to the op function.
6186    default_name: The default name to use if the `name` argument is `None`.
6187    values: The list of `Tensor` arguments that are passed to the op function.
6188    skip_on_eager: Whether to return a NullContextmanager when executing
6189      eagerly. By default this is True, since naming tensors and operations in
6190      eager mode has little use and causes unnecessary performance overhead.
6191      However, it is important to preserve variable names since they are often
6192      useful for debugging and saved models.
6193
6194  Returns:
6195    `name_scope*` context manager.
6196  """
6197  ctx = context.context()
6198  in_eager_mode = ctx.executing_eagerly()
6199  if not in_eager_mode:
6200    return internal_name_scope_v1(name, default_name, values)
6201
6202  if skip_on_eager:
6203    return NullContextmanager()
6204
6205  name = default_name if name is None else name
6206  if values:
6207    # The presence of a graph tensor in `values` overrides the context.
6208    # TODO(slebedev): this is Keras-specific and should be removed.
6209    # pylint: disable=unidiomatic-typecheck
6210    graph_value = next((value for value in values if type(value) == Tensor),
6211                       None)
6212    # pylint: enable=unidiomatic-typecheck
6213    if graph_value is not None:
6214      return graph_value.graph.name_scope(name)
6215
6216  return name_scope_v2(name or "")
6217
6218
6219class internal_name_scope_v1(object):  # pylint: disable=invalid-name
6220  """Graph-only version of `name_scope_v1`."""
6221
6222  @property
6223  def name(self):
6224    return self._name
6225
6226  def __init__(self, name, default_name=None, values=None):
6227    """Initialize the context manager.
6228
6229    Args:
6230      name: The name argument that is passed to the op function.
6231      default_name: The default name to use if the `name` argument is `None`.
6232      values: The list of `Tensor` arguments that are passed to the op function.
6233
6234    Raises:
6235      TypeError: if `default_name` is passed in but not a string.
6236    """
6237    if not (default_name is None or isinstance(default_name, six.string_types)):
6238      raise TypeError(
6239          "`default_name` type (%s) is not a string type. You likely meant to "
6240          "pass this into the `values` kwarg." % type(default_name))
6241    self._name = default_name if name is None else name
6242    self._default_name = default_name
6243    self._values = values
6244
6245  def __enter__(self):
6246    """Start the scope block.
6247
6248    Returns:
6249      The scope name.
6250
6251    Raises:
6252      ValueError: if neither `name` nor `default_name` is provided
6253        but `values` are.
6254    """
6255    if self._name is None and self._values is not None:
6256      # We only raise an error if values is not None (provided) because
6257      # currently tf.name_scope(None) (with values=None) is sometimes used as
6258      # an idiom to reset to top scope.
6259      raise ValueError(
6260          "At least one of name (%s) and default_name (%s) must be provided."
6261          % (self._name, self._default_name))
6262
6263    g = get_default_graph()
6264    if self._values and not g.building_function:
6265      # Specialize based on the knowledge that `_get_graph_from_inputs()`
6266      # ignores `inputs` when building a function.
6267      g_from_inputs = _get_graph_from_inputs(self._values)
6268      if g_from_inputs is not g:
6269        g = g_from_inputs
6270        self._g_manager = g.as_default()
6271        self._g_manager.__enter__()
6272      else:
6273        self._g_manager = None
6274    else:
6275      self._g_manager = None
6276
6277    try:
6278      self._name_scope = g.name_scope(self._name)
6279      return self._name_scope.__enter__()
6280    except:
6281      if self._g_manager is not None:
6282        self._g_manager.__exit__(*sys.exc_info())
6283      raise
6284
6285  def __exit__(self, *exc_info):
6286    self._name_scope.__exit__(*exc_info)
6287    if self._g_manager is not None:
6288      self._g_manager.__exit__(*exc_info)
6289
6290
6291# Named like a function for backwards compatibility with the
6292# @tf_contextlib.contextmanager version, which was switched to a class to avoid
6293# some object creation overhead.
6294@tf_export(v1=["name_scope"])
6295class name_scope_v1(object):  # pylint: disable=invalid-name
6296  """A context manager for use when defining a Python op.
6297
6298  This context manager validates that the given `values` are from the
6299  same graph, makes that graph the default graph, and pushes a
6300  name scope in that graph (see
6301  `tf.Graph.name_scope`
6302  for more details on that).
6303
6304  For example, to define a new Python op called `my_op`:
6305
6306  ```python
6307  def my_op(a, b, c, name=None):
6308    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
6309      a = tf.convert_to_tensor(a, name="a")
6310      b = tf.convert_to_tensor(b, name="b")
6311      c = tf.convert_to_tensor(c, name="c")
6312      # Define some computation that uses `a`, `b`, and `c`.
6313      return foo_op(..., name=scope)
6314  ```
6315  """
6316
6317  @property
6318  def name(self):
6319    return self._name
6320
6321  def __init__(self, name, default_name=None, values=None):
6322    """Initialize the context manager.
6323
6324    Args:
6325      name: The name argument that is passed to the op function.
6326      default_name: The default name to use if the `name` argument is `None`.
6327      values: The list of `Tensor` arguments that are passed to the op function.
6328
6329    Raises:
6330      TypeError: if `default_name` is passed in but not a string.
6331    """
6332    self._name_scope = name_scope(
6333        name, default_name, values, skip_on_eager=False)
6334    self._name = default_name if name is None else name
6335
6336  def __enter__(self):
6337    return self._name_scope.__enter__()
6338
6339  def __exit__(self, *exc_info):
6340    return self._name_scope.__exit__(*exc_info)
6341
6342
6343def enter_eager_name_scope(ctx, name):
6344  """Updates the eager context to enter the given name scope."""
6345  old_name = ctx.scope_name
6346  if not name:
6347    scope_name = ""
6348  else:
6349    if name.endswith("/"):
6350      # A trailing slash breaks out of nested name scopes, indicating a
6351      # fully specified scope name, for compatibility with Graph.name_scope.
6352      scope_name = name
6353    else:
6354      scope_name = name + "/"
6355      if old_name:
6356        scope_name = old_name + scope_name
6357  ctx.scope_name = scope_name
6358  return scope_name, old_name
6359
6360
6361@tf_export("name_scope", v1=[])
6362class name_scope_v2(object):
6363  """A context manager for use when defining a Python op.
6364
6365  This context manager pushes a name scope, which will make the name of all
6366  operations added within it have a prefix.
6367
6368  For example, to define a new Python op called `my_op`:
6369
6370  ```python
6371  def my_op(a, b, c, name=None):
6372    with tf.name_scope("MyOp") as scope:
6373      a = tf.convert_to_tensor(a, name="a")
6374      b = tf.convert_to_tensor(b, name="b")
6375      c = tf.convert_to_tensor(c, name="c")
6376      # Define some computation that uses `a`, `b`, and `c`.
6377      return foo_op(..., name=scope)
6378  ```
6379
6380  When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
6381  and `MyOp/c`.
6382
6383  If the scope name already exists, the name will be made unique by appending
6384  `_n`. For example, calling `my_op` a second time will generate `MyOp_1/a`,
6385  etc.
6386  """
6387
6388  def __init__(self, name):
6389    """Initialize the context manager.
6390
6391    Args:
6392      name: The prefix to use on all names created within the name scope.
6393
6394    Raises:
6395      ValueError: If name is None, or not a string.
6396    """
6397    if name is None or not isinstance(name, six.string_types):
6398      raise ValueError("name for name_scope must be a string.")
6399    self._name = name
6400    self._exit_fns = []
6401
6402  @property
6403  def name(self):
6404    return self._name
6405
6406  def __enter__(self):
6407    """Start the scope block.
6408
6409    Returns:
6410      The scope name.
6411
6412    Raises:
6413      ValueError: if neither `name` nor `default_name` is provided
6414        but `values` are.
6415    """
6416    ctx = context.context()
6417    if ctx.executing_eagerly():
6418      scope_name, old_scope_name = enter_eager_name_scope(ctx, self._name)
6419      self._exit_fns.append(
6420          lambda *a: setattr(ctx, "scope_name", old_scope_name))
6421    else:
6422      scope = get_default_graph().name_scope(self._name)
6423      scope_name = scope.__enter__()
6424      self._exit_fns.append(scope.__exit__)
6425    return scope_name
6426
6427  def __exit__(self, type_arg, value_arg, traceback_arg):
6428    exit_fn = self._exit_fns.pop()
6429    exit_fn(type_arg, value_arg, traceback_arg)
6430    return False  # False values do not suppress exceptions
6431
6432
6433def strip_name_scope(name, export_scope):
6434  """Removes name scope from a name.
6435
6436  Args:
6437    name: A `string` name.
6438    export_scope: Optional `string`. Name scope to remove.
6439
6440  Returns:
6441    Name with name scope removed, or the original name if export_scope
6442    is None.
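
  For example, a minimal sketch:

  ```python
  strip_name_scope("export/foo/weights", "export")   # -> "foo/weights"
  strip_name_scope("^export/foo/weights", "export")  # -> "^foo/weights"
  ```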
6443  """
6444  if export_scope:
6445    if export_scope[-1] == "/":
6446      export_scope = export_scope[:-1]
6447
6448    try:
6449      # Strips export_scope/, export_scope///,
6450      # ^export_scope/, loc:@export_scope/.
6451      str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)"
6452      return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
6453    except TypeError as e:
6454      # If the name is not of a type we can process, simply return it.
6455      logging.warning(e)
6456      return name
6457  else:
6458    return name
6459
6460
6461def prepend_name_scope(name, import_scope):
6462  """Prepends name scope to a name.
6463
6464  Args:
6465    name: A `string` name.
6466    import_scope: Optional `string`. Name scope to add.
6467
6468  Returns:
6469    Name with name scope added, or the original name if import_scope
6470    is None.
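
  For example, a minimal sketch:

  ```python
  prepend_name_scope("foo/weights", "import")   # -> "import/foo/weights"
  prepend_name_scope("^foo/weights", "import")  # -> "^import/foo/weights"
  ```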
6471  """
6472  if import_scope:
6473    if import_scope[-1] == "/":
6474      import_scope = import_scope[:-1]
6475
6476    try:
6477      str_to_replace = r"([\^]|loc:@|^)(.*)"
6478      return re.sub(str_to_replace, r"\1" + import_scope + r"/\2",
6479                    compat.as_str(name))
6480    except TypeError as e:
6481      # If the name is not of a type we can process, simply return it.
6482      logging.warning(e)
6483      return name
6484  else:
6485    return name
6486
6487
6488# pylint: disable=g-doc-return-or-yield
6489# pylint: disable=not-context-manager
6490@tf_export(v1=["op_scope"])
6491@tf_contextlib.contextmanager
6492def op_scope(values, name, default_name=None):
6493  """DEPRECATED. Same as name_scope above, just different argument order."""
6494  logging.warn("tf.op_scope(values, name, default_name) is deprecated,"
6495               " use tf.name_scope(name, default_name, values)")
6496  with name_scope(name, default_name=default_name, values=values) as scope:
6497    yield scope
6498
6499
6500_proto_function_registry = registry.Registry("proto functions")
6501
6502
6503def register_proto_function(collection_name,
6504                            proto_type=None,
6505                            to_proto=None,
6506                            from_proto=None):
6507  """Registers `to_proto` and `from_proto` functions for collection_name.
6508
6509  `to_proto` function converts a Python object to the corresponding protocol
6510  buffer, and returns the protocol buffer.
6511
6512  `from_proto` function converts a protocol buffer into a Python object, and
6513  returns the object.
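
  For example, an illustrative sketch (`saver_pb2` and the converter functions
  `my_to_proto` / `my_from_proto` are hypothetical stand-ins):

  ```python
  register_proto_function(
      "my_collection",
      proto_type=saver_pb2.SaverDef,
      to_proto=my_to_proto,
      from_proto=my_from_proto)
  ```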
6514
6515  Args:
6516    collection_name: Name of the collection.
6517    proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
6518      `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`.
6519    to_proto: Function that implements Python object to protobuf conversion.
6520    from_proto: Function that implements protobuf to Python object conversion.
6521  """
6522  if to_proto and not callable(to_proto):
6523    raise TypeError("to_proto must be callable.")
6524  if from_proto and not callable(from_proto):
6525    raise TypeError("from_proto must be callable.")
6526
6527  _proto_function_registry.register((proto_type, to_proto, from_proto),
6528                                    collection_name)
6529
6530
6531def get_collection_proto_type(collection_name):
6532  """Returns the proto_type for collection_name."""
6533  try:
6534    return _proto_function_registry.lookup(collection_name)[0]
6535  except LookupError:
6536    return None
6537
6538
6539def get_to_proto_function(collection_name):
6540  """Returns the to_proto function for collection_name."""
6541  try:
6542    return _proto_function_registry.lookup(collection_name)[1]
6543  except LookupError:
6544    return None
6545
6546
6547def get_from_proto_function(collection_name):
6548  """Returns the from_proto function for collection_name."""
6549  try:
6550    return _proto_function_registry.lookup(collection_name)[2]
6551  except LookupError:
6552    return None
6553
6554
6555def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
6556  """Produce a nice error if someone converts an Operation to a Tensor."""
6557  raise TypeError(("Can't convert Operation '%s' to Tensor "
6558                   "(target dtype=%r, name=%r, as_ref=%r)") %
6559                  (op.name, dtype, name, as_ref))
6560
6561
6562def _op_to_colocate_with(v, graph):
6563  """Operation object corresponding to v to use for colocation constraints."""
6564  if v is None:
6565    return None
6566  if isinstance(v, Operation):
6567    return v
6568  # We always want to colocate with the reference op.
6569  # When 'v' is a ResourceVariable, the reference op is the handle creating op.
6570  #
6571  # What this should be is:
6572  # if isinstance(v, ResourceVariable):
6573  #   return v.handle.op
6574  # However, that would require a circular import dependency.
6575  # As of October 2018, there were attempts underway to remove
6576  # colocation constraints altogether. Assuming that will
6577  # happen soon, perhaps this hack to work around the circular
6578  # import dependency is acceptable.
6579  if hasattr(v, "handle") and isinstance(v.handle, Tensor):
6580    if graph.building_function:
6581      return graph.capture(v.handle).op
6582    else:
6583      return v.handle.op
6584  return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op
6585
6586
6587def _is_keras_symbolic_tensor(x):
6588  return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"
6589
6590
6591tensor_conversion_registry.register_tensor_conversion_function(
6592    Operation, _operation_conversion_error)
6593
6594
6595# These symbols were originally defined in this module; import them for
6596# backwards compatibility until all references have been updated to access
6597# them from the indexed_slices.py module.
6598IndexedSlices = indexed_slices.IndexedSlices
6599IndexedSlicesValue = indexed_slices.IndexedSlicesValue
6600convert_to_tensor_or_indexed_slices = \
6601    indexed_slices.convert_to_tensor_or_indexed_slices
6602convert_n_to_tensor_or_indexed_slices = \
6603    indexed_slices.convert_n_to_tensor_or_indexed_slices
6604internal_convert_to_tensor_or_indexed_slices = \
6605    indexed_slices.internal_convert_to_tensor_or_indexed_slices
6606internal_convert_n_to_tensor_or_indexed_slices = \
6607    indexed_slices.internal_convert_n_to_tensor_or_indexed_slices
6608register_tensor_conversion_function = \
6609    tensor_conversion_registry.register_tensor_conversion_function
6610
6611
6612# Helper functions for op wrapper modules generated by `python_op_gen`.
6613
6614
6615def to_raw_op(f):
6616  """Make a given op wrapper function `f` raw.
6617
6618  Raw op wrappers are not included in the docs, and can only be called
6619  with keyword arguments.
6620
6621  Args:
6622    f: An op wrapper function to make raw.
6623
6624  Returns:
6625    Raw `f`.
6626  """
6627  # Copy `f` to get a new `__dict__`, otherwise `tf_export` will fail
6628  # due to double-registration.
6629  f = types.FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__,
6630                         f.__closure__)
6631  return kwarg_only(do_not_generate_docs(f))
6632
6633
6634def raise_from_not_ok_status(e, name):
6635  message = e.message + (" name: " + name if name is not None else "")
6636  # pylint: disable=protected-access
6637  six.raise_from(core._status_to_exception(e.code, message), None)
6638  # pylint: enable=protected-access
6639
6640
6641def add_exit_callback_to_default_func_graph(fn):
6642  """Add a callback to run when the default function graph goes out of scope.
6643
6644  Usage:
6645
6646  ```python
6647  @tf.function
6648  def fn(x, v):
6649    expensive = expensive_object(v)
6650    add_exit_callback_to_default_func_graph(lambda: expensive.release())
6651    return g(x, expensive)
6652
6653  fn(x=tf.constant(...), v=...)
6654  # `expensive` has been released.
6655  ```
6656
6657  Args:
6658    fn: A callable that takes no arguments and whose output is ignored.
6659      To be executed when exiting func graph scope.
6660
6661  Raises:
6662    RuntimeError: If executed when the current default graph is not a FuncGraph,
6663      or not currently executing in function creation mode (e.g., if inside
6664      an init_scope).
6665  """
6666  default_graph = get_default_graph()
6667  if not default_graph._building_function:  # pylint: disable=protected-access
6668    raise RuntimeError(
6669        "Cannot add scope exit callbacks when not building a function.  "
6670        "Default graph: {}".format(default_graph))
6671  default_graph._add_scope_exit_callback(fn)  # pylint: disable=protected-access
6672
6673
6674def _reconstruct_sequence_inputs(op_def, inputs, attrs):
6675  """Regroups a flat list of input tensors into scalar and sequence inputs.
6676
6677  Args:
6678    op_def: The `op_def_pb2.OpDef` (for knowing the input types)
6679    inputs: a list of input `Tensor`s to the op.
6680    attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
6681      how long each sequence is)
6682
6683  Returns:
6684    A list of `Tensor`s (corresponding to scalar inputs) and lists of
6685    `Tensor`s (corresponding to sequence inputs).
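
  For example (a schematic sketch): for an op whose `OpDef` declares inputs
  `(x: T, ys: N * T)` with attr `N == 2`, the flat input list `[x, y0, y1]` is
  regrouped as `[x, [y0, y1]]`.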
6686  """
6687  grouped_inputs = []
6688  i = 0
6689  for input_arg in op_def.input_arg:
6690    if input_arg.number_attr:
6691      input_len = attrs[input_arg.number_attr].i
6692      is_sequence = True
6693    elif input_arg.type_list_attr:
6694      input_len = len(attrs[input_arg.type_list_attr].list.type)
6695      is_sequence = True
6696    else:
6697      input_len = 1
6698      is_sequence = False
6699
6700    if is_sequence:
6701      grouped_inputs.append(inputs[i:i + input_len])
6702    else:
6703      grouped_inputs.append(inputs[i])
6704    i += input_len
6705
6706  assert i == len(inputs)
6707  return grouped_inputs
6708