1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions used to construct graphs."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import re
24import sys
25import threading
26import types
27from absl import app
28
29import numpy as np
30import six
31from six.moves import map  # pylint: disable=redefined-builtin
32from six.moves import xrange  # pylint: disable=redefined-builtin
33
34from tensorflow.core.framework import attr_value_pb2
35from tensorflow.core.framework import function_pb2
36from tensorflow.core.framework import graph_pb2
37from tensorflow.core.framework import node_def_pb2
38from tensorflow.core.framework import op_def_pb2
39from tensorflow.core.framework import versions_pb2
40from tensorflow.core.protobuf import config_pb2
41# pywrap_tensorflow must be imported first to avoid protobuf issues.
42# (b/143110113)
43# pylint: disable=invalid-import-order,g-bad-import-order,unused-import
44from tensorflow.python import pywrap_tensorflow
45from tensorflow.python import pywrap_tfe
46# pylint: enable=invalid-import-order,g-bad-import-order,unused-import
47from tensorflow.python import tf2
48from tensorflow.python.client import pywrap_tf_session
49from tensorflow.python.eager import context
50from tensorflow.python.eager import core
51from tensorflow.python.eager import monitoring
52from tensorflow.python.eager import tape
53from tensorflow.python.framework import c_api_util
54from tensorflow.python.framework import composite_tensor
55from tensorflow.python.framework import cpp_shape_inference_pb2
56from tensorflow.python.framework import device as pydev
57from tensorflow.python.framework import dtypes
58from tensorflow.python.framework import errors
59from tensorflow.python.framework import indexed_slices
60from tensorflow.python.framework import registry
61from tensorflow.python.framework import tensor_conversion_registry
62from tensorflow.python.framework import tensor_shape
63from tensorflow.python.framework import traceable_stack
64from tensorflow.python.framework import versions
65from tensorflow.python.ops import control_flow_util
66from tensorflow.python.platform import tf_logging as logging
67from tensorflow.python.profiler import trace
68from tensorflow.python.types import core as core_tf_types
69from tensorflow.python.types import internal
70from tensorflow.python.util import compat
71from tensorflow.python.util import decorator_utils
72from tensorflow.python.util import deprecation
73from tensorflow.python.util import dispatch
74from tensorflow.python.util import function_utils
75from tensorflow.python.util import lock_util
76from tensorflow.python.util import memory
77from tensorflow.python.util import object_identity
78from tensorflow.python.util import tf_contextlib
79from tensorflow.python.util import tf_stack
80from tensorflow.python.util import traceback_utils
81from tensorflow.python.util.compat import collections_abc
82from tensorflow.python.util.deprecation import deprecated_args
83from tensorflow.python.util.lazy_loader import LazyLoader
84from tensorflow.python.util.tf_export import kwarg_only
85from tensorflow.python.util.tf_export import tf_export
86
87ag_ctx = LazyLoader(
88    "ag_ctx", globals(),
89    "tensorflow.python.autograph.core.ag_ctx")
90
91
92# Temporary global switches determining if we should enable the work-in-progress
93# calls to the C API. These will be removed once all functionality is supported.
94_USE_C_API = True
95_USE_C_SHAPES = True
96
97_api_usage_gauge = monitoring.BoolGauge(
98    "/tensorflow/api/ops_eager_execution",
99    "Whether ops.enable_eager_execution() is called.")
100
101_tensor_equality_api_usage_gauge = monitoring.BoolGauge(
102    "/tensorflow/api/enable_tensor_equality",
103    "Whether ops.enable_tensor_equality() is called.")
104
105_control_flow_api_gauge = monitoring.BoolGauge(
106    "/tensorflow/api/enable_control_flow_v2",
107    "Whether enable_control_flow_v2() is called.")
108
109_tf_function_api_guage = monitoring.BoolGauge(
110    "/tensorflow/api/tf_function",
111    "Whether tf.function() is used.")
112
113# pylint: disable=protected-access
114_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE
115# pylint: enable=protected-access
116
117
118def tensor_id(tensor):
119  """Returns a unique identifier for this Tensor."""
120  return tensor._id  # pylint: disable=protected-access
121
122
123class _UserDeviceSpec(object):
124  """Store user-specified device and provide computation of merged device."""
125
126  def __init__(self, device_name_or_function):
127    self._device_name_or_function = device_name_or_function
128    self.display_name = str(self._device_name_or_function)
129    self.function = device_name_or_function
130    self.raw_string = None
131
132    if isinstance(device_name_or_function, pydev.MergeDevice):
133      self.is_null_merge = device_name_or_function.is_null_merge
134
135    elif callable(device_name_or_function):
136      self.is_null_merge = False
137      dev_func = self._device_name_or_function
138      func_name = function_utils.get_func_name(dev_func)
139      func_code = function_utils.get_func_code(dev_func)
140      if func_code:
141        fname = func_code.co_filename
142        lineno = func_code.co_firstlineno
143      else:
144        fname = "unknown"
145        lineno = -1
146      self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
147
148    elif device_name_or_function is None:
149      # NOTE(taylorrobie): This MUST be False. None signals a break in the
150      #   device stack, so `is_null_merge` must be False for such a case to
151      #   allow callers to safely skip over null merges without missing a None.
152      self.is_null_merge = False
153
154    else:
155      self.raw_string = device_name_or_function
156      self.function = pydev.merge_device(device_name_or_function)
157      self.is_null_merge = self.function.is_null_merge
158
159    # We perform this check in __init__ because it is of non-trivial cost,
160    # and self.string_merge is typically called many times.
161    self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)
162
163  def string_merge(self, node_def):
164    if self.fast_string_merge:
165      return self.function.shortcut_string_merge(node_def)
166
167    return compat.as_str(_device_string(self.function(node_def)))
168
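# A brief illustrative sketch (not part of the public API): `_UserDeviceSpec`
# wraps the three kinds of values accepted by device scopes:
#
#   _UserDeviceSpec("/device:GPU:0")      # a device string, merged per node
#   _UserDeviceSpec(lambda op: "/CPU:0")  # a callable of the NodeDef
#   _UserDeviceSpec(None)                 # a break in the device stack
#
# In each case `string_merge(node_def)` yields the device string to assign to
# a node, merging it with any previously specified device.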
169
170class NullContextmanager(object):
171
172  def __init__(self, *args, **kwargs):
173    pass
174
175  def __enter__(self):
176    pass
177
178  def __exit__(self, type_arg, value_arg, traceback_arg):
179    return False  # False values do not suppress exceptions
180
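# Illustrative sketch: `NullContextmanager` is a no-op stand-in used when a
# context manager is only conditionally needed. The names below are
# hypothetical:
#
#   ctx = graph.device(device_str) if device_str else NullContextmanager()
#   with ctx:
#     ...  # body runs unchanged; exceptions are not suppressed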
181
182def _override_helper(clazz_object, operator, func):
183  """Overrides (string) operator on Tensors to call func.
184
185  Args:
186    clazz_object: the class to override for; either Tensor or SparseTensor.
187    operator: the string name of the operator to override.
188    func: the function that replaces the overridden operator.
189
190  Raises:
191    ValueError: If operator is not allowed to be overwritten.
192  """
193  if operator not in Tensor.OVERLOADABLE_OPERATORS:
194    raise ValueError("Overriding %s is disallowed" % operator)
195  setattr(clazz_object, operator, func)
196
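# Illustrative sketch: operator overloading on `Tensor` goes through
# `Tensor._override_operator` (defined below), which delegates to
# `_override_helper`. A registration looks roughly like the following, where
# `add_wrapper` is a hypothetical wrapper around the real add op:
#
#   def add_wrapper(x, y, name=None):
#     return math_ops.add(x, y, name=name)
#   Tensor._override_operator("__add__", add_wrapper)
#
# Overriding an operator that is not in `OVERLOADABLE_OPERATORS` raises a
# ValueError.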
197
198def _as_graph_element(obj):
199  """Convert `obj` to a graph element if possible, otherwise return `None`.
200
201  Args:
202    obj: Object to convert.
203
204  Returns:
205    The result of `obj._as_graph_element()` if that method is available;
206        otherwise `None`.
207  """
208  conv_fn = getattr(obj, "_as_graph_element", None)
209  if conv_fn and callable(conv_fn):
210    return conv_fn()
211  return None
212
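# Illustrative sketch: `_as_graph_element` relies on a small duck-typing
# protocol. Any object exposing an `_as_graph_element()` method (for example,
# `tf.Variable` resolves to its underlying graph element this way) is
# converted; other objects yield `None`. A hypothetical wrapper type:
#
#   class TensorWrapper(object):
#     def __init__(self, tensor):
#       self._tensor = tensor
#     def _as_graph_element(self):
#       return self._tensor
#
#   _as_graph_element(TensorWrapper(t))  # -> t
#   _as_graph_element(42)                # -> None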
213
214# Deprecated - do not use.
215# This API is kept to avoid breaking estimator and tensorflow-mesh, which depend
216# on this internal API. The stub should be safe to use after TF 2.3 is released.
217def is_dense_tensor_like(t):
218  return isinstance(t, core_tf_types.Tensor)
219
220
221def uid():
222  """A unique (within this program execution) integer."""
223  return pywrap_tfe.TFE_Py_UID()
224
225
226def numpy_text(tensor, is_repr=False):
227  """Human readable representation of a tensor's numpy value."""
228  if tensor.dtype.is_numpy_compatible:
229    # pylint: disable=protected-access
230    text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
231    # pylint: enable=protected-access
232  else:
233    text = "<unprintable>"
234  if "\n" in text:
235    text = "\n" + text
236  return text
237
238
239@tf_export(v1=["enable_tensor_equality"])
240def enable_tensor_equality():
241  """Compare Tensors with element-wise comparison and thus be unhashable.
242
243  Comparing tensors with element-wise allows comparisons such as
244  tf.Variable(1.0) == 1.0. Element-wise equality implies that tensors are
245  unhashable. Thus tensors can no longer be directly used in sets or as a key in
246  a dictionary.
247  """
248  logging.vlog(1, "Enabling tensor equality")
249  _tensor_equality_api_usage_gauge.get_cell().set(True)
250  Tensor._USE_EQUALITY = True  # pylint: disable=protected-access
251
252
253@tf_export(v1=["disable_tensor_equality"])
254def disable_tensor_equality():
255  """Compare Tensors by their id and be hashable.
256
257  This is a legacy behaviour of TensorFlow and is highly discouraged.
258  """
259  logging.vlog(1, "Disabling tensor equality")
260  _tensor_equality_api_usage_gauge.get_cell().set(False)
261  Tensor._USE_EQUALITY = False  # pylint: disable=protected-access
262
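# Illustrative sketch of the behavioral difference (eager mode):
#
#   tf.compat.v1.enable_tensor_equality()   # the TF2 default
#   tf.constant(1.0) == 1.0                 # -> a boolean Tensor (element-wise)
#   {tf.constant(1.0)}                      # -> TypeError: Tensor is unhashable
#
#   tf.compat.v1.disable_tensor_equality()  # legacy behavior
#   tf.constant(1.0) == 1.0                 # -> False (identity comparison)
#   {tf.constant(1.0)}                      # -> allowed; hashes by id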
263
264# TODO(mdan): This object should subclass Symbol, not just Tensor.
265@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"])
266class Tensor(internal.NativeObject, core_tf_types.Tensor):
267  """A `tf.Tensor` represents a multidimensional array of elements.
268
269  All elements are of a single known data type.
270
271  When writing a TensorFlow program, the main object that is
272  manipulated and passed around is the `tf.Tensor`.
273
274  A `tf.Tensor` has the following properties:
275
276  * a single data type (float32, int32, or string, for example)
277  * a shape
278
279  TensorFlow supports eager execution and graph execution.  In eager
280  execution, operations are evaluated immediately.  In graph
281  execution, a computational graph is constructed for later
282  evaluation.
283
284  TensorFlow defaults to eager execution.  In the example below, the
285  matrix multiplication results are calculated immediately.
286
287  >>> # Compute some values using a Tensor
288  >>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
289  >>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
290  >>> e = tf.matmul(c, d)
291  >>> print(e)
292  tf.Tensor(
293  [[1. 3.]
294   [3. 7.]], shape=(2, 2), dtype=float32)
295
296  Note that during eager execution, you may discover your `Tensors` are actually
297  of type `EagerTensor`.  This is an internal detail, but it does give you
298  access to a useful function, `numpy`:
299
300  >>> type(e)
301  <class '...ops.EagerTensor'>
302  >>> print(e.numpy())
303    [[1. 3.]
304     [3. 7.]]
305
306  In TensorFlow, `tf.function`s are a common way to define graph execution.
307
308  A Tensor's shape (that is, the rank of the Tensor and the size of
309  each dimension) may not always be fully known.  In `tf.function`
310  definitions, the shape may only be partially known.
311
312  Most operations produce tensors of fully-known shapes if the shapes of their
313  inputs are also fully known, but in some cases it's only possible to find the
314  shape of a tensor at execution time.
315
316  A number of specialized tensors are available: see `tf.Variable`,
317  `tf.constant`, `tf.placeholder`, `tf.sparse.SparseTensor`, and
318  `tf.RaggedTensor`.
319
320  For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor).
321
322  """
323
324  # List of Python operators that we allow to override.
325  OVERLOADABLE_OPERATORS = {
326      # Binary.
327      "__add__",
328      "__radd__",
329      "__sub__",
330      "__rsub__",
331      "__mul__",
332      "__rmul__",
333      "__div__",
334      "__rdiv__",
335      "__truediv__",
336      "__rtruediv__",
337      "__floordiv__",
338      "__rfloordiv__",
339      "__mod__",
340      "__rmod__",
341      "__lt__",
342      "__le__",
343      "__gt__",
344      "__ge__",
345      "__ne__",
346      "__eq__",
347      "__and__",
348      "__rand__",
349      "__or__",
350      "__ror__",
351      "__xor__",
352      "__rxor__",
353      "__getitem__",
354      "__pow__",
355      "__rpow__",
356      # Unary.
357      "__invert__",
358      "__neg__",
359      "__abs__",
360      "__matmul__",
361      "__rmatmul__"
362  }
363
364  # Whether to allow hashing or numpy-style equality
365  _USE_EQUALITY = tf2.enabled()
366
367  def __init__(self, op, value_index, dtype):
368    """Creates a new `Tensor`.
369
370    Args:
371      op: An `Operation`. `Operation` that computes this tensor.
372      value_index: An `int`. Index of the operation's endpoint that produces
373        this tensor.
374      dtype: A `DType`. Type of elements stored in this tensor.
375
376    Raises:
377      TypeError: If the op is not an `Operation`.
378    """
379    if not isinstance(op, Operation):
380      raise TypeError("op needs to be an Operation: %s" % (op,))
381    self._op = op
382    self._value_index = value_index
383    self._dtype = dtypes.as_dtype(dtype)
384    # This will be set by self._as_tf_output().
385    self._tf_output = None
386    # This will be set by self.shape().
387    self._shape_val = None
388    # List of operations that use this Tensor as input.  We maintain this list
389    # to easily navigate a computation graph.
390    self._consumers = []
391    self._id = uid()
392    self._name = None
393
394  def __getattr__(self, name):
395    if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size",
396                "tolist", "data"}:
397      # TODO(wangpeng): Export the enable_numpy_behavior knob
398      raise AttributeError("""
399        '{}' object has no attribute '{}'.
400        If you are looking for numpy-related methods, please run the following:
401        from tensorflow.python.ops.numpy_ops import np_config
402        np_config.enable_numpy_behavior()""".format(type(self).__name__, name))
403    self.__getattribute__(name)
404
405  @staticmethod
406  def _create_with_tf_output(op, value_index, dtype, tf_output):
407    ret = Tensor(op, value_index, dtype)
408    ret._tf_output = tf_output
409    return ret
410
411  @property
412  def op(self):
413    """The `Operation` that produces this tensor as an output."""
414    return self._op
415
416  @property
417  def dtype(self):
418    """The `DType` of elements in this tensor."""
419    return self._dtype
420
421  @property
422  def graph(self):
423    """The `Graph` that contains this tensor."""
424    return self._op.graph
425
426  @property
427  def name(self):
428    """The string name of this tensor."""
429    if self._name is None:
430      if not self._op.name:
431        raise ValueError("Operation was not named: %s" % self._op)
432      self._name = "%s:%d" % (self._op.name, self._value_index)
433    return self._name
434
435  @property
436  def device(self):
437    """The name of the device on which this tensor will be produced, or None."""
438    return self._op.device
439
440  @property
441  def shape(self):
442    """Returns a `tf.TensorShape` that represents the shape of this tensor.
443
444    >>> t = tf.constant([1,2,3,4,5])
445    >>> t.shape
446    TensorShape([5])
447
448    `tf.Tensor.shape` is equivalent to `tf.Tensor.get_shape()`.
449
450    In a `tf.function` or when building a model using
451    `tf.keras.Input`, they return the build-time shape of the
452    tensor, which may be partially unknown.
453
454    A `tf.TensorShape` is not a tensor. Use `tf.shape(t)` to get a tensor
455    containing the shape, calculated at runtime.
456
457    See `tf.Tensor.get_shape()`, and `tf.TensorShape` for details and examples.
458    """
459    if self._shape_val is None:
460      self._shape_val = self._c_api_shape()
461    return self._shape_val
462
463  def _c_api_shape(self):
464    """Returns the TensorShape of this tensor according to the C API."""
465    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
466    shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper(
467        c_graph, self._as_tf_output())
468    if unknown_shape:
469      return tensor_shape.unknown_shape()
470    else:
471      shape_vec = [None if d == -1 else d for d in shape_vec]
472      return tensor_shape.TensorShape(shape_vec)
473
474  @property
475  def _shape(self):
476    logging.warning("Tensor._shape is private, use Tensor.shape "
477                    "instead. Tensor._shape will eventually be removed.")
478    return self.shape
479
480  @_shape.setter
481  def _shape(self, value):
482    raise ValueError(
483        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
484
485  def _disallow_when_autograph_disabled(self, task):
486    raise errors.OperatorNotAllowedInGraphError(
487        "{} is not allowed: AutoGraph is disabled in this function."
488        " Try decorating it directly with @tf.function.".format(task))
489
490  def _disallow_when_autograph_enabled(self, task):
491    raise errors.OperatorNotAllowedInGraphError(
492        "{} is not allowed: AutoGraph did convert this function. This might"
493        " indicate you are trying to use an unsupported feature.".format(task))
494
495  def _disallow_in_graph_mode(self, task):
496    raise errors.OperatorNotAllowedInGraphError(
497        "{} is not allowed in Graph execution. Use Eager execution or decorate"
498        " this function with @tf.function.".format(task))
499
500  def _disallow_bool_casting(self):
501    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
502      self._disallow_when_autograph_disabled(
503          "using a `tf.Tensor` as a Python `bool`")
504    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
505      self._disallow_when_autograph_enabled(
506          "using a `tf.Tensor` as a Python `bool`")
507    else:
508      # Default: V1-style Graph execution.
509      self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`")
510
511  def _disallow_iteration(self):
512    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
513      self._disallow_when_autograph_disabled("iterating over `tf.Tensor`")
514    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
515      self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
516    else:
517      # Default: V1-style Graph execution.
518      self._disallow_in_graph_mode("iterating over `tf.Tensor`")
519
520  def __iter__(self):
521    if not context.executing_eagerly():
522      self._disallow_iteration()
523
524    shape = self._shape_tuple()
525    if shape is None:
526      raise TypeError("Cannot iterate over a tensor with unknown shape.")
527    if not shape:
528      raise TypeError("Cannot iterate over a scalar tensor.")
529    if shape[0] is None:
530      raise TypeError(
531          "Cannot iterate over a tensor with unknown first dimension.")
532    return _TensorIterator(self, shape[0])
533
534  def _shape_as_list(self):
535    if self.shape.ndims is not None:
536      return [dim.value for dim in self.shape.dims]
537    else:
538      return None
539
540  def _shape_tuple(self):
541    shape = self._shape_as_list()
542    if shape is None:
543      return None
544    return tuple(shape)
545
546  def _rank(self):
547    """Integer rank of this Tensor, if known, else None.
548
549    Returns:
550      Integer rank or None
551    """
552    return self.shape.ndims
553
554  def get_shape(self):
555    """Returns a `tf.TensorShape` that represents the shape of this tensor.
556
557    In eager execution the shape is always fully-known.
558
559    >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
560    >>> print(a.shape)
561    (2, 3)
562
563    `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`.
564
565
566    When executing in a `tf.function` or building a model using
567    `tf.keras.Input`, `Tensor.shape` may return a partial shape (including
568    `None` for unknown dimensions). See `tf.TensorShape` for more details.
569
570    >>> inputs = tf.keras.Input(shape = [10])
571    >>> # Unknown batch size
572    >>> print(inputs.shape)
573    (None, 10)
574
575    The shape is computed using shape inference functions that are
576    registered for each `tf.Operation`.
577
578    The returned `tf.TensorShape` is determined at *build* time, without
579    executing the underlying kernel. It is not a `tf.Tensor`. If you need a
580    shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or
581    use the `tf.shape(tensor)` function, which returns the tensor's shape at
582    *execution* time.
583
584    This is useful for debugging and providing early errors. For
585    example, when tracing a `tf.function`, no ops are being executed, so shapes
586    may be unknown (see the [Concrete Functions
587    Guide](https://www.tensorflow.org/guide/concrete_function) for details).
588
589    >>> @tf.function
590    ... def my_matmul(a, b):
591    ...   result = a@b
592    ...   # the `print` executes during tracing.
593    ...   print("Result shape: ", result.shape)
594    ...   return result
595
596    The shape inference functions propagate shapes to the extent possible:
597
598    >>> f = my_matmul.get_concrete_function(
599    ...   tf.TensorSpec([None,3]),
600    ...   tf.TensorSpec([3,5]))
601    Result shape: (None, 5)
602
604    Tracing may fail if a shape mismatch can be detected:
604
605    >>> cf = my_matmul.get_concrete_function(
606    ...   tf.TensorSpec([None,3]),
607    ...   tf.TensorSpec([4,5]))
608    Traceback (most recent call last):
609    ...
610    ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op:
611    'MatMul') with input shapes: [?,3], [4,5].
612
613    In some cases, the inferred shape may have unknown dimensions. If
614    the caller has additional information about the values of these
615    dimensions, `tf.ensure_shape` or `Tensor.set_shape()` can be used to augment
616    the inferred shape.
617
618    >>> @tf.function
619    ... def my_fun(a):
620    ...   a = tf.ensure_shape(a, [5, 5])
621    ...   # the `print` executes during tracing.
622    ...   print("Result shape: ", a.shape)
623    ...   return a
624
625    >>> cf = my_fun.get_concrete_function(
626    ...   tf.TensorSpec([None, None]))
627    Result shape: (5, 5)
628
629    Returns:
630      A `tf.TensorShape` representing the shape of this tensor.
631
632    """
633    return self.shape
634
635  def set_shape(self, shape):
636    """Updates the shape of this tensor.
637
638    Note: It is recommended to use `tf.ensure_shape` instead of
639    `Tensor.set_shape`, because `tf.ensure_shape` provides better checking for
640    programming errors and can create guarantees for compiler
641    optimization.
642
643    With eager execution this operates as a shape assertion.
644    Here the shapes match:
645
646    >>> t = tf.constant([[1,2,3]])
647    >>> t.set_shape([1, 3])
648
649    Passing a `None` in the new shape allows any value for that axis:
650
651    >>> t.set_shape([1,None])
652
653    An error is raised if an incompatible shape is passed.
654
655    >>> t.set_shape([1,5])
656    Traceback (most recent call last):
657    ...
658    ValueError: Tensor's shape (1, 3) is not compatible with supplied
659    shape [1, 5]
660
661    When executing in a `tf.function`, or building a model using
662    `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with
663    the current shape of this tensor, and set the tensor's shape to the
664    merged value (see `tf.TensorShape.merge_with` for details):
665
666    >>> t = tf.keras.Input(shape=[None, None, 3])
667    >>> print(t.shape)
668    (None, None, None, 3)
669
670    Dimensions set to `None` are not updated:
671
672    >>> t.set_shape([None, 224, 224, None])
673    >>> print(t.shape)
674    (None, 224, 224, 3)
675
676    The main use case for this is to provide additional shape information
677    that cannot be inferred from the graph alone.
678
679    For example, if you know all the images in a dataset have shape [28, 28, 3],
680    you can set it with `Tensor.set_shape`:
681
682    >>> @tf.function
683    ... def load_image(filename):
684    ...   raw = tf.io.read_file(filename)
685    ...   image = tf.image.decode_png(raw, channels=3)
686    ...   # the `print` executes during tracing.
687    ...   print("Initial shape: ", image.shape)
688    ...   image.set_shape([28, 28, 3])
689    ...   print("Final shape: ", image.shape)
690    ...   return image
691
692    Trace the function; see the [Concrete Functions
693    Guide](https://www.tensorflow.org/guide/concrete_function) for details.
694
695    >>> cf = load_image.get_concrete_function(
696    ...     tf.TensorSpec([], dtype=tf.string))
697    Initial shape:  (None, None, 3)
698    Final shape: (28, 28, 3)
699
700    Similarly, the `tf.io.parse_tensor` function could return a tensor with
701    any shape, even one whose `tf.rank` is unknown. If you know that all your
702    serialized tensors will be 2-D, set their shape with `set_shape`:
703
704    >>> @tf.function
705    ... def my_parse(string_tensor):
706    ...   result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
707    ...   # the `print` executes during tracing.
708    ...   print("Initial shape: ", result.shape)
709    ...   result.set_shape([None, None])
710    ...   print("Final shape: ", result.shape)
711    ...   return result
712
713    Trace the function
714
715    >>> concrete_parse = my_parse.get_concrete_function(
716    ...     tf.TensorSpec([], dtype=tf.string))
717    Initial shape:  <unknown>
718    Final shape:  (None, None)
719
720    Make sure it works:
721
722    >>> t = tf.ones([5,3], dtype=tf.float32)
723    >>> serialized = tf.io.serialize_tensor(t)
724    >>> print(serialized.dtype)
725    <dtype: 'string'>
726    >>> print(serialized.shape)
727    ()
728    >>> t2 = concrete_parse(serialized)
729    >>> print(t2.shape)
730    (5, 3)
731
732    Caution: `set_shape` ensures that the applied shape is compatible with
733    the existing shape, but it does not check at runtime. Setting
734    incorrect shapes can result in inconsistencies between the
735    statically-known graph and the runtime value of tensors. For runtime
736    validation of the shape, use `tf.ensure_shape` instead; like `set_shape`, it
737    also updates the static `shape` of the tensor.
738
739    >>> # Serialize a rank-3 tensor
740    >>> t = tf.ones([5,5,5], dtype=tf.float32)
741    >>> serialized = tf.io.serialize_tensor(t)
742    >>> # The function still runs, even though it called `set_shape([None, None])`
743    >>> t2 = concrete_parse(serialized)
744    >>> print(t2.shape)
745    (5, 5, 5)
746
747    Args:
748      shape: A `TensorShape` representing the shape of this tensor, a
749        `TensorShapeProto`, a list, a tuple, or None.
750
751    Raises:
752      ValueError: If `shape` is not compatible with the current shape of
753        this tensor.
754    """
755    # Reset cached shape.
756    self._shape_val = None
757
758    # We want set_shape to be reflected in the C API graph for when we run it.
759    if not isinstance(shape, tensor_shape.TensorShape):
760      shape = tensor_shape.TensorShape(shape)
761    dim_list = []
762    if shape.dims is None:
763      unknown_shape = True
764    else:
765      unknown_shape = False
766      for dim in shape.dims:
767        if dim.value is None:
768          dim_list.append(-1)
769        else:
770          dim_list.append(dim.value)
771    try:
772      pywrap_tf_session.TF_GraphSetTensorShape_wrapper(
773          self._op._graph._c_graph,  # pylint: disable=protected-access
774          self._as_tf_output(),
775          dim_list,
776          unknown_shape)
777    except errors.InvalidArgumentError as e:
778      # Convert to ValueError for backwards compatibility.
779      raise ValueError(e.message)
780
781  @property
782  def value_index(self):
783    """The index of this tensor in the outputs of its `Operation`."""
784    return self._value_index
785
786  def consumers(self):
787    """Returns a list of `Operation`s that consume this tensor.
788
789    Returns:
790      A list of `Operation`s.
791    """
792    consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper(
793        self._as_tf_output())
794    # pylint: disable=protected-access
795    return [
796        self.graph._get_operation_by_name_unsafe(name)
797        for name in consumer_names
798    ]
799    # pylint: enable=protected-access
800
801  def _as_node_def_input(self):
802    """Return a value to use for the NodeDef "input" attribute.
803
804    The returned string can be used in a NodeDef "input" attribute
805    to indicate that the NodeDef uses this Tensor as input.
806
807    Raises:
808      ValueError: if this Tensor's Operation does not have a name.
809
810    Returns:
811      a string.
812    """
813    if not self._op.name:
814      raise ValueError("Operation was not named: %s" % self._op)
815    if self._value_index == 0:
816      return self._op.name
817    else:
818      return "%s:%d" % (self._op.name, self._value_index)
819
820  def _as_tf_output(self):
821    # pylint: disable=protected-access
822    # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object
823    # also guarantees that a dictionary of tf_output objects will retain a
824    # deterministic (yet unsorted) order which prevents memory blowup in the
825    # cache of executor(s) stored for every session.
826    if self._tf_output is None:
827      self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index)
828    return self._tf_output
829    # pylint: enable=protected-access
830
831  def __str__(self):
832    return "Tensor(\"%s\"%s%s%s)" % (
833        self.name,
834        (", shape=%s" %
835         self.get_shape()) if self.get_shape().ndims is not None else "",
836        (", dtype=%s" % self._dtype.name) if self._dtype else "",
837        (", device=%s" % self.device) if self.device else "")
838
839  def __repr__(self):
840    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
841                                                   self._dtype.name)
842
843  def __hash__(self):
844    g = getattr(self, "graph", None)
845    if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
846        (g is None or g.building_function)):
847      raise TypeError("Tensor is unhashable. "
848                      "Instead, use tensor.ref() as the key.")
849    else:
850      return id(self)
851
852  def __copy__(self):
853    # TODO(b/77597810): get rid of Tensor copies.
854    cls = self.__class__
855    result = cls.__new__(cls)
856    result.__dict__.update(self.__dict__)
857    return result
858
859  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
860  # operators to run when the left operand is an ndarray, because it
861  # accords the Tensor class higher priority than an ndarray, or a
862  # numpy matrix.
863  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
864  # mechanism, which allows more control over how Tensors interact
865  # with ndarrays.
866  __array_priority__ = 100
867
868  def __array__(self):
869    raise NotImplementedError(
870        "Cannot convert a symbolic Tensor ({}) to a numpy array."
871        " This error may indicate that you're trying to pass a Tensor to"
872        " a NumPy call, which is not supported".format(self.name))
873
874  def __len__(self):
875    raise TypeError("len is not well defined for symbolic Tensors. ({}) "
876                    "Please call `x.shape` rather than `len(x)` for "
877                    "shape information.".format(self.name))
878
879  # TODO(mdan): This convoluted machinery is hard to maintain. Clean up.
880  @staticmethod
881  def _override_operator(operator, func):
882    _override_helper(Tensor, operator, func)
883
884  def __bool__(self):
885    """Dummy method to prevent a tensor from being used as a Python `bool`.
886
887    This overload raises a `TypeError` when the user inadvertently
888    treats a `Tensor` as a boolean (most commonly in an `if` or `while`
889    statement), in code that was not converted by AutoGraph. For example:
890
891    ```python
892    if tf.constant(True):  # Will raise.
893      # ...
894
895    if tf.constant(5) < tf.constant(7):  # Will raise.
896      # ...
897    ```
898
899    Raises:
900      `TypeError`.
901    """
902    self._disallow_bool_casting()
903
904  def __nonzero__(self):
905    """Dummy method to prevent a tensor from being used as a Python `bool`.
906
907    This is the Python 2.x counterpart to `__bool__()` above.
908
909    Raises:
910      `TypeError`.
911    """
912    self._disallow_bool_casting()
913
914  def eval(self, feed_dict=None, session=None):
915    """Evaluates this tensor in a `Session`.
916
917    Note: If you are not using `compat.v1` libraries, you should not need this,
918    (or `feed_dict` or `Session`).  In eager execution (or within `tf.function`)
919    you do not need to call `eval`.
920
921    Calling this method will execute all preceding operations that
922    produce the inputs needed for the operation that produces this
923    tensor.
924
925    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
926    launched in a session, and either a default session must be
927    available, or `session` must be specified explicitly.
928
929    Args:
930      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
931        `tf.Session.run` for a description of the valid feed values.
932      session: (Optional.) The `Session` to be used to evaluate this tensor. If
933        none, the default session will be used.
934
935    Returns:
936      A numpy array corresponding to the value of this tensor.
937    """
938    return _eval_using_default_session(self, feed_dict, self.graph, session)
939
940  @deprecation.deprecated(None, "Use ref() instead.")
941  def experimental_ref(self):
942    return self.ref()
943
944  def ref(self):
945    # tf.Variable also has the same ref() API.  If you update the
946    # documentation here, please update tf.Variable.ref() as well.
947    """Returns a hashable reference object to this Tensor.
948
949    The primary use case for this API is to put tensors in a set/dictionary.
950    We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer
951    available as of TensorFlow 2.0.
952
953    The following will raise an exception as of TensorFlow 2.0:
954
955    >>> x = tf.constant(5)
956    >>> y = tf.constant(10)
957    >>> z = tf.constant(10)
958    >>> tensor_set = {x, y, z}
959    Traceback (most recent call last):
960      ...
961    TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
962    >>> tensor_dict = {x: 'five', y: 'ten'}
963    Traceback (most recent call last):
964      ...
965    TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
966
967    Instead, we can use `tensor.ref()`.
968
969    >>> tensor_set = {x.ref(), y.ref(), z.ref()}
970    >>> x.ref() in tensor_set
971    True
972    >>> tensor_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
973    >>> tensor_dict[y.ref()]
974    'ten'
975
976    Also, the reference object provides a `.deref()` function that returns the
977    original Tensor.
978
979    >>> x = tf.constant(5)
980    >>> x.ref().deref()
981    <tf.Tensor: shape=(), dtype=int32, numpy=5>
982    """
983    return object_identity.Reference(self)
984
985
986# TODO(agarwal): consider getting rid of this.
987# TODO(mdan): This object should not subclass ops.Tensor.
988class _EagerTensorBase(Tensor):
989  """Base class for EagerTensor."""
990
991  # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and
992  # only work for scalars; values are cast as per numpy.
993  def __complex__(self):
994    return complex(self._numpy())
995
996  def __int__(self):
997    return int(self._numpy())
998
999  def __long__(self):
1000    return long(self._numpy())
1001
1002  def __float__(self):
1003    return float(self._numpy())
1004
1005  def __index__(self):
1006    return self._numpy().__index__()
1007
1008  def __bool__(self):
1009    return bool(self._numpy())
1010
1011  __nonzero__ = __bool__
1012
1013  def __format__(self, format_spec):
1014    if self._prefer_custom_summarizer():
1015      return self._summarize_value().__format__(format_spec)
1016    elif self.dtype.is_numpy_compatible:
1017      # Not numpy_text here, otherwise the __format__ behaves differently.
1018      return self._numpy().__format__(format_spec)
1019    else:
1020      return "<unprintable>".__format__(format_spec)
1021
1022  def __reduce__(self):
1023    return convert_to_tensor, (self._numpy(),)
1024
1025  def __copy__(self):
1026    # Eager Tensors are immutable so it's safe to return themselves as a copy.
1027    return self
1028
1029  def __deepcopy__(self, memo):
1030    # Eager Tensors are immutable so it's safe to return themselves as a copy.
1031    del memo
1032    return self
1033
1034  def __str__(self):
1035    if self._prefer_custom_summarizer():
1036      value_text = self._summarize_value()
1037    else:
1038      value_text = numpy_text(self)
1039    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (value_text, self.shape,
1040                                                  self.dtype.name)
1041
1042  def __repr__(self):
1043    if self._prefer_custom_summarizer():
1044      value_text = "value=" + self._summarize_value()
1045    else:
1046      value_text = "numpy=" + numpy_text(self, is_repr=True)
1047    return "<tf.Tensor: shape=%s, dtype=%s, %s>" % (self.shape, self.dtype.name,
1048                                                    value_text)
1049
1050  def __len__(self):
1051    """Returns the length of the first dimension in the Tensor."""
1052    if not self.shape.ndims:
1053      raise TypeError("Scalar tensor has no `len()`")
1054    # pylint: disable=protected-access
1055    try:
1056      return self._shape_tuple()[0]
1057    except core._NotOkStatusException as e:
1058      raise core._status_to_exception(e.code, e.message) from None
1059
1060  def __array__(self):
1061    return self._numpy()
1062
1063  def _numpy_internal(self):
1064    raise NotImplementedError()
1065
1066  def _numpy(self):
1067    try:
1068      return self._numpy_internal()
1069    except core._NotOkStatusException as e:  # pylint: disable=protected-access
1070      raise core._status_to_exception(e.code, e.message) from None  # pylint: disable=protected-access
1071
1072  @property
1073  def dtype(self):
1074    # Note: using the intern table directly here as this is
1075    # performance-sensitive in some models.
1076    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access
1077
1078  def numpy(self):
1079    """Copy of the contents of this Tensor into a NumPy array or scalar.
1080
1081    Unlike NumPy arrays, Tensors are immutable, so this method has to copy
1082    the contents to ensure safety. Use `memoryview` to get a readonly
1083    view of the contents without doing a copy:
1084
1085    >>> t = tf.constant([42])
1086    >>> np.array(memoryview(t))
1087    array([42], dtype=int32)
1088
1089    Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor
1090    is on GPU, it will have to be transferred to CPU first in order for
1091    `memoryview` to work.
1092
1093    Returns:
1094      A NumPy array of the same shape and dtype or a NumPy scalar, if this
1095      Tensor has rank 0.
1096
1097    Raises:
1098      ValueError: If the dtype of this Tensor does not have a compatible
1099        NumPy dtype.
1100    """
1101    # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
1102    maybe_arr = self._numpy()  # pylint: disable=protected-access
1103    return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
1104
1105  @property
1106  def backing_device(self):
1107    """Returns the name of the device holding this tensor's memory.
1108
1109    `.backing_device` is usually the same as `.device`, which returns
1110    the device on which the kernel of the operation that produced this tensor
1111    ran. However, some operations can produce tensors on a different device
1112    (e.g., an operation that executes on the GPU but produces output tensors
1113    in host memory).
1114    """
1115    raise NotImplementedError()
1116
1117  def _datatype_enum(self):
1118    raise NotImplementedError()
1119
1120  def _shape_tuple(self):
1121    """The shape of this Tensor, as a tuple.
1122
1123    This is more performant than tuple(shape().as_list()) as it avoids
1124    creating two lists and one extra object. Marked private for now since, from an
1125    API perspective, it would be better to have a single performant way of
1126    getting a shape rather than exposing shape() and shape_tuple()
1127    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
1128    but ideally one would work things out and remove the need for this method.
1129
1130    Returns:
1131      tuple with the shape.
1132    """
1133    raise NotImplementedError()
1134
1135  def _rank(self):
1136    """Integer rank of this Tensor.
1137
1138    Unlike regular Tensors, the rank is always known for EagerTensors.
1139
1140    This is more performant than len(self._shape_tuple())
1141
1142    Returns:
1143      Integer rank
1144    """
1145    raise NotImplementedError()
1146
1147  def _num_elements(self):
1148    """Number of elements of this Tensor.
1149
1150    Unlike regular Tensors, the number of elements is always known for
1151    EagerTensors.
1152
1153    This is more performant than tensor.shape.num_elements
1154
1155    Returns:
1156      Long - num elements in the tensor
1157    """
1158    raise NotImplementedError()
1159
1160  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
1161    raise NotImplementedError()
1162
1163  @staticmethod
1164  def _override_operator(name, func):
1165    setattr(_EagerTensorBase, name, func)
1166
1167  def _copy_nograd(self, ctx=None, device_name=None):
1168    """Copies tensor to dest device, but doesn't record the operation."""
1169    # Creates a new tensor on the dest device.
1170    if ctx is None:
1171      ctx = context.context()
1172    if device_name is None:
1173      device_name = ctx.device_name
1174    # pylint: disable=protected-access
1175    try:
1176      ctx.ensure_initialized()
1177      new_tensor = self._copy_to_device(device_name)
1178    except core._NotOkStatusException as e:
1179      raise core._status_to_exception(e.code, e.message) from None
1180    return new_tensor
1181
1182  def _copy(self, ctx=None, device_name=None):
1183    """Copies tensor to dest device."""
1184    new_tensor = self._copy_nograd(ctx, device_name)
1185    # Record the copy on tape and define backprop copy as well.
1186    if context.executing_eagerly():
1187      self_device = self.device
1188
1189      def grad_fun(dresult):
1190        return [
1191            dresult._copy(device_name=self_device)
1192            if hasattr(dresult, "_copy") else dresult
1193        ]
1194
1195      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
1196    return new_tensor
1197    # pylint: enable=protected-access
1198
1199  @property
1200  def shape(self):
1201    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
1202      # pylint: disable=protected-access
1203      try:
1204        # `_tensor_shape` is declared and defined in the definition of
1205        # `EagerTensor`, in C.
1206        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
1207      except core._NotOkStatusException as e:
1208        raise core._status_to_exception(e.code, e.message) from None
1209
1210    return self._tensor_shape
1211
1212  def get_shape(self):
1213    """Alias of Tensor.shape."""
1214    return self.shape
1215
1216  def _shape_as_list(self):
1217    """The shape of the tensor as a list."""
1218    return list(self._shape_tuple())
1219
1220  @property
1221  def ndim(self):
1222    """Returns the number of Tensor dimensions."""
1223    return self.shape.ndims
1224
1225  @deprecation.deprecated(None, "Use tf.identity instead.")
1226  def cpu(self):
1227    """A copy of this Tensor with contents backed by host memory."""
1228    return self._copy(context.context(), "CPU:0")
1229
1230  @deprecation.deprecated(None, "Use tf.identity instead.")
1231  def gpu(self, gpu_index=0):
1232    """A copy of this Tensor with contents backed by memory on the GPU.
1233
1234    Args:
1235      gpu_index: Identifies which GPU the returned Tensor's contents should be
1236        placed on.
1237
1238    Returns:
1239      A GPU-memory backed Tensor object initialized with the same contents
1240      as this Tensor.
1241    """
1242    return self._copy(context.context(), "GPU:" + str(gpu_index))
1243
1244  def set_shape(self, shape):
1245    if not self.shape.is_compatible_with(shape):
1246      raise ValueError(
1247          "Tensor's shape %s is not compatible with supplied shape %s" %
1248          (self.shape, shape))
1249
1250  # Methods not supported / implemented for Eager Tensors.
1251  @property
1252  def op(self):
1253    raise AttributeError(
1254        "Tensor.op is meaningless when eager execution is enabled.")
1255
1256  @property
1257  def graph(self):
1258    raise AttributeError(
1259        "Tensor.graph is meaningless when eager execution is enabled.")
1260
1261  @property
1262  def name(self):
1263    raise AttributeError(
1264        "Tensor.name is meaningless when eager execution is enabled.")
1265
1266  @property
1267  def value_index(self):
1268    raise AttributeError(
1269        "Tensor.value_index is meaningless when eager execution is enabled.")
1270
1271  def consumers(self):
1272    raise NotImplementedError(
1273        "Tensor.consumers is meaningless when eager execution is enabled.")
1274
1275  def _add_consumer(self, consumer):
1276    raise NotImplementedError(
1277        "_add_consumer not supported when eager execution is enabled.")
1278
1279  def _as_node_def_input(self):
1280    raise NotImplementedError(
1281        "_as_node_def_input not supported when eager execution is enabled.")
1282
1283  def _as_tf_output(self):
1284    raise NotImplementedError(
1285        "_as_tf_output not supported when eager execution is enabled.")
1286
1287  def eval(self, feed_dict=None, session=None):
1288    raise NotImplementedError(
1289        "eval is not supported when eager execution is enabled, "
1290        "is .numpy() what you're looking for?")
1291
1292
1293# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
1294# registers it with the current module.
1295# It is exposed as an __internal__ api for now (b/171081052), though we
1296# expect it to be eventually covered by tf Tensor types and typing.
1297EagerTensor = tf_export("__internal__.EagerTensor", v1=[])(
1298    pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase))
1299
1300
1301@tf_export(v1=["convert_to_tensor"])
1302@dispatch.add_dispatch_support
1303def convert_to_tensor_v1_with_dispatch(
1304    value,
1305    dtype=None,
1306    name=None,
1307    preferred_dtype=None,
1308    dtype_hint=None):
1309  """Converts the given `value` to a `Tensor`.
1310
1311  This function converts Python objects of various types to `Tensor`
1312  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
1313  and Python scalars. For example:
1314
1315  ```python
1316  import numpy as np
1317
1318  def my_func(arg):
1319    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
1320    return tf.matmul(arg, arg) + arg
1321
1322  # The following calls are equivalent.
1323  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
1324  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
1325  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
1326  ```
1327
1328  This function can be useful when composing a new operation in Python
1329  (such as `my_func` in the example above). All standard Python op
1330  constructors apply this function to each of their Tensor-valued
1331  inputs, which allows those ops to accept numpy arrays, Python lists,
1332  and scalars in addition to `Tensor` objects.
1333
1334  Note: This function diverges from default Numpy behavior for `float` and
1335    `string` types when `None` is present in a Python list or scalar. Rather
1336    than silently converting `None` values, an error will be thrown.
1337
1338  Args:
1339    value: An object whose type has a registered `Tensor` conversion function.
1340    dtype: Optional element type for the returned tensor. If missing, the type
1341      is inferred from the type of `value`.
1342    name: Optional name to use if a new `Tensor` is created.
1343    preferred_dtype: Optional element type for the returned tensor, used when
1344      dtype is None. In some cases, a caller may not have a dtype in mind when
1345      converting to a tensor, so preferred_dtype can be used as a soft
1346      preference.  If the conversion to `preferred_dtype` is not possible, this
1347      argument has no effect.
1348    dtype_hint: same meaning as preferred_dtype, and overrides it.
1349
1350  Returns:
1351    A `Tensor` based on `value`.
1352
1353  Raises:
1354    TypeError: If no conversion function is registered for `value` to `dtype`.
1355    RuntimeError: If a registered conversion function returns an invalid value.
1356    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
1357  """
1358  return convert_to_tensor_v1(value, dtype=dtype, name=name,
1359                              preferred_dtype=preferred_dtype,
1360                              dtype_hint=dtype_hint)
1361
1362
1363def convert_to_tensor_v1(value,
1364                         dtype=None,
1365                         name=None,
1366                         preferred_dtype=None,
1367                         dtype_hint=None):
1368  """Converts the given `value` to a `Tensor` (with the TF1 API)."""
1369  preferred_dtype = deprecation.deprecated_argument_lookup(
1370      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
1371  return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
1372
1373
1374@tf_export("convert_to_tensor", v1=[])
1375@dispatch.add_dispatch_support
1376def convert_to_tensor_v2_with_dispatch(
1377    value, dtype=None, dtype_hint=None, name=None):
1378  """Converts the given `value` to a `Tensor`.
1379
1380  This function converts Python objects of various types to `Tensor`
1381  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
1382  and Python scalars.
1383
1384  For example:
1385
1386  >>> import numpy as np
1387  >>> def my_func(arg):
1388  ...   arg = tf.convert_to_tensor(arg, dtype=tf.float32)
1389  ...   return arg
1390
1391  >>> # The following calls are equivalent.
1392  ...
1393  >>> value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
1394  >>> print(value_1)
1395  tf.Tensor(
1396    [[1. 2.]
1397     [3. 4.]], shape=(2, 2), dtype=float32)
1398  >>> value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
1399  >>> print(value_2)
1400  tf.Tensor(
1401    [[1. 2.]
1402     [3. 4.]], shape=(2, 2), dtype=float32)
1403  >>> value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
1404  >>> print(value_3)
1405  tf.Tensor(
1406    [[1. 2.]
1407     [3. 4.]], shape=(2, 2), dtype=float32)
1408
1409  This function can be useful when composing a new operation in Python
1410  (such as `my_func` in the example above). All standard Python op
1411  constructors apply this function to each of their Tensor-valued
1412  inputs, which allows those ops to accept numpy arrays, Python lists,
1413  and scalars in addition to `Tensor` objects.
1414
1415  Note: This function diverges from default Numpy behavior for `float` and
1416    `string` types when `None` is present in a Python list or scalar. Rather
1417    than silently converting `None` values, an error will be thrown.
1418
1419  Args:
1420    value: An object whose type has a registered `Tensor` conversion function.
1421    dtype: Optional element type for the returned tensor. If missing, the type
1422      is inferred from the type of `value`.
1423    dtype_hint: Optional element type for the returned tensor, used when dtype
1424      is None. In some cases, a caller may not have a dtype in mind when
1425      converting to a tensor, so dtype_hint can be used as a soft preference.
1426      If the conversion to `dtype_hint` is not possible, this argument has no
1427      effect.
1428    name: Optional name to use if a new `Tensor` is created.
1429
1430  Returns:
1431    A `Tensor` based on `value`.
1432
1433  Raises:
1434    TypeError: If no conversion function is registered for `value` to `dtype`.
1435    RuntimeError: If a registered conversion function returns an invalid value.
1436    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
1437  """
1438  return convert_to_tensor_v2(
1439      value, dtype=dtype, dtype_hint=dtype_hint, name=name)
1440
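# Illustrative sketch: in addition to the registered conversion functions,
# `convert_to_tensor` consults the `__tf_tensor__` protocol (see the
# implementation below), so a user-defined class can opt in to conversion.
# `Duration` is a hypothetical example type:
#
#   class Duration(object):
#     def __init__(self, seconds):
#       self.seconds = seconds
#     def __tf_tensor__(self, dtype=None, name=None):
#       return tf.constant(self.seconds, dtype=dtype, name=name)
#
#   tf.convert_to_tensor(Duration(3.5))  # -> <tf.Tensor: ... numpy=3.5>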
1441
1442def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
1443  """Converts the given `value` to a `Tensor`."""
1444  return convert_to_tensor(
1445      value=value,
1446      dtype=dtype,
1447      name=name,
1448      preferred_dtype=dtype_hint,
1449      as_ref=False)
1450
1451
1452def _error_prefix(name):
1453  return "" if name is None else "%s: " % name
1454
1455
1456def pack_eager_tensors(tensors, ctx=None):
1457  """Pack multiple `EagerTensor`s of the same dtype and shape.
1458
1459  Args:
1460    tensors: a list of EagerTensors to pack.
1461    ctx: context.context().
1462
1463  Returns:
1464    A packed EagerTensor.
1465  """
1466  if not isinstance(tensors, list):
1467    raise TypeError("tensors must be a list or a tuple: %s" % tensors)
1468
1469  if not tensors:
1470    raise ValueError("Empty tensors is unexpected for packing.")
1471
1472  dtype = tensors[0].dtype
1473  shape = tensors[0].shape
1474  handle_data = tensors[0]._handle_data  # pylint: disable=protected-access
1475  is_resource = dtype == dtypes.resource
1476  for i in range(len(tensors)):
1477    t = tensors[i]
1478    if not isinstance(t, EagerTensor):
1479      raise TypeError("tensors must be a list of EagerTensors: %s" % t)
1480
1481    if t.dtype != dtype:
1482      raise ValueError(
1483          "All tensors being packed should have the same dtype %s, "
1484          "but the %d-th tensor is of dtype %s" % (dtype, i, t.dtype))
1485    if t.shape != shape:
1486      raise ValueError(
1487          "All tensors being packed should have the same shape %s, "
1488          "but the %d-th tensor is of shape %s" % (shape, i, t.shape))
1489    # pylint: disable=protected-access
1490    if is_resource and t._handle_data != handle_data:
1491      raise ValueError(
1492          "All tensors being packed should have the same handle data %s, "
1493          "but the %d-th tensor is of handle data %s" %
1494          (handle_data, i, t._handle_data))
1495    # pylint: enable=protected-access
1496
1497  if ctx is None:
1498    ctx = context.context()
1499
1500  # Propagate handle data for resource variables.
1501  packed_tensor = ctx.pack_eager_tensors(tensors)
1502  if handle_data is not None:
1503    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access
1504
1505  def grad_fun(_):
1506    raise ValueError(
1507        "Gradients through pack_eager_tensors are not supported yet.")
1508
1509  tape.record_operation("pack_eager_tensors", [packed_tensor], tensors,
1510                        grad_fun)
1511
1512  return packed_tensor
1513
1514
1515@trace.trace_wrapper("convert_to_tensor")
1516def convert_to_tensor(value,
1517                      dtype=None,
1518                      name=None,
1519                      as_ref=False,
1520                      preferred_dtype=None,
1521                      dtype_hint=None,
1522                      ctx=None,
1523                      accepted_result_types=(Tensor,)):
1524  """Implementation of the public convert_to_tensor."""
1525  # TODO(b/142518781): Fix all call-sites and remove redundant arg
1526  preferred_dtype = preferred_dtype or dtype_hint
1527  if isinstance(value, EagerTensor):
1528    if ctx is None:
1529      ctx = context.context()
1530    if not ctx.executing_eagerly():
1531      graph = get_default_graph()
1532      if not graph.building_function:
1533        raise RuntimeError("Attempting to capture an EagerTensor without "
1534                           "building a function.")
1535      return graph.capture(value, name=name)
1536
1537  if dtype is not None:
1538    dtype = dtypes.as_dtype(dtype)
1539  if isinstance(value, Tensor):
1540    if dtype is not None and not dtype.is_compatible_with(value.dtype):
1541      raise ValueError(
1542          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
1543          (dtype.name, value.dtype.name, value))
1544    return value
1545
1546  if preferred_dtype is not None:
1547    preferred_dtype = dtypes.as_dtype(preferred_dtype)
1548
1549  # See below for the reason why it's `type(value)` and not just `value`.
1550  # https://docs.python.org/3.8/reference/datamodel.html#special-lookup
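  #
  # An illustrative sketch of the protocol this lookup dispatches to (the
  # class below is hypothetical, not part of TensorFlow): a user-defined type
  # can opt in to conversion by defining `__tf_tensor__`:
  #
  #   class MyValue:
  #     def __init__(self, data):
  #       self.data = data
  #
  #     def __tf_tensor__(self, dtype=None, name=None):
  #       return tf.convert_to_tensor(self.data, dtype=dtype, name=name)
  #
  #   tf.convert_to_tensor(MyValue([1, 2, 3]))  # dispatches via __tf_tensor__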
1551  overload = getattr(type(value), "__tf_tensor__", None)
1552  if overload is not None:
1553    return overload(value, dtype, name)  #  pylint: disable=not-callable
1554
1555  for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
1556    # If dtype is None but preferred_dtype is not None, we try to
1557    # cast to preferred_dtype first.
1558    ret = None
1559    if dtype is None and preferred_dtype is not None:
1560      try:
1561        ret = conversion_func(
1562            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
1563      except (TypeError, ValueError):
1564        # Could not coerce the conversion to use the preferred dtype.
1565        pass
1566      else:
1567        if (ret is not NotImplemented and
1568            ret.dtype.base_dtype != preferred_dtype.base_dtype):
1569          raise TypeError("convert_to_tensor did not convert to "
1570                          "the preferred dtype: %s vs %s " %
1571                          (ret.dtype.base_dtype, preferred_dtype.base_dtype))
1572
1573    if ret is None:
1574      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1575
1576    if ret is NotImplemented:
1577      continue
1578
1579    if not isinstance(ret, accepted_result_types):
1580      raise RuntimeError(
1581          "%sConversion function %r for type %s returned non-Tensor: %r" %
1582          (_error_prefix(name), conversion_func, base_type, ret))
1583    if dtype and not dtype.is_compatible_with(ret.dtype):
1584      raise RuntimeError(
1585          "%sConversion function %r for type %s returned incompatible "
1586          "dtype: requested = %s, actual = %s" %
1587          (_error_prefix(name), conversion_func, base_type, dtype.name,
1588           ret.dtype.name))
1589    return ret
1590  raise TypeError("%sCannot convert %r with type %s to Tensor: "
1591                  "no conversion function registered." %
1592                  (_error_prefix(name), value, type(value)))
1593
1594
1595internal_convert_to_tensor = convert_to_tensor
1596
1597
1598def internal_convert_n_to_tensor(values,
1599                                 dtype=None,
1600                                 name=None,
1601                                 as_ref=False,
1602                                 preferred_dtype=None,
1603                                 ctx=None):
1604  """Converts `values` to a list of `Tensor` objects.
1605
1606  Args:
1607    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1608    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1609    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1610      which case element `i` will be given the name `name + '_' + i`.
1611    as_ref: True if the caller wants the results as ref tensors.
1612    preferred_dtype: Optional element type for the returned tensors, used when
1613      dtype is None. In some cases, a caller may not have a dtype in mind when
1614      converting to a tensor, so preferred_dtype can be used as a soft
1615      preference.  If the conversion to `preferred_dtype` is not possible, this
1616      argument has no effect.
1617    ctx: The value of context.context().
1618
1619  Returns:
1620    A list of `Tensor` and/or `IndexedSlices` objects.
1621
1622  Raises:
1623    TypeError: If no conversion function is registered for an element in
1624      `values`.
1625    RuntimeError: If a registered conversion function returns an invalid
1626      value.
1627  """
1628  if not isinstance(values, collections_abc.Sequence):
1629    raise TypeError("values must be a sequence.")
1630  ret = []
1631  if ctx is None:
1632    ctx = context.context()
1633  for i, value in enumerate(values):
1634    n = None if name is None else "%s_%d" % (name, i)
1635    ret.append(
1636        convert_to_tensor(
1637            value,
1638            dtype=dtype,
1639            name=n,
1640            as_ref=as_ref,
1641            preferred_dtype=preferred_dtype,
1642            ctx=ctx))
1643  return ret
1644
1645
1646def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
1647  """Converts `values` to a list of `Tensor` objects.
1648
1649  Args:
1650    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1651    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1652    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1653      which case element `i` will be given the name `name + '_' + i`.
1654    preferred_dtype: Optional element type for the returned tensors, used when
1655      dtype is None. In some cases, a caller may not have a dtype in mind when
1656      converting to a tensor, so preferred_dtype can be used as a soft
1657      preference.  If the conversion to `preferred_dtype` is not possible, this
1658      argument has no effect.
1659
1660  Returns:
1661    A list of `Tensor` and/or `IndexedSlices` objects.
1662
1663  Raises:
1664    TypeError: If no conversion function is registered for an element in
1665      `values`.
1666    RuntimeError: If a registered conversion function returns an invalid
1667      value.
1668  """
1669  return internal_convert_n_to_tensor(
1670      values=values,
1671      dtype=dtype,
1672      name=name,
1673      preferred_dtype=preferred_dtype,
1674      as_ref=False)
1675
1676
1677def convert_to_tensor_or_composite(value, dtype=None, name=None):
1678  """Converts the given object to a `Tensor` or `CompositeTensor`.
1679
1680  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
1681  is converted to a `Tensor` using `convert_to_tensor()`.
1682
1683  Args:
1684    value: A `CompositeTensor` or an object that can be consumed by
1685      `convert_to_tensor()`.
1686    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1687      `CompositeTensor`.
1688    name: (Optional.) A name to use if a new `Tensor` is created.
1689
1690  Returns:
1691    A `Tensor` or `CompositeTensor`, based on `value`.
1692
1693  Raises:
1694    ValueError: If `dtype` does not match the element type of `value`.
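
  Example (a sketch; `tf.sparse.SparseTensor` is one kind of
  `CompositeTensor`):

  ```python
  import tensorflow as tf

  sp = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0],
                              dense_shape=[2, 2])
  convert_to_tensor_or_composite(sp)          # returned unmodified
  convert_to_tensor_or_composite([1.0, 2.0])  # converted to a dense `Tensor`
  ```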
1695  """
1696  return internal_convert_to_tensor_or_composite(
1697      value=value, dtype=dtype, name=name, as_ref=False)
1698
1699
1700def internal_convert_to_tensor_or_composite(value,
1701                                            dtype=None,
1702                                            name=None,
1703                                            as_ref=False):
1704  """Converts the given object to a `Tensor` or `CompositeTensor`.
1705
1706  If `value` is a `CompositeTensor` it is returned unmodified.  Otherwise, it
1707  is converted to a `Tensor` using `convert_to_tensor()`.
1708
1709  Args:
1710    value: A `CompositeTensor`, or an object that can be consumed by
1711      `convert_to_tensor()`.
1712    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1713      `CompositeTensor`.
1714    name: (Optional.) A name to use if a new `Tensor` is created.
1715    as_ref: True if the caller wants the results as ref tensors.
1716
1717  Returns:
1718    A `Tensor` or `CompositeTensor`, based on `value`.
1719
1720  Raises:
1721    ValueError: If `dtype` does not match the element type of `value`.
1722  """
1723  if isinstance(value, composite_tensor.CompositeTensor):
1724    value_dtype = getattr(value, "dtype", None)
1725    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value_dtype):
1726      raise ValueError(
1727          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
1728          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
1729    return value
1730  else:
1731    return convert_to_tensor(
1732        value,
1733        dtype=dtype,
1734        name=name,
1735        as_ref=as_ref,
1736        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))
1737
1738
1739def internal_convert_n_to_tensor_or_composite(values,
1740                                              dtype=None,
1741                                              name=None,
1742                                              as_ref=False):
1743  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.
1744
1745  Any `CompositeTensor` objects in `values` are returned unmodified.
1746
1747  Args:
1748    values: A list of `None`, `CompositeTensor`, or objects that can be consumed
1749      by `convert_to_tensor()`.
1750    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
1751      `CompositeTensor`s.
1752    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1753      which case element `i` will be given the name `name + '_' + i`.
1754    as_ref: True if the caller wants the results as ref tensors.
1755
1756  Returns:
1757    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.
1758
1759  Raises:
1760    TypeError: If no conversion function is registered for an element in
1761      `values`.
1762    RuntimeError: If a registered conversion function returns an invalid
1763      value.
1764  """
1765  if not isinstance(values, collections_abc.Sequence):
1766    raise TypeError("values must be a sequence.")
1767  ret = []
1768  for i, value in enumerate(values):
1769    if value is None:
1770      ret.append(value)
1771    else:
1772      n = None if name is None else "%s_%d" % (name, i)
1773      ret.append(
1774          internal_convert_to_tensor_or_composite(
1775              value, dtype=dtype, name=n, as_ref=as_ref))
1776  return ret
1777
1778
1779def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
1780  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.
1781
1782  Any `CompositeTensor` objects in `values` are returned unmodified.
1783
1784  Args:
1785    values: A list of `None`, `CompositeTensor`, or objects that can be
1786      consumed by `convert_to_tensor()`.
1787    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
1788      `CompositeTensor`s.
1789    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
1790      which case element `i` will be given the name `name + '_' + i`.
1791
1792  Returns:
1793    A list of `Tensor` and/or `CompositeTensor` objects.
1794
1795  Raises:
1796    TypeError: If no conversion function is registered for an element in
1797      `values`.
1798    RuntimeError: If a registered conversion function returns an invalid
1799      value.
1800  """
1801  return internal_convert_n_to_tensor_or_composite(
1802      values=values, dtype=dtype, name=name, as_ref=False)
1803
1804
1805def _device_string(dev_spec):
1806  if pydev.is_device_spec(dev_spec):
1807    return dev_spec.to_string()
1808  else:
1809    return dev_spec
1810
1811
1812def _NodeDef(op_type, name, attrs=None):
1813  """Create a NodeDef proto.
1814
1815  Args:
1816    op_type: Value for the "op" attribute of the NodeDef proto.
1817    name: Value for the "name" attribute of the NodeDef proto.
1818    attrs: Dictionary where the key is the attribute name (a string)
1819      and the value is the respective "attr" attribute of the NodeDef proto (an
1820      AttrValue).
1821
1822  Returns:
1823    A node_def_pb2.NodeDef protocol buffer.
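
  Example (a sketch; the op type, node name, and attr value are arbitrary):

  ```python
  nd = _NodeDef("Split", "my_split",
                attrs={"num_split": attr_value_pb2.AttrValue(i=2)})
  ```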
1824  """
1825  node_def = node_def_pb2.NodeDef(op=compat.as_bytes(op_type),
1826                                  name=compat.as_bytes(name))
1827  if attrs:
1828    for k, v in six.iteritems(attrs):
1829      node_def.attr[k].CopyFrom(v)
1830  return node_def
1831
1832
1833# Copied from core/framework/node_def_util.cc
1834# TODO(mrry,josh11b): Consolidate this validation in C++ code.
1835_VALID_OP_NAME_REGEX = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
1836_VALID_SCOPE_NAME_REGEX = re.compile(r"^[A-Za-z0-9_.\\/>-]*$")
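# For example, "layer_1/MatMul" and "a.b" match _VALID_OP_NAME_REGEX, while
# "_hidden" (leading underscore) and "bad name" (embedded space) do not.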
1837
1838
1839@tf_export("__internal__.create_c_op", v1=[])
1840@traceback_utils.filter_traceback
1841def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
1842  """Creates a TF_Operation.
1843
1844  Args:
1845    graph: a `Graph`.
1846    node_def: `node_def_pb2.NodeDef` for the operation to create.
1847    inputs: A flattened list of `Tensor`s. This function handles grouping
1848      tensors into lists as per attributes in the `node_def`.
1849    control_inputs: A list of `Operation`s to set as control dependencies.
1850    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
1851      specified, it is looked up from the `graph` using `node_def.op`.
1852
1853  Returns:
1854    A wrapped TF_Operation*.
1855  """
1856  if op_def is None:
1857    op_def = graph._get_op_def(node_def.op)  # pylint: disable=protected-access
1858  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
1859  # Refactor so we don't have to do this here.
1860  inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
1861  # pylint: disable=protected-access
1862  op_desc = pywrap_tf_session.TF_NewOperation(graph._c_graph,
1863                                              compat.as_str(node_def.op),
1864                                              compat.as_str(node_def.name))
1865  if node_def.device:
1866    pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
1867  # Add inputs
1868  for op_input in inputs:
1869    if isinstance(op_input, (list, tuple)):
1870      pywrap_tf_session.TF_AddInputList(op_desc,
1871                                        [t._as_tf_output() for t in op_input])
1872    else:
1873      pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())
1874
1875  # Add control inputs
1876  for control_input in control_inputs:
1877    pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
1878  # pylint: enable=protected-access
1879
1880  # Add attrs
1881  for name, attr_value in node_def.attr.items():
1882    serialized = attr_value.SerializeToString()
1883    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
1884    # It might be worth creating a convenient way to re-use the same status.
1885    pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
1886                                           serialized)
1887
1888  try:
1889    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
1890  except errors.InvalidArgumentError as e:
1891    # Convert to ValueError for backwards compatibility.
1892    raise ValueError(e.message)
1893
1894  return c_op
1895
1896
1897@tf_export("Operation")
1898class Operation(object):
1899  """Represents a graph node that performs computation on tensors.
1900
1901  An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor`
1902  objects as input, and produces zero or more `Tensor` objects as output.
1903  Objects of type `Operation` are created by calling a Python op constructor
1904  (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default`
1905  context manager.
1906
1907  For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an
1908  `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and
1909  produces `c` as output.
1910
1911  If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be
1912  executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for
1913  calling `tf.compat.v1.get_default_session().run(op)`.
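
  Example (a small sketch using the deprecated direct-graph API):

  ```python
  import tensorflow as tf

  g = tf.Graph()
  with g.as_default():
    c = tf.constant(1.0)        # creates a "Const" Operation
    print(c.op.type)            # "Const"
    print(len(c.op.outputs))    # 1
  ```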
1914  """
1915
1916  def __init__(self,
1917               node_def,
1918               g,
1919               inputs=None,
1920               output_types=None,
1921               control_inputs=None,
1922               input_types=None,
1923               original_op=None,
1924               op_def=None):
1925    r"""Creates an `Operation`.
1926
1927    NOTE: This constructor validates the name of the `Operation` (passed
1928    as `node_def.name`). Valid `Operation` names match the following
1929    regular expression:
1930
1931        [A-Za-z0-9.][A-Za-z0-9_.\\/>-]*
1932
1933    Args:
1934      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`. Used for
1935        attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and
1936        `device`.  The `input` attribute is irrelevant here as it will be
1937        computed when generating the model.
1938      g: `Graph`. The parent graph.
1939      inputs: list of `Tensor` objects. The inputs to this `Operation`.
1940      output_types: list of `DType` objects.  List of the types of the `Tensors`
1941        computed by this operation.  The length of this list indicates the
1942        number of output endpoints of the `Operation`.
1943      control_inputs: list of operations or tensors from which to have a control
1944        dependency.
1945      input_types: List of `DType` objects representing the types of the tensors
1946        accepted by the `Operation`.  By default uses `[x.dtype.base_dtype for x
1947        in inputs]`.  Operations that expect reference-typed inputs must specify
1948        these explicitly.
1949      original_op: Optional. Used to associate the new `Operation` with an
1950        existing `Operation` (for example, a replica with the op that was
1951        replicated).
1952      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type
1953        that this `Operation` represents.
1954
1955    Raises:
1956      TypeError: if control inputs are not Operations or Tensors,
1957        or if `node_def` is not a `NodeDef`,
1958        or if `g` is not a `Graph`,
1959        or if `inputs` are not tensors,
1960        or if `inputs` and `input_types` are incompatible.
1961      ValueError: if the `node_def` name is not valid.
1962    """
1963    # For internal use only: `node_def` can be set to a TF_Operation to create
1964    # an Operation for that op. This is useful for creating Operations for ops
1965    # indirectly created by C API methods, e.g. the ops created by
1966    # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
1967    # should be None.
1968
1969    if isinstance(node_def, node_def_pb2.NodeDef):
1970      if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
1971        raise ValueError(
1972            "Cannot create an Operation with a NodeDef larger than 2GB.")
1973      if not _VALID_OP_NAME_REGEX.match(node_def.name):
1974        raise ValueError("'%s' is not a valid node name" % node_def.name)
1975      c_op = None
1976    elif type(node_def).__name__ == "TF_Operation":
1977      assert inputs is None
1978      assert output_types is None
1979      assert control_inputs is None
1980      assert input_types is None
1981      assert original_op is None
1982      assert op_def is None
1983      c_op = node_def
1984    else:
1985      raise TypeError("node_def needs to be a NodeDef: %s" % (node_def,))
1986
1987    if not isinstance(g, Graph):
1988      raise TypeError("g needs to be a Graph: %s" % (g,))
1989    self._graph = g
1990
1991    if inputs is None:
1992      inputs = []
1993    elif not isinstance(inputs, list):
1994      raise TypeError("inputs needs to be a list of Tensors: %s" % (inputs,))
1995    for a in inputs:
1996      if not isinstance(a, Tensor):
1997        raise TypeError("input needs to be a Tensor: %s" % (a,))
1998    if input_types is None:
1999      input_types = [i.dtype.base_dtype for i in inputs]
2000    else:
2001      if not all(
2002          x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)):
2003        raise TypeError("In op '%s', input types (%s) are not compatible "
2004                        "with expected types (%s)" %
2005                        (node_def.name, [i.dtype for i in inputs], input_types))
2006
2007    # Build the list of control inputs.
2008    control_input_ops = []
2009    if control_inputs:
2010      for c in control_inputs:
2011        control_op = None
2012        if isinstance(c, Operation):
2013          control_op = c
2014        elif isinstance(c, (Tensor, IndexedSlices)):
2015          control_op = c.op
2016        else:
2017          raise TypeError("Control input must be an Operation, "
2018                          "a Tensor, or IndexedSlices: %s" % c)
2019        control_input_ops.append(control_op)
2020
2021    # This will be set by self.inputs.
2022    self._inputs_val = None
2023
2024    # pylint: disable=protected-access
2025    self._original_op = original_op
2026
2027    # List of _UserDevSpecs holding code location of device context manager
2028    # invocations and the user's original argument to them.
2029    self._device_code_locations = None
2030    # Dict mapping op name to file and line information for op colocation
2031    # context managers.
2032    self._colocation_code_locations = None
2033    self._control_flow_context = self.graph._get_control_flow_context()
2034
2035    # Gradient function for this op. There are three ways to specify gradient
2036    # function, and first available gradient gets used, in the following order.
2037    # function, and the first available one gets used, in the following order.
2038    # 2. Gradient name registered by "_gradient_op_type" attribute.
2039    # 3. Gradient name registered by op.type.
2040    self._gradient_function = None
2041
2042    # Initialize self._c_op.
2043    if c_op:
2044      self._c_op = c_op
2045      op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op))
2046      name = self.name
2047    else:
2048      if op_def is None:
2049        op_def = self._graph._get_op_def(node_def.op)
2050      self._c_op = _create_c_op(self._graph, node_def, inputs,
2051                                control_input_ops, op_def)
2052      name = compat.as_str(node_def.name)
2053
2054    self._traceback = tf_stack.extract_stack_for_node(self._c_op)
2055
2056    # pylint: enable=protected-access
2057
2058    self._is_stateful = op_def.is_stateful
2059
2060    # Initialize self._outputs.
2061    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
2062    self._outputs = []
2063    for i in range(num_outputs):
2064      tf_output = c_api_util.tf_output(self._c_op, i)
2065      output_type = pywrap_tf_session.TF_OperationOutputType(tf_output)
2066      tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output)  # pylint: disable=protected-access
2067      self._outputs.append(tensor)
2068
2069    self._id_value = self._graph._add_op(self, name)  # pylint: disable=protected-access
2070
2071    if not c_op:
2072      self._control_flow_post_processing(input_tensors=inputs)
2073
2074  def _control_flow_post_processing(self, input_tensors=None):
2075    """Add this op to its control flow context.
2076
2077    This may add new ops and change this op's inputs. self.inputs must be
2078    available before calling this method.
2079
2080    Args:
2081      input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
2082        of this op, which should be equivalent to `self.inputs`. Pass this
2083        argument to avoid evaluating `self.inputs` unnecessarily.
2084    """
2085    if input_tensors is None:
2086      input_tensors = self.inputs
2087    for input_tensor in input_tensors:
2088      control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
2089    if self._control_flow_context is not None:
2090      self._control_flow_context.AddOp(self)
2091
2092  def colocation_groups(self):
2093    """Returns the list of colocation groups of the op."""
2094    default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)]
2095    try:
2096      class_attr = self.get_attr("_class")
2097    except ValueError:
2098      # This op has no explicit colocation group, so it is itself its
2099      # own root of a colocation group.
2100      return default_colocation_group
2101
2102    attr_groups = [
2103        class_name for class_name in class_attr
2104        if class_name.startswith(b"loc:@")
2105    ]
2106
2107    # If there are no colocation groups in the explicit _class field,
2108    # return the default colocation group.
2109    return attr_groups if attr_groups else default_colocation_group
2110
2111  def values(self):
2112    """DEPRECATED: Use outputs."""
2113    return tuple(self.outputs)
2114
2115  def _get_control_flow_context(self):
2116    """Returns the control flow context of this op.
2117
2118    Returns:
2119      A context object.
2120    """
2121    return self._control_flow_context
2122
2123  def _set_control_flow_context(self, ctx):
2124    """Sets the current control flow context of this op.
2125
2126    Args:
2127      ctx: a context object.
2128    """
2129    self._control_flow_context = ctx
2130
2131  @property
2132  def name(self):
2133    """The full name of this operation."""
2134    return pywrap_tf_session.TF_OperationName(self._c_op)
2135
2136  @property
2137  def _id(self):
2138    """The unique integer id of this operation."""
2139    return self._id_value
2140
2141  @property
2142  def device(self):
2143    """The name of the device to which this op has been assigned, if any.
2144
2145    Returns:
2146      The string name of the device to which this op has been
2147      assigned, or an empty string if it has not been assigned to a
2148      device.
2149    """
2150    return pywrap_tf_session.TF_OperationDevice(self._c_op)
2151
2152  @property
2153  def _device_assignments(self):
2154    """Code locations for device context managers active at op creation.
2155
2156    This property will return a list of traceable_stack.TraceableObject
2157    instances where .obj is a string representing the assigned device
2158    (or information about the function that would be applied to this op
2159    to compute the desired device) and the filename and lineno members
2160    record the location of the relevant device context manager.
2161
2162    For example, suppose file_a contained these lines:
2163
2164      file_a.py:
2165        15: with tf.device('/gpu:0'):
2166        16:   node_b = tf.constant(4, name='NODE_B')
2167
2168    Then a TraceableObject t_obj representing the device context manager
2169    would have these member values:
2170
2171      t_obj.obj -> '/gpu:0'
2172      t_obj.filename = 'file_a.py'
2173      t_obj.lineno = 15
2174
2175    and node_b.op._device_assignments would return the list [t_obj].
2176
2177    Returns:
2178      [str: traceable_stack.TraceableObject, ...] as per this method's
2179      description, above.
2180    """
2181    return self._device_code_locations or []
2182
2183  @property
2184  def _colocation_dict(self):
2185    """Code locations for colocation context managers active at op creation.
2186
2187    This property will return a dictionary for which the keys are nodes with
2188    which this Operation is colocated, and for which the values are
2189    traceable_stack.TraceableObject instances.  The TraceableObject instances
2190    record the location of the relevant colocation context manager but have the
2191    "obj" field set to None to prevent leaking private data.
2192
2193    For example, suppose file_a contained these lines:
2194
2195      file_a.py:
2196        14: node_a = tf.constant(3, name='NODE_A')
2197        15: with tf.compat.v1.colocate_with(node_a):
2198        16:   node_b = tf.constant(4, name='NODE_B')
2199
2200    Then a TraceableObject t_obj representing the colocation context manager
2201    would have these member values:
2202
2203      t_obj.obj -> None
2204      t_obj.filename = 'file_a.py'
2205      t_obj.lineno = 15
2206
2207    and node_b.op._colocation_dict would return the dictionary
2208
2209      { 'NODE_A': t_obj }
2210
2211    Returns:
2212      {str: traceable_stack.TraceableObject} as per this method's description,
2213      above.
2214    """
2215    locations_dict = self._colocation_code_locations or {}
2216    return locations_dict.copy()
2217
2218  @property
2219  def _output_types(self):
2220    """List this operation's output types.
2221
2222    Returns:
2223      List of the types of the Tensors computed by this operation.
2224      Each element in the list is an integer whose value is one of
2225      the TF_DataType enums defined in pywrap_tf_session.h.
2226      The length of this list indicates the number of output endpoints
2227      of the operation.
2228    """
2229    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
2230    output_types = [
2231        int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i)))
2232        for i in xrange(num_outputs)
2233    ]
2234
2235    return output_types
2236
2237  def _tf_output(self, output_idx):
2238    """Create and return a new TF_Output for output_idx'th output of this op."""
2239    tf_output = pywrap_tf_session.TF_Output()
2240    tf_output.oper = self._c_op
2241    tf_output.index = output_idx
2242    return tf_output
2243
2244  def _tf_input(self, input_idx):
2245    """Create and return a new TF_Input for input_idx'th input of this op."""
2246    tf_input = pywrap_tf_session.TF_Input()
2247    tf_input.oper = self._c_op
2248    tf_input.index = input_idx
2249    return tf_input
2250
2251  def _set_device(self, device):  # pylint: disable=redefined-outer-name
2252    """Set the device of this operation.
2253
2254    Args:
2255      device: string or device spec. The device to set.
2256    """
2257    self._set_device_from_string(compat.as_str(_device_string(device)))
2258
2259  def _set_device_from_string(self, device_str):
2260    """Fast path to set device if the type is known to be a string.
2261
2262    This function is called frequently enough during graph construction that
2263    there are non-trivial performance gains if the caller can guarantee that
2264    the specified device is already a string.
2265
2266    Args:
2267      device_str: A string specifying where to place this op.
2268    """
2269    pywrap_tf_session.SetRequestedDevice(
2270        self._graph._c_graph,  # pylint: disable=protected-access
2271        self._c_op,  # pylint: disable=protected-access
2272        device_str)
2273
2274  def _update_input(self, index, tensor):
2275    """Update the input to this operation at the given index.
2276
2277    NOTE: This is for TF internal use only. Please don't use it.
2278
2279    Args:
2280      index: the index of the input to update.
2281      tensor: the Tensor to be used as the input at the given index.
2282
2283    Raises:
2284      TypeError: if tensor is not a Tensor,
2285        or if input tensor type is not convertible to dtype.
2286      ValueError: if the Tensor is from a different graph.
2287    """
2288    if not isinstance(tensor, Tensor):
2289      raise TypeError("tensor must be a Tensor: %s" % tensor)
2290    _assert_same_graph(self, tensor)
2291
2292    # Reset cached inputs.
2293    self._inputs_val = None
2294    pywrap_tf_session.UpdateEdge(
2295        self._graph._c_graph,  # pylint: disable=protected-access
2296        tensor._as_tf_output(),  # pylint: disable=protected-access
2297        self._tf_input(index))
2298
2299  def _add_while_inputs(self, tensors):
2300    """See AddWhileInputHack in python_api.h.
2301
2302    NOTE: This is for TF internal use only. Please don't use it.
2303
2304    Args:
2305      tensors: list of Tensors
2306
2307    Raises:
2308      TypeError: if tensor is not a Tensor,
2309        or if input tensor type is not convertible to dtype.
2310      ValueError: if the Tensor is from a different graph.
2311    """
2312    for tensor in tensors:
2313      if not isinstance(tensor, Tensor):
2314        raise TypeError("tensor must be a Tensor: %s" % tensor)
2315      _assert_same_graph(self, tensor)
2316
2317      # Reset cached inputs.
2318      self._inputs_val = None
2319      pywrap_tf_session.AddWhileInputHack(
2320          self._graph._c_graph,  # pylint: disable=protected-access
2321          tensor._as_tf_output(),  # pylint: disable=protected-access
2322          self._c_op)
2323
2324  def _add_control_inputs(self, ops):
2325    """Add a list of new control inputs to this operation.
2326
2327    Args:
2328      ops: the list of Operations to add as control input.
2329
2330    Raises:
2331      TypeError: if ops is not a list of Operations.
2332      ValueError: if any op in ops is from a different graph.
2333    """
2334    for op in ops:
2335      if not isinstance(op, Operation):
2336        raise TypeError("op must be an Operation: %s" % op)
2337      pywrap_tf_session.AddControlInput(
2338          self._graph._c_graph,  # pylint: disable=protected-access
2339          self._c_op,  # pylint: disable=protected-access
2340          op._c_op)  # pylint: disable=protected-access
2341
2342  def _add_control_input(self, op):
2343    """Add a new control input to this operation.
2344
2345    Args:
2346      op: the Operation to add as control input.
2347
2348    Raises:
2349      TypeError: if op is not an Operation.
2350      ValueError: if op is from a different graph.
2351    """
2352    if not isinstance(op, Operation):
2353      raise TypeError("op must be an Operation: %s" % op)
2354    pywrap_tf_session.AddControlInput(
2355        self._graph._c_graph,  # pylint: disable=protected-access
2356        self._c_op,  # pylint: disable=protected-access
2357        op._c_op)  # pylint: disable=protected-access
2358
2359  def _remove_all_control_inputs(self):
2360    """Removes any control inputs to this operation."""
2361    pywrap_tf_session.RemoveAllControlInputs(self._graph._c_graph, self._c_op)  # pylint: disable=protected-access
2362
2363  def _add_outputs(self, types, shapes):
2364    """Adds new Tensors to self.outputs.
2365
2366    Note: this is generally unsafe to use. This is used in certain situations in
2367    conjunction with _set_type_list_attr.
2368
2369    Args:
2370      types: list of DTypes
2371      shapes: list of TensorShapes
2372    """
2373    assert len(types) == len(shapes)
2374    orig_num_outputs = len(self.outputs)
2375    for i in range(len(types)):
2376      t = Tensor(self, orig_num_outputs + i, types[i])
2377      self._outputs.append(t)
2378      t.set_shape(shapes[i])
2379
2380  def __str__(self):
2381    return str(self.node_def)
2382
2383  def __repr__(self):
2384    return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
2385
2386  def __tf_tensor__(self, dtype=None, name=None):
2387    """Raises a helpful error."""
2388    raise TypeError("can't convert Operation '{}' to Tensor".format(self.name))
2389
2390  @property
2391  def outputs(self):
2392    """The list of `Tensor` objects representing the outputs of this op."""
2393    return self._outputs
2394
2395  @property
2396  def inputs(self):
2397    """The sequence of `Tensor` objects representing the data inputs of this op."""
2398    if self._inputs_val is None:
2399      # pylint: disable=protected-access
2400      self._inputs_val = tuple(
2401          map(self.graph._get_tensor_by_tf_output,
2402              pywrap_tf_session.GetOperationInputs(self._c_op)))
2403      # pylint: enable=protected-access
2404    return self._inputs_val
2405
2406  @property
2407  def _input_types(self):
2408    num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op)
2409    input_types = [
2410        dtypes.as_dtype(
2411            pywrap_tf_session.TF_OperationInputType(self._tf_input(i)))
2412        for i in xrange(num_inputs)
2413    ]
2414    return input_types
2415
2416  @property
2417  def control_inputs(self):
2418    """The `Operation` objects on which this op has a control dependency.
2419
2420    Before this op is executed, TensorFlow will ensure that the
2421    operations in `self.control_inputs` have finished executing. This
2422    mechanism can be used to run ops sequentially for performance
2423    reasons, or to ensure that the side effects of an op are observed
2424    in the correct order.
2425
2426    Returns:
2427      A list of `Operation` objects.
2428
2429    """
2430    control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper(
2431        self._c_op)
2432    # pylint: disable=protected-access
2433    return [
2434        self.graph._get_operation_by_name_unsafe(
2435            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
2436    ]
2437    # pylint: enable=protected-access
2438
2439  @property
2440  def _control_outputs(self):
2441    """The `Operation` objects which have a control dependency on this op.
2442
2443    Before any of the ops in self._control_outputs can execute, TensorFlow
2444    will ensure that self has finished executing.
2445
2446    Returns:
2447      A list of `Operation` objects.
2448
2449    """
2450    control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
2451        self._c_op)
2452    # pylint: disable=protected-access
2453    return [
2454        self.graph._get_operation_by_name_unsafe(
2455            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
2456    ]
2457    # pylint: enable=protected-access
2458
2459  @property
2460  def type(self):
2461    """The type of the op (e.g. `"MatMul"`)."""
2462    return pywrap_tf_session.TF_OperationOpType(self._c_op)
2463
2464  @property
2465  def graph(self):
2466    """The `Graph` that contains this operation."""
2467    return self._graph
2468
2469  @property
2470  def node_def(self):
2471    # pylint: disable=line-too-long
2472    """Returns the `NodeDef` representation of this operation.
2473
2474    Returns:
2475      A
2476      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
2477      protocol buffer.
2478    """
2479    # pylint: enable=line-too-long
2480    with c_api_util.tf_buffer() as buf:
2481      pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
2482      data = pywrap_tf_session.TF_GetBuffer(buf)
2483    node_def = node_def_pb2.NodeDef()
2484    node_def.ParseFromString(compat.as_bytes(data))
2485    return node_def
2486
2487  @property
2488  def op_def(self):
2489    # pylint: disable=line-too-long
2490    """Returns the `OpDef` proto that represents the type of this op.
2491
2492    Returns:
2493      An
2494      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
2495      protocol buffer.
2496    """
2497    # pylint: enable=line-too-long
2498    return self._graph._get_op_def(self.type)
2499
2500  @property
2501  def traceback(self):
2502    """Returns the call stack from when this operation was constructed."""
2503    return self._traceback
2504
2505  def _set_attr(self, attr_name, attr_value):
2506    """Private method used to set an attribute in the node_def."""
2507    buf = pywrap_tf_session.TF_NewBufferFromString(
2508        compat.as_bytes(attr_value.SerializeToString()))
2509    try:
2510      self._set_attr_with_buf(attr_name, buf)
2511    finally:
2512      pywrap_tf_session.TF_DeleteBuffer(buf)
2513
2514  def _set_attr_with_buf(self, attr_name, attr_buf):
2515    """Set an attr in the node_def with a pre-allocated buffer."""
2516    # pylint: disable=protected-access
2517    pywrap_tf_session.SetAttr(self._graph._c_graph, self._c_op, attr_name,
2518                              attr_buf)
2519    # pylint: enable=protected-access
2520
2521  def _set_func_attr(self, attr_name, func_name):
2522    """Private method used to set a function attribute in the node_def."""
2523    func = attr_value_pb2.NameAttrList(name=func_name)
2524    self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
2525
2526  def _set_func_list_attr(self, attr_name, func_names):
2527    """Private method used to set a list(function) attribute in the node_def."""
2528    funcs = [attr_value_pb2.NameAttrList(name=func_name)
2529             for func_name in func_names]
2530    funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
2531    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))
2532
2533  def _set_type_list_attr(self, attr_name, types):
2534    """Private method used to set a list(type) attribute in the node_def."""
2535    if not types:
2536      return
2537    if isinstance(types[0], dtypes.DType):
2538      types = [dt.as_datatype_enum for dt in types]
2539    types_list = attr_value_pb2.AttrValue.ListValue(type=types)
2540    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
2541
2542  def _set_shape_list_attr(self, attr_name, shapes):
2543    """Private method used to set a list(shape) attribute in the node_def."""
2544    shapes = [s.as_proto() for s in shapes]
2545    shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
2546    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
2547
2548  def _clear_attr(self, attr_name):
2549    """Private method used to clear an attribute in the node_def."""
2550    # pylint: disable=protected-access
2551    pywrap_tf_session.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
2552    # pylint: enable=protected-access
2553
2554  def get_attr(self, name):
2555    """Returns the value of the attr of this op with the given `name`.
2556
2557    Args:
2558      name: The name of the attr to fetch.
2559
2560    Returns:
2561      The value of the attr, as a Python object.
2562
2563    Raises:
2564      ValueError: If this op does not have an attr with the given `name`.
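
    Example (a sketch; assumes `op` is a "MatMul" operation built from
    float32 inputs):

    ```python
    op.get_attr("transpose_a")  # False
    op.get_attr("T")            # tf.float32 (a `tf.DType`)
    ```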
2565    """
2566    fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
2567    try:
2568      with c_api_util.tf_buffer() as buf:
2569        pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
2570        data = pywrap_tf_session.TF_GetBuffer(buf)
2571    except errors.InvalidArgumentError as e:
2572      # Convert to ValueError for backwards compatibility.
2573      raise ValueError(e.message)
2574    x = attr_value_pb2.AttrValue()
2575    x.ParseFromString(data)
2576
2577    oneof_value = x.WhichOneof("value")
2578    if oneof_value is None:
2579      return []
2580    if oneof_value == "list":
2581      for f in fields:
2582        if getattr(x.list, f):
2583          if f == "type":
2584            return [dtypes.as_dtype(t) for t in x.list.type]
2585          else:
2586            return list(getattr(x.list, f))
2587      return []
2588    if oneof_value == "type":
2589      return dtypes.as_dtype(x.type)
2590    assert oneof_value in fields, "Unsupported field type in " + str(x)
2591    return getattr(x, oneof_value)
2592
2593  def _get_attr_type(self, name):
2594    """Returns the `DType` value of the attr of this op with the given `name`."""
2595    try:
2596      dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
2597      return _DTYPES_INTERN_TABLE[dtype_enum]
2598    except errors.InvalidArgumentError as e:
2599      # Convert to ValueError for backwards compatibility.
2600      raise ValueError(e.message)
2601
2602  def _get_attr_bool(self, name):
2603    """Returns the `bool` value of the attr of this op with the given `name`."""
2604    try:
2605      return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
2606    except errors.InvalidArgumentError as e:
2607      # Convert to ValueError for backwards compatibility.
2608      raise ValueError(e.message)
2609
2610  def _get_attr_int(self, name):
2611    """Returns the `int` value of the attr of this op with the given `name`."""
2612    try:
2613      return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
2614    except errors.InvalidArgumentError as e:
2615      # Convert to ValueError for backwards compatibility.
2616      raise ValueError(e.message)
2617
2618  def run(self, feed_dict=None, session=None):
2619    """Runs this operation in a `Session`.
2620
2621    Calling this method will execute all preceding operations that
2622    produce the inputs needed for this operation.
2623
2624    *N.B.* Before invoking `Operation.run()`, its graph must have been
2625    launched in a session, and either a default session must be
2626    available, or `session` must be specified explicitly.
2627
2628    Args:
2629      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
2630        `tf.Session.run` for a description of the valid feed values.
2631      session: (Optional.) The `Session` to be used to run to this operation. If
2632        none, the default session will be used.
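
    For example, a minimal TF1-style sketch (assuming eager execution has
    been disabled with `tf.compat.v1.disable_eager_execution()`):

    ```python
    v = tf.compat.v1.get_variable("v", initializer=1.0)
    init = tf.compat.v1.global_variables_initializer()
    with tf.compat.v1.Session() as sess:
      init.run()       # equivalent to sess.run(init)
      print(v.eval())  # 1.0
    ```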
2633    """
2634    _run_using_default_session(self, feed_dict, self.graph, session)
2635
2636# TODO(b/185395742): Clean up usages of _gradient_registry
2637gradient_registry = _gradient_registry = registry.Registry("gradient")
2638
2639
2640@tf_export("RegisterGradient")
2641class RegisterGradient(object):
2642  """A decorator for registering the gradient function for an op type.
2643
2644  This decorator is only used when defining a new op type. For an op
2645  with `m` inputs and `n` outputs, the gradient function is a function
2646  that takes the original `Operation` and `n` `Tensor` objects
2647  (representing the gradients with respect to each output of the op),
2648  and returns `m` `Tensor` objects (representing the partial gradients
2649  with respect to each input of the op).
2650
2651  For example, assuming that operations of type `"Sub"` take two
2652  inputs `x` and `y`, and return a single output `x - y`, the
2653  following gradient function would be registered:
2654
2655  ```python
2656  @tf.RegisterGradient("Sub")
2657  def _sub_grad(unused_op, grad):
2658    return grad, tf.negative(grad)
2659  ```
2660
2661  The decorator argument `op_type` is the string type of an
2662  operation. This corresponds to the `OpDef.name` field for the proto
2663  that defines the operation.
2664  """
2665
2666  __slots__ = ["_op_type"]
2667
2668  def __init__(self, op_type):
2669    """Creates a new decorator with `op_type` as the Operation type.
2670
2671    Args:
2672      op_type: The string type of an operation. This corresponds to the
2673        `OpDef.name` field for the proto that defines the operation.
2674
2675    Raises:
2676      TypeError: If `op_type` is not string.
2677    """
2678    if not isinstance(op_type, six.string_types):
2679      raise TypeError("op_type must be a string")
2680    self._op_type = op_type
2681
2682  def __call__(self, f):
2683    """Registers the function `f` as gradient function for `op_type`."""
2684    gradient_registry.register(f, self._op_type)
2685    return f
2686
2687
2688@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient")
2689@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"])
2690def no_gradient(op_type):
2691  """Specifies that ops of type `op_type` are not differentiable.
2692
2693  This function should *not* be used for operations that have a
2694  well-defined gradient that is not yet implemented.
2695
2696  This function is only used when defining a new op type. It may be
2697  used for ops such as `tf.size()` that are not differentiable.  For
2698  example:
2699
2700  ```python
2701  tf.no_gradient("Size")
2702  ```
2703
2704  The gradient computed for 'op_type' will then propagate zeros.
2705
2706  For ops that have a well-defined gradient but are not yet implemented,
2707  no declaration should be made, and an error *must* be thrown if
2708  an attempt to request its gradient is made.
2709
2710  Args:
2711    op_type: The string type of an operation. This corresponds to the
2712      `OpDef.name` field for the proto that defines the operation.
2713
2714  Raises:
2715    TypeError: If `op_type` is not a string.
2716
2717  """
2718  if not isinstance(op_type, six.string_types):
2719    raise TypeError("op_type must be a string")
2720  gradient_registry.register(None, op_type)
2721
2722
2723# Aliases for the old names, will be eventually removed.
2724NoGradient = no_gradient
2725NotDifferentiable = no_gradient
2726
2727
2728def get_gradient_function(op):
2729  """Returns the function that computes gradients for "op"."""
2730  if not op.inputs:
2731    return None
2732
2733  gradient_function = op._gradient_function  # pylint: disable=protected-access
2734  if gradient_function:
2735    return gradient_function
2736
2737  try:
2738    op_type = op.get_attr("_gradient_op_type")
2739  except ValueError:
2740    op_type = op.type
2741  return gradient_registry.lookup(op_type)
2742
2743
2744def set_shape_and_handle_data_for_outputs(_):
2745  """No op. TODO(b/74620627): Remove this."""
2746  pass
2747
2748
2749class OpStats(object):
2750  """A holder for statistics about an operator.
2751
2752  This class holds information about the resource requirements for an op,
2753  including the size of its weight parameters on-disk and how many FLOPS it
2754  requires to execute forward inference.
2755
2756  If you define a new operation, you can create a function that will return a
2757  set of information about its usage of the CPU and disk space when serialized.
2758  The function itself takes a Graph object that's been set up so you can call
2759  methods like get_tensor_by_name to help calculate the results, and a NodeDef
2760  argument.
2761
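  For example, a small sketch of accumulating a statistic:

  ```python
  total = OpStats("flops", 1024)
  total += OpStats("flops", 512)  # total.value is now 1536
  ```
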
2762  """
2763
2764  __slots__ = ["_statistic_type", "_value"]
2765
2766  def __init__(self, statistic_type, value=None):
2767    """Sets up the initial placeholders for the statistics."""
2768    self.statistic_type = statistic_type
2769    self.value = value
2770
2771  @property
2772  def statistic_type(self):
2773    return self._statistic_type
2774
2775  @statistic_type.setter
2776  def statistic_type(self, statistic_type):
2777    self._statistic_type = statistic_type
2778
2779  @property
2780  def value(self):
2781    return self._value
2782
2783  @value.setter
2784  def value(self, value):
2785    self._value = value
2786
2787  def __iadd__(self, other):
2788    if other.statistic_type != self.statistic_type:
2789      raise ValueError("Can't add an OpStat of type %s to one of %s." %
2790                       (self.statistic_type, other.statistic_type))
2791    if self.value is None:
2792      self.value = other.value
2793    elif other.value is not None:
2794      self._value += other.value
2795    return self
2796
2797
2798_stats_registry = registry.Registry("statistical functions")
2799
2800
2801class RegisterStatistics(object):
2802  """A decorator for registering the statistics function for an op type.
2803
2804  This decorator can be defined for an op type so that it gives a
2805  report on the resources used by an instance of an operator, in the
2806  form of an OpStats object.
2807
2808  Well-known types of statistics include these so far:
2809
2810  - flops: When running a graph, the bulk of the computation happens doing
2811    numerical calculations like matrix multiplications. This type allows a node
2812    to return how many floating-point operations it takes to complete. The
2813    total number of FLOPs for a graph is a good guide to its expected latency.
2814
2815  You can add your own statistics just by picking a new type string, registering
2816  functions for the ops you care about, and then calling get_stats_for_node_def.
2817
2818  If a statistic for an op is registered multiple times, a KeyError will be
2819  raised.
2820
2821  Since statistics are counted on a per-op basis, they are not suitable for
2822  model parameters (capacity), which are expected to be counted only once,
2823  even if they are shared by multiple ops (e.g. an RNN).
2824
2825  For example, you can define a new metric called doohickey for a Foo operation
2826  by placing this in your code:
2827
2828  ```python
2829  @ops.RegisterStatistics("Foo", "doohickey")
2830  def _calc_foo_bojangles(unused_graph, unused_node_def):
2831    return ops.OpStats("doohickey", 20)
2832  ```
2833
2834  Then in client code you can retrieve the value by making this call:
2835
2836  ```python
2837  doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
2838  ```
2839
2840  If the NodeDef is for an op with a registered doohickey function, you'll get
2841  back the calculated amount in doohickey.value, or None if it's not defined.
2842
2843  """
2844
2845  __slots__ = ["_op_type", "_statistic_type"]
2846
2847  def __init__(self, op_type, statistic_type):
2848    """Saves the `op_type` as the `Operation` type."""
2849    if not isinstance(op_type, six.string_types):
2850      raise TypeError("op_type must be a string.")
2851    if "," in op_type:
2852      raise TypeError("op_type must not contain a comma.")
2853    self._op_type = op_type
2854    if not isinstance(statistic_type, six.string_types):
2855      raise TypeError("statistic_type must be a string.")
2856    if "," in statistic_type:
2857      raise TypeError("statistic_type must not contain a comma.")
2858    self._statistic_type = statistic_type
2859
2860  def __call__(self, f):
2861    """Registers "f" as the statistics function for "op_type"."""
2862    _stats_registry.register(f, self._op_type + "," + self._statistic_type)
2863    return f
2864
2865
2866def get_stats_for_node_def(graph, node, statistic_type):
2867  """Looks up the node's statistics function in the registry and calls it.
2868
2869  This function takes a Graph object and a NodeDef from a GraphDef, and if
2870  there's an associated statistics method, calls it and returns a result. If no
2871  function has been registered for the particular node type, it returns an empty
2872  statistics object.
2873
2874  Args:
2875    graph: A Graph object that's been set up with the node's graph.
2876    node: A NodeDef describing the operator.
2877    statistic_type: A string identifying the statistic we're interested in.
2878
2879  Returns:
2880    An OpStats object containing information about resource usage.
2881  """
2882
2883  try:
2884    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
2885    result = stats_func(graph, node)
2886  except LookupError:
2887    result = OpStats(statistic_type)
2888  return result
2889
2890
2891def name_from_scope_name(name):
2892  """Returns the name of an op given the name of its scope.
2893
2894  Args:
2895    name: the name of the scope.
2896
2897  Returns:
2898    the name of the op (equal to scope name minus any trailing slash).
2899  """
2900  return name[:-1] if (name and name[-1] == "/") else name
2901
2902
2903_MUTATION_LOCK_GROUP = 0
2904_SESSION_RUN_LOCK_GROUP = 1
2905
2906
2907@tf_contextlib.contextmanager
2908def resource_creator_scope(resource_type, resource_creator):
2909  with get_default_graph()._resource_creator_scope(resource_type,  # pylint: disable=protected-access
2910                                                   resource_creator):
2911    yield
2912
2913
2914@tf_export("Graph")
2915class Graph(object):
2916  """A TensorFlow computation, represented as a dataflow graph.
2917
2918  Graphs are used by `tf.function`s to represent the function's computations.
2919  Each graph contains a set of `tf.Operation` objects, which represent units of
2920  computation; and `tf.Tensor` objects, which represent the units of data that
2921  flow between operations.
2922
2923  ### Using graphs directly (deprecated)
2924
2925  A `tf.Graph` can be constructed and used directly without a `tf.function`, as
2926  was required in TensorFlow 1, but this is deprecated and it is recommended to
2927  use a `tf.function` instead. If a graph is directly used, other deprecated
2928  TensorFlow 1 classes are also required to execute the graph, such as a
2929  `tf.compat.v1.Session`.
2930
2931  A default graph can be registered with the `tf.Graph.as_default` context
2932  manager. Then, operations will be added to the graph instead of being executed
2933  eagerly. For example:
2934
2935  ```python
2936  g = tf.Graph()
2937  with g.as_default():
2938    # Define operations and tensors in `g`.
2939    c = tf.constant(30.0)
2940    assert c.graph is g
2941  ```
2942
2943  `tf.compat.v1.get_default_graph()` can be used to obtain the default graph.
2944
2945  Important note: This class *is not* thread-safe for graph construction. All
2946  operations should be created from a single thread, or external
2947  synchronization must be provided. Unless otherwise specified, all methods
2948  are not thread-safe.
2949
2950  A `Graph` instance supports an arbitrary number of "collections"
2951  that are identified by name. For convenience when building a large
2952  graph, collections can store groups of related objects: for
2953  example, `tf.Variable` objects use a collection (named
2954  `tf.GraphKeys.GLOBAL_VARIABLES`) for
2955  all variables that are created during the construction of a graph. The caller
2956  may define additional collections by specifying a new name.
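
  For example, a user-defined collection can be populated and read back
  (a minimal sketch):

  ```python
  g = tf.Graph()
  with g.as_default():
    c = tf.constant(1.0)
    g.add_to_collection("my_constants", c)
  assert g.get_collection("my_constants") == [c]
  ```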
2957  """
2958
2959  def __init__(self):
2960    """Creates a new, empty Graph."""
2961    # Protects core state that can be returned via public accessors.
2962    # Thread-safety is provided on a best-effort basis to support buggy
2963    # programs, and is not guaranteed by the public `tf.Graph` API.
2964    #
2965    # NOTE(mrry): This does not protect the various stacks. A warning will
2966    # be reported if these are used from multiple threads
2967    self._lock = threading.RLock()
2968    # The group lock synchronizes Session.run calls with methods that create
2969    # and mutate ops (e.g. Graph.create_op()). This synchronization is
2970    # necessary because it's illegal to modify an operation after it's been run.
2971    # The group lock allows any number of threads to mutate ops at the same time,
2972    # but if any modification is going on, all Session.run calls have to wait.
2973    # Similarly, if one or more Session.run calls are going on, all mutate ops
2974    # have to wait until all Session.run calls have finished.
2975    self._group_lock = lock_util.GroupLock(num_groups=2)
2976    self._nodes_by_id = {}  # GUARDED_BY(self._lock)
2977    self._next_id_counter = 0  # GUARDED_BY(self._lock)
2978    self._nodes_by_name = {}  # GUARDED_BY(self._lock)
2979    self._version = 0  # GUARDED_BY(self._lock)
2980    # Maps a name used in the graph to the next id to use for that name.
2981    self._names_in_use = {}
2982    self._stack_state_is_thread_local = False
2983    self._thread_local = threading.local()
2984    # Functions that will be applied to choose a device if none is specified.
2985    # In TF2.x or after switch_to_thread_local(),
2986    # self._thread_local._device_function_stack is used instead.
2987    self._graph_device_function_stack = traceable_stack.TraceableStack()
2988    # Default original_op applied to new ops.
2989    self._default_original_op = None
2990    # Current control flow context. It could be either CondContext or
2991    # WhileContext defined in ops/control_flow_ops.py
2992    self._control_flow_context = None
2993    # A new node will depend on the union of all of the nodes in the stack.
2994    # In TF2.x or after switch_to_thread_local(),
2995    # self._thread_local._control_dependencies_stack is used instead.
2996    self._graph_control_dependencies_stack = []
2997    # Arbitrary collections of objects.
2998    self._collections = {}
2999    # The graph-level random seed
3000    self._seed = None
3001    # A dictionary of attributes that should be applied to all ops.
3002    self._attr_scope_map = {}
3003    # A map from op type to the kernel label that should be used.
3004    self._op_to_kernel_label_map = {}
3005    # A map from op type to an alternative op type that should be used when
3006    # computing gradients.
3007    self._gradient_override_map = {}
3008    # A map from op type to a gradient function that should be used instead.
3009    self._gradient_function_map = {}
3010    # True if the graph is considered "finalized".  In that case no
3011    # new operations can be added.
3012    self._finalized = False
3013    # Functions defined in the graph
3014    self._functions = collections.OrderedDict()
3015    # Default GraphDef versions
3016    self._graph_def_versions = versions_pb2.VersionDef(
3017        producer=versions.GRAPH_DEF_VERSION,
3018        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
3019    self._building_function = False
3020    # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
3021    # self._thread_local._colocation_stack is used instead.
3022    self._graph_colocation_stack = traceable_stack.TraceableStack()
3023    # Set of tensors that are dangerous to feed!
3024    self._unfeedable_tensors = object_identity.ObjectIdentitySet()
3025    # Set of operations that are dangerous to fetch!
3026    self._unfetchable_ops = set()
3027    # A map of tensor handle placeholder to tensor dtype.
3028    self._handle_feeders = {}
3029    # A map from tensor handle to its read op.
3030    self._handle_readers = {}
3031    # A map from tensor handle to its move op.
3032    self._handle_movers = {}
3033    # A map from tensor handle to its delete op.
3034    self._handle_deleters = {}
3035    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
3036    # will be shared when defining function graphs, for example, so optimizers
3037    # being called inside function definitions behave as if they were seeing the
3038    # actual outside graph).
3039    self._graph_key = "grap-key-%d/" % (uid(),)
3040    # A string with the last reduction method passed to
3041    # losses.compute_weighted_loss(), or None. This is required only for
3042    # backward compatibility with Estimator and optimizer V1 use cases.
3043    self._last_loss_reduction = None
3044    # Flag that is used to indicate whether loss has been scaled by optimizer.
3045    # If this flag has been set, then estimator uses it to scale loss back
3046    # before reporting. This is required only for backward compatibility with
3047    # Estimator and optimizer V1 use cases.
3048    self._is_loss_scaled_by_optimizer = False
3049    self._container = ""
3050    # Set to True if this graph is being built in an
3051    # AutomaticControlDependencies context.
3052    self._add_control_dependencies = False
3053    # Cache for OpDef protobufs retrieved via the C API.
3054    self._op_def_cache = {}
3055    # Cache for constant results of `broadcast_gradient_args()`. The keys are
3056    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
3057    # values are tuples of reduction indices: (rx, ry).
3058    self._bcast_grad_args_cache = {}
3059    # Cache for constant results of `reduced_shape()`. The keys are pairs of
3060    # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
3061    # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
3062    self._reduced_shape_cache = {}
3063
3064    # TODO(skyewm): fold as much of the above as possible into the C
3065    # implementation
3066    self._scoped_c_graph = c_api_util.ScopedTFGraph()
3067    # The C API requires all ops to have shape functions. Disable this
3068    # requirement (many custom ops do not have shape functions, and we don't
3069    # want to break these existing cases).
3070    pywrap_tf_session.SetRequireShapeInferenceFns(self._c_graph, False)
3071    if tf2.enabled():
3072      self.switch_to_thread_local()
3073
3074  # Note: this method is private because the API of tf.Graph() is public and
3075  # frozen, and this functionality is still not ready for public visibility.
3076  @tf_contextlib.contextmanager
3077  def _variable_creator_scope(self, creator, priority=100):
3078    """Scope which defines a variable creation function.
3079
3080    Args:
3081      creator: A callable taking `next_creator` and `kwargs`. See the
3082        `tf.variable_creator_scope` docstring.
3083      priority: Creators with a higher `priority` are called first. Within the
3084        same priority, creators are called inner-to-outer.
3085
3086    Yields:
3087      Nothing. This is a context manager used for its side effect of
3088      installing `creator`; it does not yield a value.
3089
3090    Raises:
3091      RuntimeError: If variable creator scopes are not properly nested.
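
    A minimal sketch (assuming `graph` is the `Graph` in question; the creator
    name and the `trainable` override below are illustrative only):

    ```python
    def my_creator(next_creator, **kwargs):
      kwargs["trainable"] = False
      return next_creator(**kwargs)

    with graph._variable_creator_scope(my_creator):
      ...  # Variables created here are routed through `my_creator`.
    ```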
3092    """
3093    # This step keeps a reference to the existing stack, and it also initializes
3094    # self._thread_local._variable_creator_stack if it doesn't exist yet.
3095    old = self._variable_creator_stack
3096    new = list(old)
3097    new.append((priority, creator))
3098    # Sorting is stable, so we'll put higher-priority creators later in the list
3099    # but otherwise maintain registration order.
3100    new.sort(key=lambda item: item[0])
3101    self._thread_local._variable_creator_stack = new  # pylint: disable=protected-access
3102    try:
3103      yield
3104    finally:
3105      if self._thread_local._variable_creator_stack is not new:  # pylint: disable=protected-access
3106        raise RuntimeError(
3107            "Exiting variable_creator_scope without proper nesting.")
3108      self._thread_local._variable_creator_stack = old  # pylint: disable=protected-access
3109
3110  # TODO(b/192405401): unify resource_creator_scope with variable_creator_scope.
3111  # pylint: disable=protected-access
3112  @tf_contextlib.contextmanager
3113  def _resource_creator_scope(self, resource_type, creator):
3114    """Scope which defines a resource creation function used by some resource.
3115
3116    The resource should be a subclass of CachableResource with a class method
3117    `cls._resource_type`, the output of which is what the `resource_type`
3118    argument should be. By default, `cls._resource_type` returns the class name,
3119    `cls.__name__`. Within a scope, creators added with the same
3120    `resource_type` argument are composed together and apply to all classes
3121    with this `_resource_type`.
3122
3123
3124    `creator` is expected to be a function with the following signature:
3125
3126    ```
3127      def resource_creator(next_creator, *a, **kwargs)
3128    ```
3129
3130    If a creator does want to create an instance, it should eventually call
3131    `next_creator` rather than calling the class initialization method
3132    directly. This helps make creators
3133    composable. A creator may choose to create multiple instances, return
3134    already existing instances, or simply register that an instance was created
3135    and defer to the next creator in line. Creators can also modify keyword
3136    arguments seen by the next creators.
3137
3138    Valid keyword arguments in `kwargs` depend on the specific resource
3139    class. For StaticHashTable, this may be:
3140    * initializer: The table initializer to use.
3141    * default_value: The value to use if a key is missing in the table.
3142    * name: Optional name for the table; defaults to None.
3143
3144
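    A minimal sketch (assuming `graph` is the `Graph` in question; the creator
    name below is illustrative, and "StaticHashTable" stands in for the
    `_resource_type` of the class being customized):

    ```python
    def logging_creator(next_creator, *args, **kwargs):
      # A creator may inspect or adjust `kwargs` before deferring to the next
      # creator in line.
      print("Creating resource named:", kwargs.get("name"))
      return next_creator(*args, **kwargs)

    with graph._resource_creator_scope("StaticHashTable", logging_creator):
      ...  # Matching resources created here go through `logging_creator`.
    ```
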
3145    Args:
3146      resource_type: the output of the resource class's `_resource_type` method.
3147      creator: the passed creator for the resource.
3148
3149    Yields:
3150      A scope in which the creator is active.
3151
3152    Raises:
3153      RuntimeError: If resource_creator_scope is exited without proper nesting.
3154    """
3155    # This step keeps a reference to the existing stack, and it also initializes
3156    # self._thread_local._resource_creator_stack if it doesn't exist yet.
3157    old = self._resource_creator_stack
3158    new = copy.deepcopy(old)
3159    if isinstance(resource_type, (list, tuple)):
3160      for r in resource_type:
3161        new[r].append(creator)
3162    else:
3163      new[resource_type].append(creator)
3164    self._thread_local._resource_creator_stack = new
3165    try:
3166      yield
3167    finally:
3168      if self._thread_local._resource_creator_stack is not new:
3169        raise RuntimeError(
3170            "Exiting resource_creator_scope without proper nesting.")
3171      self._thread_local._resource_creator_stack = old
3172
3173  @property
3174  def _resource_creator_stack(self):
3175    if not hasattr(self._thread_local, "_resource_creator_stack"):
3176      self._thread_local._resource_creator_stack = collections.defaultdict(list)
3177    return self._thread_local._resource_creator_stack
3178
3179  @_resource_creator_stack.setter
3180  def _resource_creator_stack(self, resource_creator_stack):
3181    self._thread_local._resource_creator_stack = resource_creator_stack
3182  # pylint: enable=protected-access
3183
3184  # Note: this method is private because the API of tf.Graph() is public and
3185  # frozen, and this functionality is still not ready for public visibility.
3186  @property
3187  def _variable_creator_stack(self):
3188    if not hasattr(self._thread_local, "_variable_creator_stack"):
3189      self._thread_local._variable_creator_stack = []  # pylint: disable=protected-access
3190
3191    # This previously returned a copy of the stack instead of the stack itself,
3192    # to guard against accidental mutation. Consider, however, code that wants
3193    # to save and restore the variable creator stack:
3194    #     def f():
3195    #       original_stack = graph._variable_creator_stack
3196    #       graph._variable_creator_stack = new_stack
3197    #       ...  # Some code
3198    #       graph._variable_creator_stack = original_stack
3199    #
3200    # And lets say you have some code that calls this function with some
3201    # variable_creator:
3202    #     def g():
3203    #       with variable_scope.variable_creator_scope(creator):
3204    #         f()
3205    # When exiting the variable creator scope, it would see a different stack
3206    # object than it expected, leading to an "Exiting variable_creator_scope
3207    # without proper nesting" error.
3208    return self._thread_local._variable_creator_stack  # pylint: disable=protected-access
3209
3210  @_variable_creator_stack.setter
3211  def _variable_creator_stack(self, variable_creator_stack):
3212    self._thread_local._variable_creator_stack = variable_creator_stack  # pylint: disable=protected-access
3213
3214  def _check_not_finalized(self):
3215    """Check if the graph is finalized.
3216
3217    Raises:
3218      RuntimeError: If the graph is finalized.
3219    """
3220    if self._finalized:
3221      raise RuntimeError("Graph is finalized and cannot be modified.")
3222
3223  def _add_op(self, op, op_name):
3224    """Adds 'op' to the graph and returns the unique ID for the added Operation.
3225
3226    Args:
3227      op: the Operation to add.
3228      op_name: the name of the Operation.
3229
3230    Returns:
3231      An integer that is a unique ID for the added Operation.
3232    """
3233    self._check_not_finalized()
3234    with self._lock:
3235      self._next_id_counter += 1
3236      op_id = self._next_id_counter
3237      self._nodes_by_id[op_id] = op
3238      self._nodes_by_name[op_name] = op
3239      self._version = max(self._version, op_id)
3240      return op_id
3241
3242  @property
3243  def _c_graph(self):
3244    return self._scoped_c_graph.graph
3245
3246  @property
3247  def version(self):
3248    """Returns a version number that increases as ops are added to the graph.
3249
3250    Note that this is unrelated to the
3251    `tf.Graph.graph_def_versions`.
3252
3253    Returns:
3254       An integer version that increases as ops are added to the graph.
3255    """
3256    if self._finalized:
3257      return self._version
3258
3259    with self._lock:
3260      return self._version
3261
3262  @property
3263  def graph_def_versions(self):
3264    # pylint: disable=line-too-long
3265    """The GraphDef version information of this graph.
3266
3267    For details on the meaning of each version, see
3268    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).
3269
3270    Returns:
3271      A `VersionDef`.
3272    """
3273    # pylint: enable=line-too-long
3274    with c_api_util.tf_buffer() as buf:
3275      pywrap_tf_session.TF_GraphVersions(self._c_graph, buf)
3276      data = pywrap_tf_session.TF_GetBuffer(buf)
3277    version_def = versions_pb2.VersionDef()
3278    version_def.ParseFromString(compat.as_bytes(data))
3279    return version_def
3280
3281  @property
3282  def seed(self):
3283    """The graph-level random seed of this graph."""
3284    return self._seed
3285
3286  @seed.setter
3287  def seed(self, seed):
3288    self._seed = seed
3289
3290  @property
3291  def finalized(self):
3292    """True if this graph has been finalized."""
3293    return self._finalized
3294
3295  def finalize(self):
3296    """Finalizes this graph, making it read-only.
3297
3298    After calling `g.finalize()`, no new operations can be added to
3299    `g`.  This method is used to ensure that no operations are added
3300    to a graph when it is shared between multiple threads, for example
3301    when using a `tf.compat.v1.train.QueueRunner`.
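
    For example:

    ```python
    g = tf.Graph()
    with g.as_default():
      tf.constant(1.0)
    g.finalize()
    # Adding any further operation to `g` now raises a RuntimeError.
    ```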
3302    """
3303    self._finalized = True
3304
3305  def _unsafe_unfinalize(self):
3306    """Opposite of `finalize`.
3307
3308    Internal interface.
3309
3310    NOTE: Unfinalizing a graph could have negative impact on performance,
3311    especially in a multi-threaded environment.  Unfinalizing a graph
3312    when it is in use by a Session may lead to undefined behavior. Ensure
3313    that all sessions using a graph are closed before calling this method.
3314    """
3315    self._finalized = False
3316
3317  def _get_control_flow_context(self):
3318    """Returns the current control flow context.
3319
3320    Returns:
3321      A context object.
3322    """
3323    return self._control_flow_context
3324
3325  def _set_control_flow_context(self, ctx):
3326    """Sets the current control flow context.
3327
3328    Args:
3329      ctx: a context object.
3330    """
3331    self._control_flow_context = ctx
3332
3333  def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
3334    """If this graph contains functions, copy them to `graph_def`."""
3335    bytesize = starting_bytesize
3336    for f in self._functions.values():
3337      bytesize += f.definition.ByteSize()
3338      if bytesize >= (1 << 31) or bytesize < 0:
3339        raise ValueError("GraphDef cannot be larger than 2GB.")
3340      graph_def.library.function.extend([f.definition])
3341      if f.grad_func_name:
3342        grad_def = function_pb2.GradientDef()
3343        grad_def.function_name = f.name
3344        grad_def.gradient_func = f.grad_func_name
3345        graph_def.library.gradient.extend([grad_def])
3346
3347  def _as_graph_def(self, from_version=None, add_shapes=False):
3348    # pylint: disable=line-too-long
3349    """Returns a serialized `GraphDef` representation of this graph.
3350
3351    The serialized `GraphDef` can be imported into another `Graph`
3352    (using `tf.import_graph_def`) or used with the
3353    [C++ Session API](../../../../api_docs/cc/index.md).
3354
3355    This method is thread-safe.
3356
3357    Args:
3358      from_version: Optional.  If this is set, returns a `GraphDef` containing
3359        only the nodes that were added to this graph since its `version`
3360        property had the given value.
3361      add_shapes: If true, adds an "_output_shapes" list attr to each node with
3362        the inferred shapes of each of its outputs.
3363
3364    Returns:
3365      A tuple containing a
3366      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3367      protocol buffer, and the version of the graph to which that
3368      `GraphDef` corresponds.
3369
3370    Raises:
3371      ValueError: If the `graph_def` would be too large.
3372
3373    """
3374    # pylint: enable=line-too-long
3375    with self._lock:
3376      with c_api_util.tf_buffer() as buf:
3377        pywrap_tf_session.TF_GraphToGraphDef(self._c_graph, buf)
3378        data = pywrap_tf_session.TF_GetBuffer(buf)
3379      graph = graph_pb2.GraphDef()
3380      graph.ParseFromString(compat.as_bytes(data))
3381      # Strip the experimental library field iff it's empty.
3382      if not graph.library.function:
3383        graph.ClearField("library")
3384
3385      if add_shapes:
3386        for node in graph.node:
3387          op = self._nodes_by_name[node.name]
3388          if op.outputs:
3389            node.attr["_output_shapes"].list.shape.extend(
3390                [output.get_shape().as_proto() for output in op.outputs])
3391        for function_def in graph.library.function:
3392          defined_function = self._functions[function_def.signature.name]
3393          try:
3394            func_graph = defined_function.graph
3395          except AttributeError:
3396            # _DefinedFunction doesn't have a graph, _EagerDefinedFunction
3397            # does. Both rely on ops.py, so we can't really isinstance check
3398            # them.
3399            continue
3400          input_shapes = function_def.attr["_input_shapes"]
3401          try:
3402            func_graph_inputs = func_graph.inputs
3403          except AttributeError:
3404            continue
3405          # TODO(b/141471245): Fix the inconsistency when inputs of func graph
3406          # are appended during gradient computation of while/cond.
3407          assert len(input_shapes.list.shape) in [0, len(func_graph_inputs)]
3408          # If the function_def has inputs already filled out, skip this step.
3409          if not input_shapes.list.shape:
3410            for input_tensor, arg_def in zip(func_graph_inputs,
3411                                             function_def.signature.input_arg):
3412              input_shapes.list.shape.add().CopyFrom(
3413                  input_tensor.get_shape().as_proto())
3414              if input_tensor.dtype == dtypes.resource:
3415                _copy_handle_data_to_arg_def(input_tensor, arg_def)
3416
3417          for output_tensor, arg_def in zip(func_graph.outputs,
3418                                            function_def.signature.output_arg):
3419            if output_tensor.dtype == dtypes.resource:
3420              _copy_handle_data_to_arg_def(output_tensor, arg_def)
3421
3422          for node in function_def.node_def:
3423            try:
3424              op = func_graph.get_operation_by_name(node.name)
3425            except KeyError:
3426              continue
3427            outputs = op.outputs
3428
3429            if op.type == "StatefulPartitionedCall":
3430              # Filter out any extra outputs (possibly added by function
3431              # backpropagation rewriting).
3432              num_outputs = len(node.attr["Tout"].list.type)
3433              outputs = outputs[:num_outputs]
3434
3435            node.attr["_output_shapes"].list.shape.extend(
3436                [output.get_shape().as_proto() for output in outputs])
3437
3438    return graph, self._version
3439
3440  def as_graph_def(self, from_version=None, add_shapes=False):
3441    # pylint: disable=line-too-long
3442    """Returns a serialized `GraphDef` representation of this graph.
3443
3444    The serialized `GraphDef` can be imported into another `Graph`
3445    (using `tf.import_graph_def`) or used with the
3446    [C++ Session API](../../api_docs/cc/index.md).
3447
3448    This method is thread-safe.
3449
3450    Args:
3451      from_version: Optional.  If this is set, returns a `GraphDef` containing
3452        only the nodes that were added to this graph since its `version`
3453        property had the given value.
3454      add_shapes: If true, adds an "_output_shapes" list attr to each node with
3455        the inferred shapes of each of its outputs.
3456
3457    Returns:
3458      A
3459      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3460      protocol buffer.
3461
3462    Raises:
3463      ValueError: If the `graph_def` would be too large.
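
    For example:

    ```python
    g = tf.Graph()
    with g.as_default():
      tf.constant(1.0, name="one")
    graph_def = g.as_graph_def()
    assert any(node.name == "one" for node in graph_def.node)
    ```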
3464    """
3465    # pylint: enable=line-too-long
3466    result, _ = self._as_graph_def(from_version, add_shapes)
3467    return result
3468
3469  def _is_function(self, name):
3470    """Tests whether 'name' is registered in this graph's function library.
3471
3472    Args:
3473      name: string op name.
3474
3475    Returns:
3476      bool indicating whether or not 'name' is registered in function library.
3477    """
3478    return compat.as_str(name) in self._functions
3479
3480  def _get_function(self, name):
3481    """Returns the function definition for 'name'.
3482
3483    Args:
3484      name: string function name.
3485
3486    Returns:
3487      The function def proto.
3488    """
3489    return self._functions.get(compat.as_str(name), None)
3490
3491  def _add_function(self, function):
3492    """Adds a function to the graph.
3493
3494    After the function has been added, you can call to the function by
3495    passing the function name in place of an op name to
3496    `Graph.create_op()`.
3497
3498    Args:
3499      function: A `_DefinedFunction` object.
3500
3501    Raises:
3502      ValueError: if another function is defined with the same name.
3503    """
3504    self._check_not_finalized()
3505
3506    name = function.name
3507    # Sanity checks on gradient definition.
3508    if (function.grad_func_name is not None) and (function.python_grad_func is
3509                                                  not None):
3510      raise ValueError("Gradient defined twice for function %s" % name)
3511
3512    # Add function to graph
3513    # pylint: disable=protected-access
3514    gradient = (
3515        function._grad_func._c_func.func if function._grad_func else None)
3516    pywrap_tf_session.TF_GraphCopyFunction(self._c_graph, function._c_func.func,
3517                                           gradient)
3518    # pylint: enable=protected-access
3519
3520    self._functions[compat.as_str(name)] = function
3521
3522    # Need a new-enough consumer to support the functions we add to the graph.
3523    if self._graph_def_versions.min_consumer < 12:
3524      self._graph_def_versions.min_consumer = 12
3525
3526  @property
3527  def building_function(self):
3528    """Returns True iff this graph represents a function."""
3529    return self._building_function
3530
3531  # Helper functions to create operations.
3532  @deprecated_args(None,
3533                   "Shapes are always computed; don't use the compute_shapes "
3534                   "as it has no effect.", "compute_shapes")
3535  @traceback_utils.filter_traceback
3536  def create_op(
3537      self,
3538      op_type,
3539      inputs,
3540      dtypes=None,  # pylint: disable=redefined-outer-name
3541      input_types=None,
3542      name=None,
3543      attrs=None,
3544      op_def=None,
3545      compute_shapes=True,
3546      compute_device=True):
3547    """Creates an `Operation` in this graph.
3548
3549    This is a low-level interface for creating an `Operation`. Most
3550    programs will not call this method directly, and instead use the
3551    Python op constructors, such as `tf.constant()`, which add ops to
3552    the default graph.
3553
3554    Args:
3555      op_type: The `Operation` type to create. This corresponds to the
3556        `OpDef.name` field for the proto that defines the operation.
3557      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3558      dtypes: (Optional) A list of `DType` objects that will be the types of the
3559        tensors that the operation produces.
3560      input_types: (Optional.) A list of `DType`s that will be the types of the
3561        tensors that the operation consumes. By default, uses the base `DType`
3562        of each input in `inputs`. Operations that expect reference-typed inputs
3563        must specify `input_types` explicitly.
3564      name: (Optional.) A string name for the operation. If not specified, a
3565        name is generated based on `op_type`.
3566      attrs: (Optional.) A dictionary where the key is the attribute name (a
3567        string) and the value is the respective `attr` attribute of the
3568        `NodeDef` proto that will represent the operation (an `AttrValue`
3569        proto).
3570      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3571        the operation will have.
3572      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
3573        computed).
3574      compute_device: (Optional.) If True, device functions will be executed to
3575        compute the device property of the Operation.
3576
3577    Raises:
3578      TypeError: if any of the inputs is not a `Tensor`.
3579      ValueError: if colocation conflicts with existing device assignment.
3580
3581    Returns:
3582      An `Operation` object.
3583    """
3584    del compute_shapes
3585    for idx, a in enumerate(inputs):
3586      if not isinstance(a, Tensor):
3587        raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
3588    return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
3589                                    attrs, op_def, compute_device)
3590
3591  def _create_op_internal(
3592      self,
3593      op_type,
3594      inputs,
3595      dtypes=None,  # pylint: disable=redefined-outer-name
3596      input_types=None,
3597      name=None,
3598      attrs=None,
3599      op_def=None,
3600      compute_device=True):
3601    """Creates an `Operation` in this graph.
3602
3603    Implements `Graph.create_op()` without the overhead of the deprecation
3604    wrapper.
3605
3606    Args:
3607      op_type: The `Operation` type to create. This corresponds to the
3608        `OpDef.name` field for the proto that defines the operation.
3609      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3610      dtypes: (Optional) A list of `DType` objects that will be the types of the
3611        tensors that the operation produces.
3612      input_types: (Optional.) A list of `DType`s that will be the types of the
3613        tensors that the operation consumes. By default, uses the base `DType`
3614        of each input in `inputs`. Operations that expect reference-typed inputs
3615        must specify `input_types` explicitly.
3616      name: (Optional.) A string name for the operation. If not specified, a
3617        name is generated based on `op_type`.
3618      attrs: (Optional.) A dictionary where the key is the attribute name (a
3619        string) and the value is the respective `attr` attribute of the
3620        `NodeDef` proto that will represent the operation (an `AttrValue`
3621        proto).
3622      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3623        the operation will have.
3624      compute_device: (Optional.) If True, device functions will be executed to
3625        compute the device property of the Operation.
3626
3627    Raises:
3628      ValueError: if colocation conflicts with existing device assignment.
3629
3630    Returns:
3631      An `Operation` object.
3632    """
3633    self._check_not_finalized()
3634    if name is None:
3635      name = op_type
3636    # If a name ends with a '/' it is a "name scope", and we use it as-is,
3637    # after removing the trailing '/'.
3638    if name and name[-1] == "/":
3639      name = name_from_scope_name(name)
3640    else:
3641      name = self.unique_name(name)
3642
3643    node_def = _NodeDef(op_type, name, attrs)
3644
3645    input_ops = set(t.op for t in inputs)
3646    control_inputs = self._control_dependencies_for_inputs(input_ops)
3647    # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
3648    # Session.run call cannot occur between creating and mutating the op.
3649    with self._mutation_lock():
3650      ret = Operation(
3651          node_def,
3652          self,
3653          inputs=inputs,
3654          output_types=dtypes,
3655          control_inputs=control_inputs,
3656          input_types=input_types,
3657          original_op=self._default_original_op,
3658          op_def=op_def)
3659      self._create_op_helper(ret, compute_device=compute_device)
3660    return ret
3661
3662  def _create_op_from_tf_operation(self, c_op, compute_device=True):
3663    """Creates an `Operation` in this graph from the supplied TF_Operation.
3664
3665    This method is like create_op() except the new Operation is constructed
3666    using `c_op`. The returned Operation will have `c_op` as its _c_op
3667    field. This is used to create Operation objects around TF_Operations created
3668    indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
3669
3670    This function does not call Operation._control_flow_post_processing or
3671    Graph._control_dependencies_for_inputs (since the inputs may not be
3672    available yet). The caller is responsible for calling these methods.
3673
3674    Args:
3675      c_op: a wrapped TF_Operation
3676      compute_device: (Optional.) If True, device functions will be executed to
3677        compute the device property of the Operation.
3678
3679    Returns:
3680      An `Operation` object.
3681    """
3682    self._check_not_finalized()
3683    ret = Operation(c_op, self)
3684    # If a name_scope was created with ret.name but no nodes were created in it,
3685    # the name will still appear in _names_in_use even though the name hasn't
3686    # been used. This is ok, just leave _names_in_use as-is in this case.
3687    # TODO(skyewm): make the C API guarantee no name conflicts.
3688    name_key = ret.name.lower()
3689    if name_key not in self._names_in_use:
3690      self._names_in_use[name_key] = 1
3691    self._create_op_helper(ret, compute_device=compute_device)
3692    return ret
3693
3694  def _create_op_helper(self, op, compute_device=True):
3695    """Common logic for creating an op in this graph."""
3696    # Apply any additional attributes requested. Do not overwrite any existing
3697    # attributes.
3698    for key, value in self._attr_scope_map.items():
3699      try:
3700        op.get_attr(key)
3701      except ValueError:
3702        if callable(value):
3703          value = value(op.node_def)
3704          if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
3705            raise TypeError(
3706                "Callable for scope map key '%s' must return either None or "
3707                "an AttrValue protocol buffer; but it returned: %s" %
3708                (key, value))
3709        if value:
3710          op._set_attr(key, value)  # pylint: disable=protected-access
3711
3712    # Apply a kernel label if one has been specified for this op type.
3713    try:
3714      kernel_label = self._op_to_kernel_label_map[op.type]
3715      op._set_attr("_kernel",  # pylint: disable=protected-access
3716                   attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
3717    except KeyError:
3718      pass
3719
3720    op._gradient_function = self._gradient_function_map.get(op.type)  # pylint: disable=protected-access
3721
3722    # Apply the overriding op type for gradients if one has been specified for
3723    # this op type.
3724    try:
3725      mapped_op_type = self._gradient_override_map[op.type]
3726      op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
3727                   attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
3728    except KeyError:
3729      pass
3730
3731    self._record_op_seen_by_control_dependencies(op)
3732
3733    if compute_device:
3734      self._apply_device_functions(op)
3735
3736    # Snapshot the colocation stack metadata before we might generate error
3737    # messages using it.  Note that this snapshot depends on the actual stack
3738    # and is independent of the op's _class attribute.
3739    # pylint: disable=protected-access
3740    op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
3741    # pylint: enable=protected-access
3742
3743    if self._colocation_stack:
3744      all_colocation_groups = []
3745      is_device_set = False
3746      for colocation_op in self._colocation_stack.peek_objs():
3747        try:
3748          all_colocation_groups.extend(colocation_op.colocation_groups())
3749        except AttributeError:
3750          pass
3751        if colocation_op.device and not is_device_set:
3752          # pylint: disable=protected-access
3753          op._set_device(colocation_op.device)
3754          # pylint: enable=protected-access
3755          is_device_set = True
3756
3757      all_colocation_groups = sorted(set(all_colocation_groups))
3758      # pylint: disable=protected-access
3759      op._set_attr(
3760          "_class",
3761          attr_value_pb2.AttrValue(
3762              list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
3763      # pylint: enable=protected-access
3764
3765    # Sets "container" attribute if
3766    # (1) self._container is not None
3767    # (2) "is_stateful" is set in OpDef
3768    # (3) "container" attribute is in OpDef
3769    # (4) "container" attribute has not already been set
3770    if self._container and op._is_stateful:  # pylint: disable=protected-access
3771      try:
3772        container_attr = op.get_attr("container")
3773      except ValueError:
3774        # "container" attribute is not in OpDef
3775        pass
3776      else:
3777        if not container_attr:
3778          op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
3779              s=compat.as_bytes(self._container)))
3780
3781  def _add_new_tf_operations(self, compute_devices=True):
3782    """Creates `Operations` in this graph for any new TF_Operations.
3783
3784    This is useful for when TF_Operations are indirectly created by the C API
3785    outside of the Operation constructor (e.g. by TF_ImportGraphDef,
3786    TF_FinishWhile). This ensures there are corresponding Operations for all
3787    TF_Operations in the underlying TF_Graph.
3788
3789    Args:
3790      compute_devices: (Optional.) If True, device functions will be executed to
3791        compute the device properties of each new Operation.
3792
3793    Returns:
3794      A list of the new `Operation` objects.
3795    """
3796    self._check_not_finalized()
3797
3798    # Create all Operation objects before accessing their inputs since an op may
3799    # be created before its inputs.
3800    new_ops = [
3801        self._create_op_from_tf_operation(c_op, compute_device=compute_devices)
3802        for c_op in c_api_util.new_tf_operations(self)
3803    ]
3804
3805    # pylint: disable=protected-access
3806    for op in new_ops:
3807      new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
3808      op._add_control_inputs(new_control_inputs)
3809      op._control_flow_post_processing()
3810    # pylint: enable=protected-access
3811
3812    return new_ops
3813
3814  def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
3815    """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
3816
3817    This function validates that `obj` represents an element of this
3818    graph, and gives an informative error message if it is not.
3819
3820    This function is the canonical way to get/validate an object of
3821    one of the allowed types from an external argument reference in the
3822    Session API.
3823
3824    This method may be called concurrently from multiple threads.
3825
3826    Args:
3827      obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. Can
3828        also be any object with an `_as_graph_element()` method that returns a
3829        value of one of these types. Note: `_as_graph_element` will be called
3830        inside the graph's lock and so may not modify the graph.
3831      allow_tensor: If true, `obj` may refer to a `Tensor`.
3832      allow_operation: If true, `obj` may refer to an `Operation`.
3833
3834    Returns:
3835      The `Tensor` or `Operation` in the Graph corresponding to `obj`.
3836
3837    Raises:
3838      TypeError: If `obj` is not a type we support attempting to convert
3839        to types.
3840      ValueError: If `obj` is of an appropriate type but invalid. For
3841        example, an invalid string.
3842      KeyError: If `obj` is not an object in the graph.
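
    For example:

    ```python
    g = tf.Graph()
    with g.as_default():
      c = tf.constant(1.0, name="c")
    assert g.as_graph_element("c") is c.op   # An Operation, looked up by name.
    assert g.as_graph_element("c:0") is c    # A Tensor, looked up by name.
    assert g.as_graph_element(c) is c        # A Tensor, passed through.
    ```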
3843    """
3844    if self._finalized:
3845      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3846
3847    with self._lock:
3848      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3849
3850  def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
3851    """See `Graph.as_graph_element()` for details."""
3852    # The vast majority of this function is figuring
3853    # out what an API user might be doing wrong, so
3854    # that we can give helpful error messages.
3855    #
3856    # Ideally, it would be nice to split it up, but we
3857    # need context to generate nice error messages.
3858
3859    if allow_tensor and allow_operation:
3860      types_str = "Tensor or Operation"
3861    elif allow_tensor:
3862      types_str = "Tensor"
3863    elif allow_operation:
3864      types_str = "Operation"
3865    else:
3866      raise ValueError("allow_tensor and allow_operation can't both be False.")
3867
3868    temp_obj = _as_graph_element(obj)
3869    if temp_obj is not None:
3870      obj = temp_obj
3871
3872    # If obj appears to be a name...
3873    if isinstance(obj, compat.bytes_or_text_types):
3874      name = compat.as_str(obj)
3875
3876      if ":" in name and allow_tensor:
3877        # Looks like a Tensor name and can be a Tensor.
3878        try:
3879          op_name, out_n = name.split(":")
3880          out_n = int(out_n)
3881        except:
3882          raise ValueError("The name %s looks like a Tensor name, but is "
3883                           "not a valid one. Tensor names must be of the "
3884                           "form \"<op_name>:<output_index>\"." % repr(name))
3885        if op_name in self._nodes_by_name:
3886          op = self._nodes_by_name[op_name]
3887        else:
3888          raise KeyError("The name %s refers to a Tensor which does not "
3889                         "exist. The operation, %s, does not exist in the "
3890                         "graph." % (repr(name), repr(op_name)))
3891        try:
3892          return op.outputs[out_n]
3893        except:
3894          raise KeyError("The name %s refers to a Tensor which does not "
3895                         "exist. The operation, %s, exists but only has "
3896                         "%s outputs." %
3897                         (repr(name), repr(op_name), len(op.outputs)))
3898
3899      elif ":" in name and not allow_tensor:
3900        # Looks like a Tensor name but can't be a Tensor.
3901        raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
3902                         (repr(name), types_str))
3903
3904      elif ":" not in name and allow_operation:
3905        # Looks like an Operation name and can be an Operation.
3906        if name not in self._nodes_by_name:
3907          raise KeyError("The name %s refers to an Operation not in the "
3908                         "graph." % repr(name))
3909        return self._nodes_by_name[name]
3910
3911      elif ":" not in name and not allow_operation:
3912        # Looks like an Operation name but can't be an Operation.
3913        if name in self._nodes_by_name:
3914          # Yep, it's an Operation name
3915          err_msg = ("The name %s refers to an Operation, not a %s." %
3916                     (repr(name), types_str))
3917        else:
3918          err_msg = ("The name %s looks like an (invalid) Operation name, "
3919                     "not a %s." % (repr(name), types_str))
3920        err_msg += (" Tensor names must be of the form "
3921                    "\"<op_name>:<output_index>\".")
3922        raise ValueError(err_msg)
3923
3924    elif isinstance(obj, Tensor) and allow_tensor:
3925      # Actually obj is just the object it's referring to.
3926      if obj.graph is not self:
3927        raise ValueError("Tensor %s is not an element of this graph." % obj)
3928      return obj
3929    elif isinstance(obj, Operation) and allow_operation:
3930      # Actually obj is just the object it's referring to.
3931      if obj.graph is not self:
3932        raise ValueError("Operation %s is not an element of this graph." % obj)
3933      return obj
3934    else:
3935      # We give up!
3936      raise TypeError("Can not convert a %s into a %s." %
3937                      (type(obj).__name__, types_str))
3938
3939  def get_operations(self):
3940    """Return the list of operations in the graph.
3941
3942    You can modify the operations in place, but modifications
3943    to the list such as inserts/deletes have no effect on the
3944    list of operations known to the graph.
3945
3946    This method may be called concurrently from multiple threads.
3947
3948    Returns:
3949      A list of Operations.
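
    For example:

    ```python
    g = tf.Graph()
    with g.as_default():
      tf.constant(1.0, name="c")
    assert [op.name for op in g.get_operations()] == ["c"]
    ```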
3950    """
3951    if self._finalized:
3952      return list(self._nodes_by_id.values())
3953
3954    with self._lock:
3955      return list(self._nodes_by_id.values())
3956
3957  def get_operation_by_name(self, name):
3958    """Returns the `Operation` with the given `name`.
3959
3960    This method may be called concurrently from multiple threads.
3961
3962    Args:
3963      name: The name of the `Operation` to return.
3964
3965    Returns:
3966      The `Operation` with the given `name`.
3967
3968    Raises:
3969      TypeError: If `name` is not a string.
3970      KeyError: If `name` does not correspond to an operation in this graph.
3971    """
3972
3973    if not isinstance(name, six.string_types):
3974      raise TypeError("Operation names are strings (or similar), not %s." %
3975                      type(name).__name__)
3976    return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
3977
3978  def _get_operation_by_name_unsafe(self, name):
3979    """Returns the `Operation` with the given `name`.
3980
3981    This is an internal unsafe version of get_operation_by_name. It skips many
3982    checks and does not have user-friendly error messages, but runs considerably
3983    faster. This method may be called concurrently from multiple threads.
3984
3985    Args:
3986      name: The name of the `Operation` to return.
3987
3988    Returns:
3989      The `Operation` with the given `name`.
3990
3991    Raises:
3992      KeyError: If `name` does not correspond to an operation in this graph.
3993    """
3994
3995    if self._finalized:
3996      return self._nodes_by_name[name]
3997
3998    with self._lock:
3999      return self._nodes_by_name[name]
4000
4001  def _get_operation_by_tf_operation(self, tf_oper):
4002    op_name = pywrap_tf_session.TF_OperationName(tf_oper)
4003    return self._get_operation_by_name_unsafe(op_name)
4004
4005  def get_tensor_by_name(self, name):
4006    """Returns the `Tensor` with the given `name`.
4007
4008    This method may be called concurrently from multiple threads.
4009
4010    Args:
4011      name: The name of the `Tensor` to return.
4012
4013    Returns:
4014      The `Tensor` with the given `name`.
4015
4016    Raises:
4017      TypeError: If `name` is not a string.
4018      KeyError: If `name` does not correspond to a tensor in this graph.
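
    For example:

    ```python
    g = tf.Graph()
    with g.as_default():
      c = tf.constant(1.0, name="c")
    assert g.get_tensor_by_name("c:0") is c
    ```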
4019    """
4020    # Names should be strings.
4021    if not isinstance(name, six.string_types):
4022      raise TypeError("Tensor names are strings (or similar), not %s." %
4023                      type(name).__name__)
4024    return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
4025
4026  def _get_tensor_by_tf_output(self, tf_output):
4027    """Returns the `Tensor` representing `tf_output`.
4028
4029    Note that there is only one such `Tensor`, i.e. multiple calls to this
4030    function with the same TF_Output value will always return the same `Tensor`
4031    object.
4032
4033    Args:
4034      tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
4035
4036    Returns:
4037      The `Tensor` that represents `tf_output`.
4038    """
4039    op = self._get_operation_by_tf_operation(tf_output.oper)
4040    return op.outputs[tf_output.index]
4041
4042  @property
4043  def _last_id(self):
4044    return self._next_id_counter
4045
4046  def _get_op_def(self, type):  # pylint: disable=redefined-builtin
4047    """Returns the `OpDef` proto for `type`. `type` is a string."""
4048    # NOTE: No locking is required because the lookup and insertion operations
4049    # on Python dictionaries are atomic.
4050    try:
4051      return self._op_def_cache[type]
4052    except KeyError:
4053      with c_api_util.tf_buffer() as buf:
4054        # pylint: disable=protected-access
4055        pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type),
4056                                           buf)
4057        # pylint: enable=protected-access
4058        data = pywrap_tf_session.TF_GetBuffer(buf)
4059      op_def = op_def_pb2.OpDef()
4060      op_def.ParseFromString(compat.as_bytes(data))
4061      self._op_def_cache[type] = op_def
4062      return op_def
4063
4064  def as_default(self):
4065    """Returns a context manager that makes this `Graph` the default graph.
4066
4067    This method should be used if you want to create multiple graphs
4068    in the same process. For convenience, a global default graph is
4069    provided, and all ops will be added to this graph if you do not
4070    create a new graph explicitly.
4071
4072    Use this method with the `with` keyword to specify that ops created within
4073    the scope of a block should be added to this graph. In this case, once
4074    the scope of the `with` is exited, the previous default graph is set again
4075    as default. There is a stack, so it's ok to have multiple nested levels
4076    of `as_default` calls.
4077
4078    The default graph is a property of the current thread. If you
4079    create a new thread, and wish to use the default graph in that
4080    thread, you must explicitly add a `with g.as_default():` in that
4081    thread's function.
4082
4083    The following code examples are equivalent:
4084
4085    ```python
4086    # 1. Using Graph.as_default():
4087    g = tf.Graph()
4088    with g.as_default():
4089      c = tf.constant(5.0)
4090      assert c.graph is g
4091
4092    # 2. Constructing and making default:
4093    with tf.Graph().as_default() as g:
4094      c = tf.constant(5.0)
4095      assert c.graph is g
4096    ```
4097
4098    If eager execution is enabled, ops created under this context manager will
4099    be added to the graph instead of executed eagerly.
4100
4101    Returns:
4102      A context manager for using this graph as the default graph.
4103    """
4104    return _default_graph_stack.get_controller(self)
4105
4106  @property
4107  def collections(self):
4108    """Returns the names of the collections known to this graph."""
4109    return list(self._collections)
4110
4111  def add_to_collection(self, name, value):
4112    """Stores `value` in the collection with the given `name`.
4113
4114    Note that collections are not sets, so it is possible to add a value to
4115    a collection several times.
4116
4117    Args:
4118      name: The key for the collection. The `GraphKeys` class contains many
4119        standard names for collections.
4120      value: The value to add to the collection.
4121    """  # pylint: disable=g-doc-exception
4122    self._check_not_finalized()
4123    with self._lock:
4124      if name not in self._collections:
4125        self._collections[name] = [value]
4126      else:
4127        self._collections[name].append(value)
4128
4129  def add_to_collections(self, names, value):
4130    """Stores `value` in the collections given by `names`.
4131
4132    Note that collections are not sets, so it is possible to add a value to
4133    a collection several times. This function makes sure that duplicates in
4134    `names` are ignored, but it will not check for pre-existing membership of
4135    `value` in any of the collections in `names`.
4136
4137    `names` can be any iterable, but if `names` is a string, it is treated as a
4138    single collection name.
4139
4140    Args:
4141      names: The keys for the collections to add to. The `GraphKeys` class
4142        contains many standard names for collections.
4143      value: The value to add to the collections.
4144    """
4145    # Make sure names are unique, but treat strings as a single collection name
4146    names = (names,) if isinstance(names, six.string_types) else set(names)
4147    for name in names:
4148      self.add_to_collection(name, value)
4149
4150  def get_collection_ref(self, name):
4151    """Returns a list of values in the collection with the given `name`.
4152
4153    If the collection exists, this returns the list itself, which can
4154    be modified in place to change the collection.  If the collection does
4155    not exist, it is created as an empty list and the list is returned.
4156
4157    This is different from `get_collection()` which always returns a copy of
4158    the collection list if it exists and never creates an empty collection.
4159
4160    Args:
4161      name: The key for the collection. For example, the `GraphKeys` class
4162        contains many standard names for collections.
4163
4164    Returns:
4165      The list of values in the collection with the given `name`, or an empty
4166      list if no value has been added to that collection.
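
    For example, the returned list can be mutated in place:

    ```python
    g = tf.Graph()
    ref = g.get_collection_ref("my_collection")  # Creates the empty list.
    ref.append("value")
    assert g.get_collection("my_collection") == ["value"]
    ```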
4167    """  # pylint: disable=g-doc-exception
4168    with self._lock:
4169      coll_list = self._collections.get(name, None)
4170      if coll_list is None:
4171        coll_list = []
4172        self._collections[name] = coll_list
4173      return coll_list
4174
4175  def get_collection(self, name, scope=None):
4176    """Returns a list of values in the collection with the given `name`.
4177
4178    This is different from `get_collection_ref()`, which returns the stored
4179    collection list itself: this method returns a new copy of the list each
4180    time it is called.
4181
4182    Args:
4183      name: The key for the collection. For example, the `GraphKeys` class
4184        contains many standard names for collections.
4185      scope: (Optional.) A string. If supplied, the resulting list is filtered
4186        to include only items whose `name` attribute matches `scope` using
4187        `re.match`. Items without a `name` attribute are never returned if a
4188        scope is supplied. The choice of `re.match` means that a `scope` without
4189        special tokens filters by prefix.
4190
4191    Returns:
4192      The list of values in the collection with the given `name`, or
4193      an empty list if no value has been added to that collection. The
4194      list contains the values in the order under which they were
4195      collected.
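
    For example, a `scope` without special regex tokens filters by name prefix
    (a minimal sketch):

    ```python
    g = tf.Graph()
    with g.as_default():
      with g.name_scope("layer1"):
        c = tf.constant(1.0, name="c")
      d = tf.constant(2.0, name="d")
      g.add_to_collection("my_tensors", c)
      g.add_to_collection("my_tensors", d)
    assert g.get_collection("my_tensors", scope="layer1") == [c]
    ```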
4196    """  # pylint: disable=g-doc-exception
4197    with self._lock:
4198      collection = self._collections.get(name, None)
4199      if collection is None:
4200        return []
4201      if scope is None:
4202        return list(collection)
4203      else:
4204        c = []
4205        regex = re.compile(scope)
4206        for item in collection:
4207          try:
4208            if regex.match(item.name):
4209              c.append(item)
4210          except AttributeError:
4211            # Collection items with no name are ignored.
4212            pass
4213        return c
4214
4215  def get_all_collection_keys(self):
4216    """Returns a list of collections used in this graph."""
4217    with self._lock:
4218      return [x for x in self._collections if isinstance(x, six.string_types)]
4219
4220  def clear_collection(self, name):
4221    """Clears all values in a collection.
4222
4223    Args:
4224      name: The key for the collection. The `GraphKeys` class contains many
4225        standard names for collections.
4226    """
4227    self._check_not_finalized()
4228    with self._lock:
4229      if name in self._collections:
4230        del self._collections[name]
4231
4232  @tf_contextlib.contextmanager
4233  def _original_op(self, op):
4234    """Python 'with' handler to help annotate ops with their originator.
4235
4236    An op may have an 'original_op' property that indicates the op on which
4237    it was based. For example a replica op is based on the op that was
4238    replicated and a gradient op is based on the op that was differentiated.
4239
4240    All ops created in the scope of this 'with' handler will have
4241    the given 'op' as their original op.
4242
4243    Args:
4244      op: The Operation that all ops created in this scope will have as their
4245        original op.
4246
4247    Yields:
4248      Nothing.
4249    """
4250    old_original_op = self._default_original_op
4251    self._default_original_op = op
4252    try:
4253      yield
4254    finally:
4255      self._default_original_op = old_original_op
4256
4257  @property
4258  def _name_stack(self):
4259    # This may be called from a thread where name_stack doesn't yet exist.
4260    if not hasattr(self._thread_local, "_name_stack"):
4261      self._thread_local._name_stack = ""
4262    return self._thread_local._name_stack
4263
4264  @_name_stack.setter
4265  def _name_stack(self, name_stack):
4266    self._thread_local._name_stack = name_stack
4267
4268  # pylint: disable=g-doc-return-or-yield,line-too-long
4269  @tf_contextlib.contextmanager
4270  def name_scope(self, name):
4271    """Returns a context manager that creates hierarchical names for operations.
4272
4273    A graph maintains a stack of name scopes. A `with name_scope(...):`
4274    statement pushes a new name onto the stack for the lifetime of the context.
4275
4276    The `name` argument will be interpreted as follows:
4277
4278    * A string (not ending with '/') will create a new name scope, in which
4279      `name` is appended to the prefix of all operations created in the
4280      context. If `name` has been used before, it will be made unique by
4281      calling `self.unique_name(name)`.
4282    * A scope previously captured from a `with g.name_scope(...) as
4283      scope:` statement will be treated as an "absolute" name scope, which
4284      makes it possible to re-enter existing scopes.
4285    * A value of `None` or the empty string will reset the current name scope
4286      to the top-level (empty) name scope.
4287
4288    For example:
4289
4290    ```python
4291    with tf.Graph().as_default() as g:
4292      c = tf.constant(5.0, name="c")
4293      assert c.op.name == "c"
4294      c_1 = tf.constant(6.0, name="c")
4295      assert c_1.op.name == "c_1"
4296
4297      # Creates a scope called "nested"
4298      with g.name_scope("nested") as scope:
4299        nested_c = tf.constant(10.0, name="c")
4300        assert nested_c.op.name == "nested/c"
4301
4302        # Creates a nested scope called "inner".
4303        with g.name_scope("inner"):
4304          nested_inner_c = tf.constant(20.0, name="c")
4305          assert nested_inner_c.op.name == "nested/inner/c"
4306
4307        # Creates a nested scope called "inner_1", since "inner" is already used.
4308        with g.name_scope("inner"):
4309          nested_inner_1_c = tf.constant(30.0, name="c")
4310          assert nested_inner_1_c.op.name == "nested/inner_1/c"
4311
4312          # Treats `scope` as an absolute name scope, and
4313          # switches to the "nested/" scope.
4314          with g.name_scope(scope):
4315            nested_d = tf.constant(40.0, name="d")
4316            assert nested_d.op.name == "nested/d"
4317
4318            with g.name_scope(""):
4319              e = tf.constant(50.0, name="e")
4320              assert e.op.name == "e"
4321    ```
4322
4323    The name of the scope itself can be captured by `with
4324    g.name_scope(...) as scope:`, which stores the name of the scope
4325    in the variable `scope`. This value can be used to name an
4326    operation that represents the overall result of executing the ops
4327    in a scope. For example:
4328
4329    ```python
4330    inputs = tf.constant(...)
4331    with g.name_scope('my_layer') as scope:
4332      weights = tf.Variable(..., name="weights")
4333      biases = tf.Variable(..., name="biases")
4334      affine = tf.matmul(inputs, weights) + biases
4335      output = tf.nn.relu(affine, name=scope)
4336    ```
4337
4338    NOTE: This method validates the given `name`. Valid scope
4339    names match one of the following regular expressions:
4340
4341        [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
4342        [A-Za-z0-9_.\\-/]* (for other scopes)
4343
4344    Args:
4345      name: A name for the scope.
4346
4347    Returns:
4348      A context manager that installs `name` as a new name scope.
4349
4350    Raises:
4351      ValueError: If `name` is not a valid scope name, according to the rules
4352        above.
4353    """
4354    if name:
4355      if isinstance(name, compat.bytes_or_text_types):
4356        name = compat.as_str(name)
4357
4358      if self._name_stack:
4359        # Scopes created in a nested scope may have initial characters
4360        # that are illegal as the initial character of an op name
4361        # (viz. '-', '\', '/', and '_').
4362        if not _VALID_SCOPE_NAME_REGEX.match(name):
4363          raise ValueError("'%s' is not a valid scope name" % name)
4364      else:
4365        # Scopes created in the root must match the more restrictive
4366        # op name regex, which constrains the initial character.
4367        if not _VALID_OP_NAME_REGEX.match(name):
4368          raise ValueError("'%s' is not a valid scope name" % name)
4369    old_stack = self._name_stack
4370    if not name:  # Both for name=None and name="" we re-set to empty scope.
4371      new_stack = ""
4372      returned_scope = ""
4373    elif name[-1] == "/":
4374      new_stack = name_from_scope_name(name)
4375      returned_scope = name
4376    else:
4377      new_stack = self.unique_name(name)
4378      returned_scope = new_stack + "/"
4379    self._name_stack = new_stack
4380    try:
4381      yield returned_scope
4382    finally:
4383      self._name_stack = old_stack
4384
4385  # pylint: enable=g-doc-return-or-yield,line-too-long
4386
4387  def unique_name(self, name, mark_as_used=True):
4388    """Return a unique operation name for `name`.
4389
4390    Note: You rarely need to call `unique_name()` directly.  Most of
4391    the time you just need to create `with g.name_scope()` blocks to
4392    generate structured names.
4393
4394    `unique_name` is used to generate structured names, separated by
4395    `"/"`, to help identify operations when debugging a graph.
4396    Operation names are displayed in error messages reported by the
4397    TensorFlow runtime, and in various visualization tools such as
4398    TensorBoard.
4399
4400    If `mark_as_used` is set to `True`, which is the default, a new
4401    unique name is created and marked as in use. If it's set to `False`,
4402    the unique name is returned without actually being marked as used.
4403    This is useful when the caller simply wants to know what name
4404    would be created, without actually reserving it.
4405
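    For example, roughly (an illustrative sketch; the op name `"foo"` is
    arbitrary):

    ```python
    g = tf.Graph()
    assert g.unique_name("foo") == "foo"
    assert g.unique_name("foo") == "foo_1"

    # With `mark_as_used=False` the candidate name is returned but not
    # reserved, so repeated queries give the same answer.
    assert g.unique_name("foo", mark_as_used=False) == "foo_2"
    assert g.unique_name("foo", mark_as_used=False) == "foo_2"
    ```
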
4406    Args:
4407      name: The name for an operation.
4408      mark_as_used: Whether to mark this name as being used.
4409
4410    Returns:
4411      A string to be passed to `create_op()` that will be used
4412      to name the operation being created.
4413    """
4414    if self._name_stack:
4415      name = self._name_stack + "/" + name
4416
4417    # For the sake of checking for names in use, we treat names as case
4418    # insensitive (e.g. foo = Foo).
4419    name_key = name.lower()
4420    i = self._names_in_use.get(name_key, 0)
4421    # Increment the number for "name_key".
4422    if mark_as_used:
4423      self._names_in_use[name_key] = i + 1
4424    if i > 0:
4425      base_name_key = name_key
4426      # Make sure the composed name key is not already used.
4427      while name_key in self._names_in_use:
4428        name_key = "%s_%d" % (base_name_key, i)
4429        i += 1
4430      # Mark the composed name_key as used in case someone wants
4431      # to call unique_name("name_1").
4432      if mark_as_used:
4433        self._names_in_use[name_key] = 1
4434
4435      # Return the new name with the original capitalization of the given name.
4436      name = "%s_%d" % (name, i - 1)
4437    return name
4438
4439  def get_name_scope(self):
4440    """Returns the current name scope.
4441
4442    For example:
4443
4444    ```python
4445    with tf.name_scope('scope1'):
4446      with tf.name_scope('scope2'):
4447        print(tf.compat.v1.get_default_graph().get_name_scope())
4448    ```
4449    would print the string `scope1/scope2`.
4450
4451    Returns:
4452      A string representing the current name scope.
4453    """
4454    return self._name_stack
4455
4456  @tf_contextlib.contextmanager
4457  def _colocate_with_for_gradient(self, op, gradient_uid,
4458                                  ignore_existing=False):
4459    with self.colocate_with(op, ignore_existing):
4460      if gradient_uid is not None:
4461        ctx = _get_enclosing_context(self)
4462        if ctx is not None:
4463          ctx.EnterGradientColocation(op, gradient_uid)
4464          try:
4465            yield
4466          finally:
4467            ctx.ExitGradientColocation(op, gradient_uid)
4468        else:
4469          yield
4470      else:
4471        yield
4472
4473  @tf_contextlib.contextmanager
4474  def colocate_with(self, op, ignore_existing=False):
4475    """Returns a context manager that specifies an op to colocate with.
4476
4477    Note: this function is not for public use, only for internal libraries.
4478
4479    For example:
4480
4481    ```python
4482    a = tf.Variable([1.0])
4483    with g.colocate_with(a):
4484      b = tf.constant(1.0)
4485      c = tf.add(a, b)
4486    ```
4487
4488    `b` and `c` will always be colocated with `a`, no matter where `a`
4489    is eventually placed.
4490
4491    **NOTE** Using a colocation scope resets any existing device constraints.
4492
4493    If `op` is `None` then `ignore_existing` must be `True` and the new
4494    scope resets all colocation and device constraints.
4495
4496    Args:
4497      op: The op to colocate all created ops with, or `None`.
4498      ignore_existing: If true, only applies colocation of this op within the
4499        context, rather than applying all colocation properties on the stack.
4500        If `op` is `None`, this value must be `True`.
4501
4502    Raises:
4503      ValueError: if op is None but ignore_existing is False.
4504
4505    Yields:
4506      A context manager that specifies the op with which to colocate
4507      newly created ops.
4508    """
4509    if op is None and not ignore_existing:
4510      raise ValueError("Trying to reset colocation (op is None) but "
4511                       "ignore_existing is not True")
4512    op, device_only_candidate = _op_to_colocate_with(op, self)
4513
4514    # By default, colocate_with resets the device function stack,
4515    # since colocate_with is typically used in specific internal
4516    # library functions where colocation is intended to be "stronger"
4517    # than device functions.
4518    #
4519    # In the future, a caller may specify that device_functions win
4520    # over colocation, in which case we can add support.
4521    device_fn_tmp = self._device_function_stack
4522    self._device_function_stack = traceable_stack.TraceableStack()
4523
4524    if ignore_existing:
4525      current_stack = self._colocation_stack
4526      self._colocation_stack = traceable_stack.TraceableStack()
4527
4528    if op is not None:
4529      # offset refers to the stack frame used for storing code location.
4530      # We use 4, the sum of 1 to use our caller's stack frame and 3
4531      # to jump over layers of context managers above us.
4532      if device_only_candidate is not None:
4533        self._colocation_stack.push_obj(device_only_candidate, offset=4)
4534      self._colocation_stack.push_obj(op, offset=4)
4535    elif not ignore_existing:
4536      raise ValueError("Trying to reset colocation (op is None) but "
4537                       "ignore_existing is not True")
4538    try:
4539      yield
4540    finally:
4541      # Restore device function stack
4542      self._device_function_stack = device_fn_tmp
4543      if op is not None:
4544        self._colocation_stack.pop_obj()
4545        if device_only_candidate is not None:
4546          self._colocation_stack.pop_obj()
4547
4548      # Reset the colocation stack if requested.
4549      if ignore_existing:
4550        self._colocation_stack = current_stack
4551
4552  def _add_device_to_stack(self, device_name_or_function, offset=0):
4553    """Add device to stack manually, separate from a context manager."""
4554    total_offset = 1 + offset
4555    spec = _UserDeviceSpec(device_name_or_function)
4556    self._device_function_stack.push_obj(spec, offset=total_offset)
4557    return spec
4558
4559  @tf_contextlib.contextmanager
4560  def device(self, device_name_or_function):
4561    # pylint: disable=line-too-long
4562    """Returns a context manager that specifies the default device to use.
4563
4564    The `device_name_or_function` argument may either be a device name
4565    string, a device function, or None:
4566
4567    * If it is a device name string, all operations constructed in
4568      this context will be assigned to the device with that name, unless
4569      overridden by a nested `device()` context.
4570    * If it is a function, it will be treated as a function from
4571      Operation objects to device name strings, and invoked each time
4572      a new Operation is created. The Operation will be assigned to
4573      the device with the returned name.
4574    * If it is None, all `device()` invocations from the enclosing context
4575      will be ignored.
4576
4577    For information about the valid syntax of device name strings, see
4578    the documentation in
4579    [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).
4580
4581    For example:
4582
4583    ```python
4584    with g.device('/device:GPU:0'):
4585      # All operations constructed in this context will be placed
4586      # on GPU 0.
4587      with g.device(None):
4588        # All operations constructed in this context will have no
4589        # assigned device.
4590
4591    # Defines a function from `Operation` to device string.
4592    def matmul_on_gpu(n):
4593      if n.type == "MatMul":
4594        return "/device:GPU:0"
4595      else:
4596        return "/cpu:0"
4597
4598    with g.device(matmul_on_gpu):
4599      # All operations of type "MatMul" constructed in this context
4600      # will be placed on GPU 0; all other operations will be placed
4601      # on CPU 0.
4602    ```
4603
4604    **N.B.** The device scope may be overridden by op wrappers or
4605    other library code. For example, a variable assignment op
4606    `v.assign()` must be colocated with the `tf.Variable` `v`, and
4607    incompatible device scopes will be ignored.
4608
4609    Args:
4610      device_name_or_function: The device name or function to use in the
4611        context.
4612
4613    Yields:
4614      A context manager that specifies the default device to use for newly
4615      created ops.
4616
4617    Raises:
4618      RuntimeError: If device scopes are not properly nested.
4619    """
4620    self._add_device_to_stack(device_name_or_function, offset=2)
4621    old_top_of_stack = self._device_function_stack.peek_top_obj()
4622    try:
4623      yield
4624    finally:
4625      new_top_of_stack = self._device_function_stack.peek_top_obj()
4626      if old_top_of_stack is not new_top_of_stack:
4627        raise RuntimeError("Exiting device scope without proper scope nesting.")
4628      self._device_function_stack.pop_obj()
4629
4630  def _apply_device_functions(self, op):
4631    """Applies the current device function stack to the given operation."""
4632    # Apply any device functions in LIFO order, so that the most recently
4633    # pushed function has the first chance to apply a device to the op.
4634    # We apply here because the result can depend on the Operation's
4635    # signature, which is computed in the Operation constructor.
4636    # pylint: disable=protected-access
4637    prior_device_string = None
4638    for device_spec in self._device_function_stack.peek_objs():
4639      if device_spec.is_null_merge:
4640        continue
4641
4642      if device_spec.function is None:
4643        break
4644
4645      device_string = device_spec.string_merge(op)
4646
4647      # Take advantage of the fact that None is a singleton and Python interns
4648      # strings, since identity checks are faster than equality checks.
4649      if device_string is not prior_device_string:
4650        op._set_device_from_string(device_string)
4651        prior_device_string = device_string
4652    op._device_code_locations = self._snapshot_device_function_stack_metadata()
4653    # pylint: enable=protected-access
4654
4655  # pylint: disable=g-doc-return-or-yield
4656  @tf_contextlib.contextmanager
4657  def container(self, container_name):
4658    """Returns a context manager that specifies the resource container to use.
4659
4660    Stateful operations, such as variables and queues, can maintain their
4661    states on devices so that they can be shared by multiple processes.
4662    A resource container is a string name under which these stateful
4663    operations are tracked. These resources can be released or cleared
4664    with `tf.Session.reset()`.
4665
4666    For example:
4667
4668    ```python
4669    with g.container('experiment0'):
4670      # All stateful Operations constructed in this context will be placed
4671      # in resource container "experiment0".
4672      v1 = tf.Variable([1.0])
4673      v2 = tf.Variable([2.0])
4674      with g.container("experiment1"):
4675        # All stateful Operations constructed in this context will be
4676        # placed in resource container "experiment1".
4677        v3 = tf.Variable([3.0])
4678        q1 = tf.queue.FIFOQueue(10, tf.float32)
4679      # All stateful Operations constructed in this context will be
4680      # placed in resource container "experiment0" again.
4681      v4 = tf.Variable([4.0])
4682      q2 = tf.queue.FIFOQueue(20, tf.float32)
4683      with g.container(""):
4684        # All stateful Operations constructed in this context will be
4685        # placed in the default resource container.
4686        v5 = tf.Variable([5.0])
4687        q3 = tf.queue.FIFOQueue(30, tf.float32)
4688
4689    # Resets container "experiment0", after which the state of v1, v2, v4, q2
4690    # will become undefined (such as uninitialized).
4691    tf.Session.reset(target, ["experiment0"])
4692    ```
4693
4694    Args:
4695      container_name: container name string.
4696
4697    Returns:
4698      A context manager for defining resource containers for stateful ops,
4699        yields the container name.
4700    """
4701    original_container = self._container
4702    self._container = container_name
4703    try:
4704      yield self._container
4705    finally:
4706      self._container = original_container
4707
4708  # pylint: enable=g-doc-return-or-yield
4709
4710  class _ControlDependenciesController(object):
4711    """Context manager for `control_dependencies()`."""
4712
4713    def __init__(self, graph, control_inputs):
4714      """Create a new `_ControlDependenciesController`.
4715
4716      A `_ControlDependenciesController` is the context manager for
4717      `with tf.control_dependencies()` blocks.  These normally nest,
4718      as described in the documentation for `control_dependencies()`.
4719
4720      The `control_inputs` argument lists control dependencies that must be
4721      added to the current set of control dependencies.  Because of
4722      uniquification the set can be empty even if the caller passed a list of
4723      ops.  The special value `None` indicates that we want to start a new
4724      empty set of control dependencies instead of extending the current set.
4725
4726      In that case we also clear the current control flow context, which is an
4727      additional mechanism to add control dependencies.
4728
4729      Args:
4730        graph: The graph that this controller is managing.
4731        control_inputs: List of ops to use as control inputs in addition to the
4732          current control dependencies.  None to indicate that the dependencies
4733          should be cleared.
4734      """
4735      self._graph = graph
4736      if control_inputs is None:
4737        self._control_inputs_val = []
4738        self._new_stack = True
4739      else:
4740        self._control_inputs_val = control_inputs
4741        self._new_stack = False
4742      self._seen_nodes = set()
4743      self._old_stack = None
4744      self._old_control_flow_context = None
4745
4746# pylint: disable=protected-access
4747
4748    def __enter__(self):
4749      if self._new_stack:
4750        # Clear the control_dependencies graph.
4751        self._old_stack = self._graph._control_dependencies_stack
4752        self._graph._control_dependencies_stack = []
4753        # Clear the control_flow_context too.
4754        self._old_control_flow_context = self._graph._get_control_flow_context()
4755        self._graph._set_control_flow_context(None)
4756      self._graph._push_control_dependencies_controller(self)
4757
4758    def __exit__(self, unused_type, unused_value, unused_traceback):
4759      self._graph._pop_control_dependencies_controller(self)
4760      if self._new_stack:
4761        self._graph._control_dependencies_stack = self._old_stack
4762        self._graph._set_control_flow_context(self._old_control_flow_context)
4763
4764# pylint: enable=protected-access
4765
4766    @property
4767    def control_inputs(self):
4768      return self._control_inputs_val
4769
4770    def add_op(self, op):
4771      if isinstance(op, Tensor):
4772        op = op.ref()
4773      self._seen_nodes.add(op)
4774
4775    def op_in_group(self, op):
4776      if isinstance(op, Tensor):
4777        op = op.ref()
4778      return op in self._seen_nodes
4779
4780  def _push_control_dependencies_controller(self, controller):
4781    self._control_dependencies_stack.append(controller)
4782
4783  def _pop_control_dependencies_controller(self, controller):
4784    assert self._control_dependencies_stack[-1] is controller
4785    self._control_dependencies_stack.pop()
4786
4787  def _current_control_dependencies(self):
4788    ret = set()
4789    for controller in self._control_dependencies_stack:
4790      for op in controller.control_inputs:
4791        ret.add(op)
4792    return ret
4793
4794  def _control_dependencies_for_inputs(self, input_ops):
4795    """For an op that takes `input_ops` as inputs, compute control inputs.
4796
4797    The returned control dependencies should yield an execution that
4798    is equivalent to adding all control inputs in
4799    self._control_dependencies_stack to a newly created op. However,
4800    this function attempts to prune the returned control dependencies
4801    by observing that nodes created within the same `with
4802    control_dependencies(...):` block may have data dependencies that make
4803    the explicit approach redundant.
4804
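    For example, a rough illustration of the effect, expressed through the
    public `control_dependencies` API (sketch only):

    ```python
    with tf.Graph().as_default() as g:
      a = tf.constant(1.0)
      with g.control_dependencies([a]):
        # `b` already has a data dependency on `a`, so adding `a` as a
        # control input would be redundant and is pruned.
        b = tf.identity(a)
        # `c` has no data dependency on `a`, so `a.op` becomes a control input.
        c = tf.constant(2.0)

    assert not b.op.control_inputs
    assert a.op in c.op.control_inputs
    ```
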
4805    Args:
4806      input_ops: The data input ops for an op to be created.
4807
4808    Returns:
4809      A list of control inputs for the op to be created.
4810    """
4811    ret = []
4812    for controller in self._control_dependencies_stack:
4813      # If any of the input_ops already depends on the inputs from controller,
4814      # we say that the new op is dominated (by that input), and we therefore
4815      # do not need to add control dependencies for this controller's inputs.
4816      dominated = False
4817      for op in input_ops:
4818        if controller.op_in_group(op):
4819          dominated = True
4820          break
4821      if not dominated:
4822        # Don't add a control input if we already have a data dependency on it.
4823        # NOTE(mrry): We do not currently track transitive data dependencies,
4824        #   so we may add redundant control inputs.
4825        ret.extend(c for c in controller.control_inputs if c not in input_ops)
4826    return ret
4827
4828  def _record_op_seen_by_control_dependencies(self, op):
4829    """Record that the given op depends on all registered control dependencies.
4830
4831    Args:
4832      op: An Operation.
4833    """
4834    for controller in self._control_dependencies_stack:
4835      controller.add_op(op)
4836
4837  def control_dependencies(self, control_inputs):
4838    """Returns a context manager that specifies control dependencies.
4839
4840    Use with the `with` keyword to specify that all operations constructed
4841    within the context should have control dependencies on
4842    `control_inputs`. For example:
4843
4844    ```python
4845    with g.control_dependencies([a, b, c]):
4846      # `d` and `e` will only run after `a`, `b`, and `c` have executed.
4847      d = ...
4848      e = ...
4849    ```
4850
4851    Multiple calls to `control_dependencies()` can be nested, and in
4852    that case a new `Operation` will have control dependencies on the union
4853    of `control_inputs` from all active contexts.
4854
4855    ```python
4856    with g.control_dependencies([a, b]):
4857      # Ops constructed here run after `a` and `b`.
4858      with g.control_dependencies([c, d]):
4859        # Ops constructed here run after `a`, `b`, `c`, and `d`.
4860    ```
4861
4862    You can pass None to clear the control dependencies:
4863
4864    ```python
4865    with g.control_dependencies([a, b]):
4866      # Ops constructed here run after `a` and `b`.
4867      with g.control_dependencies(None):
4868        # Ops constructed here run normally, not waiting for either `a` or `b`.
4869        with g.control_dependencies([c, d]):
4870          # Ops constructed here run after `c` and `d`, also not waiting
4871          # for either `a` or `b`.
4872    ```
4873
4874    *N.B.* The control dependencies context applies *only* to ops that
4875    are constructed within the context. Merely using an op or tensor
4876    in the context does not add a control dependency. The following
4877    example illustrates this point:
4878
4879    ```python
4880    # WRONG
4881    def my_func(pred, tensor):
4882      t = tf.matmul(tensor, tensor)
4883      with tf.control_dependencies([pred]):
4884        # The matmul op is created outside the context, so no control
4885        # dependency will be added.
4886        return t
4887
4888    # RIGHT
4889    def my_func(pred, tensor):
4890      with tf.control_dependencies([pred]):
4891        # The matmul op is created in the context, so a control dependency
4892        # will be added.
4893        return tf.matmul(tensor, tensor)
4894    ```
4895
4896    Also note that though execution of ops created under this scope will trigger
4897    execution of the dependencies, the ops created under this scope might still
4898    be pruned from a normal TensorFlow graph. For example, in the following
4899    snippet of code the dependencies are never executed:
4900
4901    ```python
4902      loss = model.loss()
4903      with tf.control_dependencies(dependencies):
4904        loss = loss + tf.constant(1)  # note: dependencies ignored in the
4905                                      # backward pass
4906      return tf.gradients(loss, model.variables)
4907    ```
4908
4909    This is because evaluating the gradient graph does not require evaluating
4910    the constant(1) op created in the forward pass.
4911
4912    Args:
4913      control_inputs: A list of `Operation` or `Tensor` objects which must be
4914        executed or computed before running the operations defined in the
4915        context.  Can also be `None` to clear the control dependencies.
4916
4917    Returns:
4918     A context manager that specifies control dependencies for all
4919     operations constructed within the context.
4920
4921    Raises:
4922      TypeError: If `control_inputs` is not a list of `Operation` or
4923        `Tensor` objects.
4924    """
4925    if control_inputs is None:
4926      return self._ControlDependenciesController(self, None)
4927    # First convert the inputs to ops, and deduplicate them.
4928    # NOTE(mrry): Other than deduplication, we do not currently track direct
4929    #   or indirect dependencies between control_inputs, which may result in
4930    #   redundant control inputs.
4931    control_ops = []
4932    current = self._current_control_dependencies()
4933    for c in control_inputs:
4934      # The hasattr(c, "_handle") check is designed to match ResourceVariables.
4935      # This is so control dependencies on a variable or on an unread variable
4936      # don't trigger reads.
4937      if (isinstance(c, IndexedSlices) or
4938          (hasattr(c, "_handle") and hasattr(c, "op"))):
4939        c = c.op
4940      c = self.as_graph_element(c)
4941      if isinstance(c, Tensor):
4942        c = c.op
4943      elif not isinstance(c, Operation):
4944        raise TypeError("Control input must be Operation or Tensor: %s" % c)
4945      if c not in current:
4946        control_ops.append(c)
4947        current.add(c)
4948    return self._ControlDependenciesController(self, control_ops)
4949
4950  # pylint: disable=g-doc-return-or-yield
4951  @tf_contextlib.contextmanager
4952  def _attr_scope(self, attr_map):
4953    """EXPERIMENTAL: A context manager for setting attributes on operators.
4954
4955    This context manager can be used to add additional
4956    attributes to operators within the scope of the context.
4957
4958    For example:
4959
4960       with ops.Graph().as_default() as g:
4961         f_1 = Foo()  # No extra attributes
4962         with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
4963           f_2 = Foo()  # Additional attribute _a=False
4964           with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
4965             f_3 = Foo()  # Additional attribute _a=True
4966             with g._attr_scope({"_a": None}):
4967               f_4 = Foo()  # No additional attributes.
4968
4969    Args:
4970      attr_map: A dictionary mapping attr name strings to AttrValue protocol
4971        buffers, callables returning AttrValue protocol buffers, or None.
4972
4973    Returns:
4974      A context manager that sets the given attributes for one or more
4975      ops created in that context.
4976
4977    Raises:
4978      TypeError: If attr_map is not a dictionary mapping
4979        strings to AttrValue protobufs.
4980    """
4981    if not isinstance(attr_map, dict):
4982      raise TypeError("attr_map must be a dictionary mapping "
4983                      "strings to AttrValue protocol buffers")
4984    # The saved_attrs dictionary stores any currently-set labels that
4985    # will be overridden by this context manager.
4986    saved_attrs = {}
4987    # Install the given attribute
4988    for name, attr in attr_map.items():
4989      if not (isinstance(name, six.string_types) and
4990              (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
4991               callable(attr))):
4992        raise TypeError("attr_map must be a dictionary mapping "
4993                        "strings to AttrValue protocol buffers or "
4994                        "callables that emit AttrValue protocol buffers")
4995      try:
4996        saved_attrs[name] = self._attr_scope_map[name]
4997      except KeyError:
4998        pass
4999      if attr is None:
5000        del self._attr_scope_map[name]
5001      else:
5002        self._attr_scope_map[name] = attr
5003    try:
5004      yield  # The code within the context runs here.
5005    finally:
5006      # Remove the attributes set for this context, and restore any saved
5007      # attributes.
5008      for name, attr in attr_map.items():
5009        try:
5010          self._attr_scope_map[name] = saved_attrs[name]
5011        except KeyError:
5012          del self._attr_scope_map[name]
5013
5014  # pylint: enable=g-doc-return-or-yield
5015
5016  # pylint: disable=g-doc-return-or-yield
5017  @tf_contextlib.contextmanager
5018  def _kernel_label_map(self, op_to_kernel_label_map):
5019    """EXPERIMENTAL: A context manager for setting kernel labels.
5020
5021    This context manager can be used to select particular
5022    implementations of kernels within the scope of the context.
5023
5024    For example:
5025
5026        with ops.Graph().as_default() as g:
5027          f_1 = Foo()  # Uses the default registered kernel for the Foo op.
5028          with g._kernel_label_map({"Foo": "v_2"}):
5029            f_2 = Foo()  # Uses the registered kernel with label "v_2"
5030                         # for the Foo op.
5031            with g._kernel_label_map({"Foo": "v_3"}):
5032              f_3 = Foo()  # Uses the registered kernel with label "v_3"
5033                           # for the Foo op.
5034              with g._kernel_label_map({"Foo": ""}):
5035                f_4 = Foo()  # Uses the default registered kernel
5036                             # for the Foo op.
5037
5038    Args:
5039      op_to_kernel_label_map: A dictionary mapping op type strings to kernel
5040        label strings.
5041
5042    Returns:
5043      A context manager that sets the kernel label to be used for one or more
5044      ops created in that context.
5045
5046    Raises:
5047      TypeError: If op_to_kernel_label_map is not a dictionary mapping
5048        strings to strings.
5049    """
5050    if not isinstance(op_to_kernel_label_map, dict):
5051      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
5052                      "strings to strings")
5053    # The saved_labels dictionary stores any currently-set labels that
5054    # will be overridden by this context manager.
5055    saved_labels = {}
5056    # Install the given label
5057    for op_type, label in op_to_kernel_label_map.items():
5058      if not (isinstance(op_type, six.string_types) and
5059              isinstance(label, six.string_types)):
5060        raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
5061                        "strings to strings")
5062      try:
5063        saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
5064      except KeyError:
5065        pass
5066      self._op_to_kernel_label_map[op_type] = label
5067    try:
5068      yield  # The code within the context runs here.
5069    finally:
5070      # Remove the labels set for this context, and restore any saved labels.
5071      for op_type, label in op_to_kernel_label_map.items():
5072        try:
5073          self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
5074        except KeyError:
5075          del self._op_to_kernel_label_map[op_type]
5076
5077  # pylint: enable=g-doc-return-or-yield
5078
5079  @tf_contextlib.contextmanager
5080  def _override_gradient_function(self, gradient_function_map):
5081    """Specify gradient function for the given op type."""
5082
5083    # This is an internal API and we don't need nested context for this.
5084    # TODO(mdan): make it a proper context manager.
5085    assert not self._gradient_function_map
5086    self._gradient_function_map = gradient_function_map
5087    try:
5088      yield
5089    finally:
5090      self._gradient_function_map = {}
5091
5092  # pylint: disable=g-doc-return-or-yield
5093  @tf_contextlib.contextmanager
5094  def gradient_override_map(self, op_type_map):
5095    """EXPERIMENTAL: A context manager for overriding gradient functions.
5096
5097    This context manager can be used to override the gradient function
5098    that will be used for ops within the scope of the context.
5099
5100    For example:
5101
5102    ```python
5103    @tf.RegisterGradient("CustomSquare")
5104    def _custom_square_grad(op, grad):
5105      # ...
5106
5107    with tf.Graph().as_default() as g:
5108      c = tf.constant(5.0)
5109      s_1 = tf.square(c)  # Uses the default gradient for tf.square.
5110      with g.gradient_override_map({"Square": "CustomSquare"}):
5111        s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
5112                              # gradient of s_2.
5113    ```
5114
5115    Args:
5116      op_type_map: A dictionary mapping op type strings to alternative op type
5117        strings.
5118
5119    Returns:
5120      A context manager that sets the alternative op type to be used for one
5121      or more ops created in that context.
5122
5123    Raises:
5124      TypeError: If `op_type_map` is not a dictionary mapping strings to
5125        strings.
5126    """
5127    if not isinstance(op_type_map, dict):
5128      raise TypeError("op_type_map must be a dictionary mapping "
5129                      "strings to strings")
5130    # The saved_mappings dictionary stores any currently-set mappings that
5131    # will be overridden by this context manager.
5132    saved_mappings = {}
5133    # Install the given override mapping
5134    for op_type, mapped_op_type in op_type_map.items():
5135      if not (isinstance(op_type, six.string_types) and
5136              isinstance(mapped_op_type, six.string_types)):
5137        raise TypeError("op_type_map must be a dictionary mapping "
5138                        "strings to strings")
5139      try:
5140        saved_mappings[op_type] = self._gradient_override_map[op_type]
5141      except KeyError:
5142        pass
5143      self._gradient_override_map[op_type] = mapped_op_type
5144    try:
5145      yield  # The code within the context runs here.
5146    finally:
5147      # Remove the mappings set for this context, and restore any saved mappings.
5148      for op_type, mapped_op_type in op_type_map.items():
5149        try:
5150          self._gradient_override_map[op_type] = saved_mappings[op_type]
5151        except KeyError:
5152          del self._gradient_override_map[op_type]
5153
5154  # pylint: enable=g-doc-return-or-yield
5155
5156  def prevent_feeding(self, tensor):
5157    """Marks the given `tensor` as unfeedable in this graph."""
5158    self._unfeedable_tensors.add(tensor)
5159
5160  def is_feedable(self, tensor):
5161    """Returns `True` if and only if `tensor` is feedable."""
5162    return tensor not in self._unfeedable_tensors
5163
5164  def prevent_fetching(self, op):
5165    """Marks the given `op` as unfetchable in this graph."""
5166    self._unfetchable_ops.add(op)
5167
5168  def is_fetchable(self, tensor_or_op):
5169    """Returns `True` if and only if `tensor_or_op` is fetchable."""
5170    if isinstance(tensor_or_op, Tensor):
5171      return tensor_or_op.op not in self._unfetchable_ops
5172    else:
5173      return tensor_or_op not in self._unfetchable_ops
5174
5175  def switch_to_thread_local(self):
5176    """Make device, colocation and dependencies stacks thread-local.
5177
5178    Device, colocation and dependencies stacks are not thread-local by default.
5179    If multiple threads access them, then the state is shared.  This means that
5180    one thread may affect the behavior of another thread.
5181
5182    After this method is called, the stacks become thread-local.  If multiple
5183    threads access them, then the state is not shared.  Each thread uses its own
5184    value; a thread doesn't affect other threads by mutating such a stack.
5185
5186    The initial value for every thread's stack is set to the current value
5187    of the stack when `switch_to_thread_local()` was first called.
5188    """
5189    if not self._stack_state_is_thread_local:
5190      self._stack_state_is_thread_local = True
5191
5192  @property
5193  def _device_function_stack(self):
5194    if self._stack_state_is_thread_local:
5195      # This may be called from a thread where device_function_stack doesn't yet
5196      # exist.
5197      # pylint: disable=protected-access
5198      if not hasattr(self._thread_local, "_device_function_stack"):
5199        stack_copy_for_this_thread = self._graph_device_function_stack.copy()
5200        self._thread_local._device_function_stack = stack_copy_for_this_thread
5201      return self._thread_local._device_function_stack
5202      # pylint: enable=protected-access
5203    else:
5204      return self._graph_device_function_stack
5205
5206  @property
5207  def _device_functions_outer_to_inner(self):
5208    user_device_specs = self._device_function_stack.peek_objs()
5209    device_functions = [spec.function for spec in user_device_specs]
5210    device_functions_outer_to_inner = list(reversed(device_functions))
5211    return device_functions_outer_to_inner
5212
5213  def _snapshot_device_function_stack_metadata(self):
5214    """Return device function stack as a list of TraceableObjects.
5215
5216    Returns:
5217      [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj
5218      member is a displayable name for the user's argument to Graph.device, and
5219      the filename and lineno members point to the code location where
5220      Graph.device was called directly or indirectly by the user.
5221    """
5222    snapshot = []
5223    for obj in self._device_function_stack.peek_traceable_objs():
5224      obj_copy = obj.copy_metadata()
5225      obj_copy.obj = obj.obj.display_name
5226      snapshot.append(obj_copy)
5227    return snapshot
5228
5229  @_device_function_stack.setter
5230  def _device_function_stack(self, device_function_stack):
5231    if self._stack_state_is_thread_local:
5232      # pylint: disable=protected-access
5233      self._thread_local._device_function_stack = device_function_stack
5234      # pylint: enable=protected-access
5235    else:
5236      self._graph_device_function_stack = device_function_stack
5237
5238  @property
5239  def _colocation_stack(self):
5240    """Return thread-local copy of colocation stack."""
5241    if self._stack_state_is_thread_local:
5242      # This may be called from a thread where colocation_stack doesn't yet
5243      # exist.
5244      # pylint: disable=protected-access
5245      if not hasattr(self._thread_local, "_colocation_stack"):
5246        stack_copy_for_this_thread = self._graph_colocation_stack.copy()
5247        self._thread_local._colocation_stack = stack_copy_for_this_thread
5248      return self._thread_local._colocation_stack
5249      # pylint: enable=protected-access
5250    else:
5251      return self._graph_colocation_stack
5252
5253  def _snapshot_colocation_stack_metadata(self):
5254    """Return colocation stack metadata as a dictionary."""
5255    return {
5256        traceable_obj.obj.name: traceable_obj.copy_metadata()
5257        for traceable_obj in self._colocation_stack.peek_traceable_objs()
5258    }
5259
5260  @_colocation_stack.setter
5261  def _colocation_stack(self, colocation_stack):
5262    if self._stack_state_is_thread_local:
5263      # pylint: disable=protected-access
5264      self._thread_local._colocation_stack = colocation_stack
5265      # pylint: enable=protected-access
5266    else:
5267      self._graph_colocation_stack = colocation_stack
5268
5269  @property
5270  def _control_dependencies_stack(self):
5271    if self._stack_state_is_thread_local:
5272      # This may be called from a thread where control_dependencies_stack
5273      # doesn't yet exist.
5274      if not hasattr(self._thread_local, "_control_dependencies_stack"):
5275        self._thread_local._control_dependencies_stack = (
5276            self._graph_control_dependencies_stack[:])
5277      return self._thread_local._control_dependencies_stack
5278    else:
5279      return self._graph_control_dependencies_stack
5280
5281  @_control_dependencies_stack.setter
5282  def _control_dependencies_stack(self, control_dependencies):
5283    if self._stack_state_is_thread_local:
5284      self._thread_local._control_dependencies_stack = control_dependencies
5285    else:
5286      self._graph_control_dependencies_stack = control_dependencies
5287
5288  @property
5289  def _distribution_strategy_stack(self):
5290    """A stack to maintain distribution strategy context for each thread."""
5291    if not hasattr(self._thread_local, "_distribution_strategy_stack"):
5292      self._thread_local._distribution_strategy_stack = []  # pylint: disable=protected-access
5293    return self._thread_local._distribution_strategy_stack  # pylint: disable=protected-access
5294
5295  @_distribution_strategy_stack.setter
5296  def _distribution_strategy_stack(self, _distribution_strategy_stack):
5297    self._thread_local._distribution_strategy_stack = (  # pylint: disable=protected-access
5298        _distribution_strategy_stack)
5299
5300  @property
5301  def _global_distribute_strategy_scope(self):
5302    """For implementing `tf.distribute.set_strategy()`."""
5303    if not hasattr(self._thread_local, "distribute_strategy_scope"):
5304      self._thread_local.distribute_strategy_scope = None
5305    return self._thread_local.distribute_strategy_scope
5306
5307  @_global_distribute_strategy_scope.setter
5308  def _global_distribute_strategy_scope(self, distribute_strategy_scope):
5309    self._thread_local.distribute_strategy_scope = (distribute_strategy_scope)
5310
5311  def _mutation_lock(self):
5312    """Returns a lock to guard code that creates & mutates ops.
5313
5314    See the comment for self._group_lock for more info.
5315    """
5316    return self._group_lock.group(_MUTATION_LOCK_GROUP)
5317
5318  def _session_run_lock(self):
5319    """Returns a lock to guard code for Session.run.
5320
5321    See the comment for self._group_lock for more info.
5322    """
5323    return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
5324
5325
5326# TODO(agarwal): currently device directives in an outer eager scope will not
5327# apply to inner graph mode code. Fix that.
5328
5329
5330@tf_export(v1=["device"])
5331def device(device_name_or_function):
5332  """Wrapper for `Graph.device()` using the default graph.
5333
5334  See `tf.Graph.device` for more details.
5335
5336  Args:
5337    device_name_or_function: The device name or function to use in the context.
5338
5339  Returns:
5340    A context manager that specifies the default device to use for newly
5341    created ops.
5342
5343  Raises:
5344    RuntimeError: If eager execution is enabled and a function is passed in.
5345  """
5346  if context.executing_eagerly():
5347    if callable(device_name_or_function):
5348      raise RuntimeError(
5349          "tf.device does not support functions when eager execution "
5350          "is enabled.")
5351    return context.device(device_name_or_function)
5352  elif executing_eagerly_outside_functions():
5353    @tf_contextlib.contextmanager
5354    def combined(device_name_or_function):
5355      with get_default_graph().device(device_name_or_function):
5356        if not callable(device_name_or_function):
5357          with context.device(device_name_or_function):
5358            yield
5359        else:
5360          yield
5361    return combined(device_name_or_function)
5362  else:
5363    return get_default_graph().device(device_name_or_function)
5364
5365
5366@tf_export("device", v1=[])
5367def device_v2(device_name):
5368  """Specifies the device for ops created/executed in this context.
5369
5370  This function specifies the device to be used for ops created/executed in a
5371  particular context. Nested contexts will inherit and also create/execute
5372  their ops on the specified device. If a specific device is not required,
5373  consider not using this function so that a device can be automatically
5374  assigned.  In general the use of this function is optional. `device_name` can
5375  be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially
5376  specified, containing only a subset of the "/"-separated fields. Any fields
5377  which are specified will override device annotations from outer scopes.
5378
5379  For example:
5380
5381  ```python
5382  with tf.device('/job:foo'):
5383    # ops created here have devices with /job:foo
5384    with tf.device('/job:bar/task:0/device:gpu:2'):
5385      # ops created here have the fully specified device above
5386    with tf.device('/device:gpu:1'):
5387      # ops created here have the device '/job:foo/device:gpu:1'
5388  ```
5389
5390  Args:
5391    device_name: The device name to use in the context.
5392
5393  Returns:
5394    A context manager that specifies the default device to use for newly
5395    created ops.
5396
5397  Raises:
5398    RuntimeError: If a function is passed in.
5399  """
5400  if callable(device_name):
5401    raise RuntimeError("tf.device does not support functions.")
5402  return device(device_name)
5403
5404
5405@tf_export(v1=["container"])
5406def container(container_name):
5407  """Wrapper for `Graph.container()` using the default graph.
5408
5409  Args:
5410    container_name: The container string to use in the context.
5411
5412  Returns:
5413    A context manager that specifies the default container to use for newly
5414    created stateful ops.
5415  """
5416  return get_default_graph().container(container_name)
5417
5418
5419def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
5420  if context.executing_eagerly():
5421    if op is not None:
5422      if not hasattr(op, "device"):
5423        op = internal_convert_to_tensor_or_indexed_slices(op)
5424      return device(op.device)
5425    else:
5426      return NullContextmanager()
5427  else:
5428    default_graph = get_default_graph()
5429    if isinstance(op, EagerTensor):
5430      if default_graph.building_function:
5431        return default_graph.device(op.device)
5432      else:
5433        raise ValueError("Encountered an Eager-defined Tensor during graph "
5434                         "construction, but a function was not being built.")
5435    return default_graph._colocate_with_for_gradient(
5436        op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)
5437
5438
5439# Internal interface to colocate_with. colocate_with has been deprecated from
5440# public API. There are still a few internal uses of colocate_with. Add internal
5441# only API for those uses to avoid deprecation warning.
5442def colocate_with(op, ignore_existing=False):
5443  return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing)
5444
5445
5446@deprecation.deprecated(
5447    date=None, instructions="Colocations handled automatically by placer.")
5448@tf_export(v1=["colocate_with"])
5449def _colocate_with(op, ignore_existing=False):
5450  return colocate_with(op, ignore_existing)
5451
5452
5453@tf_export("control_dependencies")
5454def control_dependencies(control_inputs):
5455  """Wrapper for `Graph.control_dependencies()` using the default graph.
5456
5457  See `tf.Graph.control_dependencies` for more details.
5458
5459  Note: *In TensorFlow 2 with eager execution and/or AutoGraph, you should not
5460  need this method, as ops execute in the expected order thanks to automatic control
5461  dependencies.* Only use `tf.control_dependencies` when working with v1
5462  `tf.Graph` code.
5463
5464  When eager execution is enabled, any callable object in the `control_inputs`
5465  list will be called.
5466
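  For example, a minimal sketch (graph mode first, then the eager behavior
  described above; names are arbitrary):

  ```python
  # Graph mode: adds control edges to ops created inside the context.
  with tf.Graph().as_default():
    a = tf.constant(1.0)
    with tf.control_dependencies([a]):
      b = tf.constant(2.0)  # `b` will only run after `a`.
    assert a.op in b.op.control_inputs

  # Eager execution (the TF2 default): callables in `control_inputs` are called.
  calls = []
  with tf.control_dependencies([lambda: calls.append("ran")]):
    pass
  assert calls == ["ran"]
  ```
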
5467  Args:
5468    control_inputs: A list of `Operation` or `Tensor` objects which must be
5469      executed or computed before running the operations defined in the context.
5470      Can also be `None` to clear the control dependencies. If eager execution
5471      is enabled, any callable object in the `control_inputs` list will be
5472      called.
5473
5474  Returns:
5475   A context manager that specifies control dependencies for all
5476   operations constructed within the context.
5477  """
5478  if context.executing_eagerly():
5479    if control_inputs:
5480      # Execute any pending callables.
5481      for control in control_inputs:
5482        if callable(control):
5483          control()
5484    return NullContextmanager()
5485  else:
5486    return get_default_graph().control_dependencies(control_inputs)
5487
5488
5489class _DefaultStack(threading.local):
5490  """A thread-local stack of objects for providing implicit defaults."""
5491
5492  def __init__(self):
5493    super(_DefaultStack, self).__init__()
5494    self._enforce_nesting = True
5495    self.stack = []
5496
5497  def get_default(self):
5498    return self.stack[-1] if self.stack else None
5499
5500  def reset(self):
5501    self.stack = []
5502
5503  def is_cleared(self):
5504    return not self.stack
5505
5506  @property
5507  def enforce_nesting(self):
5508    return self._enforce_nesting
5509
5510  @enforce_nesting.setter
5511  def enforce_nesting(self, value):
5512    self._enforce_nesting = value
5513
5514  @tf_contextlib.contextmanager
5515  def get_controller(self, default):
5516    """A context manager for manipulating a default stack."""
5517    self.stack.append(default)
5518    try:
5519      yield default
5520    finally:
5521      # stack may be empty if reset() was called
5522      if self.stack:
5523        if self._enforce_nesting:
5524          if self.stack[-1] is not default:
5525            raise AssertionError(
5526                "Nesting violated for default stack of %s objects" %
5527                type(default))
5528          self.stack.pop()
5529        else:
5530          self.stack.remove(default)
5531
5532
5533_default_session_stack = _DefaultStack()  # pylint: disable=protected-access
5534
5535
5536def default_session(session):
5537  """Python "with" handler for defining a default session.
5538
5539  This function provides a means of registering a session for handling
5540  Tensor.eval() and Operation.run() calls. It is primarily intended for use
5541  by session.Session, but can be used with any object that implements
5542  the Session.run() interface.
5543
5544  Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
5545  invocations within the scope of a block should be executed by a particular
5546  session.
5547
5548  The default session applies to the current thread only, so it is always
5549  possible to inspect the call stack and determine the scope of a default
5550  session. If you create a new thread, and wish to use the default session
5551  in that thread, you must explicitly add a "with ops.default_session(sess):"
5552  block in that thread's function.
5553
5554  Example:
5555    The following code examples are equivalent:
5556
5557    # 1. Using the Session object directly:
5558    sess = ...
5559    c = tf.constant(5.0)
5560    sess.run(c)
5561
5562    # 2. Using default_session():
5563    sess = ...
5564    with ops.default_session(sess):
5565      c = tf.constant(5.0)
5566      result = c.eval()
5567
5568    # 3. Overriding default_session():
5569    sess = ...
5570    with ops.default_session(sess):
5571      c = tf.constant(5.0)
5572      with ops.default_session(...):
5573        c.eval(session=sess)
5574
5575  Args:
5576    session: The session to be installed as the default session.
5577
5578  Returns:
5579    A context manager for the default session.
5580  """
5581  return _default_session_stack.get_controller(session)
5582
5583
5584@tf_export(v1=["get_default_session"])
5585def get_default_session():
5586  """Returns the default session for the current thread.
5587
5588  The returned `Session` will be the innermost session on which a
5589  `Session` or `Session.as_default()` context has been entered.
5590
5591  NOTE: The default session is a property of the current thread. If you
5592  create a new thread, and wish to use the default session in that
5593  thread, you must explicitly add a `with sess.as_default():` in that
5594  thread's function.
5595
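  For example, a small sketch (assuming no other default session has been
  installed on this thread):

  ```python
  with tf.Graph().as_default():
    sess = tf.compat.v1.Session()
    assert tf.compat.v1.get_default_session() is None
    with sess.as_default():
      assert tf.compat.v1.get_default_session() is sess
  ```
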
5596  Returns:
5597    The default `Session` being used in the current thread.
5598  """
5599  return _default_session_stack.get_default()
5600
5601
5602def _eval_using_default_session(tensors, feed_dict, graph, session=None):
5603  """Uses the default session to evaluate one or more tensors.
5604
5605  Args:
5606    tensors: A single Tensor, or a list of Tensor objects.
5607    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5608      numpy ndarrays, TensorProtos, or strings.
5609    graph: The graph in which the tensors are defined.
5610    session: (Optional) A different session to use to evaluate "tensors".
5611
5612  Returns:
5613    Either a single numpy ndarray if "tensors" is a single tensor; or a list
5614    of numpy ndarrays that each correspond to the respective element in
5615    "tensors".
5616
5617  Raises:
5618    ValueError: If no default session is available; the default session
5619      does not have "graph" as its graph; or if "session" is specified,
5620      and it does not have "graph" as its graph.
5621  """
5622  if session is None:
5623    session = get_default_session()
5624    if session is None:
5625      raise ValueError("Cannot evaluate tensor using `eval()`: No default "
5626                       "session is registered. Use `with "
5627                       "sess.as_default()` or pass an explicit session to "
5628                       "`eval(session=sess)`")
5629    if session.graph is not graph:
5630      raise ValueError("Cannot use the default session to evaluate tensor: "
5631                       "the tensor's graph is different from the session's "
5632                       "graph. Pass an explicit session to "
5633                       "`eval(session=sess)`.")
5634  else:
5635    if session.graph is not graph:
5636      raise ValueError("Cannot use the given session to evaluate tensor: "
5637                       "the tensor's graph is different from the session's "
5638                       "graph.")
5639  return session.run(tensors, feed_dict)
5640
5641
5642def _run_using_default_session(operation, feed_dict, graph, session=None):
5643  """Uses the default session to run "operation".
5644
5645  Args:
5646    operation: The Operation to be run.
5647    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5648      numpy ndarrays, TensorProtos, or strings.
5649    graph: The graph in which "operation" is defined.
5650    session: (Optional) A different session to use to run "operation".
5651
5652  Raises:
5653    ValueError: If no default session is available; the default session
5654      does not have "graph" as its graph; or if "session" is specified,
5655      and it does not have "graph" as its graph.
5656  """
5657  if session is None:
5658    session = get_default_session()
5659    if session is None:
5660      raise ValueError("Cannot execute operation using `run()`: No default "
5661                       "session is registered. Use `with "
5662                       "sess.as_default():` or pass an explicit session to "
5663                       "`run(session=sess)`")
5664    if session.graph is not graph:
5665      raise ValueError("Cannot use the default session to execute operation: "
5666                       "the operation's graph is different from the "
5667                       "session's graph. Pass an explicit session to "
5668                       "`run(session=sess)`.")
5669  else:
5670    if session.graph is not graph:
5671      raise ValueError("Cannot use the given session to execute operation: "
5672                       "the operation's graph is different from the session's "
5673                       "graph.")
5674  session.run(operation, feed_dict)
5675
5676
5677class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
5678  """A thread-local stack of objects for providing an implicit default graph."""
5679
5680  def __init__(self):
5681    super(_DefaultGraphStack, self).__init__()
5682    self._global_default_graph = None
5683
5684  def get_default(self):
5685    """Override that returns a global default if the stack is empty."""
5686    if self.stack:
5687      return self.stack[-1]
5688    elif self._global_default_graph:
5689      return self._global_default_graph
5690    else:
5691      self._global_default_graph = Graph()
5692      return self._global_default_graph
5693
5694  def _GetGlobalDefaultGraph(self):
5695    if self._global_default_graph is None:
5696      # TODO(mrry): Perhaps log that the default graph is being used, or
5697      #   provide some other feedback to prevent confusion when the global
5698      #   default graph and an explicit graph are combined in the
5699      #   same process.
5700      self._global_default_graph = Graph()
5701    return self._global_default_graph
5702
5703  def reset(self):
5704    super(_DefaultGraphStack, self).reset()
5705    self._global_default_graph = None
5706
5707  @tf_contextlib.contextmanager
5708  def get_controller(self, default):
5709    context.context().context_switches.push(default.building_function,
5710                                            default.as_default,
5711                                            default._device_function_stack)
5712    try:
5713      with super(_DefaultGraphStack,
5714                 self).get_controller(default) as g, context.graph_mode():
5715        yield g
5716    finally:
5717      # If an exception is raised here it may be hiding a related exception in
5718      # the try-block (just above).
5719      context.context().context_switches.pop()
5720
5721
5722_default_graph_stack = _DefaultGraphStack()
5723
5724
5725# Shared helper used in init_scope and executing_eagerly_outside_functions
5726# to obtain the outermost context that is not building a function, and the
5727# innermost non-empty device stack.
5728def _get_outer_context_and_inner_device_stack():
5729  """Get the outermost context not building a function."""
5730  default_graph = get_default_graph()
5731  outer_context = None
5732  innermost_nonempty_device_stack = default_graph._device_function_stack  # pylint: disable=protected-access
5733
5734  if not _default_graph_stack.stack:
5735    # If the default graph stack is empty, then we cannot be building a
5736    # function. Install the global graph (which, in this case, is also the
5737    # default graph) as the outer context.
5738    if default_graph.building_function:
5739      raise RuntimeError("The global graph is building a function.")
5740    outer_context = default_graph.as_default
5741  else:
5742    # Find a context that is not building a function.
5743    for stack_entry in reversed(context.context().context_switches.stack):
5744      if not innermost_nonempty_device_stack:
5745        innermost_nonempty_device_stack = stack_entry.device_stack
5746      if not stack_entry.is_building_function:
5747        outer_context = stack_entry.enter_context_fn
5748        break
5749
5750    if outer_context is None:
5751      # As a last resort, obtain the global default graph; this graph doesn't
5752      # necessarily live on the graph stack (and hence it doesn't necessarily
5753      # live on the context stack), but it is stored in the graph stack's
5754      # encapsulating object.
5755      outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default  # pylint: disable=protected-access
5756
5757  if outer_context is None:
5758    # Sanity check; this shouldn't be triggered.
5759    raise RuntimeError("All graphs are building functions, and no "
5760                       "eager context was previously active.")
5761
5762  return outer_context, innermost_nonempty_device_stack
5763
5764
5765# pylint: disable=g-doc-return-or-yield,line-too-long
5766@tf_export("init_scope")
5767@tf_contextlib.contextmanager
5768def init_scope():
5769  """A context manager that lifts ops out of control-flow scopes and function-building graphs.
5770
5771  There is often a need to lift variable initialization ops out of control-flow
5772  scopes, function-building graphs, and gradient tapes. Entering an
5773  `init_scope` is a mechanism for satisfying these desiderata. In particular,
5774  entering an `init_scope` has three effects:
5775
5776    (1) All control dependencies are cleared the moment the scope is entered;
5777        this is equivalent to entering the context manager returned from
5778        `control_dependencies(None)`, which has the side-effect of exiting
5779        control-flow scopes like `tf.cond` and `tf.while_loop`.
5780
5781    (2) All operations that are created while the scope is active are lifted
5782        into the lowest context on the `context_stack` that is not building a
5783        graph function. Here, a context is defined as either a graph or an eager
5784        context. Every context switch, i.e., every installation of a graph as
5785        the default graph and every switch into eager mode, is logged in a
5786        thread-local stack called `context_switches`; the log entry for a
5787        context switch is popped from the stack when the context is exited.
5788        Entering an `init_scope` is equivalent to crawling up
5789        `context_switches`, finding the first context that is not building a
5790        graph function, and entering it. A caveat is that if graph mode is
5791        enabled but the default graph stack is empty, then entering an
5792        `init_scope` will simply install a fresh graph as the default one.
5793
5794    (3) The gradient tape is paused while the scope is active.
5795
5796  When eager execution is enabled, code inside an init_scope block runs with
5797  eager execution enabled even when tracing a `tf.function`. For example:
5798
5799  ```python
5800  tf.compat.v1.enable_eager_execution()
5801
5802  @tf.function
5803  def func():
5804    # A function constructs TensorFlow graphs,
5805    # it does not execute eagerly.
5806    assert not tf.executing_eagerly()
5807    with tf.init_scope():
5808      # Initialization runs with eager execution enabled
5809      assert tf.executing_eagerly()
5810  ```
5811
5812  Raises:
5813    RuntimeError: if graph state is incompatible with this initialization.
5814  """
5815  # pylint: enable=g-doc-return-or-yield,line-too-long
5816
5817  if context.executing_eagerly():
5818    # Fastpath.
5819    with tape.stop_recording():
5820      yield
5821  else:
5822    # Retrieve the active name scope: entering an `init_scope` preserves
5823    # the name scope of the current context.
5824    scope = get_default_graph().get_name_scope()
5825    if scope and scope[-1] != "/":
5826      # Names that end with trailing slashes are treated by `name_scope` as
5827      # absolute.
5828      scope = scope + "/"
5829
5830    outer_context, innermost_nonempty_device_stack = (
5831        _get_outer_context_and_inner_device_stack())
5832
5833    outer_graph = None
5834    outer_device_stack = None
5835    try:
5836      with outer_context(), name_scope(
5837          scope, skip_on_eager=False), control_dependencies(
5838              None), tape.stop_recording():
5839        context_manager = NullContextmanager
5840        context_manager_input = None
5841        if not context.executing_eagerly():
5842          # The device stack is preserved when lifting into a graph. Eager
5843          # execution doesn't implement device stacks and in particular it
5844          # doesn't support device functions, so in general it's not possible
5845          # to do the same when lifting into the eager context.
5846          outer_graph = get_default_graph()
5847          outer_device_stack = outer_graph._device_function_stack  # pylint: disable=protected-access
5848          outer_graph._device_function_stack = innermost_nonempty_device_stack  # pylint: disable=protected-access
5849        elif innermost_nonempty_device_stack is not None:
5850          for device_spec in innermost_nonempty_device_stack.peek_objs():
5851            if device_spec.function is None:
5852              break
5853            if device_spec.raw_string:
5854              context_manager = context.device
5855              context_manager_input = device_spec.raw_string
5856              break
5857            # It is currently not possible to have a device function in V2,
5858            # but in V1 we are unable to apply device functions in eager mode.
5859            # This means that we will silently skip some of the entries on the
5860            # device stack in V1 + eager mode.
5861
5862        with context_manager(context_manager_input):
5863          yield
5864    finally:
5865      # If an exception is raised here it may be hiding a related exception in
5866      # the try-block (just above).
5867      if outer_graph is not None:
5868        outer_graph._device_function_stack = outer_device_stack  # pylint: disable=protected-access
5869
5870
5871@tf_export(v1=["executing_eagerly_outside_functions"])
5872def executing_eagerly_outside_functions():
5873  """Returns True if executing eagerly, even if inside a graph function.
5874
5875  This function checks the outermost context for the program and determines
5876  whether it is in eager mode. It is useful to compare with
5877  `tf.executing_eagerly()`, which checks the current context and returns
5878  `False` within a `tf.function` body. It can be used to build libraries that
5879  behave differently in eager runtime and v1 session runtime (deprecated).
5880
5881  Example:
5882
5883  >>> tf.compat.v1.enable_eager_execution()
5884  >>> @tf.function
5885  ... def func():
5886  ...   # A function constructs TensorFlow graphs, it does not execute eagerly,
5887  ...   # but the outer most context is still eager.
5888  ...   assert not tf.executing_eagerly()
5889  ...   return tf.compat.v1.executing_eagerly_outside_functions()
5890  >>> func()
5891  <tf.Tensor: shape=(), dtype=bool, numpy=True>
5892
5893  Returns:
5894    boolean, whether the outermost context is in eager mode.
5895  """
5896  if context.executing_eagerly():
5897    return True
5898  else:
5899    outer_context, _ = _get_outer_context_and_inner_device_stack()
5900    with outer_context():
5901      return context.executing_eagerly()
5902
5903
5904@tf_export("inside_function", v1=[])
5905def inside_function():
5906  """Indicates whether the caller code is executing inside a `tf.function`.
5907
5908  Returns:
5909    Boolean, True if the caller code is executing inside a `tf.function`
5910    rather than eagerly.
5911
5912  Example:
5913
5914  >>> tf.inside_function()
5915  False
5916  >>> @tf.function
5917  ... def f():
5918  ...   print(tf.inside_function())
5919  >>> f()
5920  True
5921  """
5922  return get_default_graph().building_function
5923
5924
5925@tf_export(v1=["enable_eager_execution"])
5926def enable_eager_execution(config=None, device_policy=None,
5927                           execution_mode=None):
5928  """Enables eager execution for the lifetime of this program.
5929
5930  Eager execution provides an imperative interface to TensorFlow. With eager
5931  execution enabled, TensorFlow functions execute operations immediately (as
5932  opposed to adding to a graph to be executed later in a
5933  `tf.compat.v1.Session`) and return concrete values (as opposed to symbolic
5934  references to a node in a computational graph).
5936
5937  For example:
5938
5939  ```python
5940  tf.compat.v1.enable_eager_execution()
5941
5942  # After eager execution is enabled, operations are executed as they are
5943  # defined and Tensor objects hold concrete values, which can be accessed as
5944  # numpy.ndarray`s through the numpy() method.
5945  assert tf.multiply(6, 7).numpy() == 42
5946  ```
5947
5948  Eager execution cannot be enabled after TensorFlow APIs have been used to
5949  create or execute graphs. It is typically recommended to invoke this function
5950  at program startup and not in a library (as most libraries should be usable
5951  both with and without eager execution).
5952
5953  @compatibility(TF2)
5954  This function is not necessary if you are using TF2. Eager execution is
5955  enabled by default.
5956  @end_compatibility
5957
5958  Args:
5959    config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the
5960      environment in which operations are executed. Note that
5961      `tf.compat.v1.ConfigProto` is also used to configure graph execution (via
5962      `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto`
5963      are not implemented (or are irrelevant) when eager execution is enabled.
5964    device_policy: (Optional.) Policy controlling how operations requiring
5965      inputs on a specific device (e.g., GPU 0) handle inputs on a different
5966      device (e.g., GPU 1 or CPU). When set to None, an appropriate value will
5967      be picked automatically. The value picked may change between TensorFlow
5968      releases.
5969      Valid values:
5970      - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the
5971        placement is not correct.
5972      - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not
5973        on the right device but logs a warning.
5974      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors.
5975        Note that this may hide performance problems as there is no notification
5976        provided when operations are blocked on the tensor being copied between
5977        devices.
5978      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies
5979        int32 tensors, raising errors on the other ones.
5980    execution_mode: (Optional.) Policy controlling how operations dispatched are
5981      actually executed. When set to None, an appropriate value will be picked
5982      automatically. The value picked may change between TensorFlow releases.
5983      Valid values:
5984      - tf.contrib.eager.SYNC: executes each operation synchronously.
5985      - tf.contrib.eager.ASYNC: executes each operation asynchronously. These
5986        operations may return "non-ready" handles.
5987
5988  Raises:
5989    ValueError: If eager execution is enabled after creating/executing a
5990     TensorFlow graph, or if options provided conflict with a previous call
5991     to this function.
5992  """
5993  _api_usage_gauge.get_cell().set(True)
5994  logging.vlog(1, "Enabling eager execution")
5995  if context.default_execution_mode != context.EAGER_MODE:
5996    return enable_eager_execution_internal(
5997        config=config,
5998        device_policy=device_policy,
5999        execution_mode=execution_mode,
6000        server_def=None)
6001
6002
6003@tf_export(v1=["disable_eager_execution"])
6004def disable_eager_execution():
6005  """Disables eager execution.
6006
6007  This function can only be called before any Graphs, Ops, or Tensors have been
6008  created.
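
  A minimal sketch of the intended call pattern (assuming it runs at program
  startup, before any graphs, ops, or tensors are created):

  ```python
  import tensorflow as tf

  tf.compat.v1.disable_eager_execution()
  assert not tf.executing_eagerly()
  ```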
6009
6010  @compatibility(TF2)
6011  This function is not necessary if you are using TF2. Eager execution is
6012  enabled by default. If you want to use Graph mode please consider
6013  [tf.function](https://www.tensorflow.org/api_docs/python/tf/function).
6014  @end_compatibility
6015  """
6016  _api_usage_gauge.get_cell().set(False)
6017  logging.vlog(1, "Disabling eager execution")
6018  context.default_execution_mode = context.GRAPH_MODE
6019  c = context.context_safe()
6020  if c is not None:
6021    c._thread_local_data.is_eager = False  # pylint: disable=protected-access
6022
6023
6024def enable_eager_execution_internal(config=None,
6025                                    device_policy=None,
6026                                    execution_mode=None,
6027                                    server_def=None):
6028  """Enables eager execution for the lifetime of this program.
6029
6030  Most of the doc string for enable_eager_execution is relevant here as well.
6031
6032  Args:
6033    config: See enable_eager_execution doc string
6034    device_policy: See enable_eager_execution doc string
6035    execution_mode: See enable_eager_execution doc string
6036    server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on
6037      remote devices. GrpcServers need to be started by creating an identical
6038      server_def to this, and setting the appropriate task_indexes, so that the
6039      servers can communicate. It will then be possible to execute operations on
6040      remote devices.
6041
6042  Raises:
6043    ValueError: If the provided options are invalid, or if eager execution
6044      is enabled after a TensorFlow graph has already been created or executed.
6045  """
6046  if config is not None and not isinstance(config, config_pb2.ConfigProto):
6047    raise TypeError("config must be a tf.ConfigProto, but got %s" %
6048                    type(config))
6049  if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
6050                           context.DEVICE_PLACEMENT_WARN,
6051                           context.DEVICE_PLACEMENT_SILENT,
6052                           context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
6053    raise ValueError(
6054        "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
6055    )
6056  if execution_mode not in (None, context.SYNC, context.ASYNC):
6057    raise ValueError(
6058        "execution_mode must be one of None, tf.contrib.eager.SYNC, "
6059        "tf.contrib.eager.ASYNC")
6060  if context.default_execution_mode == context.GRAPH_MODE:
6061    graph_mode_has_been_used = (
6062        _default_graph_stack._global_default_graph is not None)  # pylint: disable=protected-access
6063    if graph_mode_has_been_used:
6064      raise ValueError(
6065          "tf.enable_eager_execution must be called at program startup.")
6066  context.default_execution_mode = context.EAGER_MODE
6067  # pylint: disable=protected-access
6068  with context._context_lock:
6069    if context._context is None:
6070      context._set_context_locked(context.Context(
6071          config=config,
6072          device_policy=device_policy,
6073          execution_mode=execution_mode,
6074          server_def=server_def))
6075    elif ((config is not None and config is not context._context._config) or
6076          (device_policy is not None and
6077           device_policy is not context._context._device_policy) or
6078          (execution_mode is not None and
6079           execution_mode is not context._context._execution_mode)):
6080      raise ValueError(
6081          "Trying to change the options of an active eager"
6082          " execution. Context config: %s, specified config:"
6083          " %s. Context device policy: %s, specified device"
6084          " policy: %s. Context execution mode: %s, "
6085          " specified execution mode %s." %
6086          (context._context._config, config, context._context._device_policy,
6087           device_policy, context._context._execution_mode, execution_mode))
6088    else:
6089      # We already created everything, so update the thread local data.
6090      context._context._thread_local_data.is_eager = True
6091
6092  # Monkey patch to get rid of an unnecessary conditional since the context is
6093  # now initialized.
6094  context.context = context.context_safe
6095
6096
6097def eager_run(main=None, argv=None):
6098  """Runs the program with an optional main function and argv list.
6099
6100  The program will run with eager execution enabled.
6101
6102  Example:
6103  ```python
6104  import tensorflow as tf
6105  # Import subject to future changes:
6106  from tensorflow.contrib.eager.python import tfe
6107
6108  def main(_):
6109    u = tf.constant(6.0)
6110    v = tf.constant(7.0)
6111    print(u * v)
6112
6113  if __name__ == "__main__":
6114    tfe.run()
6115  ```
6116
6117  Args:
6118    main: the main function to run.
6119    argv: the arguments to pass to it.
6120  """
6121  enable_eager_execution()
6122  app.run(main, argv)
6123
6124
6125@tf_export(v1=["reset_default_graph"])
6126def reset_default_graph():
6127  """Clears the default graph stack and resets the global default graph.
6128
6129  NOTE: The default graph is a property of the current thread. This
6130  function applies only to the current thread.  Calling this function while
6131  a `tf.compat.v1.Session` or `tf.compat.v1.InteractiveSession` is active will
6132  result in undefined behavior. Using any previously created `tf.Operation` or
6133  `tf.Tensor` objects after calling this function will result in undefined
6134  behavior.
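
  A small sketch of typical TF1-style usage (assuming no session is active and
  no graph context has been entered):

  ```python
  g1 = tf.compat.v1.get_default_graph()
  tf.compat.v1.reset_default_graph()
  g2 = tf.compat.v1.get_default_graph()
  assert g1 is not g2  # A fresh global default graph is now installed.
  ```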
6135
6136  @compatibility(TF2)
6137  `reset_default_graph` does not work with either eager execution or
6138  `tf.function`, and you should not invoke it directly. To migrate code that
6139  uses Graph-related functions to TF2, rewrite the code without them. See the
6140  [migration guide](https://www.tensorflow.org/guide/migrate) for more
6141  details about the behavior and semantic changes between TensorFlow 1 and
6142  TensorFlow 2.
6143  @end_compatibility
6144
6145  Raises:
6146    AssertionError: If this function is called within a nested graph.
6147  """
6148  if not _default_graph_stack.is_cleared():
6149    raise AssertionError("Do not use tf.reset_default_graph() to clear "
6150                         "nested graphs. If you need a cleared graph, "
6151                         "exit the nesting and create a new graph.")
6152  _default_graph_stack.reset()
6153
6154
6155@tf_export(v1=["get_default_graph"])
6156def get_default_graph():
6157  """Returns the default graph for the current thread.
6158
6159  The returned graph will be the innermost graph on which a
6160  `Graph.as_default()` context has been entered, or a global default
6161  graph if none has been explicitly created.
6162
6163  NOTE: The default graph is a property of the current thread. If you
6164  create a new thread, and wish to use the default graph in that
6165  thread, you must explicitly add a `with g.as_default():` in that
6166  thread's function.
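
  For example (a sketch, assuming TF1 graph mode):

  ```python
  g = tf.Graph()
  with g.as_default():
    assert tf.compat.v1.get_default_graph() is g
  ```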
6167
6168  @compatibility(TF2)
6169  `get_default_graph` does not work with either eager execution or
6170  `tf.function`, and you should not invoke it directly. To migrate code that
6171  uses Graph-related functions to TF2, rewrite the code without them. See the
6172  [migration guide](https://www.tensorflow.org/guide/migrate) for more
6173  details about the behavior and semantic changes between TensorFlow 1 and
6174  TensorFlow 2.
6175  @end_compatibility
6176
6177  Returns:
6178    The default `Graph` being used in the current thread.
6179  """
6180  return _default_graph_stack.get_default()
6181
6182
6183def has_default_graph():
6184  """Returns True if there is a default graph."""
6185  return len(_default_graph_stack.stack) >= 1
6186
6187
6188# Exported due to b/171079555
6189@tf_export("__internal__.get_name_scope", v1=[])
6190def get_name_scope():
6191  """Returns the current name scope in the default_graph.
6192
6193  For example:
6194
6195  ```python
6196  with tf.name_scope('scope1'):
6197    with tf.name_scope('scope2'):
6198      print(tf.get_name_scope())
6199  ```
6200  would print the string `scope1/scope2`.
6201
6202  Returns:
6203    A string representing the current name scope.
6204  """
6205  if context.executing_eagerly():
6206    return context.context().scope_name.rstrip("/")
6207  return get_default_graph().get_name_scope()
6208
6209
6210def _assert_same_graph(original_item, item):
6211  """Fail if the 2 items are from different graphs.
6212
6213  Args:
6214    original_item: Original item to check against.
6215    item: Item to check.
6216
6217  Raises:
6218    ValueError: if graphs do not match.
6219  """
6220  original_graph = getattr(original_item, "graph", None)
6221  graph = getattr(item, "graph", None)
6222  if original_graph and graph and original_graph is not graph:
6223    raise ValueError(
6224        "%s must be from the same graph as %s (graphs are %s and %s)." %
6225        (item, original_item, graph, original_graph))
6226
6227
6228def _get_graph_from_inputs(op_input_list, graph=None):
6229  """Returns the appropriate graph to use for the given inputs.
6230
6231  This library method provides a consistent algorithm for choosing the graph
6232  in which an Operation should be constructed:
6233
6234  1. If the default graph is being used to construct a function, we
6235     use the default graph.
6236  2. If the "graph" is specified explicitly, we validate that all of the inputs
6237     in "op_input_list" are compatible with that graph.
6238  3. Otherwise, we attempt to select a graph from the first Operation-
6239     or Tensor-valued input in "op_input_list", and validate that all other
6240     such inputs are in the same graph.
6241  4. If the graph was not specified and it could not be inferred from
6242     "op_input_list", we attempt to use the default graph.
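
  For example, under rule 3 the graph is inferred from the first graph-valued
  input (a sketch, assuming `tf.constant` produces a graph tensor here):

  ```python
  g = Graph()
  with g.as_default():
    a = tf.constant(1.0)
  assert _get_graph_from_inputs([a]) is g
  ```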
6243
6244  Args:
6245    op_input_list: A list of inputs to an operation, which may include `Tensor`,
6246      `Operation`, and other objects that may be converted to a graph element.
6247    graph: (Optional) The explicit graph to use.
6248
6249  Raises:
6250    TypeError: If op_input_list is not a list or tuple, or if graph is not a
6251      Graph.
6252    ValueError: If a graph is explicitly passed and not all inputs are from it,
6253      or if the inputs are from multiple graphs, or we could not find a graph
6254      and there was no default graph.
6255
6256  Returns:
6257    The appropriate graph to use for the given inputs.
6258
6259  """
6260  current_default_graph = get_default_graph()
6261  if current_default_graph.building_function:
6262    return current_default_graph
6263
6264  op_input_list = tuple(op_input_list)  # Handle generators correctly
6265  if graph and not isinstance(graph, Graph):
6266    raise TypeError("Input graph needs to be a Graph: %s" % (graph,))
6267
6268  # 1. We validate that all of the inputs are from the same graph. This is
6269  #    either the supplied graph parameter, or the first one selected from one
6270  #    of the graph-element-valued inputs. In the latter case, we hold onto
6271  #    that input in original_graph_element so we can provide a more
6272  #    informative error if a mismatch is found.
6273  original_graph_element = None
6274  for op_input in op_input_list:
6275    # Determine if this is a valid graph_element.
6276    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
6277    # up.
6278    graph_element = None
6279    if (isinstance(op_input, (Operation, internal.NativeObject)) and
6280        ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)):  # pylint: disable=unidiomatic-typecheck
6281      graph_element = op_input
6282    else:
6283      graph_element = _as_graph_element(op_input)
6284
6285    if graph_element is not None:
6286      if not graph:
6287        original_graph_element = graph_element
6288        graph = getattr(graph_element, "graph", None)
6289      elif original_graph_element is not None:
6290        _assert_same_graph(original_graph_element, graph_element)
6291      elif graph_element.graph is not graph:
6292        raise ValueError("%s is not from the passed-in graph." % graph_element)
6293
6294  # 2. If all else fails, we use the default graph, which is always there.
6295  return graph or current_default_graph
6296
6297
6298@tf_export(v1=["GraphKeys"])
6299class GraphKeys(object):
6300  """Standard names to use for graph collections.
6301
6302  The standard library uses various well-known names to collect and
6303  retrieve values associated with a graph. For example, the
6304  `tf.Optimizer` subclasses default to optimizing the variables
6305  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
6306  specified, but it is also possible to pass an explicit list of
6307  variables.
6308
6309  The following standard keys are defined:
6310
6311  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
6312    across a distributed environment (model variables are a subset of these). See
6313    `tf.compat.v1.global_variables`
6314    for more details.
6315    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
6316    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
6317  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
6318    machine. Usually used for temporary variables, like counters.
6319    Note: use `tf.contrib.framework.local_variable` to add to this collection.
6320  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
6321    model for inference (feed forward). Note: use
6322    `tf.contrib.framework.model_variable` to add to this collection.
6323  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
6324    be trained by an optimizer. See
6325    `tf.compat.v1.trainable_variables`
6326    for more details.
6327  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
6328    graph. See
6329    `tf.compat.v1.summary.merge_all`
6330    for more details.
6331  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
6332    produce input for a computation. See
6333    `tf.compat.v1.train.start_queue_runners`
6334    for more details.
6335  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
6336    keep moving averages.  See
6337    `tf.compat.v1.moving_average_variables`
6338    for more details.
6339  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
6340    construction.
6341
6342  The following standard keys are _defined_, but their collections are **not**
6343  automatically populated as many of the others are:
6344
6345  * `WEIGHTS`
6346  * `BIASES`
6347  * `ACTIVATIONS`
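
  For example, variables created with `tf.compat.v1.get_variable` are added to
  the `GLOBAL_VARIABLES` collection by default (a sketch, assuming TF1 graph
  mode):

  ```python
  v = tf.compat.v1.get_variable("v", shape=[2])
  global_vars = tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
  assert v in global_vars
  ```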
6348  """
6349
6350  # Key to collect Variable objects that are global (shared across machines).
6351  # Default collection for all variables, except local ones.
6352  GLOBAL_VARIABLES = "variables"
6353  # Key to collect local variables that are local to the machine and are not
6354  # saved/restored.
6355  LOCAL_VARIABLES = "local_variables"
6356  # Key to collect local variables which are used to accumulate internal state
6357  # to be used in tf.metrics.*.
6358  METRIC_VARIABLES = "metric_variables"
6359  # Key to collect model variables defined by layers.
6360  MODEL_VARIABLES = "model_variables"
6361  # Key to collect Variable objects that will be trained by the
6362  # optimizers.
6363  TRAINABLE_VARIABLES = "trainable_variables"
6364  # Key to collect summaries.
6365  SUMMARIES = "summaries"
6366  # Key to collect QueueRunners.
6367  QUEUE_RUNNERS = "queue_runners"
6368  # Key to collect table initializers.
6369  TABLE_INITIALIZERS = "table_initializer"
6370  # Key to collect asset filepaths. An asset represents an external resource
6371  # like a vocabulary file.
6372  ASSET_FILEPATHS = "asset_filepaths"
6373  # Key to collect Variable objects that keep moving averages.
6374  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
6375  # Key to collect regularization losses at graph construction.
6376  REGULARIZATION_LOSSES = "regularization_losses"
6377  # Key to collect concatenated sharded variables.
6378  CONCATENATED_VARIABLES = "concatenated_variables"
6379  # Key to collect savers.
6380  SAVERS = "savers"
6381  # Key to collect weights
6382  WEIGHTS = "weights"
6383  # Key to collect biases
6384  BIASES = "biases"
6385  # Key to collect activations
6386  ACTIVATIONS = "activations"
6387  # Key to collect update_ops
6388  UPDATE_OPS = "update_ops"
6389  # Key to collect losses
6390  LOSSES = "losses"
6391  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
6392  SAVEABLE_OBJECTS = "saveable_objects"
6393  # Key to collect all shared resources used by the graph which need to be
6394  # initialized once per cluster.
6395  RESOURCES = "resources"
6396  # Key to collect all shared resources used in this graph which need to be
6397  # initialized once per session.
6398  LOCAL_RESOURCES = "local_resources"
6399  # Trainable resource-style variables.
6400  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"
6401
6402  # Key to indicate various ops.
6403  INIT_OP = "init_op"
6404  LOCAL_INIT_OP = "local_init_op"
6405  READY_OP = "ready_op"
6406  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
6407  SUMMARY_OP = "summary_op"
6408  GLOBAL_STEP = "global_step"
6409
6410  # Used to count the number of evaluations performed during a single evaluation
6411  # run.
6412  EVAL_STEP = "eval_step"
6413  TRAIN_OP = "train_op"
6414
6415  # Key for control flow context.
6416  COND_CONTEXT = "cond_context"
6417  WHILE_CONTEXT = "while_context"
6418
6419  # Used to store v2 summary names.
6420  _SUMMARY_COLLECTION = "_SUMMARY_V2"
6421
6422  # List of all collections that keep track of variables.
6423  _VARIABLE_COLLECTIONS = [
6424      GLOBAL_VARIABLES,
6425      LOCAL_VARIABLES,
6426      METRIC_VARIABLES,
6427      MODEL_VARIABLES,
6428      TRAINABLE_VARIABLES,
6429      MOVING_AVERAGE_VARIABLES,
6430      CONCATENATED_VARIABLES,
6431      TRAINABLE_RESOURCE_VARIABLES,
6432  ]
6433
6434  # Key for streaming model ports.
6435  # NOTE(yuanbyu): internal and experimental.
6436  _STREAMING_MODEL_PORTS = "streaming_model_ports"
6437
6438  @decorator_utils.classproperty
6439  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
6440  def VARIABLES(cls):  # pylint: disable=no-self-argument
6441    return cls.GLOBAL_VARIABLES
6442
6443
6444def dismantle_graph(graph):
6445  """Cleans up reference cycles from a `Graph`.
6446
6447  Helpful for making sure the garbage collector doesn't need to run after a
6448  temporary `Graph` is no longer needed.
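
  A usage sketch (assuming the graph was only needed temporarily):

  ```python
  g = Graph()
  with g.as_default():
    pass  # Build some temporary ops here.
  dismantle_graph(g)  # Neither `g` nor its ops are usable after this call.
  ```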
6449
6450  Args:
6451    graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
6452      after this function runs.
6453  """
6454  memory.dismantle_ordered_dict(graph._functions)  # pylint: disable=protected-access
6455
6456  # Now clean up Operation<->Graph reference cycles by clearing all of the
6457  # attributes for the Graph and its ops.
6458  graph_operations = graph.get_operations()
6459  for op in graph_operations:
6460    op.__dict__ = {}
6461  graph.__dict__ = {}
6462
6463
6464@tf_export(v1=["add_to_collection"])
6465def add_to_collection(name, value):
6466  """Wrapper for `Graph.add_to_collection()` using the default graph.
6467
6468  See `tf.Graph.add_to_collection`
6469  for more details.
6470
6471  Args:
6472    name: The key for the collection. For example, the `GraphKeys` class
6473      contains many standard names for collections.
6474    value: The value to add to the collection.
6475
6476  @compatibility(eager)
6477  Collections are only supported in eager when variables are created inside
6478  an EagerVariableStore (e.g. as part of a layer or template).
6479  @end_compatibility
6480  """
6481  get_default_graph().add_to_collection(name, value)
6482
6483
6484@tf_export(v1=["add_to_collections"])
6485def add_to_collections(names, value):
6486  """Wrapper for `Graph.add_to_collections()` using the default graph.
6487
6488  See `tf.Graph.add_to_collections`
6489  for more details.
6490
6491  Args:
6492    names: The key for the collections. The `GraphKeys` class contains many
6493      standard names for collections.
6494    value: The value to add to the collections.
6495
6496  @compatibility(eager)
6497  Collections are only supported in eager when variables are created inside
6498  an EagerVariableStore (e.g. as part of a layer or template).
6499  @end_compatibility
6500  """
6501  get_default_graph().add_to_collections(names, value)
6502
6503
6504@tf_export(v1=["get_collection_ref"])
6505def get_collection_ref(key):
6506  """Wrapper for `Graph.get_collection_ref()` using the default graph.
6507
6508  See `tf.Graph.get_collection_ref`
6509  for more details.
6510
6511  Args:
6512    key: The key for the collection. For example, the `GraphKeys` class contains
6513      many standard names for collections.
6514
6515  Returns:
6516    The list of values in the collection with the given `name`, or an empty
6517    list if no value has been added to that collection.  Note that this returns
6518    the collection list itself, which can be modified in place to change the
6519    collection.
6520
6521  @compatibility(eager)
6522  Collections are not supported when eager execution is enabled.
6523  @end_compatibility
6524  """
6525  return get_default_graph().get_collection_ref(key)
6526
6527
6528@tf_export(v1=["get_collection"])
6529def get_collection(key, scope=None):
6530  """Wrapper for `Graph.get_collection()` using the default graph.
6531
6532  See `tf.Graph.get_collection`
6533  for more details.
6534
6535  Args:
6536    key: The key for the collection. For example, the `GraphKeys` class contains
6537      many standard names for collections.
6538    scope: (Optional.) If supplied, the resulting list is filtered to include
6539      only items whose `name` attribute matches using `re.match`. Items without
6540      a `name` attribute are never returned if a scope is supplied and the
6541      choice of `re.match` means that a `scope` without special tokens filters
6542      by prefix.
6543
6544  Returns:
6545    The list of values in the collection with the given `name`, or
6546    an empty list if no value has been added to that collection. The
6547    list contains the values in the order under which they were
6548    collected.
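
  For example, the `scope` argument filters collected values by name prefix
  (a sketch, assuming TF1 graph mode):

  ```python
  with tf.compat.v1.variable_scope("layer1"):
    w = tf.compat.v1.get_variable("w", shape=[3])
  layer1_vars = tf.compat.v1.get_collection(
      tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope="layer1")
  assert w in layer1_vars
  ```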
6549
6550  @compatibility(eager)
6551  Collections are not supported when eager execution is enabled.
6552  @end_compatibility
6553  """
6554  return get_default_graph().get_collection(key, scope)
6555
6556
6557def get_all_collection_keys():
6558  """Returns a list of collections used in the default graph."""
6559  return get_default_graph().get_all_collection_keys()
6560
6561
6562def name_scope(name, default_name=None, values=None, skip_on_eager=True):
6563  """Internal-only entry point for `name_scope*`.
6564
6565  Internal ops do not use the public API and instead rely on
6566  `ops.name_scope` regardless of the execution mode. This function
6567  dispatches to the correct `name_scope*` implementation based on
6568  the arguments provided and the current mode. Specifically,
6569
6570  * if `values` contains a graph tensor, `Graph.name_scope` is used;
6571  * `name_scope_v1` is used in graph mode;
6572  * `name_scope_v2` -- in eager mode.
6573
6574  Args:
6575    name: The name argument that is passed to the op function.
6576    default_name: The default name to use if the `name` argument is `None`.
6577    values: The list of `Tensor` arguments that are passed to the op function.
6578    skip_on_eager: Whether to return NullContextmanager when executing eagerly.
6579      By default this is True since naming tensors and operations in eager mode
6580      has little use and causes unnecessary performance overhead. However, it is
6581      important to preserve variable names since they are often useful for
6582      debugging and saved models.
6583
6584  Returns:
6585    `name_scope*` context manager.
6586  """
6587  if not context.executing_eagerly():
6588    return internal_name_scope_v1(name, default_name, values)
6589
6590  if skip_on_eager:
6591    return NullContextmanager()
6592
6593  name = default_name if name is None else name
6594  if values:
6595    # The presence of a graph tensor in `values` overrides the context.
6596    # TODO(slebedev): this is Keras-specific and should be removed.
6597    # pylint: disable=unidiomatic-typecheck
6598    graph_value = next((value for value in values if type(value) == Tensor),
6599                       None)
6600    # pylint: enable=unidiomatic-typecheck
6601    if graph_value is not None:
6602      return graph_value.graph.name_scope(name)
6603
6604  return name_scope_v2(name or "")
6605
6606
6607class internal_name_scope_v1(object):  # pylint: disable=invalid-name
6608  """Graph-only version of `name_scope_v1`."""
6609
6610  @property
6611  def name(self):
6612    return self._name
6613
6614  def __init__(self, name, default_name=None, values=None):
6615    """Initialize the context manager.
6616
6617    Args:
6618      name: The name argument that is passed to the op function.
6619      default_name: The default name to use if the `name` argument is `None`.
6620      values: The list of `Tensor` arguments that are passed to the op function.
6621
6622    Raises:
6623      TypeError: if `default_name` is passed in but not a string.
6624    """
6625    if not (default_name is None or isinstance(default_name, six.string_types)):
6626      raise TypeError(
6627          "`default_name` type (%s) is not a string type. You likely meant to "
6628          "pass this into the `values` kwarg." % type(default_name))
6629    self._name = default_name if name is None else name
6630    self._default_name = default_name
6631    self._values = values
6632
6633  def __enter__(self):
6634    """Start the scope block.
6635
6636    Returns:
6637      The scope name.
6638
6639    Raises:
6640      ValueError: if neither `name` nor `default_name` is provided
6641        but `values` are.
6642    """
6643    if self._name is None and self._values is not None:
6644      # We only raise an error if values is not None (provided) because
6645      # currently tf.name_scope(None) (values=None then) is sometimes used as
6646      # an idiom to reset to top scope.
6647      raise ValueError(
6648          "At least one of name (%s) and default_name (%s) must be provided."
6649          % (self._name, self._default_name))
6650
6651    g = get_default_graph()
6652    if self._values and not g.building_function:
6653      # Specialize based on the knowledge that `_get_graph_from_inputs()`
6654      # ignores `inputs` when building a function.
6655      g_from_inputs = _get_graph_from_inputs(self._values)
6656      if g_from_inputs is not g:
6657        g = g_from_inputs
6658        self._g_manager = g.as_default()
6659        self._g_manager.__enter__()
6660      else:
6661        self._g_manager = None
6662    else:
6663      self._g_manager = None
6664
6665    try:
6666      self._name_scope = g.name_scope(self._name)
6667      return self._name_scope.__enter__()
6668    except:
6669      if self._g_manager is not None:
6670        self._g_manager.__exit__(*sys.exc_info())
6671      raise
6672
6673  def __exit__(self, *exc_info):
6674    self._name_scope.__exit__(*exc_info)
6675    if self._g_manager is not None:
6676      self._g_manager.__exit__(*exc_info)
6677
6678
6679# Named like a function for backwards compatibility with the
6680# @tf_contextlib.contextmanager version, which was switched to a class to avoid
6681# some object creation overhead.
6682@tf_export(v1=["name_scope"])
6683class name_scope_v1(object):  # pylint: disable=invalid-name
6684  """A context manager for use when defining a Python op.
6685
6686  This context manager validates that the given `values` are from the
6687  same graph, makes that graph the default graph, and pushes a
6688  name scope in that graph (see
6689  `tf.Graph.name_scope`
6690  for more details on that).
6691
6692  For example, to define a new Python op called `my_op`:
6693
6694  ```python
6695  def my_op(a, b, c, name=None):
6696    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
6697      a = tf.convert_to_tensor(a, name="a")
6698      b = tf.convert_to_tensor(b, name="b")
6699      c = tf.convert_to_tensor(c, name="c")
6700      # Define some computation that uses `a`, `b`, and `c`.
6701      return foo_op(..., name=scope)
6702  ```
6703  """
6704
6705  __slots__ = ["_name", "_name_scope"]
6706
6707  @property
6708  def name(self):
6709    return self._name
6710
6711  def __init__(self, name, default_name=None, values=None):
6712    """Initialize the context manager.
6713
6714    Args:
6715      name: The name argument that is passed to the op function.
6716      default_name: The default name to use if the `name` argument is `None`.
6717      values: The list of `Tensor` arguments that are passed to the op function.
6718
6719    Raises:
6720      TypeError: if `default_name` is passed in but not a string.
6721    """
6722    self._name_scope = name_scope(
6723        name, default_name, values, skip_on_eager=False)
6724    self._name = default_name if name is None else name
6725
6726  def __enter__(self):
6727    return self._name_scope.__enter__()
6728
6729  def __exit__(self, *exc_info):
6730    return self._name_scope.__exit__(*exc_info)
6731
6732
6733@tf_export("get_current_name_scope", v1=[])
6734def get_current_name_scope():
6735  """Returns current full name scope specified by `tf.name_scope(...)`s.
6736
6737  For example,
6738  ```python
6739  with tf.name_scope("outer"):
6740    tf.get_current_name_scope()  # "outer"
6741
6742    with tf.name_scope("inner"):
6743      tf.get_current_name_scope()  # "outer/inner"
6744  ```
6745
6746  In other words, `tf.get_current_name_scope()` returns the op name prefix that
6747  will be prepended to the name of any op created at that point.
6748
6749  Note that `@tf.function` resets the name scope stack as shown below.
6750
6751  ```
6752  with tf.name_scope("outer"):
6753
6754    @tf.function
6755    def foo(x):
6756      with tf.name_scope("inner"):
6757        return tf.add(x, x)  # Op name is "inner/Add", not "outer/inner/Add"
6758  ```
6759  """
6760
6761  ctx = context.context()
6762  if ctx.executing_eagerly():
6763    return ctx.scope_name.rstrip("/")
6764  else:
6765    return get_default_graph().get_name_scope()
6766
6767
6768@tf_export("name_scope", v1=[])
6769class name_scope_v2(object):
6770  """A context manager for use when defining a Python op.
6771
6772  This context manager pushes a name scope, which will make the name of all
6773  operations added within it have a prefix.
6774
6775  For example, to define a new Python op called `my_op`:
6776
6777  ```python
6778  def my_op(a, b, c, name=None):
6779    with tf.name_scope("MyOp") as scope:
6780      a = tf.convert_to_tensor(a, name="a")
6781      b = tf.convert_to_tensor(b, name="b")
6782      c = tf.convert_to_tensor(c, name="c")
6783      # Define some computation that uses `a`, `b`, and `c`.
6784      return foo_op(..., name=scope)
6785  ```
6786
6787  When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
6788  and `MyOp/c`.
6789
6790  Inside a `tf.function`, if the scope name already exists, the name will be
6791  made unique by appending `_n`. For example, calling `my_op` the second time
6792  will generate `MyOp_1/a`, etc.
6793  """
6794
6795  __slots__ = ["_name", "_exit_fns"]
6796
6797  def __init__(self, name):
6798    """Initialize the context manager.
6799
6800    Args:
6801      name: The prefix to use on all names created within the name scope.
6802
6803    Raises:
6804      ValueError: If name is not a string.
6805    """
6806    if not isinstance(name, six.string_types):
6807      raise ValueError("name for name_scope must be a string.")
6808    self._name = name
6809    self._exit_fns = []
6810
6811  @property
6812  def name(self):
6813    return self._name
6814
6815  def __enter__(self):
6816    """Start the scope block.
6817
6818    Returns:
6819      The scope name.
6820    """
6821    ctx = context.context()
6822    if ctx.executing_eagerly():
6823      # Names are not auto-incremented in eager mode.
6824      # A trailing slash breaks out of nested name scopes, indicating a
6825      # fully specified scope name, for compatibility with Graph.name_scope.
6826      # This also prevents auto-incrementing.
6827      old_name = ctx.scope_name
6828      name = self._name
6829      if not name:
6830        scope_name = ""
6831      elif name[-1] == "/":
6832        scope_name = name
6833      elif old_name:
6834        scope_name = old_name + name + "/"
6835      else:
6836        scope_name = name + "/"
6837      ctx.scope_name = scope_name
6838
6839      def _restore_name_scope(*_):
6840        ctx.scope_name = old_name
6841
6842      self._exit_fns.append(_restore_name_scope)
6843    else:
6844      scope = get_default_graph().name_scope(self._name)
6845      scope_name = scope.__enter__()
6846      self._exit_fns.append(scope.__exit__)
6847    return scope_name
6848
6849  def __exit__(self, type_arg, value_arg, traceback_arg):
6850    self._exit_fns.pop()(type_arg, value_arg, traceback_arg)
6851    return False  # False values do not suppress exceptions
6852
6853  def __getstate__(self):
6854    return self._name, self._exit_fns
6855
6856  def __setstate__(self, state):
6857    self._name = state[0]
6858    self._exit_fns = state[1]
6859
6860
6861def strip_name_scope(name, export_scope):
6862  """Removes name scope from a name.
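
  For example (a sketch of the intended behavior):

  ```python
  strip_name_scope("export/weights:0", "export")      # -> "weights:0"
  strip_name_scope("loc:@export/weights", "export/")  # -> "loc:@weights"
  strip_name_scope("weights", None)                   # -> "weights"
  ```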
6863
6864  Args:
6865    name: A `string` name.
6866    export_scope: Optional `string`. Name scope to remove.
6867
6868  Returns:
6869    Name with name scope removed, or the original name if export_scope
6870    is None.
6871  """
6872  if export_scope:
6873    if export_scope[-1] == "/":
6874      export_scope = export_scope[:-1]
6875
6876    try:
6877      # Strips export_scope/, export_scope///,
6878      # ^export_scope/, loc:@export_scope/.
6879      str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)"
6880      return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
6881    except TypeError as e:
6882      # If the name is not of a type we can process, simply return it.
6883      logging.warning(e)
6884      return name
6885  else:
6886    return name
6887
6888
6889def prepend_name_scope(name, import_scope):
6890  """Prepends name scope to a name.
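
  For example (a sketch of the intended behavior):

  ```python
  prepend_name_scope("weights:0", "import")      # -> "import/weights:0"
  prepend_name_scope("loc:@weights", "import/")  # -> "loc:@import/weights"
  prepend_name_scope("weights", None)            # -> "weights"
  ```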
6891
6892  Args:
6893    name: A `string` name.
6894    import_scope: Optional `string`. Name scope to add.
6895
6896  Returns:
6897    Name with name scope added, or the original name if import_scope
6898    is None.
6899  """
6900  if import_scope:
6901    if import_scope[-1] == "/":
6902      import_scope = import_scope[:-1]
6903
6904    try:
6905      str_to_replace = r"([\^]|loc:@|^)(.*)"
6906      return re.sub(str_to_replace, r"\1" + import_scope + r"/\2",
6907                    compat.as_str(name))
6908    except TypeError as e:
6909      # If the name is not of a type we can process, simply return it.
6910      logging.warning(e)
6911      return name
6912  else:
6913    return name
6914
6915
6916# pylint: disable=g-doc-return-or-yield
6917# pylint: disable=not-context-manager
6918@tf_export(v1=["op_scope"])
6919@tf_contextlib.contextmanager
6920def op_scope(values, name, default_name=None):
6921  """DEPRECATED. Same as name_scope above, just different argument order."""
6922  logging.warn("tf.op_scope(values, name, default_name) is deprecated,"
6923               " use tf.name_scope(name, default_name, values)")
6924  with name_scope(name, default_name=default_name, values=values) as scope:
6925    yield scope
6926
6927
6928_proto_function_registry = registry.Registry("proto functions")
6929
6930
6931def register_proto_function(collection_name,
6932                            proto_type=None,
6933                            to_proto=None,
6934                            from_proto=None):
6935  """Registers `to_proto` and `from_proto` functions for collection_name.
6936
6937  `to_proto` function converts a Python object to the corresponding protocol
6938  buffer, and returns the protocol buffer.
6939
6940  `from_proto` function converts a protocol buffer into a Python object, and
6941  returns the object.
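
  A registration sketch for a hypothetical collection; the converter functions
  below are placeholders and their exact signatures are an assumption:

  ```python
  def _my_to_proto(obj, export_scope=None):  # Placeholder converter.
    ...

  def _my_from_proto(proto, import_scope=None):  # Placeholder converter.
    ...

  register_proto_function(
      "my_collection",
      proto_type=None,  # e.g. a protobuf class such as `saver_pb2.SaverDef`.
      to_proto=_my_to_proto,
      from_proto=_my_from_proto)
  ```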
6942
6943  Args:
6944    collection_name: Name of the collection.
6945    proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
6946      `variable_pb2.VariableDef`, or `queue_runner_pb2.QueueRunnerDef`.
6947    to_proto: Function that implements Python object to protobuf conversion.
6948    from_proto: Function that implements protobuf to Python object conversion.
6949  """
6950  if to_proto and not callable(to_proto):
6951    raise TypeError("to_proto must be callable.")
6952  if from_proto and not callable(from_proto):
6953    raise TypeError("from_proto must be callable.")
6954
6955  _proto_function_registry.register((proto_type, to_proto, from_proto),
6956                                    collection_name)
6957
6958
6959def get_collection_proto_type(collection_name):
6960  """Returns the proto_type for collection_name."""
6961  try:
6962    return _proto_function_registry.lookup(collection_name)[0]
6963  except LookupError:
6964    return None
6965
6966
6967def get_to_proto_function(collection_name):
6968  """Returns the to_proto function for collection_name."""
6969  try:
6970    return _proto_function_registry.lookup(collection_name)[1]
6971  except LookupError:
6972    return None
6973
6974
6975def get_from_proto_function(collection_name):
6976  """Returns the from_proto function for collection_name."""
6977  try:
6978    return _proto_function_registry.lookup(collection_name)[2]
6979  except LookupError:
6980    return None
6981
6982
def _op_to_colocate_with(v, graph):
  """Operation object corresponding to v to use for colocation constraints."""
  if v is None:
    return None, None
  if isinstance(v, Operation):
    return v, None

  # We always want to colocate with the reference op.
  # When 'v' is a ResourceVariable, the reference op is the handle creating op.
  #
  # What this should be is:
  # if isinstance(v, ResourceVariable):
  #   return v.handle.op, v
  # However, that would require a circular import dependency.
  # As of October 2018, there were attempts underway to remove
  # colocation constraints altogether. Assuming that will
  # happen soon, perhaps this hack to work around the circular
  # import dependency is acceptable.
  if hasattr(v, "handle") and isinstance(v.handle, Tensor):
    device_only_candidate = lambda: None
    device_only_candidate.device = v.device
    device_only_candidate.name = v.name
    if graph.building_function:
      return graph.capture(v.handle).op, device_only_candidate
    else:
      return v.handle.op, device_only_candidate
  return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op, None


def _is_keras_symbolic_tensor(x):
  return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"


# These symbols were originally defined in this module; import them for
# backwards compatibility until all references have been updated to access
# them from the indexed_slices.py module.
IndexedSlices = indexed_slices.IndexedSlices
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
convert_to_tensor_or_indexed_slices = \
    indexed_slices.convert_to_tensor_or_indexed_slices
convert_n_to_tensor_or_indexed_slices = \
    indexed_slices.convert_n_to_tensor_or_indexed_slices
internal_convert_to_tensor_or_indexed_slices = \
    indexed_slices.internal_convert_to_tensor_or_indexed_slices
internal_convert_n_to_tensor_or_indexed_slices = \
    indexed_slices.internal_convert_n_to_tensor_or_indexed_slices
register_tensor_conversion_function = \
    tensor_conversion_registry.register_tensor_conversion_function


# Helper functions for op wrapper modules generated by `python_op_gen`.


def to_raw_op(f):
  """Make a given op wrapper function `f` raw.

  Raw op wrappers can only be called with keyword arguments.

  Args:
    f: An op wrapper function to make raw.

  Returns:
    Raw `f`.
  """
  # Copy `f` to get a new `__dict__`, otherwise `tf_export` will fail
  # due to double-registration.
  f = types.FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__,
                         f.__closure__)
  return kwarg_only(f)


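# Example (hedged sketch): generated op wrapper modules apply `to_raw_op` to an
# existing wrapper so the result accepts keyword arguments only. The wrapper
# name below is hypothetical:
#
#   raw_identity = to_raw_op(identity_op_wrapper)  # hypothetical wrapper
#   raw_identity(input=x)   # OK: keyword argument
#   raw_identity(x)         # typically raises TypeError: no positional args

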
def raise_from_not_ok_status(e, name):
  message = e.message + (" name: " + name if name is not None else "")
  raise core._status_to_exception(e.code, message) from None  # pylint: disable=protected-access


def add_exit_callback_to_default_func_graph(fn):
  """Add a callback to run when the default function graph goes out of scope.

  Usage:

  ```python
  @tf.function
  def fn(x, v):
    expensive = expensive_object(v)
    add_exit_callback_to_default_func_graph(lambda: expensive.release())
    return g(x, expensive)

  fn(x=tf.constant(...), v=...)
  # `expensive` has been released.
  ```

  Args:
    fn: A callable that takes no arguments and whose output is ignored.
      To be executed when exiting func graph scope.

  Raises:
    RuntimeError: If executed when the current default graph is not a FuncGraph,
      or not currently executing in function creation mode (e.g., if inside
      an init_scope).
  """
  default_graph = get_default_graph()
  if not default_graph._building_function:  # pylint: disable=protected-access
    raise RuntimeError(
        "Cannot add scope exit callbacks when not building a function.  "
        "Default graph: {}".format(default_graph))
  default_graph._add_scope_exit_callback(fn)  # pylint: disable=protected-access


def _reconstruct_sequence_inputs(op_def, inputs, attrs):
  """Regroups a flat list of input tensors into scalar and sequence inputs.

  Args:
    op_def: The `op_def_pb2.OpDef` (for knowing the input types).
    inputs: a list of input `Tensor`s to the op.
    attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
      how long each sequence is).

  Returns:
    A list of `Tensor`s (corresponding to scalar inputs) and lists of
    `Tensor`s (corresponding to sequence inputs).
  """
  grouped_inputs = []
  i = 0
  for input_arg in op_def.input_arg:
    if input_arg.number_attr:
      input_len = attrs[input_arg.number_attr].i
      is_sequence = True
    elif input_arg.type_list_attr:
      input_len = len(attrs[input_arg.type_list_attr].list.type)
      is_sequence = True
    else:
      input_len = 1
      is_sequence = False

    if is_sequence:
      grouped_inputs.append(inputs[i:i + input_len])
    else:
      grouped_inputs.append(inputs[i])
    i += input_len

  assert i == len(inputs)
  return grouped_inputs


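# Worked example (illustrative): for an op whose OpDef declares a sequence
# input with `number_attr="N"` followed by a scalar input, N=2 and a flat
# input list [a, b, c] would be regrouped as [[a, b], c].

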
_numpy_style_type_promotion = False


def enable_numpy_style_type_promotion():
  """If called, follows NumPy's rules for type promotion.

  Used for enabling NumPy behavior on methods for TF NumPy.
  """
  global _numpy_style_type_promotion
  _numpy_style_type_promotion = True


_numpy_style_slicing = False


def enable_numpy_style_slicing():
  """If called, follows NumPy's rules for slicing Tensors.

  Used for enabling NumPy behavior on slicing for TF NumPy.
  """
  global _numpy_style_slicing
  _numpy_style_slicing = True


class _TensorIterator(object):
  """Iterates over the leading dim of a Tensor. Performs no error checks."""

  __slots__ = ["_tensor", "_index", "_limit"]

  def __init__(self, tensor, dim0):
    self._tensor = tensor
    self._index = 0
    self._limit = dim0

  def __iter__(self):
    return self

  def __next__(self):
    if self._index == self._limit:
      raise StopIteration
    result = self._tensor[self._index]
    self._index += 1
    return result

  next = __next__  # python2.x compatibility.


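# Example (hedged sketch): iterating a tensor of shape [2, 3] over its leading
# dimension yields two tensors of shape [3]. `constant_op` is not imported in
# this module and is only assumed here for illustration:
#
#   t = constant_op.constant([[1, 2, 3], [4, 5, 6]])
#   for row in _TensorIterator(t, 2):
#     ...  # each `row` is t[0], then t[1], with shape [3]

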
def set_int_list_attr(op, attr_name, ints):
  """TF internal method used to set a list(int) attribute in the node_def."""
  ints_list = attr_value_pb2.AttrValue.ListValue(i=ints)
  op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list))  # pylint:disable=protected-access


def _get_enclosing_context(graph):
  # pylint: disable=protected-access
  if graph is None:
    return None

  if graph._control_flow_context is not None:
    return graph._control_flow_context

  if graph.building_function and hasattr(graph, "outer_graph"):
    return _get_enclosing_context(graph.outer_graph)


def get_resource_handle_data(graph_op):
  assert type(graph_op) == Tensor  # pylint: disable=unidiomatic-typecheck

  handle_data = pywrap_tf_session.GetHandleShapeAndType(
      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access

  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
      compat.as_bytes(handle_data))


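# Example (hedged sketch): for a graph-mode resource tensor such as a
# ResourceVariable's `handle` (the variable `v` below is hypothetical), the
# returned HandleData describes the dtype and shape of the pointed-to value:
#
#   data = get_resource_handle_data(v.handle)
#   data.shape_and_type[0].dtype
#   data.shape_and_type[0].shape

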
def _copy_handle_data_to_arg_def(tensor, arg_def):
  handle_data = get_resource_handle_data(tensor)
  if handle_data.shape_and_type:
    shape_and_type = handle_data.shape_and_type[0]
    proto = arg_def.handle_data.add()
    proto.dtype = shape_and_type.dtype
    proto.shape.CopyFrom(shape_and_type.shape)