• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions used to construct graphs."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import re
23import sys
24import threading
25import types
26
27import numpy as np
28import six
29from six.moves import map  # pylint: disable=redefined-builtin
30from six.moves import xrange  # pylint: disable=redefined-builtin
31
32from tensorflow.core.framework import attr_value_pb2
33from tensorflow.core.framework import function_pb2
34from tensorflow.core.framework import graph_pb2
35from tensorflow.core.framework import node_def_pb2
36from tensorflow.core.framework import op_def_pb2
37from tensorflow.core.framework import versions_pb2
38from tensorflow.core.protobuf import config_pb2
39# pywrap_tensorflow must be imported first to avoid profobuf issues.
40# (b/143110113)
41# pylint: disable=invalid-import-order,g-bad-import-order,unused-import
42from tensorflow.python import pywrap_tensorflow
43from tensorflow.python import pywrap_tfe
44# pylint: enable=invalid-import-order,g-bad-import-order,unused-import
45from tensorflow.python import tf2
46from tensorflow.python.client import pywrap_tf_session
47from tensorflow.python.eager import context
48from tensorflow.python.eager import core
49from tensorflow.python.eager import monitoring
50from tensorflow.python.eager import tape
51from tensorflow.python.framework import c_api_util
52from tensorflow.python.framework import composite_tensor
53from tensorflow.python.framework import cpp_shape_inference_pb2
54from tensorflow.python.framework import device as pydev
55from tensorflow.python.framework import dtypes
56from tensorflow.python.framework import errors
57from tensorflow.python.framework import indexed_slices
58from tensorflow.python.framework import registry
59from tensorflow.python.framework import tensor_conversion_registry
60from tensorflow.python.framework import tensor_shape
61from tensorflow.python.framework import traceable_stack
62from tensorflow.python.framework import versions
63from tensorflow.python.ops import control_flow_util
64from tensorflow.python.platform import app
65from tensorflow.python.platform import tf_logging as logging
66from tensorflow.python.profiler import trace
67from tensorflow.python.types import core as core_tf_types
68from tensorflow.python.types import internal
69from tensorflow.python.util import compat
70from tensorflow.python.util import decorator_utils
71from tensorflow.python.util import deprecation
72from tensorflow.python.util import dispatch
73from tensorflow.python.util import function_utils
74from tensorflow.python.util import lock_util
75from tensorflow.python.util import memory
76from tensorflow.python.util import object_identity
77from tensorflow.python.util import tf_contextlib
78from tensorflow.python.util import tf_stack
79from tensorflow.python.util.compat import collections_abc
80from tensorflow.python.util.deprecation import deprecated_args
81from tensorflow.python.util.lazy_loader import LazyLoader
82from tensorflow.python.util.tf_export import kwarg_only
83from tensorflow.python.util.tf_export import tf_export
84
# AutoGraph's `ag_ctx` module, loaded through `LazyLoader` so the import is
# deferred until first use. It is consulted by
# `Tensor._disallow_bool_casting` / `Tensor._disallow_iteration`.
# NOTE(review): the laziness presumably avoids an import cycle with
# autograph -- confirm before replacing it with a plain import.
ag_ctx = LazyLoader(
    "ag_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")


# Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported.
# NOTE(review): both are hard-coded to True; whether anything still reads them
# is not visible in this chunk.
_USE_C_API = True
_USE_C_SHAPES = True

# Monitoring gauges that record whether particular public APIs were called.
_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/ops_eager_execution",
    "Whether ops.enable_eager_execution() is called.")

# Set to True/False by enable_tensor_equality()/disable_tensor_equality().
_tensor_equality_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/enable_tensor_equality",
    "Whether ops.enable_tensor_equality() is called.")

_control_flow_api_gauge = monitoring.BoolGauge(
    "/tensorflow/api/enable_control_flow_v2",
    "Whether enable_control_flow_v2() is called.")

# (sic) "guage" is misspelled; the name is kept because renaming a
# module-level symbol would break any references elsewhere.
_tf_function_api_guage = monitoring.BoolGauge(
    "/tensorflow/api/tf_function",
    "Whether tf.function() is used.")

# Module-level alias for the private intern table kept by `dtypes`.
# pylint: disable=protected-access
_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE
# pylint: enable=protected-access
114
115
def tensor_id(tensor):
  """Return the unique integer identifier assigned to `tensor`.

  Args:
    tensor: A `Tensor`-like object carrying an `_id` attribute.

  Returns:
    The value of `tensor._id`.
  """
  unique_id = tensor._id  # pylint: disable=protected-access
  return unique_id
119
120
class _UserDeviceSpec(object):
  """Store user-specified device and provide computation of merged device.

  The wrapped spec may be a `pydev.MergeDevice`, any other callable taking a
  `NodeDef`, `None` (which signals a break in the device stack), or any other
  value, which is converted with `pydev.merge_device`.

  Attributes set by `__init__`:
    display_name: Human-readable description of the spec.
    function: The callable used to compute a device for a node.
    raw_string: The original spec when it was converted with
      `pydev.merge_device`; otherwise `None`.
    is_null_merge: Read from the `MergeDevice` (given or created); `False`
      for plain callables and for `None`.
    fast_string_merge: Whether `function` is a `pydev.MergeDevice`, which
      enables its `shortcut_string_merge` path.
  """

  def __init__(self, device_name_or_function):
    spec = device_name_or_function
    self._device_name_or_function = spec
    self.display_name = str(spec)
    self.function = spec
    self.raw_string = None

    if isinstance(spec, pydev.MergeDevice):
      self.is_null_merge = spec.is_null_merge
    elif callable(spec):
      self.is_null_merge = False
      func_name = function_utils.get_func_name(spec)
      func_code = function_utils.get_func_code(spec)
      fname, lineno = ((func_code.co_filename, func_code.co_firstlineno)
                       if func_code else ("unknown", -1))
      self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
    elif spec is None:
      # NOTE(taylorrobie): This MUST be False. None signals a break in the
      #   device stack, so `is_null_merge` must be False for such a case to
      #   allow callers to safely skip over null merges without missing a None.
      self.is_null_merge = False
    else:
      self.raw_string = spec
      self.function = pydev.merge_device(spec)
      self.is_null_merge = self.function.is_null_merge

    # The isinstance check has a non-trivial cost and string_merge() is
    # typically called many times, so the fast path is decided once here.
    self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)

  def string_merge(self, node_def):
    """Return the device string for `node_def` after applying this spec."""
    if not self.fast_string_merge:
      return compat.as_str(_device_string(self.function(node_def)))
    return self.function.shortcut_string_merge(node_def)
166
167
class NullContextmanager(object):
  """A context manager that does nothing.

  Constructor arguments are accepted and ignored. Entering yields `None`, and
  exiting never suppresses an exception raised inside the `with` body.
  """

  def __init__(self, *args, **kwargs):
    del args, kwargs  # Accepted only for call-site compatibility.

  def __enter__(self):
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # A falsy return value tells Python to propagate any in-flight exception.
    return False
178
179
def _override_helper(clazz_object, operator, func):
  """Install `func` as the `operator` method of `clazz_object`.

  Only names listed in `Tensor.OVERLOADABLE_OPERATORS` may be replaced; that
  allowlist is consulted regardless of which class is being patched.

  Args:
    clazz_object: The class to patch; either Tensor or SparseTensor.
    operator: The string name of the operator method, e.g. "__add__".
    func: The callable installed under that name.

  Raises:
    ValueError: If `operator` is not allowed to be overwritten.
  """
  allowed = Tensor.OVERLOADABLE_OPERATORS
  if operator in allowed:
    setattr(clazz_object, operator, func)
    return
  raise ValueError("Overriding %s is disallowed" % operator)
194
195
196def _as_graph_element(obj):
197  """Convert `obj` to a graph element if possible, otherwise return `None`.
198
199  Args:
200    obj: Object to convert.
201
202  Returns:
203    The result of `obj._as_graph_element()` if that method is available;
204        otherwise `None`.
205  """
206  conv_fn = getattr(obj, "_as_graph_element", None)
207  if conv_fn and callable(conv_fn):
208    return conv_fn()
209  return None
210
211
# Deprecated - do not use.
# This API is kept only to avoid breaking estimator and tensorflow-mesh, which
# depend on this internal API. The stub should be safe to remove after TF 2.3
# is released.
def is_dense_tensor_like(t):
  """Deprecated: return whether `t` is a `core_tf_types.Tensor` instance."""
  tensor_type = core_tf_types.Tensor
  return isinstance(t, tensor_type)
217
218
def uid():
  """Return an integer that is unique within this program execution.

  The value is obtained from `pywrap_tfe.TFE_Py_UID()`.
  """
  next_id = pywrap_tfe.TFE_Py_UID()
  return next_id
222
223
def numpy_text(tensor, is_repr=False):
  """Render the numpy value of `tensor` as human-readable text.

  Args:
    tensor: A tensor exposing `dtype.is_numpy_compatible` and `_numpy()`.
    is_repr: If True, format the value with `repr`; otherwise with `str`.

  Returns:
    The formatted value, or "<unprintable>" when the dtype is not numpy
    compatible. Multi-line text gets a leading newline so that it starts on
    its own line when embedded in a larger message.
  """
  if not tensor.dtype.is_numpy_compatible:
    text = "<unprintable>"
  else:
    formatter = repr if is_repr else str
    text = formatter(tensor._numpy())  # pylint: disable=protected-access
  return "\n" + text if "\n" in text else text
235
@tf_export(v1=["enable_tensor_equality"])
def enable_tensor_equality():
  """Compare Tensors element-wise, which makes them unhashable.

  With element-wise equality, comparisons such as `tf.Variable(1.0) == 1.0`
  compare values. Because equality is element-wise, tensors become unhashable
  and can no longer be used directly in sets or as dictionary keys; use
  `tensor.ref()` as the key instead.
  """
  usage_cell = _tensor_equality_api_usage_gauge.get_cell()
  usage_cell.set(True)
  Tensor._USE_EQUALITY = True  # pylint: disable=protected-access
247
248
@tf_export(v1=["disable_tensor_equality"])
def disable_tensor_equality():
  """Compare Tensors by their id, which keeps them hashable.

  This is a legacy behaviour of TensorFlow and is highly discouraged.
  """
  usage_cell = _tensor_equality_api_usage_gauge.get_cell()
  usage_cell.set(False)
  Tensor._USE_EQUALITY = False  # pylint: disable=protected-access
257
258
259# TODO(mdan): This object should subclass Symbol, not just Tensor.
260@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"])
261class Tensor(internal.NativeObject, core_tf_types.Tensor):
  """A `tf.Tensor` represents a multidimensional array of elements.

  All elements are of a single known data type.
265
266  When writing a TensorFlow program, the main object that is
267  manipulated and passed around is the `tf.Tensor`.
268
269  A `tf.Tensor` has the following properties:
270
271  * a single data type (float32, int32, or string, for example)
272  * a shape
273
274  TensorFlow supports eager execution and graph execution.  In eager
275  execution, operations are evaluated immediately.  In graph
276  execution, a computational graph is constructed for later
277  evaluation.
278
279  TensorFlow defaults to eager execution.  In the example below, the
280  matrix multiplication results are calculated immediately.
281
282  >>> # Compute some values using a Tensor
283  >>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
284  >>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
285  >>> e = tf.matmul(c, d)
286  >>> print(e)
287  tf.Tensor(
288  [[1. 3.]
289   [3. 7.]], shape=(2, 2), dtype=float32)
290
291  Note that during eager execution, you may discover your `Tensors` are actually
292  of type `EagerTensor`.  This is an internal detail, but it does give you
293  access to a useful function, `numpy`:
294
295  >>> type(e)
296  <class '...ops.EagerTensor'>
297  >>> print(e.numpy())
298    [[1. 3.]
299     [3. 7.]]
300
301  In TensorFlow, `tf.function`s are a common way to define graph execution.
302
303  A Tensor's shape (that is, the rank of the Tensor and the size of
304  each dimension) may not always be fully known.  In `tf.function`
305  definitions, the shape may only be partially known.
306
307  Most operations produce tensors of fully-known shapes if the shapes of their
308  inputs are also fully known, but in some cases it's only possible to find the
309  shape of a tensor at execution time.
310
311  A number of specialized tensors are available: see `tf.Variable`,
312  `tf.constant`, `tf.placeholder`, `tf.sparse.SparseTensor`, and
313  `tf.RaggedTensor`.
314
315  For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor).
316
317  """
318
319  # List of Python operators that we allow to override.
320  OVERLOADABLE_OPERATORS = {
321      # Binary.
322      "__add__",
323      "__radd__",
324      "__sub__",
325      "__rsub__",
326      "__mul__",
327      "__rmul__",
328      "__div__",
329      "__rdiv__",
330      "__truediv__",
331      "__rtruediv__",
332      "__floordiv__",
333      "__rfloordiv__",
334      "__mod__",
335      "__rmod__",
336      "__lt__",
337      "__le__",
338      "__gt__",
339      "__ge__",
340      "__ne__",
341      "__eq__",
342      "__and__",
343      "__rand__",
344      "__or__",
345      "__ror__",
346      "__xor__",
347      "__rxor__",
348      "__getitem__",
349      "__pow__",
350      "__rpow__",
351      # Unary.
352      "__invert__",
353      "__neg__",
354      "__abs__",
355      "__matmul__",
356      "__rmatmul__"
357  }
358
359  # Whether to allow hashing or numpy-style equality
360  _USE_EQUALITY = tf2.enabled()
361
362  def __init__(self, op, value_index, dtype):
363    """Creates a new `Tensor`.
364
365    Args:
366      op: An `Operation`. `Operation` that computes this tensor.
367      value_index: An `int`. Index of the operation's endpoint that produces
368        this tensor.
369      dtype: A `DType`. Type of elements stored in this tensor.
370
371    Raises:
372      TypeError: If the op is not an `Operation`.
373    """
374    if not isinstance(op, Operation):
375      raise TypeError("op needs to be an Operation: %s" % (op,))
376    self._op = op
377    self._value_index = value_index
378    self._dtype = dtypes.as_dtype(dtype)
379    # This will be set by self._as_tf_output().
380    self._tf_output = None
381    # This will be set by self.shape().
382    self._shape_val = None
383    # List of operations that use this Tensor as input.  We maintain this list
384    # to easily navigate a computation graph.
385    self._consumers = []
386    self._id = uid()
387    self._name = None
388
389  def __getattr__(self, name):
390    if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size",
391                "tolist", "data"}:
392      # TODO(wangpeng): Export the enable_numpy_behavior knob
393      raise AttributeError("""
394        If you are looking for numpy-related methods, please run the following:
395        import tensorflow.python.ops.numpy_ops.np_config
396        np_config.enable_numpy_behavior()""")
397    self.__getattribute__(name)
398
399  @staticmethod
400  def _create_with_tf_output(op, value_index, dtype, tf_output):
401    ret = Tensor(op, value_index, dtype)
402    ret._tf_output = tf_output
403    return ret
404
405  @property
406  def op(self):
407    """The `Operation` that produces this tensor as an output."""
408    return self._op
409
410  @property
411  def dtype(self):
412    """The `DType` of elements in this tensor."""
413    return self._dtype
414
415  @property
416  def graph(self):
417    """The `Graph` that contains this tensor."""
418    return self._op.graph
419
420  @property
421  def name(self):
422    """The string name of this tensor."""
423    if self._name is None:
424      if not self._op.name:
425        raise ValueError("Operation was not named: %s" % self._op)
426      self._name = "%s:%d" % (self._op.name, self._value_index)
427    return self._name
428
429  @property
430  def device(self):
431    """The name of the device on which this tensor will be produced, or None."""
432    return self._op.device
433
434  @property
435  def shape(self):
436    """Returns a `tf.TensorShape` that represents the shape of this tensor.
437
438    >>> t = tf.constant([1,2,3,4,5])
439    >>> t.shape
440    TensorShape([5])
441
442    `tf.Tensor.shape` is equivalent to `tf.Tensor.get_shape()`.
443
444    In a `tf.function` or when building a model using
445    `tf.keras.Input`, they return the build-time shape of the
446    tensor, which may be partially unknown.
447
448    A `tf.TensorShape` is not a tensor. Use `tf.shape(t)` to get a tensor
449    containing the shape, calculated at runtime.
450
451    See `tf.Tensor.get_shape()`, and `tf.TensorShape` for details and examples.
452    """
453    if self._shape_val is None:
454      self._shape_val = self._c_api_shape()
455    return self._shape_val
456
457  def _c_api_shape(self):
458    """Returns the TensorShape of this tensor according to the C API."""
459    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
460    shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper(
461        c_graph, self._as_tf_output())
462    if unknown_shape:
463      return tensor_shape.unknown_shape()
464    else:
465      shape_vec = [None if d == -1 else d for d in shape_vec]
466      return tensor_shape.TensorShape(shape_vec)
467
468  @property
469  def _shape(self):
470    logging.warning("Tensor._shape is private, use Tensor.shape "
471                    "instead. Tensor._shape will eventually be removed.")
472    return self.shape
473
474  @_shape.setter
475  def _shape(self, value):
476    raise ValueError(
477        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
478
479  def _disallow_when_autograph_disabled(self, task):
480    raise errors.OperatorNotAllowedInGraphError(
481        "{} is not allowed: AutoGraph is disabled in this function."
482        " Try decorating it directly with @tf.function.".format(task))
483
484  def _disallow_when_autograph_enabled(self, task):
485    raise errors.OperatorNotAllowedInGraphError(
486        "{} is not allowed: AutoGraph did convert this function. This might"
487        " indicate you are trying to use an unsupported feature.".format(task))
488
489  def _disallow_in_graph_mode(self, task):
490    raise errors.OperatorNotAllowedInGraphError(
491        "{} is not allowed in Graph execution. Use Eager execution or decorate"
492        " this function with @tf.function.".format(task))
493
494  def _disallow_bool_casting(self):
495    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
496      self._disallow_when_autograph_disabled(
497          "using a `tf.Tensor` as a Python `bool`")
498    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
499      self._disallow_when_autograph_enabled(
500          "using a `tf.Tensor` as a Python `bool`")
501    else:
502      # Default: V1-style Graph execution.
503      self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`")
504
505  def _disallow_iteration(self):
506    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
507      self._disallow_when_autograph_disabled("iterating over `tf.Tensor`")
508    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
509      self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
510    else:
511      # Default: V1-style Graph execution.
512      self._disallow_in_graph_mode("iterating over `tf.Tensor`")
513
514  def __iter__(self):
515    if not context.executing_eagerly():
516      self._disallow_iteration()
517
518    shape = self._shape_tuple()
519    if shape is None:
520      raise TypeError("Cannot iterate over a tensor with unknown shape.")
521    if not shape:
522      raise TypeError("Cannot iterate over a scalar tensor.")
523    if shape[0] is None:
524      raise TypeError(
525          "Cannot iterate over a tensor with unknown first dimension.")
526    return _TensorIterator(self, shape[0])
527
528  def _shape_as_list(self):
529    if self.shape.ndims is not None:
530      return [dim.value for dim in self.shape.dims]
531    else:
532      return None
533
534  def _shape_tuple(self):
535    shape = self._shape_as_list()
536    if shape is None:
537      return None
538    return tuple(shape)
539
540  def _rank(self):
541    """Integer rank of this Tensor, if known, else None.
542
543    Returns:
544      Integer rank or None
545    """
546    return self.shape.ndims
547
548  def get_shape(self):
549    """Returns a `tf.TensorShape` that represents the shape of this tensor.
550
551    In eager execution the shape is always fully-known.
552
553    >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
554    >>> print(a.shape)
555    (2, 3)
556
557    `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`.
558
559
560    When executing in a `tf.function` or building a model using
561    `tf.keras.Input`, `Tensor.shape` may return a partial shape (including
562    `None` for unknown dimensions). See `tf.TensorShape` for more details.
563
564    >>> inputs = tf.keras.Input(shape = [10])
565    >>> # Unknown batch size
566    >>> print(inputs.shape)
567    (None, 10)
568
569    The shape is computed using shape inference functions that are
570    registered for each `tf.Operation`.
571
572    The returned `tf.TensorShape` is determined at *build* time, without
573    executing the underlying kernel. It is not a `tf.Tensor`. If you need a
574    shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or
575    use the `tf.shape(tensor)` function, which returns the tensor's shape at
576    *execution* time.
577
578    This is useful for debugging and providing early errors. For
579    example, when tracing a `tf.function`, no ops are being executed, shapes
580    may be unknown (See the [Concrete Functions
581    Guide](https://www.tensorflow.org/guide/concrete_function) for details).
582
583    >>> @tf.function
584    ... def my_matmul(a, b):
585    ...   result = a@b
586    ...   # the `print` executes during tracing.
587    ...   print("Result shape: ", result.shape)
588    ...   return result
589
590    The shape inference functions propagate shapes to the extent possible:
591
592    >>> f = my_matmul.get_concrete_function(
593    ...   tf.TensorSpec([None,3]),
594    ...   tf.TensorSpec([3,5]))
595    Result shape: (None, 5)
596
597    Tracing may fail if a shape missmatch can be detected:
598
599    >>> cf = my_matmul.get_concrete_function(
600    ...   tf.TensorSpec([None,3]),
601    ...   tf.TensorSpec([4,5]))
602    Traceback (most recent call last):
603    ...
604    ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op:
605    'MatMul') with input shapes: [?,3], [4,5].
606
607    In some cases, the inferred shape may have unknown dimensions. If
608    the caller has additional information about the values of these
609    dimensions, `tf.ensure_shape` or `Tensor.set_shape()` can be used to augment
610    the inferred shape.
611
612    >>> @tf.function
613    ... def my_fun(a):
614    ...   a = tf.ensure_shape(a, [5, 5])
615    ...   # the `print` executes during tracing.
616    ...   print("Result shape: ", a.shape)
617    ...   return a
618
619    >>> cf = my_fun.get_concrete_function(
620    ...   tf.TensorSpec([None, None]))
621    Result shape: (5, 5)
622
623    Returns:
624      A `tf.TensorShape` representing the shape of this tensor.
625
626    """
627    return self.shape
628
629  def set_shape(self, shape):
630    """Updates the shape of this tensor.
631
632    Note: It is recommended to use `tf.ensure_shape` instead of
633    `Tensor.set_shape`, because `tf.ensure_shape` provides better checking for
634    programming errors and can create guarantees for compiler
635    optimization.
636
637    With eager execution this operates as a shape assertion.
638    Here the shapes match:
639
640    >>> t = tf.constant([[1,2,3]])
641    >>> t.set_shape([1, 3])
642
643    Passing a `None` in the new shape allows any value for that axis:
644
645    >>> t.set_shape([1,None])
646
647    An error is raised if an incompatible shape is passed.
648
649    >>> t.set_shape([1,5])
650    Traceback (most recent call last):
651    ...
652    ValueError: Tensor's shape (1, 3) is not compatible with supplied
653    shape [1, 5]
654
655    When executing in a `tf.function`, or building a model using
656    `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with
657    the current shape of this tensor, and set the tensor's shape to the
658    merged value (see `tf.TensorShape.merge_with` for details):
659
660    >>> t = tf.keras.Input(shape=[None, None, 3])
661    >>> print(t.shape)
662    (None, None, None, 3)
663
664    Dimensions set to `None` are not updated:
665
666    >>> t.set_shape([None, 224, 224, None])
667    >>> print(t.shape)
668    (None, 224, 224, 3)
669
670    The main use case for this is to provide additional shape information
671    that cannot be inferred from the graph alone.
672
673    For example if you know all the images in a dataset have shape [28,28,3] you
674    can set it with `tf.set_shape`:
675
676    >>> @tf.function
677    ... def load_image(filename):
678    ...   raw = tf.io.read_file(filename)
679    ...   image = tf.image.decode_png(raw, channels=3)
680    ...   # the `print` executes during tracing.
681    ...   print("Initial shape: ", image.shape)
682    ...   image.set_shape([28, 28, 3])
683    ...   print("Final shape: ", image.shape)
684    ...   return image
685
686    Trace the function, see the [Concrete Functions
687    Guide](https://www.tensorflow.org/guide/concrete_function) for details.
688
689    >>> cf = load_image.get_concrete_function(
690    ...     tf.TensorSpec([], dtype=tf.string))
691    Initial shape:  (None, None, 3)
692    Final shape: (28, 28, 3)
693
694    Similarly the `tf.io.parse_tensor` function could return a tensor with
695    any shape, even the `tf.rank` is unknown. If you know that all your
696    serialized tensors will be 2d, set it with `set_shape`:
697
698    >>> @tf.function
699    ... def my_parse(string_tensor):
700    ...   result = tf.io.parse_tensor(string_tensor, out_type=tf.float32)
701    ...   # the `print` executes during tracing.
702    ...   print("Initial shape: ", result.shape)
703    ...   result.set_shape([None, None])
704    ...   print("Final shape: ", result.shape)
705    ...   return result
706
707    Trace the function
708
709    >>> concrete_parse = my_parse.get_concrete_function(
710    ...     tf.TensorSpec([], dtype=tf.string))
711    Initial shape:  <unknown>
712    Final shape:  (None, None)
713
714    Make sure it works:
715
716    >>> t = tf.ones([5,3], dtype=tf.float32)
717    >>> serialized = tf.io.serialize_tensor(t)
718    >>> print(serialized.dtype)
719    <dtype: 'string'>
720    >>> print(serialized.shape)
721    ()
722    >>> t2 = concrete_parse(serialized)
723    >>> print(t2.shape)
724    (5, 3)
725
726    Caution: `set_shape` ensures that the applied shape is compatible with
727    the existing shape, but it does not check at runtime. Setting
728    incorrect shapes can result in inconsistencies between the
729    statically-known graph and the runtime value of tensors. For runtime
730    validation of the shape, use `tf.ensure_shape` instead. It also modifies
731    the `shape` of the tensor.
732
733    >>> # Serialize a rank-3 tensor
734    >>> t = tf.ones([5,5,5], dtype=tf.float32)
735    >>> serialized = tf.io.serialize_tensor(t)
736    >>> # The function still runs, even though it `set_shape([None,None])`
737    >>> t2 = concrete_parse(serialized)
738    >>> print(t2.shape)
739    (5, 5, 5)
740
741    Args:
742      shape: A `TensorShape` representing the shape of this tensor, a
743        `TensorShapeProto`, a list, a tuple, or None.
744
745    Raises:
746      ValueError: If `shape` is not compatible with the current shape of
747        this tensor.
748    """
749    # Reset cached shape.
750    self._shape_val = None
751
752    # We want set_shape to be reflected in the C API graph for when we run it.
753    if not isinstance(shape, tensor_shape.TensorShape):
754      shape = tensor_shape.TensorShape(shape)
755    dim_list = []
756    if shape.dims is None:
757      unknown_shape = True
758    else:
759      unknown_shape = False
760      for dim in shape.dims:
761        if dim.value is None:
762          dim_list.append(-1)
763        else:
764          dim_list.append(dim.value)
765    try:
766      pywrap_tf_session.TF_GraphSetTensorShape_wrapper(
767          self._op._graph._c_graph,  # pylint: disable=protected-access
768          self._as_tf_output(),
769          dim_list,
770          unknown_shape)
771    except errors.InvalidArgumentError as e:
772      # Convert to ValueError for backwards compatibility.
773      raise ValueError(str(e))
774
775  @property
776  def value_index(self):
777    """The index of this tensor in the outputs of its `Operation`."""
778    return self._value_index
779
780  def consumers(self):
781    """Returns a list of `Operation`s that consume this tensor.
782
783    Returns:
784      A list of `Operation`s.
785    """
786    consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper(
787        self._as_tf_output())
788    # pylint: disable=protected-access
789    return [
790        self.graph._get_operation_by_name_unsafe(name)
791        for name in consumer_names
792    ]
793    # pylint: enable=protected-access
794
795  def _as_node_def_input(self):
796    """Return a value to use for the NodeDef "input" attribute.
797
798    The returned string can be used in a NodeDef "input" attribute
799    to indicate that the NodeDef uses this Tensor as input.
800
801    Raises:
802      ValueError: if this Tensor's Operation does not have a name.
803
804    Returns:
805      a string.
806    """
807    if not self._op.name:
808      raise ValueError("Operation was not named: %s" % self._op)
809    if self._value_index == 0:
810      return self._op.name
811    else:
812      return "%s:%d" % (self._op.name, self._value_index)
813
814  def _as_tf_output(self):
815    # pylint: disable=protected-access
816    # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object
817    # also guarantees that a dictionary of tf_output objects will retain a
818    # deterministic (yet unsorted) order which prevents memory blowup in the
819    # cache of executor(s) stored for every session.
820    if self._tf_output is None:
821      self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index)
822    return self._tf_output
823    # pylint: enable=protected-access
824
825  def __str__(self):
826    return "Tensor(\"%s\"%s%s%s)" % (
827        self.name,
828        (", shape=%s" %
829         self.get_shape()) if self.get_shape().ndims is not None else "",
830        (", dtype=%s" % self._dtype.name) if self._dtype else "",
831        (", device=%s" % self.device) if self.device else "")
832
833  def __repr__(self):
834    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
835                                                   self._dtype.name)
836
837  def __hash__(self):
838    g = getattr(self, "graph", None)
839    if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
840        (g is None or g.building_function)):
841      raise TypeError("Tensor is unhashable. "
842                      "Instead, use tensor.ref() as the key.")
843    else:
844      return id(self)
845
846  def __copy__(self):
847    # TODO(b/77597810): get rid of Tensor copies.
848    cls = self.__class__
849    result = cls.__new__(cls)
850    result.__dict__.update(self.__dict__)
851    return result
852
853  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
854  # operators to run when the left operand is an ndarray, because it
855  # accords the Tensor class higher priority than an ndarray, or a
856  # numpy matrix.
857  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
858  # mechanism, which allows more control over how Tensors interact
859  # with ndarrays.
860  __array_priority__ = 100
861
862  def __array__(self):
863    raise NotImplementedError(
864        "Cannot convert a symbolic Tensor ({}) to a numpy array."
865        " This error may indicate that you're trying to pass a Tensor to"
866        " a NumPy call, which is not supported".format(self.name))
867
868  def __len__(self):
869    raise TypeError("len is not well defined for symbolic Tensors. ({}) "
870                    "Please call `x.shape` rather than `len(x)` for "
871                    "shape information.".format(self.name))
872
  # TODO(mdan): This convoluted machinery is hard to maintain. Clean up.
  @staticmethod
  def _override_operator(operator, func):
    """Registers `func` as the implementation of `operator` on `Tensor`.

    Delegates entirely to `_override_helper(Tensor, operator, func)`.
    `operator` is presumably a dunder method name such as `"__add__"`;
    confirm the accepted names and validation in `_override_helper`.

    Args:
      operator: Name of the operator to override (see `_override_helper`).
      func: Callable to install for that operator.
    """
    _override_helper(Tensor, operator, func)
877
  def __bool__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This overload raises a `TypeError` when the user inadvertently
    treats a `Tensor` as a boolean (most commonly in an `if` or `while`
    statement), in code that was not converted by AutoGraph. For example:

    ```python
    if tf.constant(True):  # Will raise.
      # ...

    if tf.constant(5) < tf.constant(7):  # Will raise.
      # ...
    ```

    Raises:
      `TypeError`.
    """
    # The actual error is produced by `_disallow_bool_casting`, which is
    # shared with `__nonzero__` (the Python 2 spelling of this hook).
    self._disallow_bool_casting()
897
  def __nonzero__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This is the Python 2.x counterpart to `__bool__()`; Python 3 does not
    call `__nonzero__`.

    Raises:
      `TypeError`.
    """
    self._disallow_bool_casting()
907
  def eval(self, feed_dict=None, session=None):
    """Evaluates this tensor in a `Session`.

    Note: If you are not using `compat.v1` libraries, you should not need this
    (or `feed_dict` or `Session`). In eager execution (or within `tf.function`)
    you do not need to call `eval`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this tensor. If
        None, the default session will be used.

    Returns:
      A numpy array corresponding to the value of this tensor.
    """
    return _eval_using_default_session(self, feed_dict, self.graph, session)
933
  @deprecation.deprecated(None, "Use ref() instead.")
  def experimental_ref(self):
    """Deprecated alias of `ref()`; returns exactly what `self.ref()` returns."""
    return self.ref()
937
  def ref(self):
    # tf.Variable also has the same ref() API.  If you update the
    # documentation here, please update tf.Variable.ref() as well.
    """Returns a hashable reference object to this Tensor.

    The primary use case for this API is to put tensors in a set/dictionary.
    We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer
    available starting TensorFlow 2.0.

    The following will raise an exception starting 2.0:

    >>> x = tf.constant(5)
    >>> y = tf.constant(10)
    >>> z = tf.constant(10)
    >>> tensor_set = {x, y, z}
    Traceback (most recent call last):
      ...
    TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
    >>> tensor_dict = {x: 'five', y: 'ten'}
    Traceback (most recent call last):
      ...
    TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.

    Instead, we can use `tensor.ref()`.

    >>> tensor_set = {x.ref(), y.ref(), z.ref()}
    >>> x.ref() in tensor_set
    True
    >>> tensor_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'}
    >>> tensor_dict[y.ref()]
    'ten'

    Also, the reference object provides a `.deref()` function that returns the
    original Tensor.

    >>> x = tf.constant(5)
    >>> x.ref().deref()
    <tf.Tensor: shape=(), dtype=int32, numpy=5>
    """
    return object_identity.Reference(self)
978
979
# TODO(agarwal): consider getting rid of this.
# TODO(mdan): This object should not subclass ops.Tensor.
class _EagerTensorBase(Tensor):
  """Base class for EagerTensor.

  The concrete `EagerTensor` class is created from this class by
  `pywrap_tfe.TFE_Py_InitEagerTensor` (see the module-level `EagerTensor`
  definition). Methods here that only raise `NotImplementedError` are
  presumably supplied by that C implementation.
  """

  # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and
  # only work for scalars; values are cast as per numpy.
  def __complex__(self):
    return complex(self._numpy())

  def __int__(self):
    return int(self._numpy())

  def __long__(self):
    # Python 2 only: `long` does not exist on Python 3, where this hook is
    # never invoked.
    return long(self._numpy())  # pylint: disable=undefined-variable

  def __float__(self):
    return float(self._numpy())

  def __index__(self):
    return self._numpy().__index__()

  def __bool__(self):
    # Unlike symbolic tensors, eager tensors have a concrete value, so their
    # truthiness is that of the NumPy value.
    return bool(self._numpy())

  __nonzero__ = __bool__

  def __format__(self, format_spec):
    return self._numpy().__format__(format_spec)

  def __reduce__(self):
    # Pickle as a call to `convert_to_tensor` on the NumPy value.
    return convert_to_tensor, (self._numpy(),)

  def __copy__(self):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    return self

  def __deepcopy__(self, memo):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    del memo
    return self

  def __str__(self):
    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self), self.shape,
                                                  self.dtype.name)

  def __repr__(self):
    return "<tf.Tensor: shape=%s, dtype=%s, numpy=%s>" % (
        self.shape, self.dtype.name, numpy_text(self, is_repr=True))

  def __len__(self):
    """Returns the length of the first dimension in the Tensor.

    Raises:
      TypeError: If the tensor is a scalar (rank 0).
    """
    if not self.shape.ndims:
      raise TypeError("Scalar tensor has no `len()`")
    # pylint: disable=protected-access
    try:
      return self._shape_tuple()[0]
    except core._NotOkStatusException as e:
      six.raise_from(core._status_to_exception(e.code, e.message), None)

  def __array__(self, dtype=None):
    """Returns this tensor's value as a NumPy array (NumPy `__array__` protocol).

    Args:
      dtype: Optional dtype that NumPy may pass positionally (for example from
        `np.asarray(t, np.float32)`). When not None, the value is converted
        with `np.asarray(..., dtype=dtype)`. `None` is checked explicitly
        because `np.dtype` objects are falsy.

    Returns:
      The result of `self._numpy()`, converted to `dtype` when one is given.
    """
    value = self._numpy()
    if dtype is None:
      return value
    return np.asarray(value, dtype=dtype)

  def _numpy_internal(self):
    raise NotImplementedError()

  def _numpy(self):
    """Returns the raw NumPy value, mapping C-level status errors to TF errors."""
    try:
      return self._numpy_internal()
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      six.raise_from(core._status_to_exception(e.code, e.message), None)  # pylint: disable=protected-access

  @property
  def dtype(self):
    # Note: using the intern table directly here as this is
    # performance-sensitive in some models.
    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access

  def numpy(self):
    """Copy of the contents of this Tensor into a NumPy array or scalar.

    Unlike NumPy arrays, Tensors are immutable, so this method has to copy
    the contents to ensure safety. Use `memoryview` to get a readonly
    view of the contents without doing a copy:

    >>> t = tf.constant([42])
    >>> np.array(memoryview(t))
    array([42], dtype=int32)

    Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor
    is on GPU, it will have to be transferred to CPU first in order for
    `memoryview` to work.

    Returns:
      A NumPy array of the same shape and dtype or a NumPy scalar, if this
      Tensor has rank 0.

    Raises:
      ValueError: If the dtype of this Tensor does not have a compatible
        NumPy dtype.
    """
    # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
    maybe_arr = self._numpy()  # pylint: disable=protected-access
    return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr

  @property
  def backing_device(self):
    """Returns the name of the device holding this tensor's memory.

    `.backing_device` is usually the same as `.device`, which returns
    the device on which the kernel of the operation that produced this tensor
    ran. However, some operations can produce tensors on a different device
    (e.g., an operation that executes on the GPU but produces output tensors
    in host memory).
    """
    raise NotImplementedError()

  def _datatype_enum(self):
    raise NotImplementedError()

  def _shape_tuple(self):
    """The shape of this Tensor, as a tuple.

    This is more performant than tuple(shape().as_list()) as it avoids
    two list and one object creation. Marked private for now as from an API
    perspective, it would be better to have a single performant way of
    getting a shape rather than exposing shape() and shape_tuple()
    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
    but ideally one would work things out and remove the need for this method.

    Returns:
      tuple with the shape.
    """
    raise NotImplementedError()

  def _rank(self):
    """Integer rank of this Tensor.

    Unlike regular Tensors, the rank is always known for EagerTensors.

    This is more performant than len(self._shape_tuple())

    Returns:
      Integer rank
    """
    raise NotImplementedError()

  def _num_elements(self):
    """Number of elements of this Tensor.

    Unlike regular Tensors, the number of elements is always known for
    EagerTensors.

    This is more performant than tensor.shape.num_elements

    Returns:
      Long - num elements in the tensor
    """
    raise NotImplementedError()

  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
    raise NotImplementedError()

  @staticmethod
  def _override_operator(name, func):
    # Eager tensors install operator overloads directly on this base class.
    setattr(_EagerTensorBase, name, func)

  def _copy_nograd(self, ctx=None, device_name=None):
    """Copies tensor to dest device, but doesn't record the operation.

    Args:
      ctx: Optional eager context; defaults to `context.context()`.
      device_name: Optional destination device; defaults to `ctx.device_name`.

    Returns:
      The new tensor on the destination device.
    """
    # Creates a new tensor on the dest device.
    if ctx is None:
      ctx = context.context()
    if device_name is None:
      device_name = ctx.device_name
    # pylint: disable=protected-access
    try:
      ctx.ensure_initialized()
      new_tensor = self._copy_to_device(device_name)
    except core._NotOkStatusException as e:
      six.raise_from(core._status_to_exception(e.code, e.message), None)
    return new_tensor

  def _copy(self, ctx=None, device_name=None):
    """Copies tensor to dest device.

    When executing eagerly, the copy is also recorded on the gradient tape with
    a gradient function that copies the incoming gradient back to this
    tensor's device.
    """
    new_tensor = self._copy_nograd(ctx, device_name)
    # Record the copy on tape and define backprop copy as well.
    if context.executing_eagerly():
      self_device = self.device

      def grad_fun(dresult):
        return [
            dresult._copy(device_name=self_device)
            if hasattr(dresult, "_copy") else dresult
        ]

      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
    return new_tensor
    # pylint: enable=protected-access

  @property
  def shape(self):
    """The static `TensorShape`, computed on first access and then cached."""
    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
      # pylint: disable=protected-access
      try:
        # `_tensor_shape` is declared and defined in the definition of
        # `EagerTensor`, in C.
        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
      except core._NotOkStatusException as e:
        six.raise_from(core._status_to_exception(e.code, e.message), None)

    return self._tensor_shape

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def _shape_as_list(self):
    """The shape of the tensor as a list."""
    return list(self._shape_tuple())

  @property
  def ndim(self):
    """Returns the number of Tensor dimensions."""
    return self.shape.ndims

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def cpu(self):
    """A copy of this Tensor with contents backed by host memory."""
    return self._copy(context.context(), "CPU:0")

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def gpu(self, gpu_index=0):
    """A copy of this Tensor with contents backed by memory on the GPU.

    Args:
      gpu_index: Identifies which GPU to place the contents on the returned
        Tensor in.

    Returns:
      A GPU-memory backed Tensor object initialized with the same contents
      as this Tensor.
    """
    return self._copy(context.context(), "GPU:" + str(gpu_index))

  def set_shape(self, shape):
    """Checks that `shape` is compatible with this tensor's shape.

    The tensor itself is never modified; this only validates.

    Raises:
      ValueError: If `shape` is not compatible with `self.shape`.
    """
    if not self.shape.is_compatible_with(shape):
      raise ValueError(
          "Tensor's shape %s is not compatible with supplied shape %s" %
          (self.shape, shape))

  # Methods not supported / implemented for Eager Tensors.
  @property
  def op(self):
    raise AttributeError(
        "Tensor.op is meaningless when eager execution is enabled.")

  @property
  def graph(self):
    raise AttributeError(
        "Tensor.graph is meaningless when eager execution is enabled.")

  @property
  def name(self):
    raise AttributeError(
        "Tensor.name is meaningless when eager execution is enabled.")

  @property
  def value_index(self):
    raise AttributeError(
        "Tensor.value_index is meaningless when eager execution is enabled.")

  def consumers(self):
    raise NotImplementedError(
        "Tensor.consumers is meaningless when eager execution is enabled.")

  def _add_consumer(self, consumer):
    raise NotImplementedError(
        "_add_consumer not supported when eager execution is enabled.")

  def _as_node_def_input(self):
    raise NotImplementedError(
        "_as_node_def_input not supported when eager execution is enabled.")

  def _as_tf_output(self):
    raise NotImplementedError(
        "_as_tf_output not supported when eager execution is enabled.")

  def eval(self, feed_dict=None, session=None):
    raise NotImplementedError(
        "eval is not supported when eager execution is enabled, "
        "is .numpy() what you're looking for?")
1271
1272
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
# It is exposed as an __internal__ API for now (b/171081052), though we
# expect it to be eventually covered by tf Tensor types and typing.
EagerTensor = tf_export("__internal__.EagerTensor", v1=[])(
    pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase))
1279
1280
@tf_export(v1=["convert_to_tensor"])
@dispatch.add_dispatch_support
def convert_to_tensor_v1_with_dispatch(
    value,
    dtype=None,
    name=None,
    preferred_dtype=None,
    dtype_hint=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    dtype_hint: Newer name for `preferred_dtype`, with the same meaning. Pass
      at most one of the two.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  # Positional forwarding mirrors convert_to_tensor_v1's parameter order.
  return convert_to_tensor_v1(value, dtype, name, preferred_dtype, dtype_hint)
1341
1342
def convert_to_tensor_v1(value,
                         dtype=None,
                         name=None,
                         preferred_dtype=None,
                         dtype_hint=None):
  """Converts the given `value` to a `Tensor` (with the TF1 API).

  Resolves the `dtype_hint` / `preferred_dtype` pair into a single soft dtype
  preference via `deprecation.deprecated_argument_lookup`, then delegates to
  `convert_to_tensor_v2`.
  """
  soft_dtype = deprecation.deprecated_argument_lookup(
      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
  return convert_to_tensor_v2(
      value, dtype=dtype, dtype_hint=soft_dtype, name=name)
1352
1353
@tf_export("convert_to_tensor", v1=[])
@dispatch.add_dispatch_support
def convert_to_tensor_v2_with_dispatch(
    value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars.

  For example:

  >>> import numpy as np
  >>> def my_func(arg):
  ...   arg = tf.convert_to_tensor(arg, dtype=tf.float32)
  ...   return arg

  >>> # The following calls are equivalent.
  ...
  >>> value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  >>> print(value_1)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  >>> print(value_2)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  >>> print(value_3)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor, used when dtype
      is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so dtype_hint can be used as a soft preference.
      If the conversion to `dtype_hint` is not possible, this argument has no
      effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  # Same parameter order as convert_to_tensor_v2, so forward positionally.
  return convert_to_tensor_v2(value, dtype, dtype_hint, name)
1420
1421
def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`.

  Thin adapter onto `convert_to_tensor`: `dtype_hint` becomes
  `preferred_dtype`, and reference tensors are never requested.
  """
  return convert_to_tensor(
      value=value,
      dtype=dtype,
      preferred_dtype=dtype_hint,
      name=name,
      as_ref=False)
1430
1431
1432def _error_prefix(name):
1433  return "" if name is None else "%s: " % name
1434
1435
def pack_eager_tensors(tensors, ctx=None):
  """Pack multiple `EagerTensor`s of the same dtype and shape.

  Args:
    tensors: a list or tuple of EagerTensors to pack. Every element must have
      the same dtype and shape as the first one and, for resource tensors, the
      same handle data.
    ctx: Optional `context.context()`; defaults to the current context.

  Returns:
    A packed EagerTensor. Recording it on the tape with a gradient function
    that raises means gradients through the packing are not supported.

  Raises:
    TypeError: If `tensors` is not a list or tuple, or contains an element
      that is not an `EagerTensor`.
    ValueError: If `tensors` is empty, or its elements disagree in dtype,
      shape, or (for resource tensors) handle data.
  """
  # The error message below has always advertised tuple support; normalize
  # tuples to a list so everything downstream sees the same type as before.
  if isinstance(tensors, tuple):
    tensors = list(tensors)
  if not isinstance(tensors, list):
    # Wrap in a 1-tuple so a value that is itself a tuple cannot be
    # misinterpreted as the `%` argument list.
    raise TypeError("tensors must be a list or a tuple: %s" % (tensors,))

  if not tensors:
    raise ValueError("Empty tensors is unexpected for packing.")

  dtype = tensors[0].dtype
  shape = tensors[0].shape
  handle_data = tensors[0]._handle_data  # pylint: disable=protected-access
  is_resource = dtype == dtypes.resource
  for i, t in enumerate(tensors):
    if not isinstance(t, EagerTensor):
      raise TypeError("tensors must be a list of EagerTensors: %s" % (t,))

    if t.dtype != dtype:
      raise ValueError(
          "All tensors being packed should have the same dtype %s, "
          "but the %d-th tensor is of dtype %s" % (dtype, i, t.dtype))
    if t.shape != shape:
      raise ValueError(
          "All tensors being packed should have the same shape %s, "
          "but the %d-th tensor is of shape %s" % (shape, i, t.shape))
    # pylint: disable=protected-access
    if is_resource and t._handle_data != handle_data:
      raise ValueError(
          "All tensors being packed should have the same handle data %s, "
          "but the %d-th tensor is of handle data %s" %
          (handle_data, i, t._handle_data))
    # pylint: enable=protected-access

  if ctx is None:
    ctx = context.context()

  packed_tensor = ctx.pack_eager_tensors(tensors)
  # Propagate the (shared) handle data, e.g. for resource variables.
  if handle_data is not None:
    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access

  def grad_fun(_):
    raise ValueError(
        "Gradients through pack_eager_tensors are not supported yet.")

  tape.record_operation("pack_eager_tensors", [packed_tensor], tensors,
                        grad_fun)

  return packed_tensor
1493
1494
@trace.trace_wrapper("convert_to_tensor")
def convert_to_tensor(value,
                      dtype=None,
                      name=None,
                      as_ref=False,
                      preferred_dtype=None,
                      dtype_hint=None,
                      ctx=None,
                      accepted_result_types=(Tensor,)):
  """Implementation of the public convert_to_tensor.

  Resolution order:
    1. An `EagerTensor` used while not executing eagerly is captured into the
       default graph, which must be building a function.
    2. An existing `Tensor` is returned as-is after a dtype-compatibility check.
    3. A `__tf_tensor__` method on `type(value)` is used if present.
    4. Otherwise the registered conversion functions for `type(value)` are
       tried in registry order; the first that does not return
       `NotImplemented` wins.

  Args:
    value: The object to convert.
    dtype: Optional required dtype (normalized with `dtypes.as_dtype`).
    name: Optional name for a newly created tensor or the graph capture.
    as_ref: Forwarded to the registered conversion functions.
    preferred_dtype: Optional soft dtype preference, tried first when `dtype`
      is None. A `TypeError`/`ValueError` from that attempt falls back to a
      conversion with `dtype`.
    dtype_hint: Used as `preferred_dtype` when `preferred_dtype` is falsy.
    ctx: Optional eager context. It is only consulted when `value` is an
      `EagerTensor` and defaults to `context.context()`.
    accepted_result_types: Types that a conversion function's result must be
      an instance of.

  Returns:
    A `Tensor` (or an instance of `accepted_result_types`) for `value`.

  Raises:
    RuntimeError: If an `EagerTensor` is captured outside function building,
      or a conversion function returns a non-accepted type or an incompatible
      dtype.
    ValueError: If `value` is a `Tensor` whose dtype is incompatible with
      `dtype`.
    TypeError: If the preferred-dtype conversion yields a different base dtype,
      or no conversion function is registered for `type(value)`.
  """
  # TODO(b/142518781): Fix all call-sites and remove redundant arg
  preferred_dtype = preferred_dtype or dtype_hint
  if isinstance(value, EagerTensor):
    if ctx is None:
      ctx = context.context()
    if not ctx.executing_eagerly():
      # Graph mode: an eager value can only enter via function capture.
      graph = get_default_graph()
      if not graph.building_function:
        raise RuntimeError("Attempting to capture an EagerTensor without "
                           "building a function.")
      return graph.capture(value, name=name)

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if isinstance(value, Tensor):
    if dtype is not None and not dtype.is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (dtype.name, value.dtype.name, value))
    return value

  if preferred_dtype is not None:
    preferred_dtype = dtypes.as_dtype(preferred_dtype)

  # See below for the reason why it's `type(value)` and not just `value`.
  # https://docs.python.org/3.8/reference/datamodel.html#special-lookup
  overload = getattr(type(value), "__tf_tensor__", None)
  if overload is not None:
    return overload(value, dtype, name)

  for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError):
        # Could not coerce the conversion to use the preferred dtype.
        pass
      else:
        if (ret is not NotImplemented and
            ret.dtype.base_dtype != preferred_dtype.base_dtype):
          raise TypeError("convert_to_tensor did not convert to "
                          "the preferred dtype: %s vs %s " %
                          (ret.dtype.base_dtype, preferred_dtype.base_dtype))

    # `ret` is still None only if no preferred-dtype attempt was made or it
    # raised; a NotImplemented result skips straight to the next function.
    if ret is None:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    if ret is NotImplemented:
      continue

    if not isinstance(ret, accepted_result_types):
      raise RuntimeError(
          "%sConversion function %r for type %s returned non-Tensor: %r" %
          (_error_prefix(name), conversion_func, base_type, ret))
    if dtype and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          "%sConversion function %r for type %s returned incompatible "
          "dtype: requested = %s, actual = %s" %
          (_error_prefix(name), conversion_func, base_type, dtype.name,
           ret.dtype.name))
    return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered." %
                  (_error_prefix(name), value, type(value)))
1573
1574
# `internal_convert_to_tensor` is the very same function object as
# `convert_to_tensor`, exposed under an `internal_` name.
internal_convert_to_tensor = convert_to_tensor
1576
1577
def internal_convert_n_to_tensor(values,
                                 dtype=None,
                                 name=None,
                                 as_ref=False,
                                 preferred_dtype=None,
                                 ctx=None):
  """Converts each element of `values` with `convert_to_tensor`.

  Args:
    values: A sequence of objects accepted by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of every returned `Tensor`.
    name: (Optional.) Name prefix for newly created tensors; element `i` is
      named `"<name>_<i>"`.
    as_ref: True if the caller wants the results as ref tensors.
    preferred_dtype: Optional soft dtype preference used when `dtype` is None.
      If conversion to `preferred_dtype` is not possible, it has no effect.
    ctx: The value of context.context(); looked up once when omitted.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects, one per element.

  Raises:
    TypeError: If `values` is not a sequence, or no conversion function is
      registered for an element.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  if ctx is None:
    ctx = context.context()

  def _element_name(index):
    # Per-element names are derived from the shared prefix, if any.
    return None if name is None else "%s_%d" % (name, index)

  return [
      convert_to_tensor(
          element,
          dtype=dtype,
          name=_element_name(index),
          as_ref=as_ref,
          preferred_dtype=preferred_dtype,
          ctx=ctx) for index, element in enumerate(values)
  ]
1624
1625
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
  """Converts `values` to a list of `Tensor` objects (never ref tensors).

  Args:
    values: A sequence of objects accepted by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of every returned `Tensor`.
    name: (Optional.) Name prefix for newly created tensors; element `i` is
      named `"<name>_<i>"`.
    preferred_dtype: Optional soft dtype preference used when `dtype` is None.
      If conversion to `preferred_dtype` is not possible, it has no effect.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  # Positional order: values, dtype, name, as_ref, preferred_dtype.
  return internal_convert_n_to_tensor(values, dtype, name, False,
                                      preferred_dtype)
1655
1656
def convert_to_tensor_or_composite(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  A `CompositeTensor` passes through unchanged (after a dtype check); any
  other value goes through `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor` or an object accepted by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the result.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  # Positional order: value, dtype, name, as_ref.
  return internal_convert_to_tensor_or_composite(value, dtype, name, False)
1678
1679
def internal_convert_to_tensor_or_composite(value,
                                            dtype=None,
                                            name=None,
                                            as_ref=False):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  Non-composite values are delegated to `convert_to_tensor()`, which is told
  to accept either a `Tensor` or a `CompositeTensor` result. A
  `CompositeTensor` is returned unmodified once its `dtype` attribute (if any)
  has been checked against the requested `dtype`.

  Args:
    value: A `CompositeTensor`, or an object accepted by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the result.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    return convert_to_tensor(
        value,
        dtype=dtype,
        name=name,
        as_ref=as_ref,
        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))

  value_dtype = getattr(value, "dtype", None)
  if dtype:
    requested = dtypes.as_dtype(dtype)
    if not requested.is_compatible_with(value_dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (requested.name, value.dtype.name, str(value)))
  return value
1717
1718
def internal_convert_n_to_tensor_or_composite(values,
                                              dtype=None,
                                              name=None,
                                              as_ref=False):
  """Converts each element of `values` to a `Tensor` or `CompositeTensor`.

  `None` entries are kept as `None`. `CompositeTensor` entries are returned
  unmodified. Every other entry goes through
  `internal_convert_to_tensor_or_composite`.

  Args:
    values: A sequence of `None`, `CompositeTensor`, or objects that can be
      consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` is named `"<name>_<i>"`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list, parallel to `values`, of `Tensor`, `CompositeTensor`, and/or
    `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")

  def _convert_one(index, item):
    # Holes in the input stay holes in the output.
    if item is None:
      return None
    item_name = None if name is None else "%s_%d" % (name, index)
    return internal_convert_to_tensor_or_composite(
        item, dtype=dtype, name=item_name, as_ref=as_ref)

  return [_convert_one(index, item) for index, item in enumerate(values)]
1757
1758
def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified, and any
  `None` entries are passed through as `None`. This is the non-reference
  variant (it always uses `as_ref=False`).

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be
      consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_composite(
      values=values, dtype=dtype, name=name, as_ref=False)
1783
1784
def _device_string(dev_spec):
  """Returns `dev_spec` as a device string.

  Device-spec objects (as recognized by `pydev.is_device_spec`) are rendered
  with `to_string()`; anything else is returned unchanged.
  """
  if not pydev.is_device_spec(dev_spec):
    return dev_spec
  return dev_spec.to_string()
1790
1791
def _NodeDef(op_type, name, attrs=None):
  """Builds a `NodeDef` proto with the given op type, name and attrs.

  Args:
    op_type: Value for the "op" field of the NodeDef proto; converted to bytes.
    name: Value for the "name" field of the NodeDef proto; converted to bytes.
    attrs: Optional mapping from attribute name (a string) to an `AttrValue`
      proto. Each value is copied into the NodeDef's "attr" map.

  Returns:
    A node_def_pb2.NodeDef protocol buffer.
  """
  result = node_def_pb2.NodeDef(
      op=compat.as_bytes(op_type), name=compat.as_bytes(name))
  # A missing or empty `attrs` simply contributes nothing.
  for attr_name, attr_value in six.iteritems(attrs or {}):
    result.attr[attr_name].CopyFrom(attr_value)
  return result
1811
1812
# Copied from core/framework/node_def_util.cc
# TODO(mrry,josh11b): Consolidate this validation in C++ code.
# Op names: the first character must be an ASCII letter, digit or ".", and
# every following character an ASCII letter, digit, "_", ".", backslash, "/",
# ">" or "-".
_VALID_OP_NAME_REGEX = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
# Scope names: any run (including the empty string) of the same character set
# allowed after the first character of an op name.
_VALID_SCOPE_NAME_REGEX = re.compile(r"^[A-Za-z0-9_.\\/>-]*$")
1817
1818
@tf_export("__internal__.create_c_op", v1=[])
def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
  """Creates a TF_Operation.

  Starts an operation description with `TF_NewOperation` and sets the device,
  data inputs, control inputs and attrs taken from `node_def`. It then
  finalizes the description with `TF_FinishOperation`.

  Args:
    graph: a `Graph`.
    node_def: `node_def_pb2.NodeDef` for the operation to create. Its `op`,
      `name`, `device` and `attr` fields are used; its `input` field is not
      read here.
    inputs: A flattened list of `Tensor`s. This function handles grouping
      tensors into lists as per attributes in the `node_def`.
    control_inputs: A list of `Operation`s to set as control dependencies.
    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
      specified, is looked up from the `graph` using `node_def.op`.

  Returns:
    A wrapped TF_Operation*.

  Raises:
    ValueError: if `TF_FinishOperation` raises `errors.InvalidArgumentError`
      (re-raised as `ValueError` for backwards compatibility).
  """
  if op_def is None:
    op_def = graph._get_op_def(node_def.op)  # pylint: disable=protected-access
  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
  # Refactor so we don't have to do this here.
  # After regrouping, each element of `inputs` is either a single Tensor or a
  # list/tuple of Tensors (handled separately in the loop below).
  inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
  # pylint: disable=protected-access
  op_desc = pywrap_tf_session.TF_NewOperation(graph._c_graph,
                                              compat.as_str(node_def.op),
                                              compat.as_str(node_def.name))
  # An empty device string means "no explicit device"; skip the call then.
  if node_def.device:
    pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
  # Add inputs
  for op_input in inputs:
    if isinstance(op_input, (list, tuple)):
      pywrap_tf_session.TF_AddInputList(op_desc,
                                        [t._as_tf_output() for t in op_input])
    else:
      pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())

  # Add control inputs
  for control_input in control_inputs:
    pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
  # pylint: enable=protected-access

  # Add attrs
  # Each AttrValue is passed to the C API in serialized form.
  for name, attr_value in node_def.attr.items():
    serialized = attr_value.SerializeToString()
    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use the same status.
    pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
                                           serialized)

  try:
    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(str(e))

  return c_op
1874
1875
1876@tf_export("Operation")
1877class Operation(object):
1878  """Represents a graph node that performs computation on tensors.
1879
1880  An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor`
1881  objects as input, and produces zero or more `Tensor` objects as output.
1882  Objects of type `Operation` are created by calling a Python op constructor
1883  (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default`
1884  context manager.
1885
1886  For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an
1887  `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and
1888  produces `c` as output.
1889
1890  If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be
1891  executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for
1892  calling `tf.compat.v1.get_default_session().run(op)`.
1893  """
1894
1895  def __init__(self,
1896               node_def,
1897               g,
1898               inputs=None,
1899               output_types=None,
1900               control_inputs=None,
1901               input_types=None,
1902               original_op=None,
1903               op_def=None):
1904    r"""Creates an `Operation`.
1905
1906    NOTE: This constructor validates the name of the `Operation` (passed
1907    as `node_def.name`). Valid `Operation` names match the following
1908    regular expression:
1909
1910        [A-Za-z0-9.][A-Za-z0-9_.\\-/]*
1911
1912    Args:
1913      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`. Used for
1914        attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and
1915        `device`.  The `input` attribute is irrelevant here as it will be
1916        computed when generating the model.
1917      g: `Graph`. The parent graph.
1918      inputs: list of `Tensor` objects. The inputs to this `Operation`.
1919      output_types: list of `DType` objects.  List of the types of the `Tensors`
1920        computed by this operation.  The length of this list indicates the
1921        number of output endpoints of the `Operation`.
1922      control_inputs: list of operations or tensors from which to have a control
1923        dependency.
1924      input_types: List of `DType` objects representing the types of the tensors
1925        accepted by the `Operation`.  By default uses `[x.dtype.base_dtype for x
1926        in inputs]`.  Operations that expect reference-typed inputs must specify
1927        these explicitly.
1928      original_op: Optional. Used to associate the new `Operation` with an
1929        existing `Operation` (for example, a replica with the op that was
1930        replicated).
1931      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type
1932        that this `Operation` represents.
1933
1934    Raises:
1935      TypeError: if control inputs are not Operations or Tensors,
1936        or if `node_def` is not a `NodeDef`,
1937        or if `g` is not a `Graph`,
1938        or if `inputs` are not tensors,
1939        or if `inputs` and `input_types` are incompatible.
1940      ValueError: if the `node_def` name is not valid.
1941    """
1942    # For internal use only: `node_def` can be set to a TF_Operation to create
1943    # an Operation for that op. This is useful for creating Operations for ops
1944    # indirectly created by C API methods, e.g. the ops created by
1945    # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
1946    # should be None.
1947
1948    if isinstance(node_def, node_def_pb2.NodeDef):
1949      if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
1950        raise ValueError(
1951            "Cannot create a tensor proto whose content is larger than 2GB.")
1952      if not _VALID_OP_NAME_REGEX.match(node_def.name):
1953        raise ValueError("'%s' is not a valid node name" % node_def.name)
1954      c_op = None
1955    elif type(node_def).__name__ == "TF_Operation":
1956      assert inputs is None
1957      assert output_types is None
1958      assert control_inputs is None
1959      assert input_types is None
1960      assert original_op is None
1961      assert op_def is None
1962      c_op = node_def
1963    else:
1964      raise TypeError("node_def needs to be a NodeDef: %s" % (node_def,))
1965
1966    if not isinstance(g, Graph):
1967      raise TypeError("g needs to be a Graph: %s" % (g,))
1968    self._graph = g
1969
1970    if inputs is None:
1971      inputs = []
1972    elif not isinstance(inputs, list):
1973      raise TypeError("inputs needs to be a list of Tensors: %s" % (inputs,))
1974    for a in inputs:
1975      if not isinstance(a, Tensor):
1976        raise TypeError("input needs to be a Tensor: %s" % (a,))
1977    if input_types is None:
1978      input_types = [i.dtype.base_dtype for i in inputs]
1979    else:
1980      if not all(
1981          x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)):
1982        raise TypeError("In op '%s', input types (%s) are not compatible "
1983                        "with expected types (%s)" %
1984                        (node_def.name, [i.dtype for i in inputs], input_types))
1985
1986    # Build the list of control inputs.
1987    control_input_ops = []
1988    if control_inputs:
1989      for c in control_inputs:
1990        control_op = None
1991        if isinstance(c, Operation):
1992          control_op = c
1993        elif isinstance(c, (Tensor, IndexedSlices)):
1994          control_op = c.op
1995        else:
1996          raise TypeError("Control input must be an Operation, "
1997                          "a Tensor, or IndexedSlices: %s" % c)
1998        control_input_ops.append(control_op)
1999
2000    # This will be set by self.inputs.
2001    self._inputs_val = None
2002
2003    # pylint: disable=protected-access
2004    self._original_op = original_op
2005
2006    # List of _UserDevSpecs holding code location of device context manager
2007    # invocations and the users original argument to them.
2008    self._device_code_locations = None
2009    # Dict mapping op name to file and line information for op colocation
2010    # context managers.
2011    self._colocation_code_locations = None
2012    self._control_flow_context = self.graph._get_control_flow_context()
2013
2014    # Gradient function for this op. There are three ways to specify gradient
2015    # function, and first available gradient gets used, in the following order.
2016    # 1. self._gradient_function
2017    # 2. Gradient name registered by "_gradient_op_type" attribute.
2018    # 3. Gradient name registered by op.type.
2019    self._gradient_function = None
2020
2021    # Initialize self._c_op.
2022    if c_op:
2023      self._c_op = c_op
2024      op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op))
2025      name = self.name
2026    else:
2027      if op_def is None:
2028        op_def = self._graph._get_op_def(node_def.op)
2029      self._c_op = _create_c_op(self._graph, node_def, inputs,
2030                                control_input_ops, op_def)
2031      name = compat.as_str(node_def.name)
2032
2033    self._traceback = tf_stack.extract_stack_for_node(self._c_op)
2034
2035    # pylint: enable=protected-access
2036
2037    self._is_stateful = op_def.is_stateful
2038
2039    # Initialize self._outputs.
2040    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
2041    self._outputs = []
2042    for i in range(num_outputs):
2043      tf_output = c_api_util.tf_output(self._c_op, i)
2044      output_type = pywrap_tf_session.TF_OperationOutputType(tf_output)
2045      tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output)  # pylint: disable=protected-access
2046      self._outputs.append(tensor)
2047
2048    self._id_value = self._graph._add_op(self, name)  # pylint: disable=protected-access
2049
2050    if not c_op:
2051      self._control_flow_post_processing(input_tensors=inputs)
2052
2053  def _control_flow_post_processing(self, input_tensors=None):
2054    """Add this op to its control flow context.
2055
2056    This may add new ops and change this op's inputs. self.inputs must be
2057    available before calling this method.
2058
2059    Args:
2060      input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
2061        of this op, which should be equivalent to `self.inputs`. Pass this
2062        argument to avoid evaluating `self.inputs` unnecessarily.
2063    """
2064    if input_tensors is None:
2065      input_tensors = self.inputs
2066    for input_tensor in input_tensors:
2067      control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
2068    if self._control_flow_context is not None:
2069      self._control_flow_context.AddOp(self)
2070
2071  def colocation_groups(self):
2072    """Returns the list of colocation groups of the op."""
2073    default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)]
2074    try:
2075      class_attr = self.get_attr("_class")
2076    except ValueError:
2077      # This op has no explicit colocation group, so it is itself its
2078      # own root of a colocation group.
2079      return default_colocation_group
2080
2081    attr_groups = [
2082        class_name for class_name in class_attr
2083        if class_name.startswith(b"loc:@")
2084    ]
2085
2086    # If there are no colocation groups in the explicit _class field,
2087    # return the default colocation group.
2088    return attr_groups if attr_groups else default_colocation_group
2089
2090  def values(self):
2091    """DEPRECATED: Use outputs."""
2092    return tuple(self.outputs)
2093
2094  def _get_control_flow_context(self):
2095    """Returns the control flow context of this op.
2096
2097    Returns:
2098      A context object.
2099    """
2100    return self._control_flow_context
2101
2102  def _set_control_flow_context(self, ctx):
2103    """Sets the current control flow context of this op.
2104
2105    Args:
2106      ctx: a context object.
2107    """
2108    self._control_flow_context = ctx
2109
2110  @property
2111  def name(self):
2112    """The full name of this operation."""
2113    return pywrap_tf_session.TF_OperationName(self._c_op)
2114
2115  @property
2116  def _id(self):
2117    """The unique integer id of this operation."""
2118    return self._id_value
2119
2120  @property
2121  def device(self):
2122    """The name of the device to which this op has been assigned, if any.
2123
2124    Returns:
2125      The string name of the device to which this op has been
2126      assigned, or an empty string if it has not been assigned to a
2127      device.
2128    """
2129    return pywrap_tf_session.TF_OperationDevice(self._c_op)
2130
2131  @property
2132  def _device_assignments(self):
2133    """Code locations for device context managers active at op creation.
2134
2135    This property will return a list of traceable_stack.TraceableObject
2136    instances where .obj is a string representing the assigned device
2137    (or information about the function that would be applied to this op
2138    to compute the desired device) and the filename and lineno members
2139    record the location of the relevant device context manager.
2140
2141    For example, suppose file_a contained these lines:
2142
2143      file_a.py:
2144        15: with tf.device('/gpu:0'):
2145        16:   node_b = tf.constant(4, name='NODE_B')
2146
2147    Then a TraceableObject t_obj representing the device context manager
2148    would have these member values:
2149
2150      t_obj.obj -> '/gpu:0'
2151      t_obj.filename = 'file_a.py'
2152      t_obj.lineno = 15
2153
2154    and node_b.op._device_assignments would return the list [t_obj].
2155
2156    Returns:
2157      [str: traceable_stack.TraceableObject, ...] as per this method's
2158      description, above.
2159    """
2160    return self._device_code_locations or []
2161
2162  @property
2163  def _colocation_dict(self):
2164    """Code locations for colocation context managers active at op creation.
2165
2166    This property will return a dictionary for which the keys are nodes with
2167    which this Operation is colocated, and for which the values are
2168    traceable_stack.TraceableObject instances.  The TraceableObject instances
2169    record the location of the relevant colocation context manager but have the
2170    "obj" field set to None to prevent leaking private data.
2171
2172    For example, suppose file_a contained these lines:
2173
2174      file_a.py:
2175        14: node_a = tf.constant(3, name='NODE_A')
2176        15: with tf.compat.v1.colocate_with(node_a):
2177        16:   node_b = tf.constant(4, name='NODE_B')
2178
2179    Then a TraceableObject t_obj representing the colocation context manager
2180    would have these member values:
2181
2182      t_obj.obj -> None
2183      t_obj.filename = 'file_a.py'
2184      t_obj.lineno = 15
2185
2186    and node_b.op._colocation_dict would return the dictionary
2187
2188      { 'NODE_A': t_obj }
2189
2190    Returns:
2191      {str: traceable_stack.TraceableObject} as per this method's description,
2192      above.
2193    """
2194    locations_dict = self._colocation_code_locations or {}
2195    return locations_dict.copy()
2196
2197  @property
2198  def _output_types(self):
2199    """List this operation's output types.
2200
2201    Returns:
2202      List of the types of the Tensors computed by this operation.
2203      Each element in the list is an integer whose value is one of
2204      the TF_DataType enums defined in pywrap_tf_session.h
2205      The length of this list indicates the number of output endpoints
2206      of the operation.
2207    """
2208    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
2209    output_types = [
2210        int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i)))
2211        for i in xrange(num_outputs)
2212    ]
2213
2214    return output_types
2215
2216  def _tf_output(self, output_idx):
2217    """Create and return a new TF_Output for output_idx'th output of this op."""
2218    tf_output = pywrap_tf_session.TF_Output()
2219    tf_output.oper = self._c_op
2220    tf_output.index = output_idx
2221    return tf_output
2222
2223  def _tf_input(self, input_idx):
2224    """Create and return a new TF_Input for input_idx'th input of this op."""
2225    tf_input = pywrap_tf_session.TF_Input()
2226    tf_input.oper = self._c_op
2227    tf_input.index = input_idx
2228    return tf_input
2229
2230  def _set_device(self, device):  # pylint: disable=redefined-outer-name
2231    """Set the device of this operation.
2232
2233    Args:
2234      device: string or device..  The device to set.
2235    """
2236    self._set_device_from_string(compat.as_str(_device_string(device)))
2237
2238  def _set_device_from_string(self, device_str):
2239    """Fast path to set device if the type is known to be a string.
2240
2241    This function is called frequently enough during graph construction that
2242    there are non-trivial performance gains if the caller can guarantee that
2243    the specified device is already a string.
2244
2245    Args:
2246      device_str: A string specifying where to place this op.
2247    """
2248    pywrap_tf_session.SetRequestedDevice(
2249        self._graph._c_graph,  # pylint: disable=protected-access
2250        self._c_op,  # pylint: disable=protected-access
2251        device_str)
2252
2253  def _update_input(self, index, tensor):
2254    """Update the input to this operation at the given index.
2255
2256    NOTE: This is for TF internal use only. Please don't use it.
2257
2258    Args:
2259      index: the index of the input to update.
2260      tensor: the Tensor to be used as the input at the given index.
2261
2262    Raises:
2263      TypeError: if tensor is not a Tensor,
2264        or if input tensor type is not convertible to dtype.
2265      ValueError: if the Tensor is from a different graph.
2266    """
2267    if not isinstance(tensor, Tensor):
2268      raise TypeError("tensor must be a Tensor: %s" % tensor)
2269    _assert_same_graph(self, tensor)
2270
2271    # Reset cached inputs.
2272    self._inputs_val = None
2273    pywrap_tf_session.UpdateEdge(
2274        self._graph._c_graph,  # pylint: disable=protected-access
2275        tensor._as_tf_output(),  # pylint: disable=protected-access
2276        self._tf_input(index))
2277
2278  def _add_while_inputs(self, tensors):
2279    """See AddWhileInputHack in python_api.h.
2280
2281    NOTE: This is for TF internal use only. Please don't use it.
2282
2283    Args:
2284      tensors: list of Tensors
2285
2286    Raises:
2287      TypeError: if tensor is not a Tensor,
2288        or if input tensor type is not convertible to dtype.
2289      ValueError: if the Tensor is from a different graph.
2290    """
2291    for tensor in tensors:
2292      if not isinstance(tensor, Tensor):
2293        raise TypeError("tensor must be a Tensor: %s" % tensor)
2294      _assert_same_graph(self, tensor)
2295
2296      # Reset cached inputs.
2297      self._inputs_val = None
2298      pywrap_tf_session.AddWhileInputHack(
2299          self._graph._c_graph,  # pylint: disable=protected-access
2300          tensor._as_tf_output(),  # pylint: disable=protected-access
2301          self._c_op)
2302
2303  def _add_control_inputs(self, ops):
2304    """Add a list of new control inputs to this operation.
2305
2306    Args:
2307      ops: the list of Operations to add as control input.
2308
2309    Raises:
2310      TypeError: if ops is not a list of Operations.
2311      ValueError: if any op in ops is from a different graph.
2312    """
2313    for op in ops:
2314      if not isinstance(op, Operation):
2315        raise TypeError("op must be an Operation: %s" % op)
2316      pywrap_tf_session.AddControlInput(
2317          self._graph._c_graph,  # pylint: disable=protected-access
2318          self._c_op,  # pylint: disable=protected-access
2319          op._c_op)  # pylint: disable=protected-access
2320
2321  def _add_control_input(self, op):
2322    """Add a new control input to this operation.
2323
2324    Args:
2325      op: the Operation to add as control input.
2326
2327    Raises:
2328      TypeError: if op is not an Operation.
2329      ValueError: if op is from a different graph.
2330    """
2331    if not isinstance(op, Operation):
2332      raise TypeError("op must be an Operation: %s" % op)
2333    pywrap_tf_session.AddControlInput(
2334        self._graph._c_graph,  # pylint: disable=protected-access
2335        self._c_op,  # pylint: disable=protected-access
2336        op._c_op)  # pylint: disable=protected-access
2337
2338  def _remove_all_control_inputs(self):
2339    """Removes any control inputs to this operation."""
2340    pywrap_tf_session.RemoveAllControlInputs(self._graph._c_graph, self._c_op)  # pylint: disable=protected-access
2341
2342  def _add_outputs(self, types, shapes):
2343    """Adds new Tensors to self.outputs.
2344
2345    Note: this is generally unsafe to use. This is used in certain situations in
2346    conjunction with _set_type_list_attr.
2347
2348    Args:
2349      types: list of DTypes
2350      shapes: list of TensorShapes
2351    """
2352    assert len(types) == len(shapes)
2353    orig_num_outputs = len(self.outputs)
2354    for i in range(len(types)):
2355      t = Tensor(self, orig_num_outputs + i, types[i])
2356      self._outputs.append(t)
2357      t.set_shape(shapes[i])
2358
2359  def __str__(self):
2360    return str(self.node_def)
2361
2362  def __repr__(self):
2363    return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
2364
2365  def __tf_tensor__(self, dtype=None, name=None):
2366    """Raises a helpful error."""
2367    raise TypeError("can't convert Operation '{}' to Tensor".format(self.name))
2368
2369  @property
2370  def outputs(self):
2371    """The list of `Tensor` objects representing the outputs of this op."""
2372    return self._outputs
2373
2374  @property
2375  def inputs(self):
2376    """The sequence of `Tensor` objects representing the data inputs of this op."""
2377    if self._inputs_val is None:
2378      # pylint: disable=protected-access
2379      self._inputs_val = tuple(
2380          map(self.graph._get_tensor_by_tf_output,
2381              pywrap_tf_session.GetOperationInputs(self._c_op)))
2382      # pylint: enable=protected-access
2383    return self._inputs_val
2384
2385  @property
2386  def _input_types(self):
2387    num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op)
2388    input_types = [
2389        dtypes.as_dtype(
2390            pywrap_tf_session.TF_OperationInputType(self._tf_input(i)))
2391        for i in xrange(num_inputs)
2392    ]
2393    return input_types
2394
2395  @property
2396  def control_inputs(self):
2397    """The `Operation` objects on which this op has a control dependency.
2398
2399    Before this op is executed, TensorFlow will ensure that the
2400    operations in `self.control_inputs` have finished executing. This
2401    mechanism can be used to run ops sequentially for performance
2402    reasons, or to ensure that the side effects of an op are observed
2403    in the correct order.
2404
2405    Returns:
2406      A list of `Operation` objects.
2407
2408    """
2409    control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper(
2410        self._c_op)
2411    # pylint: disable=protected-access
2412    return [
2413        self.graph._get_operation_by_name_unsafe(
2414            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
2415    ]
2416    # pylint: enable=protected-access
2417
2418  @property
2419  def _control_outputs(self):
2420    """The `Operation` objects which have a control dependency on this op.
2421
2422    Before any of the ops in self._control_outputs can execute tensorflow will
2423    ensure self has finished executing.
2424
2425    Returns:
2426      A list of `Operation` objects.
2427
2428    """
2429    control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
2430        self._c_op)
2431    # pylint: disable=protected-access
2432    return [
2433        self.graph._get_operation_by_name_unsafe(
2434            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
2435    ]
2436    # pylint: enable=protected-access
2437
2438  @property
2439  def type(self):
2440    """The type of the op (e.g. `"MatMul"`)."""
2441    return pywrap_tf_session.TF_OperationOpType(self._c_op)
2442
2443  @property
2444  def graph(self):
2445    """The `Graph` that contains this operation."""
2446    return self._graph
2447
2448  @property
2449  def node_def(self):
2450    # pylint: disable=line-too-long
2451    """Returns the `NodeDef` representation of this operation.
2452
2453    Returns:
2454      A
2455      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
2456      protocol buffer.
2457    """
2458    # pylint: enable=line-too-long
2459    with c_api_util.tf_buffer() as buf:
2460      pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
2461      data = pywrap_tf_session.TF_GetBuffer(buf)
2462    node_def = node_def_pb2.NodeDef()
2463    node_def.ParseFromString(compat.as_bytes(data))
2464    return node_def
2465
2466  @property
2467  def op_def(self):
2468    # pylint: disable=line-too-long
2469    """Returns the `OpDef` proto that represents the type of this op.
2470
2471    Returns:
2472      An
2473      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
2474      protocol buffer.
2475    """
2476    # pylint: enable=line-too-long
2477    return self._graph._get_op_def(self.type)
2478
2479  @property
2480  def traceback(self):
2481    """Returns the call stack from when this operation was constructed."""
2482    return self._traceback
2483
2484  def _set_attr(self, attr_name, attr_value):
2485    """Private method used to set an attribute in the node_def."""
2486    buf = pywrap_tf_session.TF_NewBufferFromString(
2487        compat.as_bytes(attr_value.SerializeToString()))
2488    try:
2489      self._set_attr_with_buf(attr_name, buf)
2490    finally:
2491      pywrap_tf_session.TF_DeleteBuffer(buf)
2492
2493  def _set_attr_with_buf(self, attr_name, attr_buf):
2494    """Set an attr in the node_def with a pre-allocated buffer."""
2495    # pylint: disable=protected-access
2496    pywrap_tf_session.SetAttr(self._graph._c_graph, self._c_op, attr_name,
2497                              attr_buf)
2498    # pylint: enable=protected-access
2499
2500  def _set_func_attr(self, attr_name, func_name):
2501    """Private method used to set a function attribute in the node_def."""
2502    func = attr_value_pb2.NameAttrList(name=func_name)
2503    self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
2504
2505  def _set_func_list_attr(self, attr_name, func_names):
2506    """Private method used to set a list(function) attribute in the node_def."""
2507    funcs = [attr_value_pb2.NameAttrList(name=func_name)
2508             for func_name in func_names]
2509    funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
2510    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))
2511
2512  def _set_type_list_attr(self, attr_name, types):
2513    """Private method used to set a list(type) attribute in the node_def."""
2514    if not types:
2515      return
2516    if isinstance(types[0], dtypes.DType):
2517      types = [dt.as_datatype_enum for dt in types]
2518    types_list = attr_value_pb2.AttrValue.ListValue(type=types)
2519    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
2520
2521  def _set_shape_list_attr(self, attr_name, shapes):
2522    """Private method used to set a list(shape) attribute in the node_def."""
2523    shapes = [s.as_proto() for s in shapes]
2524    shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
2525    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
2526
2527  def _clear_attr(self, attr_name):
2528    """Private method used to clear an attribute in the node_def."""
2529    # pylint: disable=protected-access
2530    pywrap_tf_session.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
2531    # pylint: enable=protected-access
2532
2533  def get_attr(self, name):
2534    """Returns the value of the attr of this op with the given `name`.
2535
2536    Args:
2537      name: The name of the attr to fetch.
2538
2539    Returns:
2540      The value of the attr, as a Python object.
2541
2542    Raises:
2543      ValueError: If this op does not have an attr with the given `name`.
2544    """
2545    fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
2546    try:
2547      with c_api_util.tf_buffer() as buf:
2548        pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
2549        data = pywrap_tf_session.TF_GetBuffer(buf)
2550    except errors.InvalidArgumentError as e:
2551      # Convert to ValueError for backwards compatibility.
2552      raise ValueError(str(e))
2553    x = attr_value_pb2.AttrValue()
2554    x.ParseFromString(data)
2555
2556    oneof_value = x.WhichOneof("value")
2557    if oneof_value is None:
2558      return []
2559    if oneof_value == "list":
2560      for f in fields:
2561        if getattr(x.list, f):
2562          if f == "type":
2563            return [dtypes.as_dtype(t) for t in x.list.type]
2564          else:
2565            return list(getattr(x.list, f))
2566      return []
2567    if oneof_value == "type":
2568      return dtypes.as_dtype(x.type)
2569    assert oneof_value in fields, "Unsupported field type in " + str(x)
2570    return getattr(x, oneof_value)
2571
2572  def _get_attr_type(self, name):
2573    """Returns the `DType` value of the attr of this op with the given `name`."""
2574    try:
2575      dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
2576      return _DTYPES_INTERN_TABLE[dtype_enum]
2577    except errors.InvalidArgumentError as e:
2578      # Convert to ValueError for backwards compatibility.
2579      raise ValueError(str(e))
2580
2581  def _get_attr_bool(self, name):
2582    """Returns the `bool` value of the attr of this op with the given `name`."""
2583    try:
2584      return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
2585    except errors.InvalidArgumentError as e:
2586      # Convert to ValueError for backwards compatibility.
2587      raise ValueError(str(e))
2588
2589  def _get_attr_int(self, name):
2590    """Returns the `int` value of the attr of this op with the given `name`."""
2591    try:
2592      return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
2593    except errors.InvalidArgumentError as e:
2594      # Convert to ValueError for backwards compatibility.
2595      raise ValueError(str(e))
2596
2597  def run(self, feed_dict=None, session=None):
2598    """Runs this operation in a `Session`.
2599
2600    Calling this method will execute all preceding operations that
2601    produce the inputs needed for this operation.
2602
2603    *N.B.* Before invoking `Operation.run()`, its graph must have been
2604    launched in a session, and either a default session must be
2605    available, or `session` must be specified explicitly.
2606
2607    Args:
2608      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
2609        `tf.Session.run` for a description of the valid feed values.
2610      session: (Optional.) The `Session` to be used to run to this operation. If
2611        none, the default session will be used.
2612    """
2613    _run_using_default_session(self, feed_dict, self.graph, session)
2614
# Registry mapping op type strings to gradient functions. `RegisterGradient`
# adds entries; `no_gradient` registers `None` for non-differentiable ops.
_gradient_registry = registry.Registry("gradient")
2616
2617
@tf_export("RegisterGradient")
class RegisterGradient(object):
  """A decorator for registering the gradient function for an op type.

  Use this only when defining a new op type. For an op with `m` inputs and
  `n` outputs, the gradient function receives the original `Operation` plus
  `n` `Tensor` objects (the gradients with respect to each output of the op).
  It returns `m` `Tensor` objects (the partial gradients with respect to
  each input of the op).

  For example, if operations of type `"Sub"` take two inputs `x` and `y` and
  return a single output `x - y`, the gradient function would be registered
  like this:

  ```python
  @tf.RegisterGradient("Sub")
  def _sub_grad(unused_op, grad):
    return grad, tf.negative(grad)
  ```

  The decorator argument `op_type` is the string type of an operation. It
  corresponds to the `OpDef.name` field of the proto that defines the
  operation.
  """

  __slots__ = ["_op_type"]

  def __init__(self, op_type):
    """Creates a decorator that registers gradients for `op_type`.

    Args:
      op_type: The string type of an operation. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.

    Raises:
      TypeError: If `op_type` is not string.
    """
    if isinstance(op_type, six.string_types):
      self._op_type = op_type
    else:
      raise TypeError("op_type must be a string")

  def __call__(self, f):
    """Registers `f` as the gradient function for `op_type`.

    Returns:
      `f` unchanged, so the decorated name still refers to the function.
    """
    registry_key = self._op_type
    _gradient_registry.register(f, registry_key)
    return f
2664
2665
@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient")
@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"])
def no_gradient(op_type):
  """Specifies that ops of type `op_type` are not differentiable.

  This function should *not* be used for operations that have a
  well-defined gradient that is not yet implemented.

  Use this only when defining a new op type. It is meant for ops such as
  `tf.size()` that are not differentiable. For example:

  ```python
  tf.no_gradient("Size")
  ```

  The gradient computed for 'op_type' will then propagate zeros.

  For ops that have a well-defined gradient but are not yet implemented,
  make no declaration; an error *must* be raised if their gradient is
  requested.

  Args:
    op_type: The string type of an operation. This corresponds to the
      `OpDef.name` field for the proto that defines the operation.

  Raises:
    TypeError: If `op_type` is not a string.
  """
  if isinstance(op_type, six.string_types):
    # A `None` entry in the gradient registry marks the type as having no
    # gradient function.
    _gradient_registry.register(None, op_type)
    return
  raise TypeError("op_type must be a string")
2699
2700
# Backwards-compatible aliases for the old (deprecated) names of
# `no_gradient`; these will eventually be removed.
NoGradient = no_gradient
NotDifferentiable = no_gradient
2704
2705
def get_gradient_function(op):
  """Returns the function that computes gradients for `op`.

  Resolution order:
    1. Ops without inputs have no gradient function (`None`).
    2. A truthy `op._gradient_function` is returned as-is.
    3. Otherwise the gradient registry is queried, using the op's
       `_gradient_op_type` attr if present, else `op.type`.
  """
  if not op.inputs:
    return None

  attached = op._gradient_function  # pylint: disable=protected-access
  if attached:
    return attached

  try:
    lookup_type = op.get_attr("_gradient_op_type")
  except ValueError:
    # `get_attr` raises ValueError when the attr is absent.
    lookup_type = op.type
  return _gradient_registry.lookup(lookup_type)
2720
2721
def set_shape_and_handle_data_for_outputs(_):
  """Does nothing; retained only so existing callers keep working.

  TODO(b/74620627): Remove this.
  """
  return None
2725
2726
class OpStats(object):
  """A holder for statistics about an operator.

  This class holds information about the resource requirements for an op,
  including the size of its weight parameters on-disk and how many FLOPS it
  requires to execute forward inference.

  If you define a new operation, you can create a function that will return a
  set of information about its usage of the CPU and disk space when serialized.
  The function itself takes a Graph object that's been set up so you can call
  methods like get_tensor_by_name to help calculate the results, and a NodeDef
  argument.

  """

  __slots__ = ["_statistic_type", "_value"]

  def __init__(self, statistic_type, value=None):
    """Sets up the initial placeholders for the statistics.

    Args:
      statistic_type: String naming the statistic (e.g. "flops").
      value: Initial value of the statistic, or None if unknown.
    """
    self.statistic_type = statistic_type
    self.value = value

  @property
  def statistic_type(self):
    """The string name of the statistic held by this object."""
    return self._statistic_type

  @statistic_type.setter
  def statistic_type(self, statistic_type):
    self._statistic_type = statistic_type

  @property
  def value(self):
    """The accumulated value, or None if nothing has been recorded."""
    return self._value

  @value.setter
  def value(self, value):
    self._value = value

  def __iadd__(self, other):
    """Accumulates `other` into this object in place (`self += other`).

    If this object's value is None it takes `other.value`; otherwise
    `other.value` is added to it unless `other.value` is None.

    Args:
      other: Another `OpStats` with the same `statistic_type`.

    Returns:
      `self`.

    Raises:
      ValueError: If the two objects hold different statistic types.
    """
    if other.statistic_type != self.statistic_type:
      # `other` is being added into `self`, so name the types in that order.
      raise ValueError("Can't add an OpStat of type %s to one of %s." %
                       (other.statistic_type, self.statistic_type))
    if self.value is None:
      self.value = other.value
    elif other.value is not None:
      self.value += other.value
    return self
2774
2775
# Registry of statistics functions, keyed by "<op_type>,<statistic_type>".
# Populated by `RegisterStatistics`; read by `get_stats_for_node_def`.
_stats_registry = registry.Registry("statistical functions")
2777
2778
class RegisterStatistics(object):
  """A decorator for registering the statistics function for an op type.

  Apply this to a function for an op type so that the function can report the
  resources used by an instance of that operator, in the form of an `OpStats`
  object.

  Well-known statistic types so far:

  - flops: When running a graph, most of the computation goes into numerical
    work such as matrix multiplications. This type lets a node report how
    many floating-point operations it takes to complete. The total number of
    FLOPs for a graph is a good guide to its expected latency.

  You can add your own statistics by picking a new type string, registering
  functions for the ops you care about, and then calling
  get_stats_for_node_def.

  If a statistic for an op is registered multiple times, a KeyError will be
  raised.

  Because statistics are counted per op, they are not suitable for model
  parameters (capacity), which should be counted only once even when shared
  by multiple ops (e.g. RNN).

  For example, you can define a new metric called doohickey for a Foo operation
  by placing this in your code:

  ```python
  @ops.RegisterStatistics("Foo", "doohickey")
  def _calc_foo_bojangles(unused_graph, unused_node_def):
    return ops.OpStats("doohickey", 20)
  ```

  Then in client code you can retrieve the value by making this call:

  ```python
  doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
  ```

  If the NodeDef is for an op with a registered doohickey function, you'll get
  back the calculated amount in doohickey.value, or None if it's not defined.

  """

  __slots__ = ["_op_type", "_statistic_type"]

  def __init__(self, op_type, statistic_type):
    """Validates and saves `op_type` and `statistic_type`.

    Raises:
      TypeError: If either argument is not a string or contains a comma
        (the comma separates the two parts of the registry key).
    """
    if not isinstance(op_type, six.string_types):
      raise TypeError("op_type must be a string.")
    if "," in op_type:
      raise TypeError("op_type must not contain a comma.")
    if not isinstance(statistic_type, six.string_types):
      raise TypeError("statistic_type must be a string.")
    if "," in statistic_type:
      raise TypeError("statistic_type must not contain a comma.")
    self._op_type = op_type
    self._statistic_type = statistic_type

  def __call__(self, f):
    """Registers `f` as the statistics function for the saved op/stat pair."""
    registry_key = ",".join((self._op_type, self._statistic_type))
    _stats_registry.register(f, registry_key)
    return f
2842
2843
def get_stats_for_node_def(graph, node, statistic_type):
  """Looks up the node's statistics function in the registry and calls it.

  This function takes a Graph object and a NodeDef from a GraphDef. If a
  statistics function is registered for the node's op type, it is called and
  its result returned. If no function has been registered for that node type,
  an empty statistics object is returned.

  Args:
    graph: A Graph object that's been set up with the node's graph.
    node: A NodeDef describing the operator.
    statistic_type: A string identifying the statistic we're interested in.

  Returns:
    An OpStats object containing information about resource usage.
  """
  # Functions are registered by `RegisterStatistics` under the key
  # "<op_type>,<statistic_type>".
  try:
    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
  except LookupError:
    # Nothing registered for this op/statistic pair.
    return OpStats(statistic_type)
  # Call outside the `try` so that a LookupError (KeyError, IndexError, ...)
  # raised by a buggy statistics function surfaces, instead of being
  # silently reported as "no statistics registered".
  return stats_func(graph, node)
2867
2868
def name_from_scope_name(name):
  """Strips one trailing "/" from a scope name to get the op name.

  Args:
    name: the name of the scope.

  Returns:
    `name` without a single trailing slash; empty or slash-free names are
    returned unchanged.
  """
  if not name or name[-1] != "/":
    return name
  return name[:-1]
2879
2880
# Group indices for `Graph._group_lock` (created with num_groups=2). Judging
# by the names, group 0 is for threads that mutate ops and group 1 is for
# Session.run calls -- NOTE(review): the use sites are outside this view.
_MUTATION_LOCK_GROUP = 0
_SESSION_RUN_LOCK_GROUP = 1
2883
2884
2885@tf_export("Graph")
2886class Graph(object):
2887  """A TensorFlow computation, represented as a dataflow graph.
2888
2889  Graphs are used by `tf.function`s to represent the function's computations.
2890  Each graph contains a set of `tf.Operation` objects, which represent units of
2891  computation; and `tf.Tensor` objects, which represent the units of data that
2892  flow between operations.
2893
2894  ### Using graphs directly (deprecated)
2895
2896  A `tf.Graph` can be constructed and used directly without a `tf.function`, as
2897  was required in TensorFlow 1, but this is deprecated and it is recommended to
2898  use a `tf.function` instead. If a graph is directly used, other deprecated
2899  TensorFlow 1 classes are also required to execute the graph, such as a
2900  `tf.compat.v1.Session`.
2901
2902  A default graph can be registered with the `tf.Graph.as_default` context
2903  manager. Then, operations will be added to the graph instead of being executed
2904  eagerly. For example:
2905
2906  ```python
2907  g = tf.Graph()
2908  with g.as_default():
2909    # Define operations and tensors in `g`.
2910    c = tf.constant(30.0)
2911    assert c.graph is g
2912  ```
2913
2914  `tf.compat.v1.get_default_graph()` can be used to obtain the default graph.
2915
2916  Important note: This class *is not* thread-safe for graph construction. All
2917  operations should be created from a single thread, or external
2918  synchronization must be provided. Unless otherwise specified, all methods
2919  are not thread-safe.
2920
2921  A `Graph` instance supports an arbitrary number of "collections"
2922  that are identified by name. For convenience when building a large
2923  graph, collections can store groups of related objects: for
2924  example, the `tf.Variable` uses a collection (named
2925  `tf.GraphKeys.GLOBAL_VARIABLES`) for
2926  all variables that are created during the construction of a graph. The caller
2927  may define additional collections by specifying a new name.
2928  """
2929
  def __init__(self):
    """Creates a new, empty Graph.

    Initializes Python-side bookkeeping (locks, name/id maps, stacks, caches),
    allocates the underlying C graph, and switches stack state to thread-local
    when TF2 behavior is enabled.
    """
    # Protects core state that can be returned via public accessors.
    # Thread-safety is provided on a best-effort basis to support buggy
    # programs, and is not guaranteed by the public `tf.Graph` API.
    #
    # NOTE(mrry): This does not protect the various stacks. A warning will
    # be reported if these are used from multiple threads.
    self._lock = threading.RLock()
    # The group lock synchronizes Session.run calls with methods that create
    # and mutate ops (e.g. Graph.create_op()). This synchronization is
    # necessary because it's illegal to modify an operation after it's been run.
    # The group lock allows any number of threads to mutate ops at the same time
    # but if any modification is going on, all Session.run calls have to wait.
    # Similarly, if one or more Session.run calls are going on, all mutate ops
    # have to wait until all Session.run calls have finished.
    self._group_lock = lock_util.GroupLock(num_groups=2)
    self._nodes_by_id = {}  # GUARDED_BY(self._lock)
    self._next_id_counter = 0  # GUARDED_BY(self._lock)
    self._nodes_by_name = {}  # GUARDED_BY(self._lock)
    self._version = 0  # GUARDED_BY(self._lock)
    # Maps a name used in the graph to the next id to use for that name.
    self._names_in_use = {}
    # Presumably flipped by switch_to_thread_local() (called below when TF2
    # is enabled) -- NOTE(review): confirm.
    self._stack_state_is_thread_local = False
    # Per-thread storage; e.g. holds _variable_creator_stack.
    self._thread_local = threading.local()
    # Functions that will be applied to choose a device if none is specified.
    # In TF2.x or after switch_to_thread_local(),
    # self._thread_local._device_function_stack is used instead.
    self._graph_device_function_stack = traceable_stack.TraceableStack()
    # Default original_op applied to new ops.
    self._default_original_op = None
    # Current control flow context. It could be either CondContext or
    # WhileContext defined in ops/control_flow_ops.py
    self._control_flow_context = None
    # A new node will depend on the union of all of the nodes in the stack.
    # In TF2.x or after switch_to_thread_local(),
    # self._thread_local._control_dependencies_stack is used instead.
    self._graph_control_dependencies_stack = []
    # Arbitrary collections of objects.
    self._collections = {}
    # The graph-level random seed
    self._seed = None
    # A dictionary of attributes that should be applied to all ops.
    self._attr_scope_map = {}
    # A map from op type to the kernel label that should be used.
    self._op_to_kernel_label_map = {}
    # A map from op type to an alternative op type that should be used when
    # computing gradients.
    self._gradient_override_map = {}
    # A map from op type to a gradient function that should be used instead.
    self._gradient_function_map = {}
    # True if the graph is considered "finalized".  In that case no
    # new operations can be added.
    self._finalized = False
    # Functions defined in the graph, in definition order.
    self._functions = collections.OrderedDict()
    # Default GraphDef versions
    self._graph_def_versions = versions_pb2.VersionDef(
        producer=versions.GRAPH_DEF_VERSION,
        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
    self._building_function = False
    # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
    # self._thread_local._colocation_stack is used instead.
    self._graph_colocation_stack = traceable_stack.TraceableStack()
    # Set of tensors that are dangerous to feed!
    self._unfeedable_tensors = object_identity.ObjectIdentitySet()
    # Set of operations that are dangerous to fetch!
    self._unfetchable_ops = set()
    # A map of tensor handle placeholder to tensor dtype.
    self._handle_feeders = {}
    # A map from tensor handle to its read op.
    self._handle_readers = {}
    # A map from tensor handle to its move op.
    self._handle_movers = {}
    # A map from tensor handle to its delete op.
    self._handle_deleters = {}
    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
    # will be shared when defining function graphs, for example, so optimizers
    # being called inside function definitions behave as if they were seeing the
    # actual outside graph).
    # NOTE(review): "grap-key" looks like a typo of "graph-key"; the string is
    # left as-is in case anything depends on the exact key format.
    self._graph_key = "grap-key-%d/" % (uid(),)
    # A string with the last reduction method passed to
    # losses.compute_weighted_loss(), or None. This is required only for
    # backward compatibility with Estimator and optimizer V1 use cases.
    self._last_loss_reduction = None
    # Flag that is used to indicate whether loss has been scaled by optimizer.
    # If this flag has been set, then estimator uses it to scale loss back
    # before reporting. This is required only for backward compatibility with
    # Estimator and optimizer V1 use cases.
    self._is_loss_scaled_by_optimizer = False
    self._container = ""
    # Set to True if this graph is being built in an
    # AutomaticControlDependencies context.
    self._add_control_dependencies = False
    # Cache for OpDef protobufs retrieved via the C API.
    self._op_def_cache = {}
    # Cache for constant results of `broadcast_gradient_args()`. The keys are
    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
    # values are tuples of reduction indices: (rx, ry).
    self._bcast_grad_args_cache = {}
    # Cache for constant results of `reduced_shape()`. The keys are pairs of
    # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
    # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
    self._reduced_shape_cache = {}

    # TODO(skyewm): fold as much of the above as possible into the C
    # implementation
    self._scoped_c_graph = c_api_util.ScopedTFGraph()
    # The C API requires all ops to have shape functions. Disable this
    # requirement (many custom ops do not have shape functions, and we don't
    # want to break these existing cases).
    pywrap_tf_session.SetRequireShapeInferenceFns(self._c_graph, False)
    if tf2.enabled():
      self.switch_to_thread_local()
3044
3045  # Note: this method is private because the API of tf.Graph() is public and
3046  # frozen, and this functionality is still not ready for public visibility.
  @tf_contextlib.contextmanager
  def _variable_creator_scope(self, creator, priority=100):
    """Scope which defines a variable creation function.

    For the duration of the scope, a new copy of this thread's variable
    creator stack that includes `creator` is installed; the previous stack
    object is restored on a properly nested exit.

    Args:
      creator: A callable taking `next_creator` and `kwargs`. See the
        `tf.variable_creator_scope` docstring.
      priority: Creators with a higher `priority` are called first. Within the
        same priority, creators are called inner-to-outer.

    Yields:
      `_variable_creator_scope` is a context manager with a side effect, but
      doesn't return a value.

    Raises:
      RuntimeError: If variable creator scopes are not properly nested.
    """
    # This step keeps a reference to the existing stack, and it also initializes
    # self._thread_local._variable_creator_stack if it doesn't exist yet.
    old = self._variable_creator_stack
    # Copy rather than mutate `old`, so restoring `old` on exit undoes exactly
    # this scope's addition.
    new = list(old)
    new.append((priority, creator))
    # Sorting is stable, so we'll put higher-priority creators later in the list
    # but otherwise maintain registration order.
    new.sort(key=lambda item: item[0])
    self._thread_local._variable_creator_stack = new  # pylint: disable=protected-access
    try:
      yield
    finally:
      # Identity (not equality) check: if the installed stack object is no
      # longer `new`, an inner scope exited out of order or something replaced
      # the stack. In that case the error is raised before `old` is restored.
      if self._thread_local._variable_creator_stack is not new:  # pylint: disable=protected-access
        raise RuntimeError(
            "Exiting variable_creator_scope without proper nesting.")
      self._thread_local._variable_creator_stack = old  # pylint: disable=protected-access
3080
3081  # Note: this method is private because the API of tf.Graph() is public and
3082  # frozen, and this functionality is still not ready for public visibility.
3083  @property
3084  def _variable_creator_stack(self):
3085    if not hasattr(self._thread_local, "_variable_creator_stack"):
3086      self._thread_local._variable_creator_stack = []  # pylint: disable=protected-access
3087
3088    # This previously returned a copy of the stack instead of the stack itself,
3089    # to guard against accidental mutation. Consider, however, code that wants
3090    # to save and restore the variable creator stack:
3091    #     def f():
3092    #       original_stack = graph._variable_creator_stack
3093    #       graph._variable_creator_stack = new_stack
3094    #       ...  # Some code
3095    #       graph._variable_creator_stack = original_stack
3096    #
3097    # And lets say you have some code that calls this function with some
3098    # variable_creator:
3099    #     def g():
3100    #       with variable_scope.variable_creator_scope(creator):
3101    #         f()
3102    # When exiting the variable creator scope, it would see a different stack
3103    # object than it expected leading to a "Exiting variable_creator_scope
3104    # without proper nesting" error.
3105    return self._thread_local._variable_creator_stack  # pylint: disable=protected-access
3106
3107  @_variable_creator_stack.setter
3108  def _variable_creator_stack(self, variable_creator_stack):
3109    self._thread_local._variable_creator_stack = variable_creator_stack  # pylint: disable=protected-access
3110
3111  def _check_not_finalized(self):
3112    """Check if the graph is finalized.
3113
3114    Raises:
3115      RuntimeError: If the graph finalized.
3116    """
3117    if self._finalized:
3118      raise RuntimeError("Graph is finalized and cannot be modified.")
3119
3120  def _add_op(self, op, op_name):
3121    """Adds 'op' to the graph and returns the unique ID for the added Operation.
3122
3123    Args:
3124      op: the Operation to add.
3125      op_name: the name of the Operation.
3126
3127    Returns:
3128      An integer that is a unique ID for the added Operation.
3129    """
3130    self._check_not_finalized()
3131    with self._lock:
3132      self._next_id_counter += 1
3133      op_id = self._next_id_counter
3134      self._nodes_by_id[op_id] = op
3135      self._nodes_by_name[op_name] = op
3136      self._version = max(self._version, op_id)
3137      return op_id
3138
3139  @property
3140  def _c_graph(self):
3141    if self._scoped_c_graph:
3142      return self._scoped_c_graph.graph
3143    return None
3144
3145  @property
3146  def version(self):
3147    """Returns a version number that increases as ops are added to the graph.
3148
3149    Note that this is unrelated to the
3150    `tf.Graph.graph_def_versions`.
3151
3152    Returns:
3153       An integer version that increases as ops are added to the graph.
3154    """
3155    if self._finalized:
3156      return self._version
3157
3158    with self._lock:
3159      return self._version
3160
3161  @property
3162  def graph_def_versions(self):
3163    # pylint: disable=line-too-long
3164    """The GraphDef version information of this graph.
3165
3166    For details on the meaning of each version, see
3167    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).
3168
3169    Returns:
3170      A `VersionDef`.
3171    """
3172    # pylint: enable=line-too-long
3173    with c_api_util.tf_buffer() as buf:
3174      pywrap_tf_session.TF_GraphVersions(self._c_graph, buf)
3175      data = pywrap_tf_session.TF_GetBuffer(buf)
3176    version_def = versions_pb2.VersionDef()
3177    version_def.ParseFromString(compat.as_bytes(data))
3178    return version_def
3179
3180  @property
3181  def seed(self):
3182    """The graph-level random seed of this graph."""
3183    return self._seed
3184
3185  @seed.setter
3186  def seed(self, seed):
3187    self._seed = seed
3188
3189  @property
3190  def finalized(self):
3191    """True if this graph has been finalized."""
3192    return self._finalized
3193
3194  def finalize(self):
3195    """Finalizes this graph, making it read-only.
3196
3197    After calling `g.finalize()`, no new operations can be added to
3198    `g`.  This method is used to ensure that no operations are added
3199    to a graph when it is shared between multiple threads, for example
3200    when using a `tf.compat.v1.train.QueueRunner`.
3201    """
3202    self._finalized = True
3203
3204  def _unsafe_unfinalize(self):
3205    """Opposite of `finalize`.
3206
3207    Internal interface.
3208
3209    NOTE: Unfinalizing a graph could have negative impact on performance,
3210    especially in a multi-threaded environment.  Unfinalizing a graph
3211    when it is in use by a Session may lead to undefined behavior. Ensure
3212    that all sessions using a graph are closed before calling this method.
3213    """
3214    self._finalized = False
3215
3216  def _get_control_flow_context(self):
3217    """Returns the current control flow context.
3218
3219    Returns:
3220      A context object.
3221    """
3222    return self._control_flow_context
3223
3224  def _set_control_flow_context(self, ctx):
3225    """Sets the current control flow context.
3226
3227    Args:
3228      ctx: a context object.
3229    """
3230    self._control_flow_context = ctx
3231
3232  def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
3233    """If this graph contains functions, copy them to `graph_def`."""
3234    bytesize = starting_bytesize
3235    for f in self._functions.values():
3236      bytesize += f.definition.ByteSize()
3237      if bytesize >= (1 << 31) or bytesize < 0:
3238        raise ValueError("GraphDef cannot be larger than 2GB.")
3239      graph_def.library.function.extend([f.definition])
3240      if f.grad_func_name:
3241        grad_def = function_pb2.GradientDef()
3242        grad_def.function_name = f.name
3243        grad_def.gradient_func = f.grad_func_name
3244        graph_def.library.gradient.extend([grad_def])
3245
3246  def _as_graph_def(self, from_version=None, add_shapes=False):
3247    # pylint: disable=line-too-long
3248    """Returns a serialized `GraphDef` representation of this graph.
3249
3250    The serialized `GraphDef` can be imported into another `Graph`
3251    (using `tf.import_graph_def`) or used with the
3252    [C++ Session API](../../../../api_docs/cc/index.md).
3253
3254    This method is thread-safe.
3255
3256    Args:
3257      from_version: Optional.  If this is set, returns a `GraphDef` containing
3258        only the nodes that were added to this graph since its `version`
3259        property had the given value.
3260      add_shapes: If true, adds an "_output_shapes" list attr to each node with
3261        the inferred shapes of each of its outputs.
3262
3263    Returns:
3264      A tuple containing a
3265      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3266      protocol buffer, and the version of the graph to which that
3267      `GraphDef` corresponds.
3268
3269    Raises:
3270      ValueError: If the `graph_def` would be too large.
3271
3272    """
3273    # pylint: enable=line-too-long
3274    with self._lock:
3275      with c_api_util.tf_buffer() as buf:
3276        pywrap_tf_session.TF_GraphToGraphDef(self._c_graph, buf)
3277        data = pywrap_tf_session.TF_GetBuffer(buf)
3278      graph = graph_pb2.GraphDef()
3279      graph.ParseFromString(compat.as_bytes(data))
3280      # Strip the experimental library field iff it's empty.
3281      if not graph.library.function:
3282        graph.ClearField("library")
3283
3284      if add_shapes:
3285        for node in graph.node:
3286          op = self._nodes_by_name[node.name]
3287          if op.outputs:
3288            node.attr["_output_shapes"].list.shape.extend(
3289                [output.get_shape().as_proto() for output in op.outputs])
3290        for function_def in graph.library.function:
3291          defined_function = self._functions[function_def.signature.name]
3292          try:
3293            func_graph = defined_function.graph
3294          except AttributeError:
3295            # _DefinedFunction doesn't have a graph, _EagerDefinedFunction
3296            # does. Both rely on ops.py, so we can't really isinstance check
3297            # them.
3298            continue
3299          input_shapes = function_def.attr["_input_shapes"]
3300          try:
3301            func_graph_inputs = func_graph.inputs
3302          except AttributeError:
3303            continue
3304          # TODO(b/141471245): Fix the inconsistency when inputs of func graph
3305          # are appended during gradient computation of while/cond.
3306          for input_tensor, arg_def in zip(func_graph_inputs,
3307                                           function_def.signature.input_arg):
3308            input_shapes.list.shape.add().CopyFrom(
3309                input_tensor.get_shape().as_proto())
3310            if input_tensor.dtype == dtypes.resource:
3311              _copy_handle_data_to_arg_def(input_tensor, arg_def)
3312
3313          for output_tensor, arg_def in zip(func_graph.outputs,
3314                                            function_def.signature.output_arg):
3315            if output_tensor.dtype == dtypes.resource:
3316              _copy_handle_data_to_arg_def(output_tensor, arg_def)
3317
3318          for node in function_def.node_def:
3319            try:
3320              op = func_graph.get_operation_by_name(node.name)
3321            except KeyError:
3322              continue
3323            outputs = op.outputs
3324
3325            if op.type == "StatefulPartitionedCall":
3326              # Filter out any extra outputs (possibly added by function
3327              # backpropagation rewriting).
3328              num_outputs = len(node.attr["Tout"].list.type)
3329              outputs = outputs[:num_outputs]
3330
3331            node.attr["_output_shapes"].list.shape.extend(
3332                [output.get_shape().as_proto() for output in outputs])
3333
3334    return graph, self._version
3335
3336  def as_graph_def(self, from_version=None, add_shapes=False):
3337    # pylint: disable=line-too-long
3338    """Returns a serialized `GraphDef` representation of this graph.
3339
3340    The serialized `GraphDef` can be imported into another `Graph`
3341    (using `tf.import_graph_def`) or used with the
3342    [C++ Session API](../../api_docs/cc/index.md).
3343
3344    This method is thread-safe.
3345
3346    Args:
3347      from_version: Optional.  If this is set, returns a `GraphDef` containing
3348        only the nodes that were added to this graph since its `version`
3349        property had the given value.
3350      add_shapes: If true, adds an "_output_shapes" list attr to each node with
3351        the inferred shapes of each of its outputs.
3352
3353    Returns:
3354      A
3355      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3356      protocol buffer.
3357
3358    Raises:
3359      ValueError: If the `graph_def` would be too large.
3360    """
3361    # pylint: enable=line-too-long
3362    result, _ = self._as_graph_def(from_version, add_shapes)
3363    return result
3364
3365  def _is_function(self, name):
3366    """Tests whether 'name' is registered in this graph's function library.
3367
3368    Args:
3369      name: string op name.
3370
3371    Returns:
3372      bool indicating whether or not 'name' is registered in function library.
3373    """
3374    return compat.as_str(name) in self._functions
3375
3376  def _get_function(self, name):
3377    """Returns the function definition for 'name'.
3378
3379    Args:
3380      name: string function name.
3381
3382    Returns:
3383      The function def proto.
3384    """
3385    return self._functions.get(compat.as_str(name), None)
3386
3387  def _add_function(self, function):
3388    """Adds a function to the graph.
3389
3390    After the function has been added, you can call to the function by
3391    passing the function name in place of an op name to
3392    `Graph.create_op()`.
3393
3394    Args:
3395      function: A `_DefinedFunction` object.
3396
3397    Raises:
3398      ValueError: if another function is defined with the same name.
3399    """
3400    self._check_not_finalized()
3401
3402    name = function.name
3403    # Sanity checks on gradient definition.
3404    if (function.grad_func_name is not None) and (function.python_grad_func is
3405                                                  not None):
3406      raise ValueError("Gradient defined twice for function %s" % name)
3407
3408    # Add function to graph
3409    # pylint: disable=protected-access
3410    gradient = (
3411        function._grad_func._c_func.func if function._grad_func else None)
3412    pywrap_tf_session.TF_GraphCopyFunction(self._c_graph, function._c_func.func,
3413                                           gradient)
3414    # pylint: enable=protected-access
3415
3416    self._functions[compat.as_str(name)] = function
3417
3418    # Need a new-enough consumer to support the functions we add to the graph.
3419    if self._graph_def_versions.min_consumer < 12:
3420      self._graph_def_versions.min_consumer = 12
3421
3422  @property
3423  def building_function(self):
3424    """Returns True iff this graph represents a function."""
3425    return self._building_function
3426
3427  # Helper functions to create operations.
3428  @deprecated_args(None,
3429                   "Shapes are always computed; don't use the compute_shapes "
3430                   "as it has no effect.", "compute_shapes")
3431  def create_op(
3432      self,
3433      op_type,
3434      inputs,
3435      dtypes=None,  # pylint: disable=redefined-outer-name
3436      input_types=None,
3437      name=None,
3438      attrs=None,
3439      op_def=None,
3440      compute_shapes=True,
3441      compute_device=True):
3442    """Creates an `Operation` in this graph.
3443
3444    This is a low-level interface for creating an `Operation`. Most
3445    programs will not call this method directly, and instead use the
3446    Python op constructors, such as `tf.constant()`, which add ops to
3447    the default graph.
3448
3449    Args:
3450      op_type: The `Operation` type to create. This corresponds to the
3451        `OpDef.name` field for the proto that defines the operation.
3452      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3453      dtypes: (Optional) A list of `DType` objects that will be the types of the
3454        tensors that the operation produces.
3455      input_types: (Optional.) A list of `DType`s that will be the types of the
3456        tensors that the operation consumes. By default, uses the base `DType`
3457        of each input in `inputs`. Operations that expect reference-typed inputs
3458        must specify `input_types` explicitly.
3459      name: (Optional.) A string name for the operation. If not specified, a
3460        name is generated based on `op_type`.
3461      attrs: (Optional.) A dictionary where the key is the attribute name (a
3462        string) and the value is the respective `attr` attribute of the
3463        `NodeDef` proto that will represent the operation (an `AttrValue`
3464        proto).
3465      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3466        the operation will have.
3467      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
3468        computed).
3469      compute_device: (Optional.) If True, device functions will be executed to
3470        compute the device property of the Operation.
3471
3472    Raises:
3473      TypeError: if any of the inputs is not a `Tensor`.
3474      ValueError: if colocation conflicts with existing device assignment.
3475
3476    Returns:
3477      An `Operation` object.
3478    """
3479    del compute_shapes
3480    for idx, a in enumerate(inputs):
3481      if not isinstance(a, Tensor):
3482        raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
3483    return self._create_op_internal(op_type, inputs, dtypes, input_types, name,
3484                                    attrs, op_def, compute_device)
3485
3486  def _create_op_internal(
3487      self,
3488      op_type,
3489      inputs,
3490      dtypes=None,  # pylint: disable=redefined-outer-name
3491      input_types=None,
3492      name=None,
3493      attrs=None,
3494      op_def=None,
3495      compute_device=True):
3496    """Creates an `Operation` in this graph.
3497
3498    Implements `Graph.create_op()` without the overhead of the deprecation
3499    wrapper.
3500
3501    Args:
3502      op_type: The `Operation` type to create. This corresponds to the
3503        `OpDef.name` field for the proto that defines the operation.
3504      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3505      dtypes: (Optional) A list of `DType` objects that will be the types of the
3506        tensors that the operation produces.
3507      input_types: (Optional.) A list of `DType`s that will be the types of the
3508        tensors that the operation consumes. By default, uses the base `DType`
3509        of each input in `inputs`. Operations that expect reference-typed inputs
3510        must specify `input_types` explicitly.
3511      name: (Optional.) A string name for the operation. If not specified, a
3512        name is generated based on `op_type`.
3513      attrs: (Optional.) A dictionary where the key is the attribute name (a
3514        string) and the value is the respective `attr` attribute of the
3515        `NodeDef` proto that will represent the operation (an `AttrValue`
3516        proto).
3517      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3518        the operation will have.
3519      compute_device: (Optional.) If True, device functions will be executed to
3520        compute the device property of the Operation.
3521
3522    Raises:
3523      ValueError: if colocation conflicts with existing device assignment.
3524
3525    Returns:
3526      An `Operation` object.
3527    """
3528    self._check_not_finalized()
3529    if name is None:
3530      name = op_type
3531    # If a names ends with a '/' it is a "name scope" and we use it as-is,
3532    # after removing the trailing '/'.
3533    if name and name[-1] == "/":
3534      name = name_from_scope_name(name)
3535    else:
3536      name = self.unique_name(name)
3537
3538    node_def = _NodeDef(op_type, name, attrs)
3539
3540    input_ops = set(t.op for t in inputs)
3541    control_inputs = self._control_dependencies_for_inputs(input_ops)
3542    # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
3543    # Session.run call cannot occur between creating and mutating the op.
3544    with self._mutation_lock():
3545      ret = Operation(
3546          node_def,
3547          self,
3548          inputs=inputs,
3549          output_types=dtypes,
3550          control_inputs=control_inputs,
3551          input_types=input_types,
3552          original_op=self._default_original_op,
3553          op_def=op_def)
3554      self._create_op_helper(ret, compute_device=compute_device)
3555    return ret
3556
3557  def _create_op_from_tf_operation(self, c_op, compute_device=True):
3558    """Creates an `Operation` in this graph from the supplied TF_Operation.
3559
3560    This method is like create_op() except the new Operation is constructed
3561    using `c_op`. The returned Operation will have `c_op` as its _c_op
3562    field. This is used to create Operation objects around TF_Operations created
3563    indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
3564
3565    This function does not call Operation._control_flow_post_processing or
3566    Graph._control_dependencies_for_inputs (since the inputs may not be
3567    available yet). The caller is responsible for calling these methods.
3568
3569    Args:
3570      c_op: a wrapped TF_Operation
3571      compute_device: (Optional.) If True, device functions will be executed to
3572        compute the device property of the Operation.
3573
3574    Returns:
3575      An `Operation` object.
3576    """
3577    self._check_not_finalized()
3578    ret = Operation(c_op, self)
3579    # If a name_scope was created with ret.name but no nodes were created in it,
3580    # the name will still appear in _names_in_use even though the name hasn't
3581    # been used. This is ok, just leave _names_in_use as-is in this case.
3582    # TODO(skyewm): make the C API guarantee no name conflicts.
3583    name_key = ret.name.lower()
3584    if name_key not in self._names_in_use:
3585      self._names_in_use[name_key] = 1
3586    self._create_op_helper(ret, compute_device=compute_device)
3587    return ret
3588
  def _create_op_helper(self, op, compute_device=True):
    """Common logic for creating an op in this graph.

    Applies the graph's current scopes and overrides to a freshly constructed
    `op`, mutating it in place, in this order:
      1. attribute-scope values from `self._attr_scope_map` (never
         overwriting attrs the op already has),
      2. a kernel label from `self._op_to_kernel_label_map`,
      3. the gradient function from `self._gradient_function_map`,
      4. a gradient op-type override from `self._gradient_override_map`,
      5. control-dependency bookkeeping,
      6. device functions (only if `compute_device`),
      7. colocation constraints from `self._colocation_stack`,
      8. the current container (`self._container`) for stateful ops.

    Args:
      op: The new `Operation` to post-process.
      compute_device: If True, calls `self._apply_device_functions(op)`.

    Raises:
      TypeError: If a callable value in the attribute-scope map returns
        something other than None or an `AttrValue` proto.
    """
    # Apply any additional attributes requested. Do not overwrite any existing
    # attributes.
    for key, value in self._attr_scope_map.items():
      try:
        # Probe only: a ValueError here is taken to mean `key` is not set yet.
        op.get_attr(key)
      except ValueError:
        if callable(value):
          # Callables compute the attr from the op's NodeDef.
          value = value(op.node_def)
          if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
            raise TypeError(
                "Callable for scope map key '%s' must return either None or "
                "an AttrValue protocol buffer; but it returned: %s" %
                (key, value))
        # A None (falsy) value means "leave this attr unset".
        if value:
          op._set_attr(key, value)  # pylint: disable=protected-access

    # Apply a kernel label if one has been specified for this op type.
    # NOTE(review): the `except KeyError` also covers `_set_attr`, not only
    # the map lookup.
    try:
      kernel_label = self._op_to_kernel_label_map[op.type]
      op._set_attr("_kernel",  # pylint: disable=protected-access
                   attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
    except KeyError:
      pass

    # Attach the registered gradient function for this op type (None if none).
    op._gradient_function = self._gradient_function_map.get(op.type)  # pylint: disable=protected-access

    # Apply the overriding op type for gradients if one has been specified for
    # this op type.
    try:
      mapped_op_type = self._gradient_override_map[op.type]
      op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
                   attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
    except KeyError:
      pass

    self._record_op_seen_by_control_dependencies(op)

    if compute_device:
      self._apply_device_functions(op)

    # Snapshot the colocation stack metadata before we might generate error
    # messages using it.  Note that this snapshot depends on the actual stack
    # and is independent of the op's _class attribute.
    # pylint: disable=protected-access
    op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
    # pylint: enable=protected-access

    if self._colocation_stack:
      # Merge the colocation groups of every op on the stack. The first
      # stacked op with a device decides `op`'s device. Stack entries
      # without `colocation_groups()` contribute no groups but may still
      # provide a device.
      all_colocation_groups = []
      is_device_set = False
      for colocation_op in self._colocation_stack.peek_objs():
        try:
          all_colocation_groups.extend(colocation_op.colocation_groups())
        except AttributeError:
          pass
        if colocation_op.device and not is_device_set:
          # pylint: disable=protected-access
          op._set_device(colocation_op.device)
          # pylint: enable=protected-access
          is_device_set = True

      # De-duplicate and sort so the "_class" attr is deterministic.
      all_colocation_groups = sorted(set(all_colocation_groups))
      # pylint: disable=protected-access
      op._set_attr(
          "_class",
          attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
      # pylint: enable=protected-access

    # Sets "container" attribute if
    # (1) self._container is not None (truthy, i.e. non-empty)
    # (2) the op is stateful (op._is_stateful)
    # (3) "container" attribute is in OpDef (get_attr does not raise)
    # (4) the op's current "container" attribute is empty/falsy
    if self._container and op._is_stateful:  # pylint: disable=protected-access
      try:
        container_attr = op.get_attr("container")
      except ValueError:
        # "container" attribute is not in OpDef
        pass
      else:
        if not container_attr:
          op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
              s=compat.as_bytes(self._container)))
3675
3676  def _add_new_tf_operations(self, compute_devices=True):
3677    """Creates `Operations` in this graph for any new TF_Operations.
3678
3679    This is useful for when TF_Operations are indirectly created by the C API
3680    outside of the Operation constructor (e.g. by TF_ImportGraphDef,
3681    TF_FinishWhile). This ensures there are corresponding Operations for all
3682    TF_Operations in the underlying TF_Graph.
3683
3684    Args:
3685      compute_devices: (Optional.) If True, device functions will be executed to
3686        compute the device properties of each new Operation.
3687
3688    Returns:
3689      A list of the new `Operation` objects.
3690    """
3691    self._check_not_finalized()
3692
3693    # Create all Operation objects before accessing their inputs since an op may
3694    # be created before its inputs.
3695    new_ops = [
3696        self._create_op_from_tf_operation(c_op, compute_device=compute_devices)
3697        for c_op in c_api_util.new_tf_operations(self)
3698    ]
3699
3700    # pylint: disable=protected-access
3701    for op in new_ops:
3702      new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
3703      op._add_control_inputs(new_control_inputs)
3704      op._control_flow_post_processing()
3705    # pylint: enable=protected-access
3706
3707    return new_ops
3708
3709  def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
3710    """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
3711
3712    This function validates that `obj` represents an element of this
3713    graph, and gives an informative error message if it is not.
3714
3715    This function is the canonical way to get/validate an object of
3716    one of the allowed types from an external argument reference in the
3717    Session API.
3718
3719    This method may be called concurrently from multiple threads.
3720
3721    Args:
3722      obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. Can
3723        also be any object with an `_as_graph_element()` method that returns a
3724        value of one of these types. Note: `_as_graph_element` will be called
3725        inside the graph's lock and so may not modify the graph.
3726      allow_tensor: If true, `obj` may refer to a `Tensor`.
3727      allow_operation: If true, `obj` may refer to an `Operation`.
3728
3729    Returns:
3730      The `Tensor` or `Operation` in the Graph corresponding to `obj`.
3731
3732    Raises:
3733      TypeError: If `obj` is not a type we support attempting to convert
3734        to types.
3735      ValueError: If `obj` is of an appropriate type but invalid. For
3736        example, an invalid string.
3737      KeyError: If `obj` is not an object in the graph.
3738    """
3739    if self._finalized:
3740      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3741
3742    with self._lock:
3743      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3744
3745  def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
3746    """See `Graph.as_graph_element()` for details."""
3747    # The vast majority of this function is figuring
3748    # out what an API user might be doing wrong, so
3749    # that we can give helpful error messages.
3750    #
3751    # Ideally, it would be nice to split it up, but we
3752    # need context to generate nice error messages.
3753
3754    if allow_tensor and allow_operation:
3755      types_str = "Tensor or Operation"
3756    elif allow_tensor:
3757      types_str = "Tensor"
3758    elif allow_operation:
3759      types_str = "Operation"
3760    else:
3761      raise ValueError("allow_tensor and allow_operation can't both be False.")
3762
3763    temp_obj = _as_graph_element(obj)
3764    if temp_obj is not None:
3765      obj = temp_obj
3766
3767    # If obj appears to be a name...
3768    if isinstance(obj, compat.bytes_or_text_types):
3769      name = compat.as_str(obj)
3770
3771      if ":" in name and allow_tensor:
3772        # Looks like a Tensor name and can be a Tensor.
3773        try:
3774          op_name, out_n = name.split(":")
3775          out_n = int(out_n)
3776        except:
3777          raise ValueError("The name %s looks a like a Tensor name, but is "
3778                           "not a valid one. Tensor names must be of the "
3779                           "form \"<op_name>:<output_index>\"." % repr(name))
3780        if op_name in self._nodes_by_name:
3781          op = self._nodes_by_name[op_name]
3782        else:
3783          raise KeyError("The name %s refers to a Tensor which does not "
3784                         "exist. The operation, %s, does not exist in the "
3785                         "graph." % (repr(name), repr(op_name)))
3786        try:
3787          return op.outputs[out_n]
3788        except:
3789          raise KeyError("The name %s refers to a Tensor which does not "
3790                         "exist. The operation, %s, exists but only has "
3791                         "%s outputs." %
3792                         (repr(name), repr(op_name), len(op.outputs)))
3793
3794      elif ":" in name and not allow_tensor:
3795        # Looks like a Tensor name but can't be a Tensor.
3796        raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
3797                         (repr(name), types_str))
3798
3799      elif ":" not in name and allow_operation:
3800        # Looks like an Operation name and can be an Operation.
3801        if name not in self._nodes_by_name:
3802          raise KeyError("The name %s refers to an Operation not in the "
3803                         "graph." % repr(name))
3804        return self._nodes_by_name[name]
3805
3806      elif ":" not in name and not allow_operation:
3807        # Looks like an Operation name but can't be an Operation.
3808        if name in self._nodes_by_name:
3809          # Yep, it's an Operation name
3810          err_msg = ("The name %s refers to an Operation, not a %s." %
3811                     (repr(name), types_str))
3812        else:
3813          err_msg = ("The name %s looks like an (invalid) Operation name, "
3814                     "not a %s." % (repr(name), types_str))
3815        err_msg += (" Tensor names must be of the form "
3816                    "\"<op_name>:<output_index>\".")
3817        raise ValueError(err_msg)
3818
3819    elif isinstance(obj, Tensor) and allow_tensor:
3820      # Actually obj is just the object it's referring to.
3821      if obj.graph is not self:
3822        raise ValueError("Tensor %s is not an element of this graph." % obj)
3823      return obj
3824    elif isinstance(obj, Operation) and allow_operation:
3825      # Actually obj is just the object it's referring to.
3826      if obj.graph is not self:
3827        raise ValueError("Operation %s is not an element of this graph." % obj)
3828      return obj
3829    else:
3830      # We give up!
3831      raise TypeError("Can not convert a %s into a %s." %
3832                      (type(obj).__name__, types_str))
3833
3834  def get_operations(self):
3835    """Return the list of operations in the graph.
3836
3837    You can modify the operations in place, but modifications
3838    to the list such as inserts/delete have no effect on the
3839    list of operations known to the graph.
3840
3841    This method may be called concurrently from multiple threads.
3842
3843    Returns:
3844      A list of Operations.
3845    """
3846    if self._finalized:
3847      return list(self._nodes_by_id.values())
3848
3849    with self._lock:
3850      return list(self._nodes_by_id.values())
3851
3852  def get_operation_by_name(self, name):
3853    """Returns the `Operation` with the given `name`.
3854
3855    This method may be called concurrently from multiple threads.
3856
3857    Args:
3858      name: The name of the `Operation` to return.
3859
3860    Returns:
3861      The `Operation` with the given `name`.
3862
3863    Raises:
3864      TypeError: If `name` is not a string.
3865      KeyError: If `name` does not correspond to an operation in this graph.
3866    """
3867
3868    if not isinstance(name, six.string_types):
3869      raise TypeError("Operation names are strings (or similar), not %s." %
3870                      type(name).__name__)
3871    return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
3872
3873  def _get_operation_by_name_unsafe(self, name):
3874    """Returns the `Operation` with the given `name`.
3875
3876    This is a internal unsafe version of get_operation_by_name. It skips many
3877    checks and does not have user friendly error messages but runs considerably
3878    faster. This method may be called concurrently from multiple threads.
3879
3880    Args:
3881      name: The name of the `Operation` to return.
3882
3883    Returns:
3884      The `Operation` with the given `name`.
3885
3886    Raises:
3887      KeyError: If `name` does not correspond to an operation in this graph.
3888    """
3889
3890    if self._finalized:
3891      return self._nodes_by_name[name]
3892
3893    with self._lock:
3894      return self._nodes_by_name[name]
3895
3896  def _get_operation_by_tf_operation(self, tf_oper):
3897    op_name = pywrap_tf_session.TF_OperationName(tf_oper)
3898    return self._get_operation_by_name_unsafe(op_name)
3899
3900  def get_tensor_by_name(self, name):
3901    """Returns the `Tensor` with the given `name`.
3902
3903    This method may be called concurrently from multiple threads.
3904
3905    Args:
3906      name: The name of the `Tensor` to return.
3907
3908    Returns:
3909      The `Tensor` with the given `name`.
3910
3911    Raises:
3912      TypeError: If `name` is not a string.
3913      KeyError: If `name` does not correspond to a tensor in this graph.
3914    """
3915    # Names should be strings.
3916    if not isinstance(name, six.string_types):
3917      raise TypeError("Tensor names are strings (or similar), not %s." %
3918                      type(name).__name__)
3919    return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
3920
3921  def _get_tensor_by_tf_output(self, tf_output):
3922    """Returns the `Tensor` representing `tf_output`.
3923
3924    Note that there is only one such `Tensor`, i.e. multiple calls to this
3925    function with the same TF_Output value will always return the same `Tensor`
3926    object.
3927
3928    Args:
3929      tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
3930
3931    Returns:
3932      The `Tensor` that represents `tf_output`.
3933    """
3934    op = self._get_operation_by_tf_operation(tf_output.oper)
3935    return op.outputs[tf_output.index]
3936
3937  @property
3938  def _last_id(self):
3939    return self._next_id_counter
3940
3941  def _get_op_def(self, type):  # pylint: disable=redefined-builtin
3942    """Returns the `OpDef` proto for `type`. `type` is a string."""
3943    # NOTE: No locking is required because the lookup and insertion operations
3944    # on Python dictionaries are atomic.
3945    try:
3946      return self._op_def_cache[type]
3947    except KeyError:
3948      with c_api_util.tf_buffer() as buf:
3949        # pylint: disable=protected-access
3950        pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type),
3951                                           buf)
3952        # pylint: enable=protected-access
3953        data = pywrap_tf_session.TF_GetBuffer(buf)
3954      op_def = op_def_pb2.OpDef()
3955      op_def.ParseFromString(compat.as_bytes(data))
3956      self._op_def_cache[type] = op_def
3957      return op_def
3958
3959  def as_default(self):
3960    """Returns a context manager that makes this `Graph` the default graph.
3961
3962    This method should be used if you want to create multiple graphs
3963    in the same process. For convenience, a global default graph is
3964    provided, and all ops will be added to this graph if you do not
3965    create a new graph explicitly.
3966
3967    Use this method with the `with` keyword to specify that ops created within
3968    the scope of a block should be added to this graph. In this case, once
3969    the scope of the `with` is exited, the previous default graph is set again
3970    as default. There is a stack, so it's ok to have multiple nested levels
3971    of `as_default` calls.
3972
3973    The default graph is a property of the current thread. If you
3974    create a new thread, and wish to use the default graph in that
3975    thread, you must explicitly add a `with g.as_default():` in that
3976    thread's function.
3977
3978    The following code examples are equivalent:
3979
3980    ```python
3981    # 1. Using Graph.as_default():
3982    g = tf.Graph()
3983    with g.as_default():
3984      c = tf.constant(5.0)
3985      assert c.graph is g
3986
3987    # 2. Constructing and making default:
3988    with tf.Graph().as_default() as g:
3989      c = tf.constant(5.0)
3990      assert c.graph is g
3991    ```
3992
3993    If eager execution is enabled ops created under this context manager will be
3994    added to the graph instead of executed eagerly.
3995
3996    Returns:
3997      A context manager for using this graph as the default graph.
3998    """
3999    return _default_graph_stack.get_controller(self)
4000
4001  @property
4002  def collections(self):
4003    """Returns the names of the collections known to this graph."""
4004    return list(self._collections)
4005
4006  def add_to_collection(self, name, value):
4007    """Stores `value` in the collection with the given `name`.
4008
4009    Note that collections are not sets, so it is possible to add a value to
4010    a collection several times.
4011
4012    Args:
4013      name: The key for the collection. The `GraphKeys` class contains many
4014        standard names for collections.
4015      value: The value to add to the collection.
4016    """  # pylint: disable=g-doc-exception
4017    self._check_not_finalized()
4018    with self._lock:
4019      if name not in self._collections:
4020        self._collections[name] = [value]
4021      else:
4022        self._collections[name].append(value)
4023
4024  def add_to_collections(self, names, value):
4025    """Stores `value` in the collections given by `names`.
4026
4027    Note that collections are not sets, so it is possible to add a value to
4028    a collection several times. This function makes sure that duplicates in
4029    `names` are ignored, but it will not check for pre-existing membership of
4030    `value` in any of the collections in `names`.
4031
4032    `names` can be any iterable, but if `names` is a string, it is treated as a
4033    single collection name.
4034
4035    Args:
4036      names: The keys for the collections to add to. The `GraphKeys` class
4037        contains many standard names for collections.
4038      value: The value to add to the collections.
4039    """
4040    # Make sure names are unique, but treat strings as a single collection name
4041    names = (names,) if isinstance(names, six.string_types) else set(names)
4042    for name in names:
4043      self.add_to_collection(name, value)
4044
4045  def get_collection_ref(self, name):
4046    """Returns a list of values in the collection with the given `name`.
4047
4048    If the collection exists, this returns the list itself, which can
4049    be modified in place to change the collection.  If the collection does
4050    not exist, it is created as an empty list and the list is returned.
4051
4052    This is different from `get_collection()` which always returns a copy of
4053    the collection list if it exists and never creates an empty collection.
4054
4055    Args:
4056      name: The key for the collection. For example, the `GraphKeys` class
4057        contains many standard names for collections.
4058
4059    Returns:
4060      The list of values in the collection with the given `name`, or an empty
4061      list if no value has been added to that collection.
4062    """  # pylint: disable=g-doc-exception
4063    with self._lock:
4064      coll_list = self._collections.get(name, None)
4065      if coll_list is None:
4066        coll_list = []
4067        self._collections[name] = coll_list
4068      return coll_list
4069
4070  def get_collection(self, name, scope=None):
4071    """Returns a list of values in the collection with the given `name`.
4072
4073    This is different from `get_collection_ref()` which always returns the
4074    actual collection list if it exists in that it returns a new list each time
4075    it is called.
4076
4077    Args:
4078      name: The key for the collection. For example, the `GraphKeys` class
4079        contains many standard names for collections.
4080      scope: (Optional.) A string. If supplied, the resulting list is filtered
4081        to include only items whose `name` attribute matches `scope` using
4082        `re.match`. Items without a `name` attribute are never returned if a
4083        scope is supplied. The choice of `re.match` means that a `scope` without
4084        special tokens filters by prefix.
4085
4086    Returns:
4087      The list of values in the collection with the given `name`, or
4088      an empty list if no value has been added to that collection. The
4089      list contains the values in the order under which they were
4090      collected.
4091    """  # pylint: disable=g-doc-exception
4092    with self._lock:
4093      collection = self._collections.get(name, None)
4094      if collection is None:
4095        return []
4096      if scope is None:
4097        return list(collection)
4098      else:
4099        c = []
4100        regex = re.compile(scope)
4101        for item in collection:
4102          try:
4103            if regex.match(item.name):
4104              c.append(item)
4105          except AttributeError:
4106            # Collection items with no name are ignored.
4107            pass
4108        return c
4109
4110  def get_all_collection_keys(self):
4111    """Returns a list of collections used in this graph."""
4112    with self._lock:
4113      return [x for x in self._collections if isinstance(x, six.string_types)]
4114
4115  def clear_collection(self, name):
4116    """Clears all values in a collection.
4117
4118    Args:
4119      name: The key for the collection. The `GraphKeys` class contains many
4120        standard names for collections.
4121    """
4122    self._check_not_finalized()
4123    with self._lock:
4124      if name in self._collections:
4125        del self._collections[name]
4126
4127  @tf_contextlib.contextmanager
4128  def _original_op(self, op):
4129    """Python 'with' handler to help annotate ops with their originator.
4130
4131    An op may have an 'original_op' property that indicates the op on which
4132    it was based. For example a replica op is based on the op that was
4133    replicated and a gradient op is based on the op that was differentiated.
4134
4135    All ops created in the scope of this 'with' handler will have
4136    the given 'op' as their original op.
4137
4138    Args:
4139      op: The Operation that all ops created in this scope will have as their
4140        original op.
4141
4142    Yields:
4143      Nothing.
4144    """
4145    old_original_op = self._default_original_op
4146    self._default_original_op = op
4147    try:
4148      yield
4149    finally:
4150      self._default_original_op = old_original_op
4151
4152  @property
4153  def _name_stack(self):
4154    # This may be called from a thread where name_stack doesn't yet exist.
4155    if not hasattr(self._thread_local, "_name_stack"):
4156      self._thread_local._name_stack = ""
4157    return self._thread_local._name_stack
4158
4159  @_name_stack.setter
4160  def _name_stack(self, name_stack):
4161    self._thread_local._name_stack = name_stack
4162
4163  # pylint: disable=g-doc-return-or-yield,line-too-long
4164  @tf_contextlib.contextmanager
4165  def name_scope(self, name):
4166    """Returns a context manager that creates hierarchical names for operations.
4167
4168    A graph maintains a stack of name scopes. A `with name_scope(...):`
4169    statement pushes a new name onto the stack for the lifetime of the context.
4170
4171    The `name` argument will be interpreted as follows:
4172
4173    * A string (not ending with '/') will create a new name scope, in which
4174      `name` is appended to the prefix of all operations created in the
4175      context. If `name` has been used before, it will be made unique by
4176      calling `self.unique_name(name)`.
4177    * A scope previously captured from a `with g.name_scope(...) as
4178      scope:` statement will be treated as an "absolute" name scope, which
4179      makes it possible to re-enter existing scopes.
4180    * A value of `None` or the empty string will reset the current name scope
4181      to the top-level (empty) name scope.
4182
4183    For example:
4184
4185    ```python
4186    with tf.Graph().as_default() as g:
4187      c = tf.constant(5.0, name="c")
4188      assert c.op.name == "c"
4189      c_1 = tf.constant(6.0, name="c")
4190      assert c_1.op.name == "c_1"
4191
4192      # Creates a scope called "nested"
4193      with g.name_scope("nested") as scope:
4194        nested_c = tf.constant(10.0, name="c")
4195        assert nested_c.op.name == "nested/c"
4196
4197        # Creates a nested scope called "inner".
4198        with g.name_scope("inner"):
4199          nested_inner_c = tf.constant(20.0, name="c")
4200          assert nested_inner_c.op.name == "nested/inner/c"
4201
4202        # Create a nested scope called "inner_1".
4203        with g.name_scope("inner"):
4204          nested_inner_1_c = tf.constant(30.0, name="c")
4205          assert nested_inner_1_c.op.name == "nested/inner_1/c"
4206
4207          # Treats `scope` as an absolute name scope, and
4208          # switches to the "nested/" scope.
4209          with g.name_scope(scope):
4210            nested_d = tf.constant(40.0, name="d")
4211            assert nested_d.op.name == "nested/d"
4212
4213            with g.name_scope(""):
4214              e = tf.constant(50.0, name="e")
4215              assert e.op.name == "e"
4216    ```
4217
4218    The name of the scope itself can be captured by `with
4219    g.name_scope(...) as scope:`, which stores the name of the scope
4220    in the variable `scope`. This value can be used to name an
4221    operation that represents the overall result of executing the ops
4222    in a scope. For example:
4223
4224    ```python
4225    inputs = tf.constant(...)
4226    with g.name_scope('my_layer') as scope:
4227      weights = tf.Variable(..., name="weights")
4228      biases = tf.Variable(..., name="biases")
4229      affine = tf.matmul(inputs, weights) + biases
4230      output = tf.nn.relu(affine, name=scope)
4231    ```
4232
4233    NOTE: This constructor validates the given `name`. Valid scope
4234    names match one of the following regular expressions:
4235
4236        [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
4237        [A-Za-z0-9_.\\-/]* (for other scopes)
4238
4239    Args:
4240      name: A name for the scope.
4241
4242    Returns:
4243      A context manager that installs `name` as a new name scope.
4244
4245    Raises:
4246      ValueError: If `name` is not a valid scope name, according to the rules
4247        above.
4248    """
4249    if name:
4250      if isinstance(name, compat.bytes_or_text_types):
4251        name = compat.as_str(name)
4252
4253      if self._name_stack:
4254        # Scopes created in a nested scope may have initial characters
4255        # that are illegal as the initial character of an op name
4256        # (viz. '-', '\', '/', and '_').
4257        if not _VALID_SCOPE_NAME_REGEX.match(name):
4258          raise ValueError("'%s' is not a valid scope name" % name)
4259      else:
4260        # Scopes created in the root must match the more restrictive
4261        # op name regex, which constrains the initial character.
4262        if not _VALID_OP_NAME_REGEX.match(name):
4263          raise ValueError("'%s' is not a valid scope name" % name)
4264    old_stack = self._name_stack
4265    if not name:  # Both for name=None and name="" we re-set to empty scope.
4266      new_stack = None
4267    elif name[-1] == "/":
4268      new_stack = name_from_scope_name(name)
4269    else:
4270      new_stack = self.unique_name(name)
4271    self._name_stack = new_stack
4272    try:
4273      yield "" if new_stack is None else new_stack + "/"
4274    finally:
4275      self._name_stack = old_stack
4276
4277  # pylint: enable=g-doc-return-or-yield,line-too-long
4278
4279  def unique_name(self, name, mark_as_used=True):
4280    """Return a unique operation name for `name`.
4281
4282    Note: You rarely need to call `unique_name()` directly.  Most of
4283    the time you just need to create `with g.name_scope()` blocks to
4284    generate structured names.
4285
4286    `unique_name` is used to generate structured names, separated by
4287    `"/"`, to help identify operations when debugging a graph.
4288    Operation names are displayed in error messages reported by the
4289    TensorFlow runtime, and in various visualization tools such as
4290    TensorBoard.
4291
4292    If `mark_as_used` is set to `True`, which is the default, a new
4293    unique name is created and marked as in use. If it's set to `False`,
4294    the unique name is returned without actually being marked as used.
4295    This is useful when the caller simply wants to know what the name
4296    to be created will be.
4297
4298    Args:
4299      name: The name for an operation.
4300      mark_as_used: Whether to mark this name as being used.
4301
4302    Returns:
4303      A string to be passed to `create_op()` that will be used
4304      to name the operation being created.
4305    """
4306    if self._name_stack:
4307      name = self._name_stack + "/" + name
4308
4309    # For the sake of checking for names in use, we treat names as case
4310    # insensitive (e.g. foo = Foo).
4311    name_key = name.lower()
4312    i = self._names_in_use.get(name_key, 0)
4313    # Increment the number for "name_key".
4314    if mark_as_used:
4315      self._names_in_use[name_key] = i + 1
4316    if i > 0:
4317      base_name_key = name_key
4318      # Make sure the composed name key is not already used.
4319      while name_key in self._names_in_use:
4320        name_key = "%s_%d" % (base_name_key, i)
4321        i += 1
4322      # Mark the composed name_key as used in case someone wants
4323      # to call unique_name("name_1").
4324      if mark_as_used:
4325        self._names_in_use[name_key] = 1
4326
4327      # Return the new name with the original capitalization of the given name.
4328      name = "%s_%d" % (name, i - 1)
4329    return name
4330
4331  def get_name_scope(self):
4332    """Returns the current name scope.
4333
4334    For example:
4335
4336    ```python
4337    with tf.name_scope('scope1'):
4338      with tf.name_scope('scope2'):
4339        print(tf.compat.v1.get_default_graph().get_name_scope())
4340    ```
4341    would print the string `scope1/scope2`.
4342
4343    Returns:
4344      A string representing the current name scope.
4345    """
4346    return self._name_stack
4347
4348  @tf_contextlib.contextmanager
4349  def _colocate_with_for_gradient(self, op, gradient_uid,
4350                                  ignore_existing=False):
4351    with self.colocate_with(op, ignore_existing):
4352      if gradient_uid is not None:
4353        ctx = _get_enclosing_context(self)
4354        if ctx is not None:
4355          ctx.EnterGradientColocation(op, gradient_uid)
4356          try:
4357            yield
4358          finally:
4359            ctx.ExitGradientColocation(op, gradient_uid)
4360        else:
4361          yield
4362      else:
4363        yield
4364
4365  @tf_contextlib.contextmanager
4366  def colocate_with(self, op, ignore_existing=False):
4367    """Returns a context manager that specifies an op to colocate with.
4368
4369    Note: this function is not for public use, only for internal libraries.
4370
4371    For example:
4372
4373    ```python
4374    a = tf.Variable([1.0])
4375    with g.colocate_with(a):
4376      b = tf.constant(1.0)
4377      c = tf.add(a, b)
4378    ```
4379
4380    `b` and `c` will always be colocated with `a`, no matter where `a`
4381    is eventually placed.
4382
4383    **NOTE** Using a colocation scope resets any existing device constraints.
4384
4385    If `op` is `None` then `ignore_existing` must be `True` and the new
4386    scope resets all colocation and device constraints.
4387
4388    Args:
4389      op: The op to colocate all created ops with, or `None`.
4390      ignore_existing: If true, only applies colocation of this op within the
4391        context, rather than applying all colocation properties on the stack.
4392        If `op` is `None`, this value must be `True`.
4393
4394    Raises:
4395      ValueError: if op is None but ignore_existing is False.
4396
4397    Yields:
4398      A context manager that specifies the op with which to colocate
4399      newly created ops.
4400    """
4401    if op is None and not ignore_existing:
4402      raise ValueError("Trying to reset colocation (op is None) but "
4403                       "ignore_existing is not True")
4404    op, device_only_candidate = _op_to_colocate_with(op, self)
4405
4406    # By default, colocate_with resets the device function stack,
4407    # since colocate_with is typically used in specific internal
4408    # library functions where colocation is intended to be "stronger"
4409    # than device functions.
4410    #
4411    # In the future, a caller may specify that device_functions win
4412    # over colocation, in which case we can add support.
4413    device_fn_tmp = self._device_function_stack
4414    self._device_function_stack = traceable_stack.TraceableStack()
4415
4416    if ignore_existing:
4417      current_stack = self._colocation_stack
4418      self._colocation_stack = traceable_stack.TraceableStack()
4419
4420    if op is not None:
4421      # offset refers to the stack frame used for storing code location.
4422      # We use 4, the sum of 1 to use our caller's stack frame and 3
4423      # to jump over layers of context managers above us.
4424      if device_only_candidate is not None:
4425        self._colocation_stack.push_obj(device_only_candidate, offset=4)
4426      self._colocation_stack.push_obj(op, offset=4)
4427    elif not ignore_existing:
4428      raise ValueError("Trying to reset colocation (op is None) but "
4429                       "ignore_existing is not True")
4430    try:
4431      yield
4432    finally:
4433      # Restore device function stack
4434      self._device_function_stack = device_fn_tmp
4435      if op is not None:
4436        self._colocation_stack.pop_obj()
4437        if device_only_candidate is not None:
4438          self._colocation_stack.pop_obj()
4439
4440      # Reset the colocation stack if requested.
4441      if ignore_existing:
4442        self._colocation_stack = current_stack
4443
4444  def _add_device_to_stack(self, device_name_or_function, offset=0):
4445    """Add device to stack manually, separate from a context manager."""
4446    total_offset = 1 + offset
4447    spec = _UserDeviceSpec(device_name_or_function)
4448    self._device_function_stack.push_obj(spec, offset=total_offset)
4449    return spec
4450
4451  @tf_contextlib.contextmanager
4452  def device(self, device_name_or_function):
4453    # pylint: disable=line-too-long
4454    """Returns a context manager that specifies the default device to use.
4455
4456    The `device_name_or_function` argument may either be a device name
4457    string, a device function, or None:
4458
4459    * If it is a device name string, all operations constructed in
4460      this context will be assigned to the device with that name, unless
4461      overridden by a nested `device()` context.
4462    * If it is a function, it will be treated as a function from
4463      Operation objects to device name strings, and invoked each time
4464      a new Operation is created. The Operation will be assigned to
4465      the device with the returned name.
4466    * If it is None, all `device()` invocations from the enclosing context
4467      will be ignored.
4468
4469    For information about the valid syntax of device name strings, see
4470    the documentation in
4471    [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).
4472
4473    For example:
4474
4475    ```python
4476    with g.device('/device:GPU:0'):
4477      # All operations constructed in this context will be placed
4478      # on GPU 0.
4479      with g.device(None):
4480        # All operations constructed in this context will have no
4481        # assigned device.
4482
4483    # Defines a function from `Operation` to device string.
4484    def matmul_on_gpu(n):
4485      if n.type == "MatMul":
4486        return "/device:GPU:0"
4487      else:
4488        return "/cpu:0"
4489
4490    with g.device(matmul_on_gpu):
4491      # All operations of type "MatMul" constructed in this context
4492      # will be placed on GPU 0; all other operations will be placed
4493      # on CPU 0.
4494    ```
4495
4496    **N.B.** The device scope may be overridden by op wrappers or
4497    other library code. For example, a variable assignment op
4498    `v.assign()` must be colocated with the `tf.Variable` `v`, and
4499    incompatible device scopes will be ignored.
4500
4501    Args:
4502      device_name_or_function: The device name or function to use in the
4503        context.
4504
4505    Yields:
4506      A context manager that specifies the default device to use for newly
4507      created ops.
4508
4509    Raises:
4510      RuntimeError: If device scopes are not properly nested.
4511    """
4512    self._add_device_to_stack(device_name_or_function, offset=2)
4513    old_top_of_stack = self._device_function_stack.peek_top_obj()
4514    try:
4515      yield
4516    finally:
4517      new_top_of_stack = self._device_function_stack.peek_top_obj()
4518      if old_top_of_stack is not new_top_of_stack:
4519        raise RuntimeError("Exiting device scope without proper scope nesting.")
4520      self._device_function_stack.pop_obj()
4521
4522  def _apply_device_functions(self, op):
4523    """Applies the current device function stack to the given operation."""
4524    # Apply any device functions in LIFO order, so that the most recently
4525    # pushed function has the first chance to apply a device to the op.
4526    # We apply here because the result can depend on the Operation's
4527    # signature, which is computed in the Operation constructor.
4528    # pylint: disable=protected-access
4529    prior_device_string = None
4530    for device_spec in self._device_function_stack.peek_objs():
4531      if device_spec.is_null_merge:
4532        continue
4533
4534      if device_spec.function is None:
4535        break
4536
4537      device_string = device_spec.string_merge(op)
4538
4539      # Take advantage of the fact that None is a singleton and Python interns
4540      # strings, since identity checks are faster than equality checks.
4541      if device_string is not prior_device_string:
4542        op._set_device_from_string(device_string)
4543        prior_device_string = device_string
4544    op._device_code_locations = self._snapshot_device_function_stack_metadata()
4545    # pylint: enable=protected-access
4546
4547  # pylint: disable=g-doc-return-or-yield
4548  @tf_contextlib.contextmanager
4549  def container(self, container_name):
4550    """Returns a context manager that specifies the resource container to use.
4551
4552    Stateful operations, such as variables and queues, can maintain their
4553    states on devices so that they can be shared by multiple processes.
4554    A resource container is a string name under which these stateful
4555    operations are tracked. These resources can be released or cleared
4556    with `tf.Session.reset()`.
4557
4558    For example:
4559
4560    ```python
4561    with g.container('experiment0'):
4562      # All stateful Operations constructed in this context will be placed
4563      # in resource container "experiment0".
4564      v1 = tf.Variable([1.0])
4565      v2 = tf.Variable([2.0])
4566      with g.container("experiment1"):
4567        # All stateful Operations constructed in this context will be
4568        # placed in resource container "experiment1".
4569        v3 = tf.Variable([3.0])
4570        q1 = tf.queue.FIFOQueue(10, tf.float32)
4571      # All stateful Operations constructed in this context will be
4572      # be created in the "experiment0".
4573      v4 = tf.Variable([4.0])
4574      q1 = tf.queue.FIFOQueue(20, tf.float32)
4575      with g.container(""):
4576        # All stateful Operations constructed in this context will be
4577        # be placed in the default resource container.
4578        v5 = tf.Variable([5.0])
4579        q3 = tf.queue.FIFOQueue(30, tf.float32)
4580
4581    # Resets container "experiment0", after which the state of v1, v2, v4, q1
4582    # will become undefined (such as uninitialized).
4583    tf.Session.reset(target, ["experiment0"])
4584    ```
4585
4586    Args:
4587      container_name: container name string.
4588
4589    Returns:
4590      A context manager for defining resource containers for stateful ops,
4591        yields the container name.
4592    """
4593    original_container = self._container
4594    self._container = container_name
4595    try:
4596      yield self._container
4597    finally:
4598      self._container = original_container
4599
4600  # pylint: enable=g-doc-return-or-yield
4601
4602  class _ControlDependenciesController(object):
4603    """Context manager for `control_dependencies()`."""
4604
4605    def __init__(self, graph, control_inputs):
4606      """Create a new `_ControlDependenciesController`.
4607
4608      A `_ControlDependenciesController` is the context manager for
4609      `with tf.control_dependencies()` blocks.  These normally nest,
4610      as described in the documentation for `control_dependencies()`.
4611
4612      The `control_inputs` argument list control dependencies that must be
4613      added to the current set of control dependencies.  Because of
4614      uniquification the set can be empty even if the caller passed a list of
4615      ops.  The special value `None` indicates that we want to start a new
4616      empty set of control dependencies instead of extending the current set.
4617
4618      In that case we also clear the current control flow context, which is an
4619      additional mechanism to add control dependencies.
4620
4621      Args:
4622        graph: The graph that this controller is managing.
4623        control_inputs: List of ops to use as control inputs in addition to the
4624          current control dependencies.  None to indicate that the dependencies
4625          should be cleared.
4626      """
4627      self._graph = graph
4628      if control_inputs is None:
4629        self._control_inputs_val = []
4630        self._new_stack = True
4631      else:
4632        self._control_inputs_val = control_inputs
4633        self._new_stack = False
4634      self._seen_nodes = set()
4635      self._old_stack = None
4636      self._old_control_flow_context = None
4637
4638# pylint: disable=protected-access
4639
4640    def __enter__(self):
4641      if self._new_stack:
4642        # Clear the control_dependencies graph.
4643        self._old_stack = self._graph._control_dependencies_stack
4644        self._graph._control_dependencies_stack = []
4645        # Clear the control_flow_context too.
4646        self._old_control_flow_context = self._graph._get_control_flow_context()
4647        self._graph._set_control_flow_context(None)
4648      self._graph._push_control_dependencies_controller(self)
4649
4650    def __exit__(self, unused_type, unused_value, unused_traceback):
4651      self._graph._pop_control_dependencies_controller(self)
4652      if self._new_stack:
4653        self._graph._control_dependencies_stack = self._old_stack
4654        self._graph._set_control_flow_context(self._old_control_flow_context)
4655
4656# pylint: enable=protected-access
4657
4658    @property
4659    def control_inputs(self):
4660      return self._control_inputs_val
4661
4662    def add_op(self, op):
4663      if isinstance(op, Tensor):
4664        op = op.ref()
4665      self._seen_nodes.add(op)
4666
4667    def op_in_group(self, op):
4668      if isinstance(op, Tensor):
4669        op = op.ref()
4670      return op in self._seen_nodes
4671
4672  def _push_control_dependencies_controller(self, controller):
4673    self._control_dependencies_stack.append(controller)
4674
4675  def _pop_control_dependencies_controller(self, controller):
4676    assert self._control_dependencies_stack[-1] is controller
4677    self._control_dependencies_stack.pop()
4678
4679  def _current_control_dependencies(self):
4680    ret = set()
4681    for controller in self._control_dependencies_stack:
4682      for op in controller.control_inputs:
4683        ret.add(op)
4684    return ret
4685
4686  def _control_dependencies_for_inputs(self, input_ops):
4687    """For an op that takes `input_ops` as inputs, compute control inputs.
4688
4689    The returned control dependencies should yield an execution that
4690    is equivalent to adding all control inputs in
4691    self._control_dependencies_stack to a newly created op. However,
4692    this function attempts to prune the returned control dependencies
4693    by observing that nodes created within the same `with
4694    control_dependencies(...):` block may have data dependencies that make
4695    the explicit approach redundant.
4696
4697    Args:
4698      input_ops: The data input ops for an op to be created.
4699
4700    Returns:
4701      A list of control inputs for the op to be created.
4702    """
4703    ret = []
4704    for controller in self._control_dependencies_stack:
4705      # If any of the input_ops already depends on the inputs from controller,
4706      # we say that the new op is dominated (by that input), and we therefore
4707      # do not need to add control dependencies for this controller's inputs.
4708      dominated = False
4709      for op in input_ops:
4710        if controller.op_in_group(op):
4711          dominated = True
4712          break
4713      if not dominated:
4714        # Don't add a control input if we already have a data dependency on i.
4715        # NOTE(mrry): We do not currently track transitive data dependencies,
4716        #   so we may add redundant control inputs.
4717        ret.extend(c for c in controller.control_inputs if c not in input_ops)
4718    return ret
4719
4720  def _record_op_seen_by_control_dependencies(self, op):
4721    """Record that the given op depends on all registered control dependencies.
4722
4723    Args:
4724      op: An Operation.
4725    """
4726    for controller in self._control_dependencies_stack:
4727      controller.add_op(op)
4728
4729  def control_dependencies(self, control_inputs):
4730    """Returns a context manager that specifies control dependencies.
4731
4732    Use with the `with` keyword to specify that all operations constructed
4733    within the context should have control dependencies on
4734    `control_inputs`. For example:
4735
4736    ```python
4737    with g.control_dependencies([a, b, c]):
4738      # `d` and `e` will only run after `a`, `b`, and `c` have executed.
4739      d = ...
4740      e = ...
4741    ```
4742
4743    Multiple calls to `control_dependencies()` can be nested, and in
4744    that case a new `Operation` will have control dependencies on the union
4745    of `control_inputs` from all active contexts.
4746
4747    ```python
4748    with g.control_dependencies([a, b]):
4749      # Ops constructed here run after `a` and `b`.
4750      with g.control_dependencies([c, d]):
4751        # Ops constructed here run after `a`, `b`, `c`, and `d`.
4752    ```
4753
4754    You can pass None to clear the control dependencies:
4755
4756    ```python
4757    with g.control_dependencies([a, b]):
4758      # Ops constructed here run after `a` and `b`.
4759      with g.control_dependencies(None):
4760        # Ops constructed here run normally, not waiting for either `a` or `b`.
4761        with g.control_dependencies([c, d]):
4762          # Ops constructed here run after `c` and `d`, also not waiting
4763          # for either `a` or `b`.
4764    ```
4765
4766    *N.B.* The control dependencies context applies *only* to ops that
4767    are constructed within the context. Merely using an op or tensor
4768    in the context does not add a control dependency. The following
4769    example illustrates this point:
4770
4771    ```python
4772    # WRONG
4773    def my_func(pred, tensor):
4774      t = tf.matmul(tensor, tensor)
4775      with tf.control_dependencies([pred]):
4776        # The matmul op is created outside the context, so no control
4777        # dependency will be added.
4778        return t
4779
4780    # RIGHT
4781    def my_func(pred, tensor):
4782      with tf.control_dependencies([pred]):
4783        # The matmul op is created in the context, so a control dependency
4784        # will be added.
4785        return tf.matmul(tensor, tensor)
4786    ```
4787
4788    Also note that though execution of ops created under this scope will trigger
4789    execution of the dependencies, the ops created under this scope might still
4790    be pruned from a normal tensorflow graph. For example, in the following
4791    snippet of code the dependencies are never executed:
4792
4793    ```python
4794      loss = model.loss()
4795      with tf.control_dependencies(dependencies):
4796        loss = loss + tf.constant(1)  # note: dependencies ignored in the
4797                                      # backward pass
4798      return tf.gradients(loss, model.variables)
4799    ```
4800
4801    This is because evaluating the gradient graph does not require evaluating
4802    the constant(1) op created in the forward pass.
4803
4804    Args:
4805      control_inputs: A list of `Operation` or `Tensor` objects which must be
4806        executed or computed before running the operations defined in the
4807        context.  Can also be `None` to clear the control dependencies.
4808
4809    Returns:
4810     A context manager that specifies control dependencies for all
4811     operations constructed within the context.
4812
4813    Raises:
4814      TypeError: If `control_inputs` is not a list of `Operation` or
4815        `Tensor` objects.
4816    """
4817    if control_inputs is None:
4818      return self._ControlDependenciesController(self, None)
4819    # First convert the inputs to ops, and deduplicate them.
4820    # NOTE(mrry): Other than deduplication, we do not currently track direct
4821    #   or indirect dependencies between control_inputs, which may result in
4822    #   redundant control inputs.
4823    control_ops = []
4824    current = self._current_control_dependencies()
4825    for c in control_inputs:
4826      # The hasattr(handle) is designed to match ResourceVariables. This is so
4827      # control dependencies on a variable or on an unread variable don't
4828      # trigger reads.
4829      if (isinstance(c, IndexedSlices) or
4830          (hasattr(c, "_handle") and hasattr(c, "op"))):
4831        c = c.op
4832      c = self.as_graph_element(c)
4833      if isinstance(c, Tensor):
4834        c = c.op
4835      elif not isinstance(c, Operation):
4836        raise TypeError("Control input must be Operation or Tensor: %s" % c)
4837      if c not in current:
4838        control_ops.append(c)
4839        current.add(c)
4840    return self._ControlDependenciesController(self, control_ops)
4841
4842  # pylint: disable=g-doc-return-or-yield
4843  @tf_contextlib.contextmanager
4844  def _attr_scope(self, attr_map):
4845    """EXPERIMENTAL: A context manager for setting attributes on operators.
4846
4847    This context manager can be used to add additional
4848    attributes to operators within the scope of the context.
4849
4850    For example:
4851
4852       with ops.Graph().as_default() as g:
4853         f_1 = Foo()  # No extra attributes
4854         with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
4855           f_2 = Foo()  # Additional attribute _a=False
4856           with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
4857             f_3 = Foo()  # Additional attribute _a=False
4858             with g._attr_scope({"_a": None}):
4859               f_4 = Foo()  # No additional attributes.
4860
4861    Args:
4862      attr_map: A dictionary mapping attr name strings to AttrValue protocol
4863        buffers or None.
4864
4865    Returns:
4866      A context manager that sets the kernel label to be used for one or more
4867      ops created in that context.
4868
4869    Raises:
4870      TypeError: If attr_map is not a dictionary mapping
4871        strings to AttrValue protobufs.
4872    """
4873    if not isinstance(attr_map, dict):
4874      raise TypeError("attr_map must be a dictionary mapping "
4875                      "strings to AttrValue protocol buffers")
4876    # The saved_attrs dictionary stores any currently-set labels that
4877    # will be overridden by this context manager.
4878    saved_attrs = {}
4879    # Install the given attribute
4880    for name, attr in attr_map.items():
4881      if not (isinstance(name, six.string_types) and
4882              (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
4883               callable(attr))):
4884        raise TypeError("attr_map must be a dictionary mapping "
4885                        "strings to AttrValue protocol buffers or "
4886                        "callables that emit AttrValue protocol buffers")
4887      try:
4888        saved_attrs[name] = self._attr_scope_map[name]
4889      except KeyError:
4890        pass
4891      if attr is None:
4892        del self._attr_scope_map[name]
4893      else:
4894        self._attr_scope_map[name] = attr
4895    try:
4896      yield  # The code within the context runs here.
4897    finally:
4898      # Remove the attributes set for this context, and restore any saved
4899      # attributes.
4900      for name, attr in attr_map.items():
4901        try:
4902          self._attr_scope_map[name] = saved_attrs[name]
4903        except KeyError:
4904          del self._attr_scope_map[name]
4905
4906  # pylint: enable=g-doc-return-or-yield
4907
4908  # pylint: disable=g-doc-return-or-yield
4909  @tf_contextlib.contextmanager
4910  def _kernel_label_map(self, op_to_kernel_label_map):
4911    """EXPERIMENTAL: A context manager for setting kernel labels.
4912
4913    This context manager can be used to select particular
4914    implementations of kernels within the scope of the context.
4915
4916    For example:
4917
4918        with ops.Graph().as_default() as g:
4919          f_1 = Foo()  # Uses the default registered kernel for the Foo op.
4920          with g.kernel_label_map({"Foo": "v_2"}):
4921            f_2 = Foo()  # Uses the registered kernel with label "v_2"
4922                         # for the Foo op.
4923            with g.kernel_label_map({"Foo": "v_3"}):
4924              f_3 = Foo()  # Uses the registered kernel with label "v_3"
4925                           # for the Foo op.
4926              with g.kernel_label_map({"Foo": ""}):
4927                f_4 = Foo()  # Uses the default registered kernel
4928                             # for the Foo op.
4929
4930    Args:
4931      op_to_kernel_label_map: A dictionary mapping op type strings to kernel
4932        label strings.
4933
4934    Returns:
4935      A context manager that sets the kernel label to be used for one or more
4936      ops created in that context.
4937
4938    Raises:
4939      TypeError: If op_to_kernel_label_map is not a dictionary mapping
4940        strings to strings.
4941    """
4942    if not isinstance(op_to_kernel_label_map, dict):
4943      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
4944                      "strings to strings")
4945    # The saved_labels dictionary stores any currently-set labels that
4946    # will be overridden by this context manager.
4947    saved_labels = {}
4948    # Install the given label
4949    for op_type, label in op_to_kernel_label_map.items():
4950      if not (isinstance(op_type, six.string_types) and
4951              isinstance(label, six.string_types)):
4952        raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
4953                        "strings to strings")
4954      try:
4955        saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
4956      except KeyError:
4957        pass
4958      self._op_to_kernel_label_map[op_type] = label
4959    try:
4960      yield  # The code within the context runs here.
4961    finally:
4962      # Remove the labels set for this context, and restore any saved labels.
4963      for op_type, label in op_to_kernel_label_map.items():
4964        try:
4965          self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
4966        except KeyError:
4967          del self._op_to_kernel_label_map[op_type]
4968
4969  # pylint: enable=g-doc-return-or-yield
4970
4971  @tf_contextlib.contextmanager
4972  def _override_gradient_function(self, gradient_function_map):
4973    """Specify gradient function for the given op type."""
4974
4975    # This is an internal API and we don't need nested context for this.
4976    # TODO(mdan): make it a proper context manager.
4977    assert not self._gradient_function_map
4978    self._gradient_function_map = gradient_function_map
4979    try:
4980      yield
4981    finally:
4982      self._gradient_function_map = {}
4983
4984  # pylint: disable=g-doc-return-or-yield
4985  @tf_contextlib.contextmanager
4986  def gradient_override_map(self, op_type_map):
4987    """EXPERIMENTAL: A context manager for overriding gradient functions.
4988
4989    This context manager can be used to override the gradient function
4990    that will be used for ops within the scope of the context.
4991
4992    For example:
4993
4994    ```python
4995    @tf.RegisterGradient("CustomSquare")
4996    def _custom_square_grad(op, grad):
4997      # ...
4998
4999    with tf.Graph().as_default() as g:
5000      c = tf.constant(5.0)
5001      s_1 = tf.square(c)  # Uses the default gradient for tf.square.
5002      with g.gradient_override_map({"Square": "CustomSquare"}):
5003        s_2 = tf.square(s_2)  # Uses _custom_square_grad to compute the
5004                              # gradient of s_2.
5005    ```
5006
5007    Args:
5008      op_type_map: A dictionary mapping op type strings to alternative op type
5009        strings.
5010
5011    Returns:
5012      A context manager that sets the alternative op type to be used for one
5013      or more ops created in that context.
5014
5015    Raises:
5016      TypeError: If `op_type_map` is not a dictionary mapping strings to
5017        strings.
5018    """
5019    if not isinstance(op_type_map, dict):
5020      raise TypeError("op_type_map must be a dictionary mapping "
5021                      "strings to strings")
5022    # The saved_mappings dictionary stores any currently-set mappings that
5023    # will be overridden by this context manager.
5024    saved_mappings = {}
5025    # Install the given label
5026    for op_type, mapped_op_type in op_type_map.items():
5027      if not (isinstance(op_type, six.string_types) and
5028              isinstance(mapped_op_type, six.string_types)):
5029        raise TypeError("op_type_map must be a dictionary mapping "
5030                        "strings to strings")
5031      try:
5032        saved_mappings[op_type] = self._gradient_override_map[op_type]
5033      except KeyError:
5034        pass
5035      self._gradient_override_map[op_type] = mapped_op_type
5036    try:
5037      yield  # The code within the context runs here.
5038    finally:
5039      # Remove the labels set for this context, and restore any saved labels.
5040      for op_type, mapped_op_type in op_type_map.items():
5041        try:
5042          self._gradient_override_map[op_type] = saved_mappings[op_type]
5043        except KeyError:
5044          del self._gradient_override_map[op_type]
5045
5046  # pylint: enable=g-doc-return-or-yield
5047
5048  def prevent_feeding(self, tensor):
5049    """Marks the given `tensor` as unfeedable in this graph."""
5050    self._unfeedable_tensors.add(tensor)
5051
5052  def is_feedable(self, tensor):
5053    """Returns `True` if and only if `tensor` is feedable."""
5054    return tensor not in self._unfeedable_tensors
5055
5056  def prevent_fetching(self, op):
5057    """Marks the given `op` as unfetchable in this graph."""
5058    self._unfetchable_ops.add(op)
5059
5060  def is_fetchable(self, tensor_or_op):
5061    """Returns `True` if and only if `tensor_or_op` is fetchable."""
5062    if isinstance(tensor_or_op, Tensor):
5063      return tensor_or_op.op not in self._unfetchable_ops
5064    else:
5065      return tensor_or_op not in self._unfetchable_ops
5066
5067  def switch_to_thread_local(self):
5068    """Make device, colocation and dependencies stacks thread-local.
5069
5070    Device, colocation and dependencies stacks are not thread-local be default.
5071    If multiple threads access them, then the state is shared.  This means that
5072    one thread may affect the behavior of another thread.
5073
5074    After this method is called, the stacks become thread-local.  If multiple
5075    threads access them, then the state is not shared.  Each thread uses its own
5076    value; a thread doesn't affect other threads by mutating such a stack.
5077
5078    The initial value for every thread's stack is set to the current value
5079    of the stack when `switch_to_thread_local()` was first called.
5080    """
5081    if not self._stack_state_is_thread_local:
5082      self._stack_state_is_thread_local = True
5083
5084  @property
5085  def _device_function_stack(self):
5086    if self._stack_state_is_thread_local:
5087      # This may be called from a thread where device_function_stack doesn't yet
5088      # exist.
5089      # pylint: disable=protected-access
5090      if not hasattr(self._thread_local, "_device_function_stack"):
5091        stack_copy_for_this_thread = self._graph_device_function_stack.copy()
5092        self._thread_local._device_function_stack = stack_copy_for_this_thread
5093      return self._thread_local._device_function_stack
5094      # pylint: enable=protected-access
5095    else:
5096      return self._graph_device_function_stack
5097
5098  @property
5099  def _device_functions_outer_to_inner(self):
5100    user_device_specs = self._device_function_stack.peek_objs()
5101    device_functions = [spec.function for spec in user_device_specs]
5102    device_functions_outer_to_inner = list(reversed(device_functions))
5103    return device_functions_outer_to_inner
5104
5105  def _snapshot_device_function_stack_metadata(self):
5106    """Return device function stack as a list of TraceableObjects.
5107
5108    Returns:
5109      [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj
5110      member is a displayable name for the user's argument to Graph.device, and
5111      the filename and lineno members point to the code location where
5112      Graph.device was called directly or indirectly by the user.
5113    """
5114    snapshot = []
5115    for obj in self._device_function_stack.peek_traceable_objs():
5116      obj_copy = obj.copy_metadata()
5117      obj_copy.obj = obj.obj.display_name
5118      snapshot.append(obj_copy)
5119    return snapshot
5120
5121  @_device_function_stack.setter
5122  def _device_function_stack(self, device_function_stack):
5123    if self._stack_state_is_thread_local:
5124      # pylint: disable=protected-access
5125      self._thread_local._device_function_stack = device_function_stack
5126      # pylint: enable=protected-access
5127    else:
5128      self._graph_device_function_stack = device_function_stack
5129
5130  @property
5131  def _colocation_stack(self):
5132    """Return thread-local copy of colocation stack."""
5133    if self._stack_state_is_thread_local:
5134      # This may be called from a thread where colocation_stack doesn't yet
5135      # exist.
5136      # pylint: disable=protected-access
5137      if not hasattr(self._thread_local, "_colocation_stack"):
5138        stack_copy_for_this_thread = self._graph_colocation_stack.copy()
5139        self._thread_local._colocation_stack = stack_copy_for_this_thread
5140      return self._thread_local._colocation_stack
5141      # pylint: enable=protected-access
5142    else:
5143      return self._graph_colocation_stack
5144
5145  def _snapshot_colocation_stack_metadata(self):
5146    """Return colocation stack metadata as a dictionary."""
5147    return {
5148        traceable_obj.obj.name: traceable_obj.copy_metadata()
5149        for traceable_obj in self._colocation_stack.peek_traceable_objs()
5150    }
5151
5152  @_colocation_stack.setter
5153  def _colocation_stack(self, colocation_stack):
5154    if self._stack_state_is_thread_local:
5155      # pylint: disable=protected-access
5156      self._thread_local._colocation_stack = colocation_stack
5157      # pylint: enable=protected-access
5158    else:
5159      self._graph_colocation_stack = colocation_stack
5160
5161  @property
5162  def _control_dependencies_stack(self):
5163    if self._stack_state_is_thread_local:
5164      # This may be called from a thread where control_dependencies_stack
5165      # doesn't yet exist.
5166      if not hasattr(self._thread_local, "_control_dependencies_stack"):
5167        self._thread_local._control_dependencies_stack = (
5168            self._graph_control_dependencies_stack[:])
5169      return self._thread_local._control_dependencies_stack
5170    else:
5171      return self._graph_control_dependencies_stack
5172
5173  @_control_dependencies_stack.setter
5174  def _control_dependencies_stack(self, control_dependencies):
5175    if self._stack_state_is_thread_local:
5176      self._thread_local._control_dependencies_stack = control_dependencies
5177    else:
5178      self._graph_control_dependencies_stack = control_dependencies
5179
5180  @property
5181  def _distribution_strategy_stack(self):
5182    """A stack to maintain distribution strategy context for each thread."""
5183    if not hasattr(self._thread_local, "_distribution_strategy_stack"):
5184      self._thread_local._distribution_strategy_stack = []  # pylint: disable=protected-access
5185    return self._thread_local._distribution_strategy_stack  # pylint: disable=protected-access
5186
5187  @_distribution_strategy_stack.setter
5188  def _distribution_strategy_stack(self, _distribution_strategy_stack):
5189    self._thread_local._distribution_strategy_stack = (  # pylint: disable=protected-access
5190        _distribution_strategy_stack)
5191
5192  @property
5193  def _global_distribute_strategy_scope(self):
5194    """For implementing `tf.distribute.set_strategy()`."""
5195    if not hasattr(self._thread_local, "distribute_strategy_scope"):
5196      self._thread_local.distribute_strategy_scope = None
5197    return self._thread_local.distribute_strategy_scope
5198
5199  @_global_distribute_strategy_scope.setter
5200  def _global_distribute_strategy_scope(self, distribute_strategy_scope):
5201    self._thread_local.distribute_strategy_scope = (distribute_strategy_scope)
5202
5203  def _mutation_lock(self):
5204    """Returns a lock to guard code that creates & mutates ops.
5205
5206    See the comment for self._group_lock for more info.
5207    """
5208    return self._group_lock.group(_MUTATION_LOCK_GROUP)
5209
5210  def _session_run_lock(self):
5211    """Returns a lock to guard code for Session.run.
5212
5213    See the comment for self._group_lock for more info.
5214    """
5215    return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
5216
5217
5218# TODO(agarwal): currently device directives in an outer eager scope will not
5219# apply to inner graph mode code. Fix that.
5220
5221
@tf_export(v1=["device"])
def device(device_name_or_function):
  """Wrapper for `Graph.device()` using the default graph.

  See `tf.Graph.device` for more details.

  Behavior depends on the execution mode:
  - When executing eagerly, the eager context's device scope is used.
  - When building a graph while eager execution is enabled outside of
    functions, a device name is applied to both the default graph and the
    eager context. A device function is applied only to the graph.
  - Otherwise this is exactly `get_default_graph().device(...)`.

  Args:
    device_name_or_function: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If eager execution is enabled and a function is passed in.
  """
  if not context.executing_eagerly():
    if not executing_eagerly_outside_functions():
      return get_default_graph().device(device_name_or_function)

    @tf_contextlib.contextmanager
    def _graph_and_eager_device(spec):
      # Device functions only make sense for the graph. Device names are also
      # pushed onto the eager context.
      with get_default_graph().device(spec):
        if callable(spec):
          yield
        else:
          with context.device(spec):
            yield

    return _graph_and_eager_device(device_name_or_function)

  if callable(device_name_or_function):
    raise RuntimeError(
        "tf.device does not support functions when eager execution "
        "is enabled.")
  return context.device(device_name_or_function)
5256
5257
@tf_export("device", v1=[])
def device_v2(device_name):
  """Specifies the device for ops created/executed in this context.

  This function specifies the device to be used for ops created/executed in a
  particular context. Nested contexts will inherit and also create/execute
  their ops on the specified device. If a specific device is not required,
  consider not using this function so that a device can be automatically
  assigned.  In general the use of this function is optional. `device_name` can
  be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially
  specified, containing only a subset of the "/"-separated fields. Any fields
  which are specified will override device annotations from outer scopes.

  For example:

  ```python
  with tf.device('/job:foo'):
    # ops created here have devices with /job:foo
    with tf.device('/job:bar/task:0/device:gpu:2'):
      # ops created here have the fully specified device above
    with tf.device('/device:gpu:1'):
      # ops created here have the device '/job:foo/device:gpu:1'
  ```

  Unlike the v1 `device`, this entry point only accepts device names. Device
  functions are rejected up front in every execution mode.

  Args:
    device_name: The device name to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If a function is passed in.
  """
  if not callable(device_name):
    return device(device_name)
  raise RuntimeError("tf.device does not support functions.")
5295
5296
@tf_export(v1=["container"])
def container(container_name):
  """Wrapper for `Graph.container()` using the default graph.

  Args:
    container_name: The container string to use in the context.

  Returns:
    A context manager that specifies the default container to use for newly
    created stateful ops.
  """
  graph = get_default_graph()
  return graph.container(container_name)
5309
5310
def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
  """Returns a context manager that colocates newly created ops with `op`.

  When executing eagerly:
  - A `None` op yields a null context manager.
  - Otherwise the result is a device scope on `op`'s device. If `op` has no
    `device` attribute, it is first converted with
    `internal_convert_to_tensor_or_indexed_slices`.
  - `gradient_uid` and `ignore_existing` are not used on this path.

  In graph mode:
  - An `EagerTensor` is accepted only while the default graph is building a
    function, and yields a device scope on that tensor's device.
  - Everything else is delegated to the default graph's
    `_colocate_with_for_gradient`.

  Args:
    op: The op (or tensor-like value) to colocate with, or None.
    gradient_uid: Forwarded to the graph implementation in graph mode.
    ignore_existing: Forwarded to the graph implementation in graph mode.

  Returns:
    A context manager.

  Raises:
    ValueError: In graph mode, if `op` is an `EagerTensor` and the default
      graph is not building a function.
  """
  if not context.executing_eagerly():
    default_graph = get_default_graph()
    if isinstance(op, EagerTensor):
      if not default_graph.building_function:
        raise ValueError("Encountered an Eager-defined Tensor during graph "
                         "construction, but a function was not being built.")
      return default_graph.device(op.device)
    return default_graph._colocate_with_for_gradient(  # pylint: disable=protected-access
        op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)

  if op is None:
    return NullContextmanager()
  if not hasattr(op, "device"):
    op = internal_convert_to_tensor_or_indexed_slices(op)
  return device(op.device)
5329
5330
# Internal interface to colocate_with. colocate_with has been deprecated from
# the public API, but a few internal callers still rely on it. They use this
# undeprecated entry point so they do not trigger the deprecation warning.
def colocate_with(op, ignore_existing=False):
  """Colocates ops created in the returned context with `op` (internal).

  Equivalent to `_colocate_with_for_gradient(op, None, ignore_existing)`.

  Args:
    op: The op (or tensor-like value) to colocate with, or None.
    ignore_existing: Forwarded unchanged.

  Returns:
    A context manager.
  """
  return _colocate_with_for_gradient(
      op, gradient_uid=None, ignore_existing=ignore_existing)
5336
5337
@deprecation.deprecated(
    date=None, instructions="Colocations handled automatically by placer.")
@tf_export(v1=["colocate_with"])
def _colocate_with(op, ignore_existing=False):
  """Deprecated public alias that forwards to the internal `colocate_with`."""
  return colocate_with(op, ignore_existing=ignore_existing)
5343
5344
@tf_export("control_dependencies")
def control_dependencies(control_inputs):
  """Wrapper for `Graph.control_dependencies()` using the default graph.

  See `tf.Graph.control_dependencies` for more details.

  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
  this method, as ops execute in the expected order thanks to automatic control
  dependencies.* Only use `tf.control_dependencies` when working with v1
  `tf.Graph` code.

  When eager execution is enabled, no dependencies are recorded. Every
  callable object in the `control_inputs` list is called immediately, and a
  no-op context manager is returned.

  Args:
    control_inputs: A list of `Operation` or `Tensor` objects which must be
      executed or computed before running the operations defined in the context.
      Can also be `None` to clear the control dependencies. If eager execution
      is enabled, any callable object in the `control_inputs` list will be
      called.

  Returns:
   A context manager that specifies control dependencies for all
   operations constructed within the context.
  """
  if not context.executing_eagerly():
    return get_default_graph().control_dependencies(control_inputs)
  # Eagerly there is nothing to order; just run any pending callables.
  for pending in (control_inputs or ()):
    if callable(pending):
      pending()
  return NullContextmanager()
5379
5380
class _DefaultStack(threading.local):
  """A thread-local stack of objects for providing implicit defaults.

  The class derives from `threading.local`, so `stack` and the nesting flag
  are per-thread attributes.
  """

  def __init__(self):
    super(_DefaultStack, self).__init__()
    self._enforce_nesting = True
    self.stack = []

  def get_default(self):
    """Returns the most recently pushed object, or None if the stack is empty."""
    if not self.stack:
      return None
    return self.stack[-1]

  def reset(self):
    """Discards every object on the stack."""
    self.stack = []

  def is_cleared(self):
    """Returns True when the stack holds no objects."""
    return not self.stack

  @property
  def enforce_nesting(self):
    """Whether `get_controller` requires strict LIFO exit order."""
    return self._enforce_nesting

  @enforce_nesting.setter
  def enforce_nesting(self, value):
    self._enforce_nesting = value

  @tf_contextlib.contextmanager
  def get_controller(self, default):
    """A context manager for manipulating a default stack.

    Pushes `default` and yields it. On exit:
    - If nesting is enforced, `default` must be the top of the stack (else
      AssertionError) and is popped.
    - Otherwise the first element equal to `default` is removed.
    - Nothing is removed if the stack was emptied (e.g. by `reset()`) inside
      the block.
    """
    self.stack.append(default)
    try:
      yield default
    finally:
      # stack may be empty if reset() was called
      if self.stack:
        if not self._enforce_nesting:
          self.stack.remove(default)
        elif self.stack[-1] is default:
          self.stack.pop()
        else:
          raise AssertionError(
              "Nesting violated for default stack of %s objects" %
              type(default))
5423
5424
# Module-level _DefaultStack (per-thread, since it subclasses threading.local)
# of default sessions. default_session() pushes onto it via get_controller(),
# and get_default_session() reads its top entry.
_default_session_stack = _DefaultStack()  # pylint: disable=protected-access
5426
5427
def default_session(session):
  """Python "with" handler for defining a default session.

  Registers `session` as the object that handles Tensor.eval() and
  Operation.run() calls. It is primarily intended for use by
  session.Session, but any object that implements the Session.run()
  interface can be installed.

  Use it with the "with" keyword: within the block, Tensor.eval() and
  Operation.run() invocations that do not name a session are executed by
  `session`.

  The default session is tracked per thread, so the call stack alone
  determines which default session is in effect. A new thread does not
  inherit it. If that thread should use the same session, its function must
  explicitly enter a "with ops.default_session(sess):" block.

  Example:
    The following code examples are equivalent:

    # 1. Using the Session object directly:
    sess = ...
    c = tf.constant(5.0)
    sess.run(c)

    # 2. Using default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      result = c.eval()

    # 3. Overriding default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      with ops.default_session(...):
        c.eval(session=sess)

  Args:
    session: The session to be installed as the default session.

  Returns:
    A context manager for the default session.
  """
  session_stack = _default_session_stack
  return session_stack.get_controller(session)
5474
5475
@tf_export(v1=["get_default_session"])
def get_default_session():
  """Returns the default session for the current thread.

  The result is the innermost session on which a `Session` or
  `Session.as_default()` context has been entered in this thread, or `None`
  when there is no such session.

  NOTE: The default session is tracked per thread. A newly created thread
  does not inherit it. If that thread should use the same session, wrap the
  thread's function body in `with sess.as_default():`.

  Returns:
    The default `Session` being used in the current thread.
  """
  session_stack = _default_session_stack
  return session_stack.get_default()
5492
5493
5494def _eval_using_default_session(tensors, feed_dict, graph, session=None):
5495  """Uses the default session to evaluate one or more tensors.
5496
5497  Args:
5498    tensors: A single Tensor, or a list of Tensor objects.
5499    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5500      numpy ndarrays, TensorProtos, or strings.
5501    graph: The graph in which the tensors are defined.
5502    session: (Optional) A different session to use to evaluate "tensors".
5503
5504  Returns:
5505    Either a single numpy ndarray if "tensors" is a single tensor; or a list
5506    of numpy ndarrays that each correspond to the respective element in
5507    "tensors".
5508
5509  Raises:
5510    ValueError: If no default session is available; the default session
5511      does not have "graph" as its graph; or if "session" is specified,
5512      and it does not have "graph" as its graph.
5513  """
5514  if session is None:
5515    session = get_default_session()
5516    if session is None:
5517      raise ValueError("Cannot evaluate tensor using `eval()`: No default "
5518                       "session is registered. Use `with "
5519                       "sess.as_default()` or pass an explicit session to "
5520                       "`eval(session=sess)`")
5521    if session.graph is not graph:
5522      raise ValueError("Cannot use the default session to evaluate tensor: "
5523                       "the tensor's graph is different from the session's "
5524                       "graph. Pass an explicit session to "
5525                       "`eval(session=sess)`.")
5526  else:
5527    if session.graph is not graph:
5528      raise ValueError("Cannot use the given session to evaluate tensor: "
5529                       "the tensor's graph is different from the session's "
5530                       "graph.")
5531  return session.run(tensors, feed_dict)
5532
5533
def _run_using_default_session(operation, feed_dict, graph, session=None):
  """Runs "operation" in an explicit session or, failing that, the default one.

  Args:
    operation: The Operation to be run.
    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
      numpy ndarrays, TensorProtos, or strings.
    graph: The graph in which "operation" is defined.
    session: (Optional) A different session to use to run "operation". When
      omitted, the session returned by `get_default_session()` is used.

  Raises:
    ValueError: If no default session is available; the default session
      does not have "graph" as its graph; or if "session" is specified,
      and it does not have "graph" as its graph.
  """
  caller_supplied_session = session is not None
  if not caller_supplied_session:
    session = get_default_session()
    if session is None:
      raise ValueError("Cannot execute operation using `run()`: No default "
                       "session is registered. Use `with "
                       "sess.as_default():` or pass an explicit session to "
                       "`run(session=sess)`")

  if session.graph is not graph:
    # The error wording depends on who picked the session: only the default
    # session case suggests passing one explicitly.
    if caller_supplied_session:
      raise ValueError("Cannot use the given session to execute operation: "
                       "the operation's graph is different from the session's "
                       "graph.")
    raise ValueError("Cannot use the default session to execute operation: "
                     "the operation's graph is different from the "
                     "session's graph. Pass an explicit session to "
                     "run(session=sess).")

  session.run(operation, feed_dict)
5567
5568
class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
  """A thread-local stack of objects for providing an implicit default graph.

  When no graph has been pushed onto the stack, a lazily created global
  default graph is returned instead.
  """

  def __init__(self):
    super(_DefaultGraphStack, self).__init__()
    # Created on first use by `_GetGlobalDefaultGraph`; cleared by `reset`.
    self._global_default_graph = None

  def get_default(self):
    """Override that returns a global default if the stack is empty."""
    if self.stack:
      return self.stack[-1]
    # Use the single lazy-initialization path (an explicit `is None` check)
    # rather than duplicating it here with a truthiness test.
    return self._GetGlobalDefaultGraph()

  def _GetGlobalDefaultGraph(self):
    """Returns the global default graph, creating it on first use."""
    if self._global_default_graph is None:
      # TODO(mrry): Perhaps log that the default graph is being used, or
      #   provide some other feedback to prevent confusion when a mixture of
      #   the global default graph and an explicit graph are combined in the
      #   same process.
      self._global_default_graph = Graph()
    return self._global_default_graph

  def reset(self):
    """Clears the stack and discards the global default graph."""
    super(_DefaultGraphStack, self).reset()
    self._global_default_graph = None

  @tf_contextlib.contextmanager
  def get_controller(self, default):
    """Makes `default` the default graph and enters graph mode.

    A context switch recording `default.building_function`,
    `default.as_default` and `default._device_function_stack` is pushed
    before entering and is always popped on exit.

    Args:
      default: The graph to install as the default.

    Yields:
      The object yielded by the base class controller for `default`.
    """
    context.context().context_switches.push(default.building_function,
                                            default.as_default,
                                            default._device_function_stack)
    try:
      with super(_DefaultGraphStack,
                 self).get_controller(default) as g, context.graph_mode():
        yield g
    finally:
      # If an exception is raised here it may be hiding a related exception in
      # the try-block (just above).
      context.context().context_switches.pop()


_default_graph_stack = _DefaultGraphStack()
5615
5616
# Shared helper used in init_scope and executing_eagerly_outside_functions
# to obtain the outermost context that is not building a function, and the
# innermost non empty device stack.
def _get_outer_context_and_inner_device_stack():
  """Get the outermost context not building a function.

  Returns:
    A pair `(outer_context, device_stack)`.
    - `outer_context` is a callable whose result is entered with `with` to
      switch into the outermost context that is not building a function.
    - `device_stack` is the default graph's device function stack when that
      is non-empty. Otherwise it is taken from the context-switch entries
      visited (innermost first) while searching for `outer_context`.

  Raises:
    RuntimeError: If the graph stack is empty but the global graph is
      building a function, or if no outer context could be determined.
  """
  current_graph = get_default_graph()
  device_stack = current_graph._device_function_stack  # pylint: disable=protected-access

  if _default_graph_stack.stack:
    # Walk the recorded context switches from innermost to outermost. Pick up
    # a device stack on the way, and stop at the first entry that is not
    # building a function.
    enter_fn = None
    for switch in reversed(context.context().context_switches.stack):
      if not device_stack:
        device_stack = switch.device_stack
      if not switch.is_building_function:
        enter_fn = switch.enter_context_fn
        break

    if enter_fn is None:
      # As a last resort, obtain the global default graph; this graph doesn't
      # necessarily live on the graph stack (and hence it doesn't necessarily
      # live on the context stack), but it is stored in the graph stack's
      # encapsulating object.
      enter_fn = _default_graph_stack._GetGlobalDefaultGraph().as_default  # pylint: disable=protected-access
  else:
    # With an empty graph stack we cannot be building a function. The global
    # graph, which is also the default graph here, becomes the outer context.
    if current_graph.building_function:
      raise RuntimeError("The global graph is building a function.")
    enter_fn = current_graph.as_default

  if enter_fn is None:
    # Sanity check; this shouldn't be triggered.
    raise RuntimeError("All graphs are building functions, and no "
                       "eager context was previously active.")

  return enter_fn, device_stack
5655
5656
# pylint: disable=g-doc-return-or-yield,line-too-long
@tf_export("init_scope")
@tf_contextlib.contextmanager
def init_scope():
  """A context manager that lifts ops out of control-flow scopes and function-building graphs.

  There is often a need to lift variable initialization ops out of control-flow
  scopes, function-building graphs, and gradient tapes. Entering an
  `init_scope` is a mechanism for satisfying these desiderata. In particular,
  entering an `init_scope` has three effects:

    (1) All control dependencies are cleared the moment the scope is entered;
        this is equivalent to entering the context manager returned from
        `control_dependencies(None)`, which has the side-effect of exiting
        control-flow scopes like `tf.cond` and `tf.while_loop`.

    (2) All operations that are created while the scope is active are lifted
        into the lowest context on the `context_stack` that is not building a
        graph function. Here, a context is defined as either a graph or an eager
        context. Every context switch, i.e., every installation of a graph as
        the default graph and every switch into eager mode, is logged in a
        thread-local stack called `context_switches`; the log entry for a
        context switch is popped from the stack when the context is exited.
        Entering an `init_scope` is equivalent to crawling up
        `context_switches`, finding the first context that is not building a
        graph function, and entering it. A caveat is that if graph mode is
        enabled but the default graph stack is empty, then entering an
        `init_scope` will simply install a fresh graph as the default one.

    (3) The gradient tape is paused while the scope is active.

  When eager execution is enabled, code inside an init_scope block runs with
  eager execution enabled even when tracing a `tf.function`. For example:

  ```python
  tf.compat.v1.enable_eager_execution()

  @tf.function
  def func():
    # A function constructs TensorFlow graphs,
    # it does not execute eagerly.
    assert not tf.executing_eagerly()
    with tf.init_scope():
      # Initialization runs with eager execution enabled
      assert tf.executing_eagerly()
  ```

  Raises:
    RuntimeError: if graph state is incompatible with this initialization.
  """
  # pylint: enable=g-doc-return-or-yield,line-too-long

  if context.executing_eagerly():
    # Fastpath. Already in an eager context, so there is no graph to lift out
    # of; only the gradient tape needs to be paused.
    with tape.stop_recording():
      yield
  else:
    # Retrieve the active name scope: entering an `init_scope` preserves
    # the name scope of the current context.
    scope = get_default_graph().get_name_scope()
    if scope and scope[-1] != "/":
      # Names that end with trailing slashes are treated by `name_scope` as
      # absolute.
      scope = scope + "/"

    # `outer_context()` enters the outermost context that is not building a
    # function; the device stack is carried into that context below.
    outer_context, innermost_nonempty_device_stack = (
        _get_outer_context_and_inner_device_stack())

    # Populated only when lifting into a graph; the `finally` clause uses them
    # to restore the outer graph's own device stack.
    outer_graph = None
    outer_device_stack = None
    try:
      with outer_context(), name_scope(
          scope, skip_on_eager=False), control_dependencies(
              None), tape.stop_recording():
        # By default no extra device scope is entered. In the eager case this
        # may be replaced with `context.device(<raw device string>)`.
        context_manager = NullContextmanager
        context_manager_input = None
        if not context.executing_eagerly():
          # The device stack is preserved when lifting into a graph. Eager
          # execution doesn't implement device stacks and in particular it
          # doesn't support device functions, so in general it's not possible
          # to do the same when lifting into the eager context.
          outer_graph = get_default_graph()
          outer_device_stack = outer_graph._device_function_stack  # pylint: disable=protected-access
          outer_graph._device_function_stack = innermost_nonempty_device_stack  # pylint: disable=protected-access
        elif innermost_nonempty_device_stack is not None:
          # Scan the device specs in `peek_objs()` order:
          # - stop at the first one whose `function` is None;
          # - otherwise, the first one with a `raw_string` is applied as an
          #   eager `context.device` scope.
          for device_spec in innermost_nonempty_device_stack.peek_objs():
            if device_spec.function is None:
              break
            if device_spec.raw_string:
              context_manager = context.device
              context_manager_input = device_spec.raw_string
              break
            # It is currently not possible to have a device function in V2,
            # but in V1 we are unable to apply device functions in eager mode.
            # This means that we will silently skip some of the entries on the
            # device stack in V1 + eager mode.

        with context_manager(context_manager_input):
          yield
    finally:
      # If an exception is raised here it may be hiding a related exception in
      # try-block (just above).
      # Undo the device-stack swap made on the outer graph, if one was made.
      if outer_graph is not None:
        outer_graph._device_function_stack = outer_device_stack  # pylint: disable=protected-access
5761
5762
@tf_export(v1=["executing_eagerly_outside_functions"])
def executing_eagerly_outside_functions():
  """Returns True if executing eagerly, even if inside a graph function.

  This function will check the outermost context for the program and see if
  it is in eager mode. It is a useful complement to `tf.executing_eagerly()`,
  which checks the current context and will return `False` within a
  `tf.function` body. It can be used to build libraries that behave
  differently in eager runtime and v1 session runtime (deprecated).

  Example:

  >>> tf.compat.v1.enable_eager_execution()
  >>> @tf.function
  ... def func():
  ...   # A function constructs TensorFlow graphs, it does not execute eagerly,
  ...   # but the outer most context is still eager.
  ...   assert not tf.executing_eagerly()
  ...   return tf.compat.v1.executing_eagerly_outside_functions()
  >>> func()
  <tf.Tensor: shape=(), dtype=bool, numpy=True>

  Returns:
    boolean, whether the outermost context is in eager mode.
  """
  if context.executing_eagerly():
    # The current context is already eager; no need to look further out.
    return True

  enter_outermost, _ = _get_outer_context_and_inner_device_stack()
  with enter_outermost():
    return context.executing_eagerly()
5794
5795
@tf_export("inside_function", v1=[])
def inside_function():
  """Indicates whether the caller code is executing inside a `tf.function`.

  Returns:
    Boolean, True if the caller code is executing inside a `tf.function`
    rather than eagerly.

  Example:

  >>> tf.inside_function()
  False
  >>> @tf.function
  ... def f():
  ...   print(tf.inside_function())
  >>> f()
  True
  """
  # The answer is read directly from the current default graph's
  # `building_function` flag.
  current_graph = get_default_graph()
  return current_graph.building_function
5815
5816
@tf_export(v1=["enable_eager_execution"])
def enable_eager_execution(config=None, device_policy=None,
                           execution_mode=None):
  """Enables eager execution for the lifetime of this program.

  Eager execution provides an imperative interface to TensorFlow. With eager
  execution enabled, TensorFlow functions execute operations immediately (as
  opposed to adding to a graph to be executed later in a `tf.compat.v1.Session`)
  and return concrete values (as opposed to symbolic references to a node in a
  computational graph).

  For example:

  ```python
  tf.compat.v1.enable_eager_execution()

  # After eager execution is enabled, operations are executed as they are
  # defined and Tensor objects hold concrete values, which can be accessed as
  # `numpy.ndarray`s through the numpy() method.
  assert tf.multiply(6, 7).numpy() == 42
  ```

  Eager execution cannot be enabled after TensorFlow APIs have been used to
  create or execute graphs. It is typically recommended to invoke this function
  at program startup and not in a library (as most libraries should be usable
  both with and without eager execution).

  If eager execution is already the default execution mode, this call only
  records API usage; the given options are neither validated nor applied.

  Args:
    config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the
      environment in which operations are executed. Note that
      `tf.compat.v1.ConfigProto` is also used to configure graph execution (via
      `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto`
      are not implemented (or are irrelevant) when eager execution is enabled.
    device_policy: (Optional.) Policy controlling how operations requiring
      inputs on a specific device (e.g., a GPU 0) handle inputs on a different
      device (e.g. GPU 1 or CPU). When set to None, an appropriate value will
      be picked automatically. The value picked may change between TensorFlow
      releases.
      Valid values:
      - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the
        placement is not correct.
      - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not
        on the right device but logs a warning.
      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors.
        Note that this may hide performance problems as there is no notification
        provided when operations are blocked on the tensor being copied between
        devices.
      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies
        int32 tensors, raising errors on the other ones.
    execution_mode: (Optional.) Policy controlling how operations dispatched are
      actually executed. When set to None, an appropriate value will be picked
      automatically. The value picked may change between TensorFlow releases.
      Valid values:
      - tf.contrib.eager.SYNC: executes each operation synchronously.
      - tf.contrib.eager.ASYNC: executes each operation asynchronously. These
        operations may return "non-ready" handles.

  Raises:
    TypeError: If `config` is not a `tf.compat.v1.ConfigProto`.
    ValueError: If eager execution is enabled after creating/executing a
      TensorFlow graph, if `device_policy` or `execution_mode` is not one of
      the valid values, or if options provided conflict with a previous call
      to this function.
  """
  _api_usage_gauge.get_cell().set(True)
  if context.default_execution_mode == context.EAGER_MODE:
    # Eager execution is already the default; nothing else to do.
    return None
  return enable_eager_execution_internal(
      config=config,
      device_policy=device_policy,
      execution_mode=execution_mode,
      server_def=None)
5887
5888
@tf_export(v1=["disable_eager_execution"])
def disable_eager_execution():
  """Disables eager execution.

  This function can only be called before any Graphs, Ops, or Tensors have been
  created. It can be used at the beginning of the program for complex migration
  projects from TensorFlow 1.x to 2.x.

  Graph mode becomes the default execution mode. If an eager context already
  exists, the current thread's local state is also switched out of eager mode.
  """
  _api_usage_gauge.get_cell().set(False)
  context.default_execution_mode = context.GRAPH_MODE

  existing_ctx = context.context_safe()
  if existing_ctx is None:
    return
  existing_ctx._thread_local_data.is_eager = False  # pylint: disable=protected-access
5902
5903
def enable_eager_execution_internal(config=None,
                                    device_policy=None,
                                    execution_mode=None,
                                    server_def=None):
  """Enables eager execution for the lifetime of this program.

  Most of the doc string for enable_eager_execution is relevant here as well.

  Args:
    config: See enable_eager_execution doc string
    device_policy: See enable_eager_execution doc string
    execution_mode: See enable_eager_execution doc string
    server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on
      remote devices. GrpcServers need to be started by creating an identical
      server_def to this, and setting the appropriate task_indexes, so that the
      servers can communicate. It will then be possible to execute operations on
      remote devices.

  Raises:
    TypeError: If `config` is not None and not a `config_pb2.ConfigProto`.
    ValueError: If `device_policy` or `execution_mode` is not a recognized
      value; if graph mode was the default and the global default graph has
      already been created; or if an eager context already exists and any
      non-None option differs (by identity) from the one it was created with.
  """
  if config is not None and not isinstance(config, config_pb2.ConfigProto):
    raise TypeError("config must be a tf.ConfigProto, but got %s" %
                    type(config))
  if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
                           context.DEVICE_PLACEMENT_WARN,
                           context.DEVICE_PLACEMENT_SILENT,
                           context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
    raise ValueError(
        "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
    )
  if execution_mode not in (None, context.SYNC, context.ASYNC):
    raise ValueError(
        "execution_mode must be one of None, tf.contrib.eager.SYNC, "
        "tf.contrib.eager.ASYNC")
  if context.default_execution_mode == context.GRAPH_MODE:
    # The global default graph is created lazily on first use, so its
    # existence means graph mode has already been used in this program.
    graph_mode_has_been_used = (
        _default_graph_stack._global_default_graph is not None)  # pylint: disable=protected-access
    if graph_mode_has_been_used:
      raise ValueError(
          "tf.enable_eager_execution must be called at program startup.")
  context.default_execution_mode = context.EAGER_MODE
  # pylint: disable=protected-access
  with context._context_lock:
    if context._context is None:
      context._set_context_locked(context.Context(
          config=config,
          device_policy=device_policy,
          execution_mode=execution_mode,
          server_def=server_def))
    # Options are compared by identity: an equal-but-distinct ConfigProto is
    # treated as a conflicting change.
    elif ((config is not None and config is not context._context._config) or
          (device_policy is not None and
           device_policy is not context._context._device_policy) or
          (execution_mode is not None and
           execution_mode is not context._context._execution_mode)):
      raise ValueError(
          "Trying to change the options of an active eager"
          " execution. Context config: %s, specified config:"
          " %s. Context device policy: %s, specified device"
          " policy: %s. Context execution mode: %s,"
          " specified execution mode %s." %
          (context._context._config, config, context._context._device_policy,
           device_policy, context._context._execution_mode, execution_mode))
    else:
      # We already created everything, so update the thread local data.
      context._context._thread_local_data.is_eager = True

  # Monkey patch to get rid of an unnecessary conditional since the context is
  # now initialized.
  context.context = context.context_safe
5975
5976
def eager_run(main=None, argv=None):
  """Enables eager execution, then runs the program via `app.run`.

  The program will run with eager execution enabled.

  Example:
  ```python
  import tensorflow as tf
  # Import subject to future changes:
  from tensorflow.contrib.eager.python import tfe

  def main(_):
    u = tf.constant(6.0)
    v = tf.constant(7.0)
    print(u * v)

  if __name__ == "__main__":
    tfe.run()
  ```

  Args:
    main: the main function to run.
    argv: the arguments to pass to it.
  """
  # Eager execution has to be switched on before `main` gets a chance to
  # create or execute any graphs.
  enable_eager_execution()
  app.run(main, argv)
6003
6004
@tf_export(v1=["reset_default_graph"])
def reset_default_graph():
  """Clears the default graph stack and resets the global default graph.

  NOTE: The default graph is a property of the current thread. This
  function applies only to the current thread. Calling this function while
  a `tf.compat.v1.Session` or `tf.compat.v1.InteractiveSession` is active will
  result in undefined behavior. Using any previously created `tf.Operation` or
  `tf.Tensor` objects after calling this function will result in undefined
  behavior.

  Raises:
    AssertionError: If this function is called within a nested graph.
  """
  graph_stack = _default_graph_stack
  if graph_stack.is_cleared():
    graph_stack.reset()
    return
  raise AssertionError("Do not use tf.reset_default_graph() to clear "
                       "nested graphs. If you need a cleared graph, "
                       "exit the nesting and create a new graph.")
6023
6024
@tf_export(v1=["get_default_graph"])
def get_default_graph():
  """Returns the default graph for the current thread.

  The returned graph will be the innermost graph on which a
  `Graph.as_default()` context has been entered, or a global default
  graph if none has been explicitly created.

  NOTE: The default graph is a property of the current thread. If you
  create a new thread, and wish to use the default graph in that
  thread, you must explicitly add a `with g.as_default():` in that
  thread's function.

  Returns:
    The default `Graph` being used in the current thread.
  """
  # The graph stack itself handles the fallback to the lazily created global
  # default graph when nothing has been pushed.
  graph_stack = _default_graph_stack
  return graph_stack.get_default()
6042
6043
def has_default_graph():
  """Returns True if there is a default graph.

  Only graphs pushed onto the default graph stack count; the lazily created
  global default graph is not consulted.
  """
  return bool(_default_graph_stack.stack)
6047
6048
# Exported due to b/171079555
@tf_export("__internal__.get_name_scope", v1=[])
def get_name_scope():
  """Returns the current name scope in the default_graph.

  For example:

  ```python
  with tf.name_scope('scope1'):
    with tf.name_scope('scope2'):
      print(tf.get_name_scope())
  ```
  would print the string `scope1/scope2`.

  Returns:
    A string representing the current name scope.
  """
  if not context.executing_eagerly():
    return get_default_graph().get_name_scope()
  # In eager mode the scope comes from the eager context; any trailing
  # slashes are stripped from it.
  eager_scope = context.context().scope_name
  return eager_scope.rstrip("/")
6069
6070
def _assert_same_graph(original_item, item):
  """Raises unless both items belong to the same graph.

  Items without a (truthy) `graph` attribute are treated as graph-agnostic
  and never trigger an error.

  Args:
    original_item: Original item to check against.
    item: Item to check.

  Raises:
    ValueError: if graphs do not match.
  """
  expected = getattr(original_item, "graph", None)
  actual = getattr(item, "graph", None)
  if not (expected and actual) or expected is actual:
    return
  raise ValueError(
      "%s must be from the same graph as %s (graphs are %s and %s)." %
      (item, original_item, actual, expected))
6087
6088
def _get_graph_from_inputs(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  This library method provides a consistent algorithm for choosing the graph
  in which an Operation should be constructed:

  1. If the default graph is being used to construct a function, we
     use the default graph.
  2. If the "graph" is specified explicitly, we validate that all of the inputs
     in "op_input_list" are compatible with that graph.
  3. Otherwise, we attempt to select a graph from the first Operation-
     or Tensor-valued input in "op_input_list", and validate that all other
     such inputs are in the same graph.
  4. If the graph was not specified and it could not be inferred from
     "op_input_list", we attempt to use the default graph.

  Args:
    op_input_list: A list of inputs to an operation, which may include `Tensor`,
      `Operation`, and other objects that may be converted to a graph element.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If op_input_list is not a list or tuple, or if graph is not a
      Graph.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs, or we could not find a graph
      and there was no default graph.

  Returns:
    The appropriate graph to use for the given inputs.

  """
  default_graph = get_default_graph()
  if default_graph.building_function:
    return default_graph

  inputs = tuple(op_input_list)  # Materialize generators exactly once.
  if graph and not isinstance(graph, Graph):
    raise TypeError("Input graph needs to be a Graph: %s" % (graph,))

  # The first graph element seen when no graph was supplied. Later elements
  # are checked against it so mismatches produce an informative error.
  first_element = None
  for candidate in inputs:
    # Operations and native objects are used directly, except Tensor
    # subclasses, which go through the generic conversion path.
    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean
    # this up.
    use_directly = (
        isinstance(candidate, (Operation, internal.NativeObject)) and
        ((not isinstance(candidate, Tensor)) or type(candidate) == Tensor))  # pylint: disable=unidiomatic-typecheck
    element = candidate if use_directly else _as_graph_element(candidate)
    if element is None:
      continue

    if not graph:
      first_element = element
      graph = getattr(element, "graph", None)
    elif first_element is not None:
      _assert_same_graph(first_element, element)
    elif element.graph is not graph:
      raise ValueError("%s is not from the passed-in graph." % element)

  # Nothing pinned a graph: fall back to the default graph, which always exists.
  return graph or default_graph
6157
6158
@tf_export(v1=["GraphKeys"])
class GraphKeys(object):
  """Standard names to use for graph collections.

  The standard library uses various well-known names to collect and
  retrieve values associated with a graph. For example, the
  `tf.Optimizer` subclasses default to optimizing the variables
  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
  specified, but it is also possible to pass an explicit list of
  variables.

  The following standard keys are defined:

  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
    across a distributed environment (model variables are a subset of these).
    See `tf.compat.v1.global_variables`
    for more details.
    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
    machine. Usually used for temporary variables, like counters.
    Note: use `tf.contrib.framework.local_variable` to add to this collection.
  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
    model for inference (feed forward). Note: use
    `tf.contrib.framework.model_variable` to add to this collection.
  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
    be trained by an optimizer. See
    `tf.compat.v1.trainable_variables`
    for more details.
  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
    graph. See
    `tf.compat.v1.summary.merge_all`
    for more details.
  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
    produce input for a computation. See
    `tf.compat.v1.train.start_queue_runners`
    for more details.
  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
    keep moving averages.  See
    `tf.compat.v1.moving_average_variables`
    for more details.
  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
    construction.

  The following standard keys are _defined_, but their collections are **not**
  automatically populated as many of the others are:

  * `WEIGHTS`
  * `BIASES`
  * `ACTIVATIONS`
  """

  # Key to collect Variable objects that are global (shared across machines).
  # Default collection for all variables, except local ones.
  GLOBAL_VARIABLES = "variables"
  # Key to collect local variables that are local to the machine and are not
  # saved/restored.
  LOCAL_VARIABLES = "local_variables"
  # Key to collect local variables which are used to accumulate internal state
  # to be used in tf.metrics.*.
  METRIC_VARIABLES = "metric_variables"
  # Key to collect model variables defined by layers.
  MODEL_VARIABLES = "model_variables"
  # Key to collect Variable objects that will be trained by the
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  # Key to collect summaries.
  SUMMARIES = "summaries"
  # Key to collect QueueRunners.
  QUEUE_RUNNERS = "queue_runners"
  # Key to collect table initializers.
  TABLE_INITIALIZERS = "table_initializer"
  # Key to collect asset filepaths. An asset represents an external resource
  # like a vocabulary file.
  ASSET_FILEPATHS = "asset_filepaths"
  # Key to collect Variable objects that keep moving averages.
  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
  # Key to collect regularization losses at graph construction.
  REGULARIZATION_LOSSES = "regularization_losses"
  # Key to collect concatenated sharded variables.
  CONCATENATED_VARIABLES = "concatenated_variables"
  # Key to collect savers.
  SAVERS = "savers"
  # Key to collect weights
  WEIGHTS = "weights"
  # Key to collect biases
  BIASES = "biases"
  # Key to collect activations
  ACTIVATIONS = "activations"
  # Key to collect update_ops
  UPDATE_OPS = "update_ops"
  # Key to collect losses
  LOSSES = "losses"
  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
  SAVEABLE_OBJECTS = "saveable_objects"
  # Key to collect all shared resources used by the graph which need to be
  # initialized once per cluster.
  RESOURCES = "resources"
  # Key to collect all shared resources used in this graph which need to be
  # initialized once per session.
  LOCAL_RESOURCES = "local_resources"
  # Trainable resource-style variables.
  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"

  # Key to indicate various ops.
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  READY_OP = "ready_op"
  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Used to count the number of evaluations performed during a single evaluation
  # run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # Key for control flow context.
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Used to store v2 summary names.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # List of all collections that keep track of variables.
  # NOTE(review): keep this in sync when a new variable-holding key is added
  # above.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

  # Key for streaming model ports.
  # NOTE(yuanbyu): internal and experimental.
  _STREAMING_MODEL_PORTS = "streaming_model_ports"

  # Deprecated class-level alias for `GLOBAL_VARIABLES`; reading it emits a
  # deprecation warning via the `deprecation.deprecated` decorator.
  @decorator_utils.classproperty
  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
  def VARIABLES(cls):  # pylint: disable=no-self-argument
    return cls.GLOBAL_VARIABLES
6303
6304
def dismantle_graph(graph):
  """Breaks the reference cycles held by a `Graph` and its ops.

  Calling this on a temporary `Graph` that is no longer needed means the
  garbage collector does not have to run to reclaim it.

  Args:
    graph: The `Graph` to tear down. Neither it nor any of its ops are usable
      after this function returns.
  """
  registered_functions = graph._functions  # pylint: disable=protected-access
  memory.dismantle_ordered_dict(registered_functions)

  # Wipe each op's attributes first. The op list is read from the graph, so
  # this has to happen before the graph's own attributes are cleared. Clearing
  # both sides removes the Operation<->Graph reference cycles.
  for op in graph.get_operations():
    op.__dict__ = {}
  graph.__dict__ = {}
6323
6324
@tf_export(v1=["add_to_collection"])
def add_to_collection(name, value):
  """Wrapper for `Graph.add_to_collection()` using the default graph.

  See `tf.Graph.add_to_collection`
  for more details.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    value: The value to add to the collection.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  graph = get_default_graph()
  graph.add_to_collection(name, value)
6343
6344
@tf_export(v1=["add_to_collections"])
def add_to_collections(names, value):
  """Wrapper for `Graph.add_to_collections()` using the default graph.

  See `tf.Graph.add_to_collections`
  for more details.

  Args:
    names: The keys for the collections. The `GraphKeys` class contains many
      standard names for collections.
    value: The value to add to the collections.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  graph = get_default_graph()
  graph.add_to_collections(names, value)
6363
6364
@tf_export(v1=["get_collection_ref"])
def get_collection_ref(key):
  """Wrapper for `Graph.get_collection_ref()` using the default graph.

  See `tf.Graph.get_collection_ref`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.

  Returns:
    The list of values in the collection with the given `key`, or an empty
    list if no value has been added to that collection.  Note that this returns
    the collection list itself, which can be modified in place to change the
    collection.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  return get_default_graph().get_collection_ref(key)
6387
6388
@tf_export(v1=["get_collection"])
def get_collection(key, scope=None):
  """Wrapper for `Graph.get_collection()` using the default graph.

  See `tf.Graph.get_collection`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include
      only items whose `name` attribute matches `scope` using `re.match`.
      Items without a `name` attribute are never returned if a scope is
      supplied. The choice of `re.match` means that a `scope` without special
      tokens filters by prefix.

  Returns:
    The list of values in the collection with the given `key`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  return get_default_graph().get_collection(key, scope)
6416
6417
def get_all_collection_keys():
  """Returns the keys of all collections used in the default graph.

  Thin wrapper around `Graph.get_all_collection_keys()` on the graph returned
  by `get_default_graph()`.
  """
  default_graph = get_default_graph()
  return default_graph.get_all_collection_keys()
6421
6422
def name_scope(name, default_name=None, values=None, skip_on_eager=True):
  """Internal-only entry point for `name_scope*`.

  Internal ops do not use the public API and instead rely on
  `ops.name_scope` regardless of the execution mode. This function
  dispatches to the correct `name_scope*` implementation based on
  the arguments provided and the current mode:

  * when not executing eagerly, `internal_name_scope_v1` is used;
  * when executing eagerly and `skip_on_eager` is true, a no-op
    `NullContextmanager` is returned;
  * otherwise, if `values` contains a value whose type is exactly `Tensor`,
    the first such value's `graph.name_scope` is used;
  * in all remaining eager cases, `name_scope_v2` is used.

  Args:
    name: The name argument that is passed to the op function.
    default_name: The default name to use if the `name` argument is `None`.
    values: The list of `Tensor` arguments that are passed to the op function.
    skip_on_eager: Indicates to return NullContextmanager if executing eagerly.
      By default this is True since naming tensors and operations in eager mode
      have little use and cause unnecessary performance overhead. However, it is
      important to preserve variable names since they are often useful for
      debugging and saved models.

  Returns:
    `name_scope*` context manager.
  """
  if not context.executing_eagerly():
    return internal_name_scope_v1(name, default_name, values)
  if skip_on_eager:
    return NullContextmanager()

  scope_name = name if name is not None else default_name
  if values:
    # The presence of a graph tensor in `values` overrides the context.
    # TODO(slebedev): this is Keras-specific and should be removed.
    for candidate in values:
      if type(candidate) == Tensor:  # pylint: disable=unidiomatic-typecheck
        return candidate.graph.name_scope(scope_name)

  return name_scope_v2(scope_name or "")
6466
6467
class internal_name_scope_v1(object):  # pylint: disable=invalid-name
  """Graph-only version of `name_scope_v1`.

  On entry, if `values` are given and no function is being built, the graph
  they come from is made the default graph when it differs from the current
  default. Then `Graph.name_scope` is entered on that graph. On exit, both
  steps are undone in reverse order.
  """

  @property
  def name(self):
    return self._name

  def __init__(self, name, default_name=None, values=None):
    """Initialize the context manager.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.

    Raises:
      TypeError: if `default_name` is passed in but not a string.
    """
    if not (default_name is None or isinstance(default_name, six.string_types)):
      raise TypeError(
          "`default_name` type (%s) is not a string type. You likely meant to "
          "pass this into the `values` kwarg." % type(default_name))
    self._name = default_name if name is None else name
    self._default_name = default_name
    self._values = values

  def __enter__(self):
    """Start the scope block.

    Returns:
      The scope name.

    Raises:
      ValueError: if neither `name` nor `default_name` is provided
        but `values` are.
    """
    if self._name is None and self._values is not None:
      # We only raise an error if values is not None (provided) because
      # currently tf.name_scope(None) (values=None then) is sometimes used as
      # an idiom to reset to top scope.
      raise ValueError(
          "At least one of name (%s) and default_name (%s) must be provided."
          % (self._name, self._default_name))

    g = get_default_graph()
    if self._values and not g.building_function:
      # Specialize based on the knowledge that `_get_graph_from_inputs()`
      # ignores `inputs` when building a function.
      g_from_inputs = _get_graph_from_inputs(self._values)
      if g_from_inputs is not g:
        g = g_from_inputs
        self._g_manager = g.as_default()
        self._g_manager.__enter__()
      else:
        self._g_manager = None
    else:
      self._g_manager = None

    try:
      self._name_scope = g.name_scope(self._name)
      return self._name_scope.__enter__()
    except BaseException:
      # Entering the name scope failed, so `__exit__` will not be called.
      # Undo the graph switch made above before propagating.
      if self._g_manager is not None:
        self._g_manager.__exit__(*sys.exc_info())
      raise

  def __exit__(self, *exc_info):
    """End the scope block, undoing `__enter__` in reverse order.

    The `as_default()` manager entered by `__enter__` (if any) is exited in a
    `finally` clause. That way it is released even if leaving the name scope
    raises.
    """
    try:
      self._name_scope.__exit__(*exc_info)
    finally:
      if self._g_manager is not None:
        self._g_manager.__exit__(*exc_info)
6538
6539
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
@tf_export(v1=["name_scope"])
class name_scope_v1(object):  # pylint: disable=invalid-name
  """A context manager for use when defining a Python op.

  This context manager validates that the given `values` are from the
  same graph, makes that graph the default graph, and pushes a
  name scope in that graph (see
  `tf.Graph.name_scope`
  for more details on that).

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```
  """

  __slots__ = ["_name", "_name_scope"]

  def __init__(self, name, default_name=None, values=None):
    """Initialize the context manager.

    The work is delegated to the internal `name_scope` dispatcher. It is called
    with `skip_on_eager=False`, so a scope is pushed even when executing
    eagerly.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.

    Raises:
      TypeError: In graph mode, if `default_name` is passed in but is not a
        string. (When executing eagerly, an invalid name may instead surface
        as a `ValueError`.)
    """
    self._name_scope = name_scope(
        name, default_name, values, skip_on_eager=False)
    if name is None:
      self._name = default_name
    else:
      self._name = name

  @property
  def name(self):
    """The requested name: `name`, or `default_name` when `name` is None."""
    return self._name

  def __enter__(self):
    delegate = self._name_scope
    return delegate.__enter__()

  def __exit__(self, *exc_info):
    delegate = self._name_scope
    return delegate.__exit__(*exc_info)
6592
6593
@tf_export("name_scope", v1=[])
class name_scope_v2(object):
  """A context manager for use when defining a Python op.

  This context manager pushes a name scope, which will make the names of all
  operations added within it have a prefix.

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope("MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```

  When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
  and `MyOp/c`.

  Inside a `tf.function`, if the scope name already exists, the name will be
  made unique by appending `_n`. For example, calling `my_op` the second time
  will generate `MyOp_1/a`, etc.
  """

  __slots__ = ["_name", "_exit_fns"]

  def __init__(self, name):
    """Initialize the context manager.

    Args:
      name: The prefix to use on all names created within the name scope.

    Raises:
      ValueError: If name is not a string.
    """
    if not isinstance(name, six.string_types):
      raise ValueError("name for name_scope must be a string.")
    self._name = name
    # Stack of undo callbacks, one pushed per `__enter__`.
    self._exit_fns = []

  @property
  def name(self):
    return self._name

  def __enter__(self):
    """Start the scope block.

    Returns:
      The scope name.
    """
    ctx = context.context()
    if not ctx.executing_eagerly():
      graph_scope = get_default_graph().name_scope(self._name)
      entered_name = graph_scope.__enter__()
      self._exit_fns.append(graph_scope.__exit__)
      return entered_name

    # Eager mode: names are not auto-incremented. A trailing slash means a
    # fully specified scope name that breaks out of any enclosing scope, which
    # matches Graph.name_scope and also prevents auto-incrementing.
    previous = ctx.scope_name
    requested = self._name
    if not requested:
      entered_name = ""
    elif requested.endswith("/"):
      entered_name = requested
    else:
      entered_name = (previous or "") + requested + "/"
    ctx.scope_name = entered_name

    def _restore_name_scope(*_):
      ctx.scope_name = previous

    self._exit_fns.append(_restore_name_scope)
    return entered_name

  def __exit__(self, type_arg, value_arg, traceback_arg):
    undo = self._exit_fns.pop()
    undo(type_arg, value_arg, traceback_arg)
    return False  # False values do not suppress exceptions

  def __getstate__(self):
    return self._name, self._exit_fns

  def __setstate__(self, state):
    self._name, self._exit_fns = state[0], state[1]
6685
6686
def strip_name_scope(name, export_scope):
  """Removes name scope from a name.

  Args:
    name: A `string` name.
    export_scope: Optional `string`. Name scope to remove. It is matched
      literally, so any regex metacharacters it contains (e.g. `.`) only match
      themselves.

  Returns:
    Name with name scope removed, or the original name if export_scope
    is None or empty. If `name` cannot be processed (a `TypeError` is raised),
    a warning is logged and `name` is returned unchanged.
  """
  if not export_scope:
    return name
  if export_scope[-1] == "/":
    export_scope = export_scope[:-1]

  try:
    # Strips export_scope/, export_scope///,
    # ^export_scope/, loc:@export_scope/.
    # `re.escape` keeps characters such as "." in the scope from acting as
    # regex wildcards and stripping an unrelated prefix.
    str_to_replace = (
        r"([\^]|loc:@|^)" + re.escape(export_scope) + r"[\/]+(.*)")
    return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
  except TypeError as e:
    # If the name is not of a type we can process, simply return it.
    logging.warning(e)
    return name
6713
6714
def prepend_name_scope(name, import_scope):
  """Prepends name scope to a name.

  A leading "^" (control input) or "loc:@" (colocation) marker is kept in
  front of the added scope.

  Args:
    name: A `string` name.
    import_scope: Optional `string`. Name scope to add. It is inserted
      literally.

  Returns:
    Name with name scope added, or the original name if import_scope
    is None or empty. If `name` cannot be processed (a `TypeError` is raised),
    a warning is logged and `name` is returned unchanged.
  """
  if not import_scope:
    return name
  if import_scope[-1] == "/":
    import_scope = import_scope[:-1]

  def _prepend(match):
    # Build the replacement by hand rather than with an r"\1<scope>/\2"
    # template. A scope starting with a digit would turn "\1" into a reference
    # to a nonexistent group (e.g. "\10"), and backslashes in the scope would
    # be interpreted as escapes.
    return match.group(1) + import_scope + "/" + match.group(2)

  try:
    str_to_replace = r"([\^]|loc:@|^)(.*)"
    return re.sub(str_to_replace, _prepend, compat.as_str(name))
  except TypeError as e:
    # If the name is not of a type we can process, simply return it.
    logging.warning(e)
    return name
6740
6741
# pylint: disable=g-doc-return-or-yield
# pylint: disable=not-context-manager
@tf_export(v1=["op_scope"])
@tf_contextlib.contextmanager
def op_scope(values, name, default_name=None):
  """DEPRECATED. Same as name_scope above, just different argument order.

  Logs a deprecation warning, then yields whatever entering
  `name_scope(name, default_name=default_name, values=values)` returns.
  """
  # `logging.warning` (not the deprecated `warn` alias), matching the rest of
  # this module.
  logging.warning("tf.op_scope(values, name, default_name) is deprecated,"
                  " use tf.name_scope(name, default_name, values)")
  with name_scope(name, default_name=default_name, values=values) as scope:
    yield scope
6752
6753
# Registry keyed by collection name; each entry is the tuple
# `(proto_type, to_proto, from_proto)` stored by `register_proto_function`.
_proto_function_registry = registry.Registry("proto functions")
6755
6756
def register_proto_function(collection_name,
                            proto_type=None,
                            to_proto=None,
                            from_proto=None):
  """Registers `to_proto` and `from_proto` functions for collection_name.

  `to_proto` converts a Python object into the corresponding protocol buffer
  and returns that buffer. `from_proto` converts a protocol buffer back into a
  Python object and returns the object.

  Args:
    collection_name: Name of the collection.
    proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
      `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`.
    to_proto: Function that implements Python object to protobuf conversion.
    from_proto: Function that implements protobuf to Python object conversion.

  Raises:
    TypeError: If a truthy `to_proto` or `from_proto` is not callable.
  """
  # Falsy values (e.g. None) are accepted and registered as-is.
  checks = (
      (to_proto, "to_proto must be callable."),
      (from_proto, "from_proto must be callable."),
  )
  for candidate, error_message in checks:
    if candidate and not callable(candidate):
      raise TypeError(error_message)

  entry = (proto_type, to_proto, from_proto)
  _proto_function_registry.register(entry, collection_name)
6783
6784
def _lookup_proto_function_entry(collection_name, index):
  """Returns field `index` of the registry entry, or None if unregistered."""
  try:
    return _proto_function_registry.lookup(collection_name)[index]
  except LookupError:
    return None


def get_collection_proto_type(collection_name):
  """Returns the proto_type for collection_name."""
  return _lookup_proto_function_entry(collection_name, 0)


def get_to_proto_function(collection_name):
  """Returns the to_proto function for collection_name."""
  return _lookup_proto_function_entry(collection_name, 1)


def get_from_proto_function(collection_name):
  """Returns the from_proto function for collection_name."""
  return _lookup_proto_function_entry(collection_name, 2)
6807
6808
6809def _op_to_colocate_with(v, graph):
6810  """Operation object corresponding to v to use for colocation constraints."""
6811  if v is None:
6812    return None, None
6813  if isinstance(v, Operation):
6814    return v, None
6815
6816  # We always want to colocate with the reference op.
6817  # When 'v' is a ResourceVariable, the reference op is the handle creating op.
6818  #
6819  # What this should be is:
6820  # if isinstance(v, ResourceVariable):
6821  #   return v.handle.op, v
6822  # However, that would require a circular import dependency.
6823  # As of October 2018, there were attempts underway to remove
6824  # colocation constraints altogether. Assuming that will
6825  # happen soon, perhaps this hack to work around the circular
6826  # import dependency is acceptable.
6827  if hasattr(v, "handle") and isinstance(v.handle, Tensor):
6828    device_only_candidate = lambda: None
6829    device_only_candidate.device = v.device
6830    device_only_candidate.name = v.name
6831    if graph.building_function:
6832      return graph.capture(v.handle).op, device_only_candidate
6833    else:
6834      return v.handle.op, device_only_candidate
6835  return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op, None
6836
6837
6838def _is_keras_symbolic_tensor(x):
6839  return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"
6840
6841
# These symbols were originally defined in this module; they are re-exported
# here for backwards compatibility until all references have been updated to
# access them from the indexed_slices.py module.
IndexedSlices = indexed_slices.IndexedSlices
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
convert_to_tensor_or_indexed_slices = (
    indexed_slices.convert_to_tensor_or_indexed_slices)
convert_n_to_tensor_or_indexed_slices = (
    indexed_slices.convert_n_to_tensor_or_indexed_slices)
internal_convert_to_tensor_or_indexed_slices = (
    indexed_slices.internal_convert_to_tensor_or_indexed_slices)
internal_convert_n_to_tensor_or_indexed_slices = (
    indexed_slices.internal_convert_n_to_tensor_or_indexed_slices)
# Likewise kept for backwards compatibility; the implementation lives in the
# tensor_conversion_registry module.
register_tensor_conversion_function = (
    tensor_conversion_registry.register_tensor_conversion_function)
6857
6858
6859# Helper functions for op wrapper modules generated by `python_op_gen`.
6860
6861
def to_raw_op(f):
  """Make a given op wrapper function `f` raw.

  Raw op wrappers can only be called with keyword arguments.

  Args:
    f: An op wrapper function to make raw.

  Returns:
    A `kwarg_only`-wrapped copy of `f`; `f` itself is left untouched.
  """
  # Rebuild `f` from its parts so the copy gets a new, empty `__dict__`.
  # Otherwise `tf_export` fails due to double-registration. The copy carries
  # code, globals, name, defaults and closure only (e.g. not `__kwdefaults__`).
  fresh_copy = types.FunctionType(f.__code__, f.__globals__, f.__name__,
                                  f.__defaults__, f.__closure__)
  return kwarg_only(fresh_copy)
6878
6879
def raise_from_not_ok_status(e, name):
  """Re-raises the status carried by `e` as the matching Python exception.

  Args:
    e: An object exposing `message` and `code` (presumably a not-OK status
      exception from the C++ layer -- verify at call sites).
    name: Optional op name; when not None it is appended to the message as
      " name: <name>".

  Raises:
    The exception built by `core._status_to_exception(e.code, message)`,
    raised via `six.raise_from(..., None)`.
  """
  message = e.message
  if name is not None:
    message += " name: " + name
  # pylint: disable=protected-access
  translated = core._status_to_exception(e.code, message)
  # pylint: enable=protected-access
  six.raise_from(translated, None)
6885
6886
def add_exit_callback_to_default_func_graph(fn):
  """Add a callback to run when the default function graph goes out of scope.

  Usage:

  ```python
  @tf.function
  def fn(x, v):
    expensive = expensive_object(v)
    add_exit_callback_to_default_func_graph(lambda: expensive.release())
    return g(x, expensive)

  fn(x=tf.constant(...), v=...)
  # `expensive` has been released.
  ```

  Args:
    fn: A callable that takes no arguments and whose output is ignored.
      To be executed when exiting func graph scope.

  Raises:
    RuntimeError: If executed when the current default graph is not a FuncGraph,
      or not currently executing in function creation mode (e.g., if inside
      an init_scope).
  """
  graph = get_default_graph()
  building = graph._building_function  # pylint: disable=protected-access
  if not building:
    raise RuntimeError(
        "Cannot add scope exit callbacks when not building a function.  "
        "Default graph: {}".format(graph))
  graph._add_scope_exit_callback(fn)  # pylint: disable=protected-access
6918
6919
6920def _reconstruct_sequence_inputs(op_def, inputs, attrs):
6921  """Regroups a flat list of input tensors into scalar and sequence inputs.
6922
6923  Args:
6924    op_def: The `op_def_pb2.OpDef` (for knowing the input types)
6925    inputs: a list of input `Tensor`s to the op.
6926    attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
6927      how long each sequence is)
6928
6929  Returns:
6930    A list of `Tensor`s (corresponding to scalar inputs) and lists of
6931    `Tensor`s (corresponding to sequence inputs).
6932  """
6933  grouped_inputs = []
6934  i = 0
6935  for input_arg in op_def.input_arg:
6936    if input_arg.number_attr:
6937      input_len = attrs[input_arg.number_attr].i
6938      is_sequence = True
6939    elif input_arg.type_list_attr:
6940      input_len = len(attrs[input_arg.type_list_attr].list.type)
6941      is_sequence = True
6942    else:
6943      input_len = 1
6944      is_sequence = False
6945
6946    if is_sequence:
6947      grouped_inputs.append(inputs[i:i + input_len])
6948    else:
6949      grouped_inputs.append(inputs[i])
6950    i += input_len
6951
6952  assert i == len(inputs)
6953  return grouped_inputs
6954
6955
# Module-level switch, False by default and flipped to True by
# `enable_numpy_style_type_promotion`. Presumably consulted by Tensor
# methods elsewhere in this module -- not visible here.
_numpy_style_type_promotion = False


def enable_numpy_style_type_promotion():
  """If called, follows NumPy's rules for type promotion.

  Used for enabling NumPy behavior on methods for TF NumPy. Sets the
  module-level `_numpy_style_type_promotion` flag to True.
  """
  global _numpy_style_type_promotion
  _numpy_style_type_promotion = True
6966
6967
# Module-level switch, False by default and flipped to True by
# `enable_numpy_style_slicing`. Presumably consulted by Tensor slicing code
# elsewhere -- not visible here.
_numpy_style_slicing = False


def enable_numpy_style_slicing():
  """If called, follows NumPy's rules for slicing Tensors.

  Used for enabling NumPy behavior on slicing for TF NumPy. Sets the
  module-level `_numpy_style_slicing` flag to True.
  """
  global _numpy_style_slicing
  _numpy_style_slicing = True
6978
6979
6980class _TensorIterator(object):
6981  """Iterates over the leading dim of a Tensor. Performs no error checks."""
6982
6983  __slots__ = ["_tensor", "_index", "_limit"]
6984
6985  def __init__(self, tensor, dim0):
6986    self._tensor = tensor
6987    self._index = 0
6988    self._limit = dim0
6989
6990  def __iter__(self):
6991    return self
6992
6993  def __next__(self):
6994    if self._index == self._limit:
6995      raise StopIteration
6996    result = self._tensor[self._index]
6997    self._index += 1
6998    return result
6999
7000  next = __next__  # python2.x compatibility.
7001
7002
def set_int_list_attr(op, attr_name, ints):
  """TF internal method used to set a list(int) attribute in the node_def.

  Args:
    op: The operation whose attribute is set (via its `_set_attr`).
    attr_name: Name of the attribute to set.
    ints: The integers to store in the attribute's `list.i` field.
  """
  attr_value = attr_value_pb2.AttrValue(
      list=attr_value_pb2.AttrValue.ListValue(i=ints))
  op._set_attr(attr_name, attr_value)  # pylint:disable=protected-access
7007
7008
7009def _get_enclosing_context(graph):
7010  # pylint: disable=protected-access
7011  if graph is None:
7012    return None
7013
7014  if graph._control_flow_context is not None:
7015    return graph._control_flow_context
7016
7017  if graph.building_function and hasattr(graph, "outer_graph"):
7018    return _get_enclosing_context(graph.outer_graph)
7019
7020
def get_resource_handle_data(graph_op):
  """Returns the resource-handle data recorded for a graph tensor.

  Args:
    graph_op: A graph `Tensor` (an exact-type assert rejects subclasses).

  Returns:
    A `CppShapeInferenceResult.HandleData` proto parsed from the serialized
    bytes returned by `pywrap_tf_session.GetHandleShapeAndType`.
  """
  assert type(graph_op) == Tensor  # pylint: disable=unidiomatic-typecheck

  # pylint: disable=protected-access
  serialized = pywrap_tf_session.GetHandleShapeAndType(
      graph_op.graph._c_graph, graph_op._as_tf_output())
  # pylint: enable=protected-access

  handle_data_cls = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
  return handle_data_cls.FromString(compat.as_bytes(serialized))
7029
7030
def _copy_handle_data_to_arg_def(tensor, arg_def):
  """Copies the first handle shape/dtype entry of `tensor` into `arg_def`.

  Does nothing when `tensor` has no handle data. Otherwise appends one entry to
  `arg_def.handle_data` (presumably an `OpDef.ArgDef`-like proto -- verify at
  call sites) carrying the first entry's dtype and shape.
  """
  handle_data = get_resource_handle_data(tensor)
  if not handle_data.shape_and_type:
    return
  first_entry = handle_data.shape_and_type[0]
  proto = arg_def.handle_data.add()
  proto.dtype = first_entry.dtype
  proto.shape.CopyFrom(first_entry.shape)
7038