• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes and functions used to construct graphs."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import re
24import sys
25import threading
26
27import numpy as np
28import six
29from six.moves import xrange  # pylint: disable=redefined-builtin
30
31from tensorflow.core.framework import attr_value_pb2
32from tensorflow.core.framework import function_pb2
33from tensorflow.core.framework import graph_pb2
34from tensorflow.core.framework import node_def_pb2
35from tensorflow.core.framework import op_def_pb2
36from tensorflow.core.framework import versions_pb2
37from tensorflow.core.protobuf import config_pb2
38from tensorflow.python import pywrap_tensorflow as c_api
39from tensorflow.python import tf2
40from tensorflow.python.eager import context
41from tensorflow.python.eager import core
42from tensorflow.python.eager import tape
43from tensorflow.python.framework import c_api_util
44from tensorflow.python.framework import composite_tensor
45from tensorflow.python.framework import device as pydev
46from tensorflow.python.framework import dtypes
47from tensorflow.python.framework import errors
48from tensorflow.python.framework import op_def_registry
49from tensorflow.python.framework import registry
50from tensorflow.python.framework import tensor_shape
51from tensorflow.python.framework import traceable_stack
52from tensorflow.python.framework import versions
53from tensorflow.python.ops import control_flow_util
54from tensorflow.python.platform import app
55from tensorflow.python.platform import tf_logging as logging
56from tensorflow.python.util import compat
57from tensorflow.python.util import decorator_utils
58from tensorflow.python.util import deprecation
59from tensorflow.python.util import function_utils
60from tensorflow.python.util import lock_util
61from tensorflow.python.util import memory
62from tensorflow.python.util import tf_contextlib
63from tensorflow.python.util import tf_stack
64from tensorflow.python.util.deprecation import deprecated_args
65from tensorflow.python.util.tf_export import tf_export
66
67
68# Temporary global switches determining if we should enable the work-in-progress
69# calls to the C API. These will be removed once all functionality is supported.
70_USE_C_API = True
71_USE_C_SHAPES = True
72
73
74def tensor_id(tensor):
75  """Returns a unique identifier for this Tensor."""
76  return tensor._id  # pylint: disable=protected-access
77
78
79class _UserDeviceSpec(object):
80  """Store user-specified device and provide computation of merged device."""
81
82  def __init__(self, device_name_or_function):
83    self._device_name_or_function = device_name_or_function
84
85    self.display_name = str(self._device_name_or_function)
86    if callable(self._device_name_or_function):
87      dev_func = self._device_name_or_function
88      func_name = function_utils.get_func_name(dev_func)
89      func_code = function_utils.get_func_code(dev_func)
90      if func_code:
91        fname = func_code.co_filename
92        lineno = func_code.co_firstlineno
93      else:
94        fname = "unknown"
95        lineno = -1
96      self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
97
98    self.raw_string = None
99
100    self.function = self._device_name_or_function
101    if not (self._device_name_or_function is None or
102            callable(self._device_name_or_function)):
103      self.raw_string = self._device_name_or_function
104      self.function = pydev.merge_device(self._device_name_or_function)
105
106
107class NullContextmanager(object):
108
109  def __init__(self, *args, **kwargs):
110    pass
111
112  def __enter__(self):
113    pass
114
115  def __exit__(self, type_arg, value_arg, traceback_arg):
116    return False  # False values do not suppress exceptions
117
118
119def _override_helper(clazz_object, operator, func):
120  """Overrides (string) operator on Tensors to call func.
121
122  Args:
123    clazz_object: the class to override for; either Tensor or SparseTensor.
124    operator: the string name of the operator to override.
125    func: the function that replaces the overridden operator.
126
127  Raises:
128    ValueError: If operator has already been overwritten,
129      or if operator is not allowed to be overwritten.
130  """
131  existing = getattr(clazz_object, operator, None)
132  if existing is not None:
133    # Check to see if this is a default method-wrapper or slot wrapper which
134    # will be true for the comparison operators.
135    if not isinstance(existing, type(object.__lt__)):
136      raise ValueError("operator %s cannot be overwritten again on class %s." %
137                       (operator, clazz_object))
138  if operator not in Tensor.OVERLOADABLE_OPERATORS:
139    raise ValueError("Overriding %s is disallowed" % operator)
140  setattr(clazz_object, operator, func)
141
142
143def _as_graph_element(obj):
144  """Convert `obj` to a graph element if possible, otherwise return `None`.
145
146  Args:
147    obj: Object to convert.
148
149  Returns:
150    The result of `obj._as_graph_element()` if that method is available;
151        otherwise `None`.
152  """
153  conv_fn = getattr(obj, "_as_graph_element", None)
154  if conv_fn and callable(conv_fn):
155    return conv_fn()
156  return None
157
158
159_TENSOR_LIKE_TYPES = tuple()
160
161
162def is_dense_tensor_like(t):
163  """EXPERIMENTAL: Returns true if `t` implements the tensor interface.
164
165  See `register_dense_tensor_like_type()` for the current definition of a
166  "tensor-like type".
167
168  Args:
169    t: An object.
170
171  Returns:
172    True iff `t` is an instance of one of the registered "tensor-like" types.
173  """
174  return isinstance(t, _TENSOR_LIKE_TYPES)
175
176
177def register_dense_tensor_like_type(tensor_type):
178  """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
179
180  A "tensor-like type" can represent a single dense tensor, and implements
181  the `name` and `dtype` properties.
182
183  Args:
184    tensor_type: A type implementing the tensor interface.
185
186  Raises:
187    TypeError: If `tensor_type` does not implement the tensor interface.
188  """
189  try:
190    if not isinstance(tensor_type.name, property):
191      raise TypeError("Type %s does not define a `name` property" %
192                      tensor_type.__name__)
193  except AttributeError:
194    raise TypeError("Type %s does not define a `name` property" %
195                    tensor_type.__name__)
196  try:
197    if not isinstance(tensor_type.dtype, property):
198      raise TypeError("Type %s does not define a `dtype` property" %
199                      tensor_type.__name__)
200  except AttributeError:
201    raise TypeError("Type %s does not define a `dtype` property" %
202                    tensor_type.__name__)
203  # We expect this list to be small, so choose quadratic complexity
204  # for registration, so that we have a tuple that can be used for
205  # more efficient `isinstance` checks later.
206  global _TENSOR_LIKE_TYPES
207  _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
208
209
210def uid():
211  """A unique (within this program execution) integer."""
212  return c_api.TFE_Py_UID()
213
214
215def numpy_text(tensor, is_repr=False):
216  """Human readable representation of a tensor's numpy value."""
217  if tensor.dtype.is_numpy_compatible:
218    text = repr(tensor.numpy()) if is_repr else str(tensor.numpy())
219  else:
220    text = "<unprintable>"
221  if "\n" in text:
222    text = "\n" + text
223  return text
224
225
226# NOTE(ebrevdo): Do not subclass this.  If you do, I will break you on purpose.
227class _TensorLike(object):
228  """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
229  pass
230
231
232@tf_export("Tensor")
233class Tensor(_TensorLike):
234  """Represents one of the outputs of an `Operation`.
235
236  A `Tensor` is a symbolic handle to one of the outputs of an
237  `Operation`. It does not hold the values of that operation's output,
238  but instead provides a means of computing those values in a
239  TensorFlow `tf.Session`.
240
241  This class has two primary purposes:
242
243  1. A `Tensor` can be passed as an input to another `Operation`.
244     This builds a dataflow connection between operations, which
245     enables TensorFlow to execute an entire `Graph` that represents a
246     large, multi-step computation.
247
248  2. After the graph has been launched in a session, the value of the
249     `Tensor` can be computed by passing it to
250     `tf.Session.run`.
251     `t.eval()` is a shortcut for calling
252     `tf.get_default_session().run(t)`.
253
254  In the following example, `c`, `d`, and `e` are symbolic `Tensor`
255  objects, whereas `result` is a numpy array that stores a concrete
256  value:
257
258  ```python
259  # Build a dataflow graph.
260  c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
261  d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
262  e = tf.matmul(c, d)
263
264  # Construct a `Session` to execute the graph.
265  sess = tf.Session()
266
267  # Execute the graph and store the value that `e` represents in `result`.
268  result = sess.run(e)
269  ```
270  """
271
272  # List of Python operators that we allow to override.
273  OVERLOADABLE_OPERATORS = {
274      # Binary.
275      "__add__",
276      "__radd__",
277      "__sub__",
278      "__rsub__",
279      "__mul__",
280      "__rmul__",
281      "__div__",
282      "__rdiv__",
283      "__truediv__",
284      "__rtruediv__",
285      "__floordiv__",
286      "__rfloordiv__",
287      "__mod__",
288      "__rmod__",
289      "__lt__",
290      "__le__",
291      "__gt__",
292      "__ge__",
293      "__and__",
294      "__rand__",
295      "__or__",
296      "__ror__",
297      "__xor__",
298      "__rxor__",
299      "__getitem__",
300      "__pow__",
301      "__rpow__",
302      # Unary.
303      "__invert__",
304      "__neg__",
305      "__abs__",
306      "__matmul__",
307      "__rmatmul__"
308  }
309
310  def __init__(self, op, value_index, dtype):
311    """Creates a new `Tensor`.
312
313    Args:
314      op: An `Operation`. `Operation` that computes this tensor.
315      value_index: An `int`. Index of the operation's endpoint that produces
316        this tensor.
317      dtype: A `DType`. Type of elements stored in this tensor.
318
319    Raises:
320      TypeError: If the op is not an `Operation`.
321    """
322    if not isinstance(op, Operation):
323      raise TypeError("op needs to be an Operation: %s" % op)
324    self._op = op
325    self._value_index = value_index
326    self._dtype = dtypes.as_dtype(dtype)
327    # This will be set by self._as_tf_output().
328    self._tf_output = None
329    # This will be set by self.shape().
330    self._shape_val = None
331    # List of operations that use this Tensor as input.  We maintain this list
332    # to easily navigate a computation graph.
333    self._consumers = []
334    self._id = uid()
335
336  @property
337  def op(self):
338    """The `Operation` that produces this tensor as an output."""
339    return self._op
340
341  @property
342  def dtype(self):
343    """The `DType` of elements in this tensor."""
344    return self._dtype
345
346  @property
347  def graph(self):
348    """The `Graph` that contains this tensor."""
349    return self._op.graph
350
351  @property
352  def name(self):
353    """The string name of this tensor."""
354    if not self._op.name:
355      raise ValueError("Operation was not named: %s" % self._op)
356    return "%s:%d" % (self._op.name, self._value_index)
357
358  @property
359  def device(self):
360    """The name of the device on which this tensor will be produced, or None."""
361    return self._op.device
362
363  @property
364  def shape(self):
365    """Returns the `TensorShape` that represents the shape of this tensor.
366
367    The shape is computed using shape inference functions that are
368    registered in the Op for each `Operation`.  See
369    `tf.TensorShape`
370    for more details of what a shape represents.
371
372    The inferred shape of a tensor is used to provide shape
373    information without having to launch the graph in a session. This
374    can be used for debugging, and providing early error messages. For
375    example:
376
377    ```python
378    c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
379
380    print(c.shape)
381    ==> TensorShape([Dimension(2), Dimension(3)])
382
383    d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
384
385    print(d.shape)
386    ==> TensorShape([Dimension(4), Dimension(2)])
387
388    # Raises a ValueError, because `c` and `d` do not have compatible
389    # inner dimensions.
390    e = tf.matmul(c, d)
391
392    f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
393
394    print(f.shape)
395    ==> TensorShape([Dimension(3), Dimension(4)])
396    ```
397
398    In some cases, the inferred shape may have unknown dimensions. If
399    the caller has additional information about the values of these
400    dimensions, `Tensor.set_shape()` can be used to augment the
401    inferred shape.
402
403    Returns:
404      A `TensorShape` representing the shape of this tensor.
405
406    """
407    if self._shape_val is None:
408      self._shape_val = self._c_api_shape()
409    return self._shape_val
410
411  def _get_input_ops_without_shapes(self, target_op):
412    """Returns ops needing shape inference to compute target_op's shape."""
413    result = []
414    stack = [self._op]
415    visited = set()
416    while stack:
417      op = stack.pop()
418      if op in visited: continue
419      result.append(op)
420      stack.extend(t.op for t in op.inputs if t._shape_val is None)
421      visited.add(op)
422    return result
423
424  def _c_api_shape(self):
425    """Returns the TensorShape of this tensor according to the C API."""
426    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
427    shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
428        c_graph, self._as_tf_output())
429    if unknown_shape:
430      return tensor_shape.unknown_shape()
431    else:
432      shape_vector = [None if d == -1 else d for d in shape_vector]
433      return tensor_shape.TensorShape(shape_vector)
434
435  @property
436  def _shape(self):
437    logging.warning("Tensor._shape is private, use Tensor.shape "
438                    "instead. Tensor._shape will eventually be removed.")
439    return self.shape
440
441  @_shape.setter
442  def _shape(self, value):
443    raise ValueError(
444        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
445
446  def __iter__(self):
447    if not context.executing_eagerly():
448      raise TypeError(
449          "Tensor objects are only iterable when eager execution is "
450          "enabled. To iterate over this tensor use tf.map_fn.")
451    shape = self._shape_tuple()
452    if shape is None:
453      raise TypeError("Cannot iterate over a tensor with unknown shape.")
454    if not shape:
455      raise TypeError("Cannot iterate over a scalar tensor.")
456    if shape[0] is None:
457      raise TypeError(
458          "Cannot iterate over a tensor with unknown first dimension.")
459    for i in xrange(shape[0]):
460      yield self[i]
461
462  def _shape_as_list(self):
463    if self.shape.ndims is not None:
464      return [dim.value for dim in self.shape.dims]
465    else:
466      return None
467
468  def _shape_tuple(self):
469    shape = self._shape_as_list()
470    if shape is None:
471      return None
472    return tuple(shape)
473
474  def _rank(self):
475    """Integer rank of this Tensor, if known, else None.
476
477    Returns:
478      Integer rank or None
479    """
480    return self.shape.ndims
481
482  def get_shape(self):
483    """Alias of Tensor.shape."""
484    return self.shape
485
486  def set_shape(self, shape):
487    """Updates the shape of this tensor.
488
489    This method can be called multiple times, and will merge the given
490    `shape` with the current shape of this tensor. It can be used to
491    provide additional information about the shape of this tensor that
492    cannot be inferred from the graph alone. For example, this can be used
493    to provide additional information about the shapes of images:
494
495    ```python
496    _, image_data = tf.TFRecordReader(...).read(...)
497    image = tf.image.decode_png(image_data, channels=3)
498
499    # The height and width dimensions of `image` are data dependent, and
500    # cannot be computed without executing the op.
501    print(image.shape)
502    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
503
504    # We know that each image in this dataset is 28 x 28 pixels.
505    image.set_shape([28, 28, 3])
506    print(image.shape)
507    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
508    ```
509
510    NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
511    result in inconsistencies between the statically-known graph and the runtime
512    value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
513    instead.
514
515    Args:
516      shape: A `TensorShape` representing the shape of this tensor, a
517      `TensorShapeProto`, a list, a tuple, or None.
518
519    Raises:
520      ValueError: If `shape` is not compatible with the current shape of
521        this tensor.
522    """
523    # Reset cached shape.
524    self._shape_val = None
525
526    # We want set_shape to be reflected in the C API graph for when we run it.
527    if not isinstance(shape, tensor_shape.TensorShape):
528      shape = tensor_shape.TensorShape(shape)
529    dim_list = []
530    if shape.dims is None:
531      unknown_shape = True
532    else:
533      unknown_shape = False
534      for dim in shape.dims:
535        if dim.value is None:
536          dim_list.append(-1)
537        else:
538          dim_list.append(dim.value)
539    try:
540      c_api.TF_GraphSetTensorShape_wrapper(
541          self._op._graph._c_graph,  # pylint: disable=protected-access
542          self._as_tf_output(),
543          dim_list,
544          unknown_shape)
545    except errors.InvalidArgumentError as e:
546      # Convert to ValueError for backwards compatibility.
547      raise ValueError(str(e))
548
549  @property
550  def value_index(self):
551    """The index of this tensor in the outputs of its `Operation`."""
552    return self._value_index
553
554  def consumers(self):
555    """Returns a list of `Operation`s that consume this tensor.
556
557    Returns:
558      A list of `Operation`s.
559    """
560    consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
561        self._as_tf_output())
562    # pylint: disable=protected-access
563    return [
564        self.graph._get_operation_by_name_unsafe(name)
565        for name in consumer_names
566    ]
567    # pylint: enable=protected-access
568
569  def _as_node_def_input(self):
570    """Return a value to use for the NodeDef "input" attribute.
571
572    The returned string can be used in a NodeDef "input" attribute
573    to indicate that the NodeDef uses this Tensor as input.
574
575    Raises:
576      ValueError: if this Tensor's Operation does not have a name.
577
578    Returns:
579      a string.
580    """
581    if not self._op.name:
582      raise ValueError("Operation was not named: %s" % self._op)
583    if self._value_index == 0:
584      return self._op.name
585    else:
586      return "%s:%d" % (self._op.name, self._value_index)
587
588  def _as_tf_output(self):
589    # pylint: disable=protected-access
590    # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object
591    # also guarantees that a dictionary of tf_output objects will retain a
592    # deterministic (yet unsorted) order which prevents memory blowup in the
593    # cache of executor(s) stored for every session.
594    if self._tf_output is None:
595      self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index)
596    return self._tf_output
597    # pylint: enable=protected-access
598
599  def __str__(self):
600    return "Tensor(\"%s\"%s%s%s)" % (
601        self.name, (", shape=%s" % self.get_shape())
602        if self.get_shape().ndims is not None else "",
603        (", dtype=%s" % self._dtype.name)
604        if self._dtype else "", (", device=%s" % self.device)
605        if self.device else "")
606
607  def __repr__(self):
608    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
609                                                   self._dtype.name)
610
611  def __hash__(self):
612    # Necessary to support Python's collection membership operators
613    return id(self)
614
615  def __eq__(self, other):
616    # Necessary to support Python's collection membership operators
617    return id(self) == id(other)
618
619  def __copy__(self):
620    # TODO(b/77597810): get rid of Tensor copies.
621    cls = self.__class__
622    result = cls.__new__(cls)
623    result.__dict__.update(self.__dict__)
624    return result
625
626  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
627  # operators to run when the left operand is an ndarray, because it
628  # accords the Tensor class higher priority than an ndarray, or a
629  # numpy matrix.
630  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
631  # mechanism, which allows more control over how Tensors interact
632  # with ndarrays.
633  __array_priority__ = 100
634
635  @staticmethod
636  def _override_operator(operator, func):
637    _override_helper(Tensor, operator, func)
638
639  def __bool__(self):
640    """Dummy method to prevent a tensor from being used as a Python `bool`.
641
642    This overload raises a `TypeError` when the user inadvertently
643    treats a `Tensor` as a boolean (e.g. in an `if` statement). For
644    example:
645
646    ```python
647    if tf.constant(True):  # Will raise.
648      # ...
649
650    if tf.constant(5) < tf.constant(7):  # Will raise.
651      # ...
652    ```
653
654    This disallows ambiguities between testing the Python value vs testing the
655    dynamic condition of the `Tensor`.
656
657    Raises:
658      `TypeError`.
659    """
660    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
661                    "Use `if t is not None:` instead of `if t:` to test if a "
662                    "tensor is defined, and use TensorFlow ops such as "
663                    "tf.cond to execute subgraphs conditioned on the value of "
664                    "a tensor.")
665
666  def __nonzero__(self):
667    """Dummy method to prevent a tensor from being used as a Python `bool`.
668
669    This is the Python 2.x counterpart to `__bool__()` above.
670
671    Raises:
672      `TypeError`.
673    """
674    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
675                    "Use `if t is not None:` instead of `if t:` to test if a "
676                    "tensor is defined, and use TensorFlow ops such as "
677                    "tf.cond to execute subgraphs conditioned on the value of "
678                    "a tensor.")
679
680  def eval(self, feed_dict=None, session=None):
681    """Evaluates this tensor in a `Session`.
682
683    Calling this method will execute all preceding operations that
684    produce the inputs needed for the operation that produces this
685    tensor.
686
687    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
688    launched in a session, and either a default session must be
689    available, or `session` must be specified explicitly.
690
691    Args:
692      feed_dict: A dictionary that maps `Tensor` objects to feed values.
693        See `tf.Session.run` for a
694        description of the valid feed values.
695      session: (Optional.) The `Session` to be used to evaluate this tensor. If
696        none, the default session will be used.
697
698    Returns:
699      A numpy array corresponding to the value of this tensor.
700
701    """
702    return _eval_using_default_session(self, feed_dict, self.graph, session)
703
704
705# TODO(agarwal): consider getting rid of this.
706class _EagerTensorBase(Tensor):
707  """Base class for EagerTensor."""
708
709  @property
710  def dtype(self):
711    # Note: using the intern table directly here as this is
712    # performance-sensitive in some models.
713    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access
714
715  def numpy(self):
716    """Returns a numpy array or a scalar with the same contents as the Tensor.
717
718    TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
719    buffer but instead always explicitly copy? Note that currently it may or may
720    not copy based on whether the numpy data is properly aligned or not.
721
722    Returns:
723      A numpy array or a scalar. Numpy array may share memory with the
724      Tensor object. Any changes to one may be reflected in the other. A scalar
725      value is returned when self has rank 0.
726
727    Raises:
728      ValueError: if the type of this Tensor is not representable in numpy.
729    """
730    if self.dtype == dtypes.resource:
731      raise ValueError("Resource handles are not convertible to numpy.")
732    return self._cpu_nograd()._numpy()  # pylint: disable=protected-access
733
734  # __int__, __float__ and __index__ may copy the tensor to CPU and
735  # only work for scalars; values are cast as per numpy.
736  def __int__(self):
737    return int(self.numpy())
738
739  def __float__(self):
740    return float(self.numpy())
741
742  def __index__(self):
743    return int(self.numpy())
744
745  def __array__(self, dtype=None):
746    return np.array(self.numpy(), dtype=dtype)
747
748  def __format__(self, format_spec):
749    return self.numpy().__format__(format_spec)
750
751  def __reduce__(self):
752    return (convert_to_tensor, (self.numpy(),))
753
754  def _numpy(self):
755    raise NotImplementedError()
756
757  @property
758  def backing_device(self):
759    """Returns the name of the device holding this tensor's memory.
760
761    `.backing_device` is usually the same as `.device`, which returns
762    the device on which the kernel of the operation that produced this tensor
763    ran. However, some operations can produce tensors on a different device
764    (e.g., an operation that executes on the GPU but produces output tensors
765    in host memory).
766    """
767    raise NotImplementedError()
768
769  def __copy__(self):
770    # Eager Tensors are immutable so it's safe to return themselves as a copy.
771    return self
772
773  def __deepcopy__(self, memo):
774    # Eager Tensors are immutable so it's safe to return themselves as a copy.
775    del memo
776    return self
777
778  def _datatype_enum(self):
779    raise NotImplementedError()
780
781  def _shape_tuple(self):
782    """The shape of this Tensor, as a tuple.
783
784    This is more performant than tuple(shape().as_list()) as it avoids
785    two list and one object creation. Marked private for now as from an API
786    perspective, it would be better to have a single performant way of
787    getting a shape rather than exposing shape() and shape_tuple()
788    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
789    but ideally one would work things out and remove the need for this method.
790
791    Returns:
792      tuple with the shape.
793    """
794    raise NotImplementedError()
795
796  def _rank(self):
797    """Integer rank of this Tensor.
798
799    Unlike regular Tensors, the rank is always known for EagerTensors.
800
801    This is more performant than len(self._shape_tuple())
802
803    Returns:
804      Integer rank
805    """
806    raise NotImplementedError()
807
808  def _num_elements(self):
809    """Number of elements of this Tensor.
810
811    Unlike regular Tensors, the number of elements is always known for
812    EagerTensors.
813
814    This is more performant than tensor.shape.num_elements
815
816    Returns:
817      Long - num elements in the tensor
818    """
819    raise NotImplementedError()
820
821  def _copy_to_device(self, context, device):  # pylint: disable=redefined-outer-name
822    raise NotImplementedError()
823
824  def __str__(self):
825    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self),
826                                                  self.shape,
827                                                  self.dtype.name)
828
829  def __repr__(self):
830    return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
831        self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True))
832
833  @staticmethod
834  def _override_operator(name, func):
835    setattr(_EagerTensorBase, name, func)
836
837  def _copy_nograd(self, ctx=None, device_name=None):
838    """Copies tensor to dest device, but doesn't record the operation."""
839    # pylint: disable=protected-access
840    # Creates a new tensor on the dest device.
841    if ctx is None:
842      ctx = context.context()
843    if device_name is None:
844      device_name = ctx.device_name
845    # pylint: disable=protected-access
846    try:
847      new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
848    except core._NotOkStatusException as e:
849      six.raise_from(core._status_to_exception(e.code, e.message), None)
850    return new_tensor
851
852  def _copy(self, ctx=None, device_name=None):
853    """Copies tensor to dest device."""
854    new_tensor = self._copy_nograd(ctx, device_name)
855    # Record the copy on tape and define backprop copy as well.
856    if context.executing_eagerly():
857      self_device = self.device
858      def grad_fun(dresult):
859        return [dresult._copy(device_name=self_device)]
860      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
861    return new_tensor
862    # pylint: enable=protected-access
863
864  @property
865  def shape(self):
866    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
867      # `_tensor_shape` is declared and defined in the definition of
868      # `EagerTensor`, in C.
869      self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
870    return self._tensor_shape
871
872  def get_shape(self):
873    """Alias of Tensor.shape."""
874    return self.shape
875
876  def _shape_as_list(self):
877    """The shape of the tensor as a list."""
878    return list(self._shape_tuple())
879
880  @property
881  def ndim(self):
882    """Returns the number of Tensor dimensions."""
883    return self.shape.ndims
884
885  def __len__(self):
886    """Returns the length of the first dimension in the Tensor."""
887    if not self.shape.ndims:
888      raise TypeError("Scalar tensor has no `len()`")
889    return self._shape_tuple()[0]
890
891  def _cpu_nograd(self):
892    """A copy of this Tensor with contents backed by host memory.
893
894    The copy cannot be differentiated through.
895
896    Returns:
897      A CPU-memory backed Tensor object with the same contents as this Tensor.
898    """
899    return self._copy_nograd(context.context(), "CPU:0")
900
901  def cpu(self):
902    """A copy of this Tensor with contents backed by host memory."""
903    return self._copy(context.context(), "CPU:0")
904
905  def gpu(self, gpu_index=0):
906    """A copy of this Tensor with contents backed by memory on the GPU.
907
908    Arguments:
909      gpu_index: Identifies which GPU to place the contents on the returned
910        Tensor in.
911
912    Returns:
913      A GPU-memory backed Tensor object initialized with the same contents
914      as this Tensor.
915    """
916    return self._copy(context.context(), "GPU:" + str(gpu_index))
917
918  def __bool__(self):
919    return bool(self.numpy())
920
921  def __nonzero__(self):
922    return self.__bool__()
923
924  def set_shape(self, shape):
925    if not self.shape.is_compatible_with(shape):
926      raise ValueError(
927          "Tensor's shape %s is not compatible with supplied shape %s" %
928          (self.shape, shape))
929
930  # Methods not supported / implemented for Eager Tensors.
931  @property
932  def op(self):
933    raise AttributeError(
934        "Tensor.op is meaningless when eager execution is enabled.")
935
936  @property
937  def graph(self):
938    raise AttributeError(
939        "Tensor.graph is meaningless when eager execution is enabled.")
940
941  @property
942  def name(self):
943    raise AttributeError(
944        "Tensor.name is meaningless when eager execution is enabled.")
945
946  @property
947  def value_index(self):
948    raise AttributeError(
949        "Tensor.value_index is meaningless when eager execution is enabled.")
950
951  def consumers(self):
952    raise NotImplementedError(
953        "Tensor.consumers is meaningless when eager execution is enabled.")
954
955  def _add_consumer(self, consumer):
956    raise NotImplementedError(
957        "_add_consumer not supported when eager execution is enabled.")
958
959  def _as_node_def_input(self):
960    raise NotImplementedError(
961        "_as_node_def_input not supported when eager execution is enabled.")
962
963  def _as_tf_output(self):
964    raise NotImplementedError(
965        "_as_tf_output not supported when eager execution is enabled.")
966
967  def eval(self, feed_dict=None, session=None):
968    raise NotImplementedError(
969        "eval is not supported when eager execution is enabled, "
970        "is .numpy() what you're looking for?"
971    )
972
973
974# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
975# registers it with the current module.
976EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
977
978
979def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
980  _ = name, as_ref
981  if dtype and not dtype.is_compatible_with(t.dtype):
982    raise ValueError(
983        "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
984        (dtype.name, t.dtype.name, str(t)))
985  return t
986
987
988_tensor_conversion_func_registry = {
989    0: [(Tensor, _TensorTensorConversionFunction)]
990}
991_tensor_conversion_func_cache = {}
992_tensor_conversion_func_lock = threading.Lock()
993register_dense_tensor_like_type(Tensor)
994
995
996@tf_export(v1=["convert_to_tensor"])
997def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None,
998                      dtype_hint=None):
999  """Converts the given `value` to a `Tensor`.
1000
1001  This function converts Python objects of various types to `Tensor`
1002  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
1003  and Python scalars. For example:
1004
1005  ```python
1006  import numpy as np
1007
1008  def my_func(arg):
1009    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
1010    return tf.matmul(arg, arg) + arg
1011
1012  # The following calls are equivalent.
1013  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
1014  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
1015  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
1016  ```
1017
1018  This function can be useful when composing a new operation in Python
1019  (such as `my_func` in the example above). All standard Python op
1020  constructors apply this function to each of their Tensor-valued
1021  inputs, which allows those ops to accept numpy arrays, Python lists,
1022  and scalars in addition to `Tensor` objects.
1023
1024  Note: This function diverges from default Numpy behavior for `float` and
1025    `string` types when `None` is present in a Python list or scalar. Rather
1026    than silently converting `None` values, an error will be thrown.
1027
1028  Args:
1029    value: An object whose type has a registered `Tensor` conversion function.
1030    dtype: Optional element type for the returned tensor. If missing, the
1031      type is inferred from the type of `value`.
1032    name: Optional name to use if a new `Tensor` is created.
1033    preferred_dtype: Optional element type for the returned tensor,
1034      used when dtype is None. In some cases, a caller may not have a
1035      dtype in mind when converting to a tensor, so preferred_dtype
1036      can be used as a soft preference.  If the conversion to
1037      `preferred_dtype` is not possible, this argument has no effect.
1038    dtype_hint: same meaning as preferred_dtype, and overrides it.
1039
1040  Returns:
1041    A `Tensor` based on `value`.
1042
1043  Raises:
1044    TypeError: If no conversion function is registered for `value` to `dtype`.
1045    RuntimeError: If a registered conversion function returns an invalid value.
1046    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
1047  """
1048  preferred_dtype = deprecation.deprecated_argument_lookup(
1049      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
1050  return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
1051
1052
1053@tf_export("convert_to_tensor", v1=[])
1054def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
1055  """Converts the given `value` to a `Tensor`.
1056
1057  This function converts Python objects of various types to `Tensor`
1058  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
1059  and Python scalars. For example:
1060
1061  ```python
1062  import numpy as np
1063
1064  def my_func(arg):
1065    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
1066    return tf.matmul(arg, arg) + arg
1067
1068  # The following calls are equivalent.
1069  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
1070  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
1071  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
1072  ```
1073
1074  This function can be useful when composing a new operation in Python
1075  (such as `my_func` in the example above). All standard Python op
1076  constructors apply this function to each of their Tensor-valued
1077  inputs, which allows those ops to accept numpy arrays, Python lists,
1078  and scalars in addition to `Tensor` objects.
1079
1080  Note: This function diverges from default Numpy behavior for `float` and
1081    `string` types when `None` is present in a Python list or scalar. Rather
1082    than silently converting `None` values, an error will be thrown.
1083
1084  Args:
1085    value: An object whose type has a registered `Tensor` conversion function.
1086    dtype: Optional element type for the returned tensor. If missing, the
1087      type is inferred from the type of `value`.
1088    dtype_hint: Optional element type for the returned tensor,
1089      used when dtype is None. In some cases, a caller may not have a
1090      dtype in mind when converting to a tensor, so dtype_hint
1091      can be used as a soft preference.  If the conversion to
1092      `dtype_hint` is not possible, this argument has no effect.
1093    name: Optional name to use if a new `Tensor` is created.
1094
1095  Returns:
1096    A `Tensor` based on `value`.
1097
1098  Raises:
1099    TypeError: If no conversion function is registered for `value` to `dtype`.
1100    RuntimeError: If a registered conversion function returns an invalid value.
1101    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
1102  """
1103  return internal_convert_to_tensor(
1104      value=value,
1105      dtype=dtype,
1106      name=name,
1107      preferred_dtype=dtype_hint,
1108      as_ref=False)
1109
1110
1111def _error_prefix(name):
1112  return "" if name is None else "%s: " % name
1113
1114
1115def internal_convert_to_tensor(value,
1116                               dtype=None,
1117                               name=None,
1118                               as_ref=False,
1119                               preferred_dtype=None,
1120                               ctx=None,
1121                               accept_symbolic_tensors=True):
1122  """Implementation of the public convert_to_tensor."""
1123  if ctx is None: ctx = context.context()
1124  if isinstance(value, EagerTensor):
1125    if ctx.executing_eagerly():
1126      if dtype is not None:
1127        dtype = dtypes.as_dtype(dtype)
1128        value = _TensorTensorConversionFunction(value, dtype=dtype)
1129      return value
1130    else:
1131      graph = get_default_graph()
1132      if not graph.building_function:
1133        raise RuntimeError("Attempting to capture an EagerTensor without "
1134                           "building a function.")
1135      return graph.capture(value, name=name)
1136  elif ((not accept_symbolic_tensors) and
1137        isinstance(value, Tensor) and
1138        ctx.executing_eagerly()):
1139    # Found a symbolic tensor in an eager context.
1140    # This happens when we use the Keras functional API (i.e. calling layers
1141    # on the output of `keras.Input()`, which is symbolic) while eager
1142    # execution is enabled.
1143    if _is_keras_symbolic_tensor(value):
1144      # If the graph of the tensor isn't the Keras graph, we should still
1145      # fail, for the time being. TODO(fchollet): consider allowing
1146      # all symbolic tensors to raise this exception in this case.
1147      raise core._SymbolicException(  # pylint: disable=protected-access
1148          "Using the symbolic output of a Keras layer during eager execution.")
1149
1150  if dtype is not None:
1151    dtype = dtypes.as_dtype(dtype)
1152  unwrapped_type = type(value)
1153  conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
1154  if conversion_func_list is None:
1155    with _tensor_conversion_func_lock:
1156      conversion_func_list = []
1157      for _, funcs_at_priority in sorted(
1158          _tensor_conversion_func_registry.items()):
1159        for base_type, conversion_func in funcs_at_priority:
1160          if isinstance(value, base_type):
1161            conversion_func_list.append((base_type, conversion_func))
1162      _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list
1163
1164  for base_type, conversion_func in conversion_func_list:
1165    # If dtype is None but preferred_dtype is not None, we try to
1166    # cast to preferred_dtype first.
1167    ret = None
1168    if dtype is None and preferred_dtype is not None:
1169      try:
1170        ret = conversion_func(
1171            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
1172      except (TypeError, ValueError, errors.UnimplementedError,
1173              errors.InvalidArgumentError):
1174        # Could not coerce the conversion to use the preferred dtype.
1175        ret = None
1176
1177      if ret is not None and ret is not NotImplemented:
1178        if (ret.dtype.base_dtype !=
1179            dtypes.as_dtype(preferred_dtype).base_dtype):
1180          raise TypeError("convert_to_tensor did not convert to "
1181                          "the preferred dtype: %s vs %s " %
1182                          (ret.dtype.base_dtype,
1183                           dtypes.as_dtype(preferred_dtype).base_dtype))
1184
1185    if ret is None:
1186      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1187
1188    if ret is NotImplemented:
1189      continue
1190
1191    if not isinstance(ret, Tensor):
1192      raise RuntimeError(
1193          "%sConversion function %r for type %s returned non-Tensor: %r" %
1194          (_error_prefix(name), conversion_func, base_type, ret))
1195    if dtype and not dtype.is_compatible_with(ret.dtype):
1196      raise RuntimeError(
1197          "%sConversion function %r for type %s returned incompatible "
1198          "dtype: requested = %s, actual = %s" %
1199          (_error_prefix(name), conversion_func, base_type, dtype.name,
1200           ret.dtype.name))
1201    return ret
1202  raise TypeError("%sCannot convert %r with type %s to Tensor: "
1203                  "no conversion function registered." %
1204                  (_error_prefix(name), value, unwrapped_type))
1205
1206
1207def internal_convert_n_to_tensor(values,
1208                                 dtype=None,
1209                                 name=None,
1210                                 as_ref=False,
1211                                 preferred_dtype=None,
1212                                 ctx=None):
1213  """Converts `values` to a list of `Tensor` objects.
1214
1215  Args:
1216    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1217    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1218    name: (Optional.) A name prefix to used when a new `Tensor` is
1219      created, in which case element `i` will be given the name `name
1220      + '_' + i`.
1221    as_ref: True if the caller wants the results as ref tensors.
1222    preferred_dtype: Optional element type for the returned tensors,
1223      used when dtype is None. In some cases, a caller may not have a
1224      dtype in mind when converting to a tensor, so preferred_dtype
1225      can be used as a soft preference.  If the conversion to
1226      `preferred_dtype` is not possible, this argument has no effect.
1227    ctx: The value of context.context().
1228
1229  Returns:
1230    A list of `Tensor` and/or `IndexedSlices` objects.
1231
1232  Raises:
1233    TypeError: If no conversion function is registered for an element in
1234      `values`.
1235    RuntimeError: If a registered conversion function returns an invalid
1236      value.
1237  """
1238  if not isinstance(values, collections.Sequence):
1239    raise TypeError("values must be a sequence.")
1240  ret = []
1241  if ctx is None: ctx = context.context()
1242  for i, value in enumerate(values):
1243    n = None if name is None else "%s_%d" % (name, i)
1244    ret.append(
1245        internal_convert_to_tensor(
1246            value,
1247            dtype=dtype,
1248            name=n,
1249            as_ref=as_ref,
1250            preferred_dtype=preferred_dtype,
1251            ctx=ctx))
1252  return ret
1253
1254
1255def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
1256  """Converts `values` to a list of `Tensor` objects.
1257
1258  Args:
1259    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
1260    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
1261    name: (Optional.) A name prefix to used when a new `Tensor` is
1262      created, in which case element `i` will be given the name `name
1263      + '_' + i`.
1264    preferred_dtype: Optional element type for the returned tensors,
1265      used when dtype is None. In some cases, a caller may not have a
1266      dtype in mind when converting to a tensor, so preferred_dtype
1267      can be used as a soft preference.  If the conversion to
1268      `preferred_dtype` is not possible, this argument has no effect.
1269
1270  Returns:
1271    A list of `Tensor` and/or `IndexedSlices` objects.
1272
1273  Raises:
1274    TypeError: If no conversion function is registered for an element in
1275      `values`.
1276    RuntimeError: If a registered conversion function returns an invalid
1277      value.
1278  """
1279  return internal_convert_n_to_tensor(
1280      values=values,
1281      dtype=dtype,
1282      name=name,
1283      preferred_dtype=preferred_dtype,
1284      as_ref=False)
1285
1286
1287@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
1288def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
1289  """Converts the given object to a `Tensor` or an `IndexedSlices`.
1290
1291  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
1292  unmodified. Otherwise, it is converted to a `Tensor` using
1293  `convert_to_tensor()`.
1294
1295  Args:
1296    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
1297      by `convert_to_tensor()`.
1298    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1299      `IndexedSlices`.
1300    name: (Optional.) A name to use if a new `Tensor` is created.
1301
1302  Returns:
1303    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
1304
1305  Raises:
1306    ValueError: If `dtype` does not match the element type of `value`.
1307  """
1308  return internal_convert_to_tensor_or_indexed_slices(
1309      value=value, dtype=dtype, name=name, as_ref=False)
1310
1311
1312def internal_convert_to_tensor_or_indexed_slices(value,
1313                                                 dtype=None,
1314                                                 name=None,
1315                                                 as_ref=False):
1316  """Converts the given object to a `Tensor` or an `IndexedSlices`.
1317
1318  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
1319  unmodified. Otherwise, it is converted to a `Tensor` using
1320  `convert_to_tensor()`.
1321
1322  Args:
1323    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
1324      by `convert_to_tensor()`.
1325    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1326      `IndexedSlices`.
1327    name: (Optional.) A name to use if a new `Tensor` is created.
1328    as_ref: True if the caller wants the results as ref tensors.
1329
1330  Returns:
1331    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
1332
1333  Raises:
1334    ValueError: If `dtype` does not match the element type of `value`.
1335  """
1336  if isinstance(value, EagerTensor) and not context.executing_eagerly():
1337    return internal_convert_to_tensor(
1338        value, dtype=dtype, name=name, as_ref=as_ref)
1339  elif isinstance(value, _TensorLike):
1340    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
1341      raise ValueError(
1342          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
1343          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
1344    return value
1345  else:
1346    return internal_convert_to_tensor(
1347        value, dtype=dtype, name=name, as_ref=as_ref)
1348
1349
1350def internal_convert_n_to_tensor_or_indexed_slices(values,
1351                                                   dtype=None,
1352                                                   name=None,
1353                                                   as_ref=False):
1354  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
1355
1356  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
1357  unmodified.
1358
1359  Args:
1360    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
1361      can be consumed by `convert_to_tensor()`.
1362    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1363      `IndexedSlices`.
1364    name: (Optional.) A name prefix to used when a new `Tensor` is
1365      created, in which case element `i` will be given the name `name
1366      + '_' + i`.
1367    as_ref: True if the caller wants the results as ref tensors.
1368
1369  Returns:
1370    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.
1371
1372  Raises:
1373    TypeError: If no conversion function is registered for an element in
1374      `values`.
1375    RuntimeError: If a registered conversion function returns an invalid
1376      value.
1377  """
1378  if not isinstance(values, collections.Sequence):
1379    raise TypeError("values must be a sequence.")
1380  ret = []
1381  for i, value in enumerate(values):
1382    if value is None:
1383      ret.append(value)
1384    else:
1385      n = None if name is None else "%s_%d" % (name, i)
1386      ret.append(
1387          internal_convert_to_tensor_or_indexed_slices(
1388              value, dtype=dtype, name=n, as_ref=as_ref))
1389  return ret
1390
1391
1392def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
1393  """Converts `values` to a list of `Output` or `IndexedSlices` objects.
1394
1395  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
1396  unmodified.
1397
1398  Args:
1399    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
1400      can be consumed by `convert_to_tensor()`.
1401    dtype: (Optional.) The required `DType` of the returned `Tensor`
1402      `IndexedSlices`.
1403    name: (Optional.) A name prefix to used when a new `Tensor` is
1404      created, in which case element `i` will be given the name `name
1405      + '_' + i`.
1406
1407  Returns:
1408    A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
1409
1410  Raises:
1411    TypeError: If no conversion function is registered for an element in
1412      `values`.
1413    RuntimeError: If a registered conversion function returns an invalid
1414      value.
1415  """
1416  return internal_convert_n_to_tensor_or_indexed_slices(
1417      values=values, dtype=dtype, name=name, as_ref=False)
1418
1419
1420def convert_to_tensor_or_composite(value, dtype=None, name=None):
1421  """Converts the given object to a `Tensor` or `CompositeTensor`.
1422
1423  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
1424  is converted to a `Tensor` using `convert_to_tensor()`.
1425
1426  Args:
1427    value: A `CompositeTensor` or an object that can be consumed
1428      by `convert_to_tensor()`.
1429    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1430      `CompositeTensor`.
1431    name: (Optional.) A name to use if a new `Tensor` is created.
1432
1433  Returns:
1434    A `Tensor` or `CompositeTensor`, based on `value`.
1435
1436  Raises:
1437    ValueError: If `dtype` does not match the element type of `value`.
1438  """
1439  return internal_convert_to_tensor_or_composite(
1440      value=value, dtype=dtype, name=name, as_ref=False)
1441
1442
1443def internal_convert_to_tensor_or_composite(value,
1444                                            dtype=None,
1445                                            name=None,
1446                                            as_ref=False):
1447  """Converts the given object to a `Tensor` or `CompositeTensor`.
1448
1449  If `value` is a `CompositeTensor` it is returned unmodified.  Otherwise, it
1450  is converted to a `Tensor` using `convert_to_tensor()`.
1451
1452  Args:
1453    value: A `CompositeTensor`, or an object that can be consumed
1454      by `convert_to_tensor()`.
1455    dtype: (Optional.) The required `DType` of the returned `Tensor` or
1456      `CompositeTensor`.
1457    name: (Optional.) A name to use if a new `Tensor` is created.
1458    as_ref: True if the caller wants the results as ref tensors.
1459
1460  Returns:
1461    A `Tensor` or `CompositeTensor`, based on `value`.
1462
1463  Raises:
1464    ValueError: If `dtype` does not match the element type of `value`.
1465  """
1466  if isinstance(value, composite_tensor.CompositeTensor):
1467    value_dtype = getattr(value, "dtype", None)
1468    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value_dtype):
1469      raise ValueError(
1470          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
1471          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
1472    return value
1473  else:
1474    return internal_convert_to_tensor(
1475        value, dtype=dtype, name=name, as_ref=as_ref)
1476
1477
1478def internal_convert_n_to_tensor_or_composite(values,
1479                                              dtype=None,
1480                                              name=None,
1481                                              as_ref=False):
1482  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.
1483
1484  Any `CompositeTensor` objects in `values` are returned unmodified.
1485
1486  Args:
1487    values: A list of `None`, `CompositeTensor`, or objects that
1488      can be consumed by `convert_to_tensor()`.
1489    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
1490      `CompositeTensor`s.
1491    name: (Optional.) A name prefix to used when a new `Tensor` is
1492      created, in which case element `i` will be given the name `name
1493      + '_' + i`.
1494    as_ref: True if the caller wants the results as ref tensors.
1495
1496  Returns:
1497    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.
1498
1499  Raises:
1500    TypeError: If no conversion function is registered for an element in
1501      `values`.
1502    RuntimeError: If a registered conversion function returns an invalid
1503      value.
1504  """
1505  if not isinstance(values, collections.Sequence):
1506    raise TypeError("values must be a sequence.")
1507  ret = []
1508  for i, value in enumerate(values):
1509    if value is None:
1510      ret.append(value)
1511    else:
1512      n = None if name is None else "%s_%d" % (name, i)
1513      ret.append(
1514          internal_convert_to_tensor_or_composite(
1515              value, dtype=dtype, name=n, as_ref=as_ref))
1516  return ret
1517
1518
1519def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
1520  """Converts `values` to a list of `Output` or `CompositeTensor` objects.
1521
1522  Any `CompositeTensor` objects in `values` are returned unmodified.
1523
1524  Args:
1525    values: A list of `None`, `CompositeTensor``, or objects that
1526      can be consumed by `convert_to_tensor()`.
1527    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
1528      `CompositeTensor`s.
1529    name: (Optional.) A name prefix to used when a new `Tensor` is
1530      created, in which case element `i` will be given the name `name
1531      + '_' + i`.
1532
1533  Returns:
1534    A list of `Tensor` and/or `CompositeTensor` objects.
1535
1536  Raises:
1537    TypeError: If no conversion function is registered for an element in
1538      `values`.
1539    RuntimeError: If a registered conversion function returns an invalid
1540      value.
1541  """
1542  return internal_convert_n_to_tensor_or_composite(
1543      values=values, dtype=dtype, name=name, as_ref=False)
1544
1545
1546# TODO(josh11b): Add ctx argument to conversion_func() signature.
1547@tf_export("register_tensor_conversion_function")
1548def register_tensor_conversion_function(base_type,
1549                                        conversion_func,
1550                                        priority=100):
1551  """Registers a function for converting objects of `base_type` to `Tensor`.
1552
1553  The conversion function must have the following signature:
1554
1555  ```python
1556      def conversion_func(value, dtype=None, name=None, as_ref=False):
1557        # ...
1558  ```
1559
1560  It must return a `Tensor` with the given `dtype` if specified. If the
1561  conversion function creates a new `Tensor`, it should use the given
1562  `name` if specified. All exceptions will be propagated to the caller.
1563
1564  The conversion function may return `NotImplemented` for some
1565  inputs. In this case, the conversion process will continue to try
1566  subsequent conversion functions.
1567
1568  If `as_ref` is true, the function must return a `Tensor` reference,
1569  such as a `Variable`.
1570
1571  NOTE: The conversion functions will execute in order of priority,
1572  followed by order of registration. To ensure that a conversion function
1573  `F` runs before another conversion function `G`, ensure that `F` is
1574  registered with a smaller priority than `G`.
1575
1576  Args:
1577    base_type: The base type or tuple of base types for all objects that
1578      `conversion_func` accepts.
1579    conversion_func: A function that converts instances of `base_type` to
1580      `Tensor`.
1581    priority: Optional integer that indicates the priority for applying this
1582      conversion function. Conversion functions with smaller priority values
1583      run earlier than conversion functions with larger priority values.
1584      Defaults to 100.
1585
1586  Raises:
1587    TypeError: If the arguments do not have the appropriate type.
1588
1589  """
1590  global _tensor_conversion_func_cache
1591  with _tensor_conversion_func_lock:
1592    if not (isinstance(base_type, type) or
1593            (isinstance(base_type, tuple) and
1594             all(isinstance(x, type) for x in base_type))):
1595      raise TypeError("base_type must be a type or a tuple of types.")
1596    if not callable(conversion_func):
1597      raise TypeError("conversion_func must be callable.")
1598
1599    # context._context is checked so that we don't inadvertently create it.
1600    # This is because enable_eager_execution will fail when called from the main
1601    # function if the context._context is already created, and the
1602    # register_tensor_conversion_function calls happen when the module is
1603    # imported.
1604    if context._context is not None and context.executing_eagerly(
1605    ) and isinstance(base_type, six.integer_types + (
1606        float,
1607        np.ndarray,
1608    )):
1609      # TODO(nareshmodi): consider setting a context variable which disables the
1610      # fastpath instead.
1611      raise TypeError(
1612          "Cannot register conversions for numpy arrays, python number types "
1613          "when executing eagerly.")
1614
1615    try:
1616      funcs_at_priority = _tensor_conversion_func_registry[priority]
1617    except KeyError:
1618      funcs_at_priority = []
1619      _tensor_conversion_func_registry[priority] = funcs_at_priority
1620    funcs_at_priority.append((base_type, conversion_func))
1621    _tensor_conversion_func_cache = {}
1622
1623
1624@tf_export("IndexedSlices")
1625class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
1626  """A sparse representation of a set of tensor slices at given indices.
1627
1628  This class is a simple wrapper for a pair of `Tensor` objects:
1629
1630  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
1631  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.
1632
1633  An `IndexedSlices` is typically used to represent a subset of a larger
1634  tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
1635  The values in `indices` are the indices in the first dimension of
1636  the slices that have been extracted from the larger tensor.
1637
1638  The dense tensor `dense` represented by an `IndexedSlices` `slices` has
1639
1640  ```python
1641  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
1642  ```
1643
1644  The `IndexedSlices` class is used principally in the definition of
1645  gradients for operations that have sparse gradients
1646  (e.g. `tf.gather`).
1647
1648  Contrast this representation with
1649  `tf.SparseTensor`,
1650  which uses multi-dimensional indices and scalar values.
1651  """
1652
1653  def __init__(self, values, indices, dense_shape=None):
1654    """Creates an `IndexedSlices`."""
1655    _get_graph_from_inputs([values, indices, dense_shape])
1656    self._values = values
1657    self._indices = indices
1658    self._dense_shape = dense_shape
1659
1660  @property
1661  def values(self):
1662    """A `Tensor` containing the values of the slices."""
1663    return self._values
1664
1665  @property
1666  def indices(self):
1667    """A 1-D `Tensor` containing the indices of the slices."""
1668    return self._indices
1669
1670  @property
1671  def dense_shape(self):
1672    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
1673    return self._dense_shape
1674
1675  @property
1676  def name(self):
1677    """The name of this `IndexedSlices`."""
1678    return self.values.name
1679
1680  @property
1681  def device(self):
1682    """The name of the device on which `values` will be produced, or `None`."""
1683    return self.values.device
1684
1685  @property
1686  def op(self):
1687    """The `Operation` that produces `values` as an output."""
1688    return self.values.op
1689
1690  @property
1691  def dtype(self):
1692    """The `DType` of elements in this tensor."""
1693    return self.values.dtype
1694
1695  @property
1696  def graph(self):
1697    """The `Graph` that contains the values, indices, and shape tensors."""
1698    return self._values.graph
1699
1700  def __str__(self):
1701    return "IndexedSlices(indices=%s, values=%s%s)" % (
1702        self._indices, self._values, (", dense_shape=%s" % self._dense_shape)
1703        if self._dense_shape is not None else "")
1704
1705  def __neg__(self):
1706    return IndexedSlices(-self.values, self.indices, self.dense_shape)
1707
1708  def _to_components(self):
1709    if self._dense_shape is None:
1710      return (self._values, self._indices)
1711    else:
1712      return (self._values, self._indices, self._dense_shape)
1713
1714  @classmethod
1715  def _from_components(cls, components):
1716    return cls(*components)
1717
1718  def _shape_invariant_to_components(self, shape=None):
1719    if shape is None:
1720      shape = self._values.shape
1721    if self._dense_shape is None:
1722      return [shape, shape[:1]]  # values, indices
1723    else:
1724      # values, indices, dense_shape
1725      return [shape, shape[:1], tensor_shape.TensorShape([shape.ndims])]
1726
1727  @property
1728  def _is_graph_tensor(self):
1729    return hasattr(self._values, 'graph')
1730
1731
1732IndexedSlicesValue = collections.namedtuple(
1733    "IndexedSlicesValue", ["values", "indices", "dense_shape"])
1734
1735
1736def _device_string(dev_spec):
1737  if isinstance(dev_spec, pydev.DeviceSpec):
1738    return dev_spec.to_string()
1739  else:
1740    return dev_spec
1741
1742
1743def _NodeDef(op_type, name, device=None, attrs=None):  # pylint: disable=redefined-outer-name
1744  """Create a NodeDef proto.
1745
1746  Args:
1747    op_type: Value for the "op" attribute of the NodeDef proto.
1748    name: Value for the "name" attribute of the NodeDef proto.
1749    device: string, device, or function from NodeDef to string.
1750      Value for the "device" attribute of the NodeDef proto.
1751    attrs: Optional dictionary where the key is the attribute name (a string)
1752      and the value is the respective "attr" attribute of the NodeDef proto (an
1753      AttrValue).
1754
1755  Returns:
1756    A node_def_pb2.NodeDef protocol buffer.
1757  """
1758  node_def = node_def_pb2.NodeDef()
1759  node_def.op = compat.as_bytes(op_type)
1760  node_def.name = compat.as_bytes(name)
1761  if attrs is not None:
1762    for k, v in six.iteritems(attrs):
1763      node_def.attr[k].CopyFrom(v)
1764  if device is not None:
1765    if callable(device):
1766      node_def.device = device(node_def)
1767    else:
1768      node_def.device = _device_string(device)
1769  return node_def
1770
1771
1772# Copied from core/framework/node_def_util.cc
1773# TODO(mrry,josh11b): Consolidate this validation in C++ code.
1774_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$")
1775_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$")
1776
1777
1778def _create_c_op(graph, node_def, inputs, control_inputs):
1779  """Creates a TF_Operation.
1780
1781  Args:
1782    graph: a `Graph`.
1783    node_def: `node_def_pb2.NodeDef` for the operation to create.
1784    inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of
1785      `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
1786      "list(int64)"). The length of the list should be equal to the number of
1787      inputs specified by this operation's op def.
1788    control_inputs: A list of `Operation`s to set as control dependencies.
1789
1790  Returns:
1791    A wrapped TF_Operation*.
1792  """
1793  # pylint: disable=protected-access
1794  op_desc = c_api.TF_NewOperation(graph._c_graph,
1795                                  compat.as_str(node_def.op),
1796                                  compat.as_str(node_def.name))
1797  if node_def.device:
1798    c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
1799  # Add inputs
1800  for op_input in inputs:
1801    if isinstance(op_input, (list, tuple)):
1802      c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input])
1803    else:
1804      c_api.TF_AddInput(op_desc, op_input._as_tf_output())
1805
1806  # Add control inputs
1807  for control_input in control_inputs:
1808    c_api.TF_AddControlInput(op_desc, control_input._c_op)
1809  # pylint: enable=protected-access
1810
1811  # Add attrs
1812  for name, attr_value in node_def.attr.items():
1813    serialized = attr_value.SerializeToString()
1814    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
1815    # It might be worth creating a convenient way to re-use the same status.
1816    c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)
1817
1818  try:
1819    c_op = c_api.TF_FinishOperation(op_desc)
1820  except errors.InvalidArgumentError as e:
1821    # Convert to ValueError for backwards compatibility.
1822    raise ValueError(str(e))
1823
1824  return c_op
1825
1826
1827@tf_export("Operation")
1828class Operation(object):
1829  """Represents a graph node that performs computation on tensors.
1830
1831  An `Operation` is a node in a TensorFlow `Graph` that takes zero or
1832  more `Tensor` objects as input, and produces zero or more `Tensor`
1833  objects as output. Objects of type `Operation` are created by
1834  calling a Python op constructor (such as
1835  `tf.matmul`)
1836  or `tf.Graph.create_op`.
1837
1838  For example `c = tf.matmul(a, b)` creates an `Operation` of type
1839  "MatMul" that takes tensors `a` and `b` as input, and produces `c`
1840  as output.
1841
1842  After the graph has been launched in a session, an `Operation` can
1843  be executed by passing it to
1844  `tf.Session.run`.
1845  `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
1846  """
1847
1848  def __init__(self,
1849               node_def,
1850               g,
1851               inputs=None,
1852               output_types=None,
1853               control_inputs=None,
1854               input_types=None,
1855               original_op=None,
1856               op_def=None):
1857    r"""Creates an `Operation`.
1858
1859    NOTE: This constructor validates the name of the `Operation` (passed
1860    as `node_def.name`). Valid `Operation` names match the following
1861    regular expression:
1862
1863        [A-Za-z0-9.][A-Za-z0-9_.\\-/]*
1864
1865    Args:
1866      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.
1867        Used for attributes of `node_def_pb2.NodeDef`, typically `name`,
1868        `op`, and `device`.  The `input` attribute is irrelevant here
1869        as it will be computed when generating the model.
1870      g: `Graph`. The parent graph.
1871      inputs: list of `Tensor` objects. The inputs to this `Operation`.
1872      output_types: list of `DType` objects.  List of the types of the
1873        `Tensors` computed by this operation.  The length of this list indicates
1874        the number of output endpoints of the `Operation`.
1875      control_inputs: list of operations or tensors from which to have a
1876        control dependency.
1877      input_types: List of `DType` objects representing the
1878        types of the tensors accepted by the `Operation`.  By default
1879        uses `[x.dtype.base_dtype for x in inputs]`.  Operations that expect
1880        reference-typed inputs must specify these explicitly.
1881      original_op: Optional. Used to associate the new `Operation` with an
1882        existing `Operation` (for example, a replica with the op that was
1883        replicated).
1884      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the
1885        op type that this `Operation` represents.
1886
1887    Raises:
1888      TypeError: if control inputs are not Operations or Tensors,
1889        or if `node_def` is not a `NodeDef`,
1890        or if `g` is not a `Graph`,
1891        or if `inputs` are not tensors,
1892        or if `inputs` and `input_types` are incompatible.
1893      ValueError: if the `node_def` name is not valid.
1894    """
1895    # For internal use only: `node_def` can be set to a TF_Operation to create
1896    # an Operation for that op. This is useful for creating Operations for ops
1897    # indirectly created by C API methods, e.g. the ops created by
1898    # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
1899    # should be None.
1900
1901    if isinstance(node_def, node_def_pb2.NodeDef):
1902      if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
1903        raise ValueError(
1904            "Cannot create a tensor proto whose content is larger than 2GB.")
1905      if not _VALID_OP_NAME_REGEX.match(node_def.name):
1906        raise ValueError("'%s' is not a valid node name" % node_def.name)
1907      c_op = None
1908    elif type(node_def).__name__ == "SwigPyObject":
1909      assert inputs is None
1910      assert output_types is None
1911      assert control_inputs is None
1912      assert input_types is None
1913      assert original_op is None
1914      assert op_def is None
1915      c_op = node_def
1916    else:
1917      raise TypeError("node_def needs to be a NodeDef: %s" % node_def)
1918
1919    if not isinstance(g, Graph):
1920      raise TypeError("g needs to be a Graph: %s" % g)
1921    self._graph = g
1922
1923    if inputs is None:
1924      inputs = []
1925    elif not isinstance(inputs, list):
1926      raise TypeError("inputs needs to be a list of Tensors: %s" % inputs)
1927    for a in inputs:
1928      if not isinstance(a, Tensor):
1929        raise TypeError("input needs to be a Tensor: %s" % a)
1930    if input_types is None:
1931      input_types = [i.dtype.base_dtype for i in inputs]
1932    else:
1933      if not all(
1934          x.is_compatible_with(i.dtype)
1935          for i, x in zip(inputs, input_types)):
1936        raise TypeError("In op '%s', input types (%s) are not compatible "
1937                        "with expected types (%s)" %
1938                        (node_def.name, [i.dtype for i in inputs],
1939                         input_types))
1940
1941    # Build the list of control inputs.
1942    control_input_ops = []
1943    if control_inputs:
1944      for c in control_inputs:
1945        control_op = None
1946        if isinstance(c, Operation):
1947          control_op = c
1948        elif isinstance(c, (Tensor, IndexedSlices)):
1949          control_op = c.op
1950        else:
1951          raise TypeError("Control input must be an Operation, "
1952                          "a Tensor, or IndexedSlices: %s" % c)
1953        control_input_ops.append(control_op)
1954
1955    # This will be set by self.inputs.
1956    self._inputs_val = None
1957
1958    # pylint: disable=protected-access
1959    self._id_value = self._graph._next_id()
1960    self._original_op = original_op
1961    self._traceback = tf_stack.extract_stack()
1962
1963    # List of _UserDevSpecs holding code location of device context manager
1964    # invocations and the users original argument to them.
1965    self._device_code_locations = None
1966    # Dict mapping op name to file and line information for op colocation
1967    # context managers.
1968    self._colocation_code_locations = None
1969    self._control_flow_context = self.graph._get_control_flow_context()
1970    # pylint: enable=protected-access
1971
1972    # Initialize self._c_op.
1973    if c_op:
1974      self._c_op = c_op
1975    else:
1976      if op_def is None:
1977        op_def = self._graph._get_op_def(node_def.op)
1978      # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
1979      # Refactor so we don't have to do this here.
1980      grouped_inputs = self._reconstruct_sequence_inputs(
1981          op_def, inputs, node_def.attr)
1982      self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
1983                                control_input_ops)
1984
1985    # Initialize self._outputs.
1986    num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
1987    output_types = [
1988        c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i))
1989        for i in range(num_outputs)]
1990    self._outputs = [
1991        Tensor(self, i, output_type)
1992        for i, output_type in enumerate(output_types)
1993    ]
1994
1995    self._graph._add_op(self)  # pylint: disable=protected-access
1996
1997    if not c_op:
1998      self._control_flow_post_processing()
1999
2000  def _control_flow_post_processing(self):
2001    """Add this op to its control flow context.
2002
2003    This may add new ops and change this op's inputs. self.inputs must be
2004    available before calling this method.
2005    """
2006    for input_tensor in self.inputs:
2007      control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
2008    if self._control_flow_context is not None:
2009      self._control_flow_context.AddOp(self)
2010
2011  def _reconstruct_sequence_inputs(self, op_def, inputs, attrs):
2012    """Regroups a flat list of input tensors into scalar and sequence inputs.
2013
2014    Args:
2015      op_def: The `op_def_pb2.OpDef` (for knowing the input types)
2016      inputs: a list of input `Tensor`s to the op.
2017      attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
2018        how long each sequence is)
2019
2020    Returns:
2021      A list of `Tensor`s (corresponding to scalar inputs) and lists of
2022      `Tensor`s (corresponding to sequence inputs).
2023    """
2024    grouped_inputs = []
2025    i = 0
2026    for input_arg in op_def.input_arg:
2027      if input_arg.number_attr:
2028        input_len = attrs[input_arg.number_attr].i
2029        is_sequence = True
2030      elif input_arg.type_list_attr:
2031        input_len = len(attrs[input_arg.type_list_attr].list.type)
2032        is_sequence = True
2033      else:
2034        input_len = 1
2035        is_sequence = False
2036
2037      if is_sequence:
2038        grouped_inputs.append(inputs[i:i + input_len])
2039      else:
2040        grouped_inputs.append(inputs[i])
2041      i += input_len
2042
2043    assert i == len(inputs)
2044    return grouped_inputs
2045
2046  def colocation_groups(self):
2047    """Returns the list of colocation groups of the op."""
2048    default_colocation_group = [
2049        compat.as_bytes("loc:@%s" % self.name)
2050    ]
2051    try:
2052      class_attr = self.get_attr("_class")
2053    except ValueError:
2054      # This op has no explicit colocation group, so it is itself its
2055      # own root of a colocation group.
2056      return default_colocation_group
2057
2058    attr_groups = [
2059        class_name for class_name in class_attr
2060        if class_name.startswith(b"loc:@")
2061    ]
2062
2063    # If there are no colocation groups in the explicit _class field,
2064    # return the default colocation group.
2065    return attr_groups if attr_groups else default_colocation_group
2066
2067  def values(self):
2068    """DEPRECATED: Use outputs."""
2069    return tuple(self.outputs)
2070
2071  def _get_control_flow_context(self):
2072    """Returns the control flow context of this op.
2073
2074    Returns:
2075      A context object.
2076    """
2077    return self._control_flow_context
2078
2079  def _set_control_flow_context(self, ctx):
2080    """Sets the current control flow context of this op.
2081
2082    Args:
2083      ctx: a context object.
2084    """
2085    self._control_flow_context = ctx
2086
2087  @property
2088  def name(self):
2089    """The full name of this operation."""
2090    return c_api.TF_OperationName(self._c_op)
2091
2092  @property
2093  def _id(self):
2094    """The unique integer id of this operation."""
2095    return self._id_value
2096
2097  @property
2098  def device(self):
2099    """The name of the device to which this op has been assigned, if any.
2100
2101    Returns:
2102      The string name of the device to which this op has been
2103      assigned, or an empty string if it has not been assigned to a
2104      device.
2105    """
2106    return c_api.TF_OperationDevice(self._c_op)
2107
2108  @property
2109  def _device_assignments(self):
2110    """Code locations for device context managers active at op creation.
2111
2112    This property will return a list of traceable_stack.TraceableObject
2113    instances where .obj is a string representing the assigned device
2114    (or information about the function that would be applied to this op
2115    to compute the desired device) and the filename and lineno members
2116    record the location of the relevant device context manager.
2117
2118    For example, suppose file_a contained these lines:
2119
2120      file_a.py:
2121        15: with tf.device('/gpu:0'):
2122        16:   node_b = tf.constant(4, name='NODE_B')
2123
2124    Then a TraceableObject t_obj representing the device context manager
2125    would have these member values:
2126
2127      t_obj.obj -> '/gpu:0'
2128      t_obj.filename = 'file_a.py'
2129      t_obj.lineno = 15
2130
2131    and node_b.op._device_assignments would return the list [t_obj].
2132
2133    Returns:
2134      [str: traceable_stack.TraceableObject, ...] as per this method's
2135      description, above.
2136    """
2137    return self._device_code_locations or []
2138
2139  @property
2140  def _colocation_dict(self):
2141    """Code locations for colocation context managers active at op creation.
2142
2143    This property will return a dictionary for which the keys are nodes with
2144    which this Operation is colocated, and for which the values are
2145    traceable_stack.TraceableObject instances.  The TraceableObject instances
2146    record the location of the relevant colocation context manager but have the
2147    "obj" field set to None to prevent leaking private data.
2148
2149    For example, suppose file_a contained these lines:
2150
2151      file_a.py:
2152        14: node_a = tf.constant(3, name='NODE_A')
2153        15: with tf.colocate_with(node_a):
2154        16:   node_b = tf.constant(4, name='NODE_B')
2155
2156    Then a TraceableObject t_obj representing the colocation context manager
2157    would have these member values:
2158
2159      t_obj.obj -> None
2160      t_obj.filename = 'file_a.py'
2161      t_obj.lineno = 15
2162
2163    and node_b.op._colocation_dict would return the dictionary
2164
2165      { 'NODE_A': t_obj }
2166
2167    Returns:
2168      {str: traceable_stack.TraceableObject} as per this method's description,
2169      above.
2170    """
2171    locations_dict = self._colocation_code_locations or {}
2172    return locations_dict.copy()
2173
2174  @property
2175  def _output_types(self):
2176    """List this operation's output types.
2177
2178    Returns:
2179      List of the types of the Tensors computed by this operation.
2180      Each element in the list is an integer whose value is one of
2181      the TF_DataType enums defined in c_api.h
2182      The length of this list indicates the number of output endpoints
2183      of the operation.
2184    """
2185    num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
2186    output_types = [
2187        c_api.TF_OperationOutputType(self._tf_output(i))
2188        for i in xrange(num_outputs)
2189    ]
2190    # In all the tests we have output_types that are passed into
2191    # Operation.__init__ are a list of ints (which is illegal according
2192    # to the docstring), but input_types are instances of DType.
2193    # This extra assert is to catch if we ever use DType for output_types.
2194    if output_types:
2195      assert isinstance(output_types[0], int)
2196    return output_types
2197
2198  def _tf_output(self, output_idx):
2199    """Create and return a new TF_Output for output_idx'th output of this op."""
2200    tf_output = c_api.TF_Output()
2201    tf_output.oper = self._c_op
2202    tf_output.index = output_idx
2203    return tf_output
2204
2205  def _tf_input(self, input_idx):
2206    """Create and return a new TF_Input for input_idx'th input of this op."""
2207    tf_input = c_api.TF_Input()
2208    tf_input.oper = self._c_op
2209    tf_input.index = input_idx
2210    return tf_input
2211
2212  def _set_device(self, device):  # pylint: disable=redefined-outer-name
2213    """Set the device of this operation.
2214
2215    Args:
2216      device: string or device..  The device to set.
2217    """
2218    c_api.SetRequestedDevice(
2219        self._graph._c_graph,  # pylint: disable=protected-access
2220        self._c_op,  # pylint: disable=protected-access
2221        compat.as_str(_device_string(device)))
2222
2223  def _update_input(self, index, tensor):
2224    """Update the input to this operation at the given index.
2225
2226    NOTE: This is for TF internal use only. Please don't use it.
2227
2228    Args:
2229      index: the index of the input to update.
2230      tensor: the Tensor to be used as the input at the given index.
2231
2232    Raises:
2233      TypeError: if tensor is not a Tensor,
2234        or if input tensor type is not convertible to dtype.
2235      ValueError: if the Tensor is from a different graph.
2236    """
2237    if not isinstance(tensor, Tensor):
2238      raise TypeError("tensor must be a Tensor: %s" % tensor)
2239    _assert_same_graph(self, tensor)
2240
2241    # Reset cached inputs.
2242    self._inputs_val = None
2243    c_api.UpdateEdge(
2244        self._graph._c_graph,  # pylint: disable=protected-access
2245        tensor._as_tf_output(),  # pylint: disable=protected-access
2246        self._tf_input(index))
2247
2248  def _add_while_inputs(self, tensors):
2249    """See AddWhileInputHack in python_api.h.
2250
2251    NOTE: This is for TF internal use only. Please don't use it.
2252
2253    Args:
2254      tensors: list of Tensors
2255
2256    Raises:
2257      TypeError: if tensor is not a Tensor,
2258        or if input tensor type is not convertible to dtype.
2259      ValueError: if the Tensor is from a different graph.
2260    """
2261    for tensor in tensors:
2262      if not isinstance(tensor, Tensor):
2263        raise TypeError("tensor must be a Tensor: %s" % tensor)
2264      _assert_same_graph(self, tensor)
2265
2266      # Reset cached inputs.
2267      self._inputs_val = None
2268      c_api.AddWhileInputHack(
2269          self._graph._c_graph,  # pylint: disable=protected-access
2270          tensor._as_tf_output(),  # pylint: disable=protected-access
2271          self._c_op)
2272
2273  def _add_control_inputs(self, ops):
2274    """Add a list of new control inputs to this operation.
2275
2276    Args:
2277      ops: the list of Operations to add as control input.
2278
2279    Raises:
2280      TypeError: if ops is not a list of Operations.
2281      ValueError: if any op in ops is from a different graph.
2282    """
2283    for op in ops:
2284      if not isinstance(op, Operation):
2285        raise TypeError("op must be an Operation: %s" % op)
2286      c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op)  # pylint: disable=protected-access
2287
2288  def _add_control_input(self, op):
2289    """Add a new control input to this operation.
2290
2291    Args:
2292      op: the Operation to add as control input.
2293
2294    Raises:
2295      TypeError: if op is not an Operation.
2296      ValueError: if op is from a different graph.
2297    """
2298    if not isinstance(op, Operation):
2299      raise TypeError("op must be an Operation: %s" % op)
2300    c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op)  # pylint: disable=protected-access
2301
2302  def _remove_all_control_inputs(self):
2303    """Removes any control inputs to this operation."""
2304    c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op)  # pylint: disable=protected-access
2305
2306  def _add_outputs(self, types, shapes):
2307    """Adds new Tensors to self.outputs.
2308
2309    Note: this is generally unsafe to use. This is used in certain situations in
2310    conjunction with _set_type_list_attr.
2311
2312    Arguments:
2313      types: list of DTypes
2314      shapes: list of TensorShapes
2315    """
2316    assert len(types) == len(shapes)
2317    orig_num_outputs = len(self.outputs)
2318    for i in range(len(types)):
2319      t = Tensor(self, orig_num_outputs + i, types[i])
2320      self._outputs.append(t)
2321      t.set_shape(shapes[i])
2322
2323  def __str__(self):
2324    return str(self.node_def)
2325
2326  def __repr__(self):
2327    return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
2328
2329  @property
2330  def outputs(self):
2331    """The list of `Tensor` objects representing the outputs of this op."""
2332    return self._outputs
2333
2334# pylint: disable=protected-access
2335
2336  class _InputList(object):
2337    """Immutable input list wrapper."""
2338
2339    def __init__(self, inputs):
2340      self._inputs = inputs
2341
2342    def __iter__(self):
2343      return iter(self._inputs)
2344
2345    def __len__(self):
2346      return len(self._inputs)
2347
2348    def __bool__(self):
2349      return bool(self._inputs)
2350
2351    # Python 3 wants __bool__, Python 2.7 wants __nonzero__
2352    __nonzero__ = __bool__
2353
2354    def __getitem__(self, i):
2355      return self._inputs[i]
2356
2357# pylint: enable=protected-access
2358
2359  @property
2360  def inputs(self):
2361    """The list of `Tensor` objects representing the data inputs of this op."""
2362    if self._inputs_val is None:
2363      tf_outputs = c_api.GetOperationInputs(self._c_op)
2364      # pylint: disable=protected-access
2365      retval = [
2366          self.graph._get_tensor_by_tf_output(tf_output)
2367          for tf_output in tf_outputs
2368      ]
2369      # pylint: enable=protected-access
2370      self._inputs_val = Operation._InputList(retval)
2371    return self._inputs_val
2372
2373  @property
2374  def _inputs(self):
2375    logging.warning("Operation._inputs is private, use Operation.inputs "
2376                    "instead. Operation._inputs will eventually be removed.")
2377    return self.inputs
2378
2379  @_inputs.setter
2380  def _inputs(self, value):
2381    raise ValueError("Cannot assign _inputs")
2382
2383  @property
2384  def _input_types(self):
2385    num_inputs = c_api.TF_OperationNumInputs(self._c_op)
2386    input_types = [
2387        dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
2388        for i in xrange(num_inputs)
2389    ]
2390    return input_types
2391
2392  @_input_types.setter
2393  def _input_types(self, value):
2394    raise ValueError("Cannot assign _input_types")
2395
2396  @property
2397  def control_inputs(self):
2398    """The `Operation` objects on which this op has a control dependency.
2399
2400    Before this op is executed, TensorFlow will ensure that the
2401    operations in `self.control_inputs` have finished executing. This
2402    mechanism can be used to run ops sequentially for performance
2403    reasons, or to ensure that the side effects of an op are observed
2404    in the correct order.
2405
2406    Returns:
2407      A list of `Operation` objects.
2408
2409    """
2410    control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
2411    # pylint: disable=protected-access
2412    return [
2413        self.graph._get_operation_by_name_unsafe(
2414            c_api.TF_OperationName(c_op)) for c_op in control_c_ops
2415    ]
2416    # pylint: enable=protected-access
2417
2418  @property
2419  def _control_outputs(self):
2420    """The `Operation` objects which have a control dependency on this op.
2421
2422    Before any of the ops in self._control_outputs can execute tensorflow will
2423    ensure self has finished executing.
2424
2425    Returns:
2426      A list of `Operation` objects.
2427
2428    """
2429    control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
2430    # pylint: disable=protected-access
2431    return [
2432        self.graph._get_operation_by_name_unsafe(
2433            c_api.TF_OperationName(c_op)) for c_op in control_c_ops
2434    ]
2435    # pylint: enable=protected-access
2436
2437  @property
2438  def _control_inputs(self):
2439    logging.warning("Operation._control_inputs is private, use "
2440                    "Operation.control_inputs instead. "
2441                    "Operation._control_inputs will eventually be removed.")
2442    return self.control_inputs
2443
2444  @_control_inputs.setter
2445  def _control_inputs(self, value):
2446    logging.warning("Operation._control_inputs is private, use "
2447                    "Operation.control_inputs instead. "
2448                    "Operation._control_inputs will eventually be removed.")
2449    # Copy value because it may be self._control_inputs_val (in particular if
2450    # this is called from self._control_inputs += ...), and we don't want to
2451    # clear value below.
2452    value = copy.copy(value)
2453    self._remove_all_control_inputs()
2454    self._add_control_inputs(value)
2455
2456  @property
2457  def type(self):
2458    """The type of the op (e.g. `"MatMul"`)."""
2459    return c_api.TF_OperationOpType(self._c_op)
2460
2461  @property
2462  def graph(self):
2463    """The `Graph` that contains this operation."""
2464    return self._graph
2465
2466  @property
2467  def node_def(self):
2468    # pylint: disable=line-too-long
2469    """Returns the `NodeDef` representation of this operation.
2470
2471    Returns:
2472      A
2473      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
2474      protocol buffer.
2475    """
2476    # pylint: enable=line-too-long
2477    with c_api_util.tf_buffer() as buf:
2478      c_api.TF_OperationToNodeDef(self._c_op, buf)
2479      data = c_api.TF_GetBuffer(buf)
2480    node_def = node_def_pb2.NodeDef()
2481    node_def.ParseFromString(compat.as_bytes(data))
2482    return node_def
2483
2484  @property
2485  def _node_def(self):
2486    logging.warning("Operation._node_def is private, use Operation.node_def "
2487                    "instead. Operation._node_def will eventually be removed.")
2488    return self.node_def
2489
2490  @property
2491  def op_def(self):
2492    # pylint: disable=line-too-long
2493    """Returns the `OpDef` proto that represents the type of this op.
2494
2495    Returns:
2496      An
2497      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
2498      protocol buffer.
2499    """
2500    # pylint: enable=line-too-long
2501    return self._graph._get_op_def(self.type)
2502
2503  @property
2504  def _op_def(self):
2505    logging.warning("Operation._op_def is private, use Operation.op_def "
2506                    "instead. Operation._op_def will eventually be removed.")
2507    return self.op_def
2508
2509  @property
2510  def traceback(self):
2511    """Returns the call stack from when this operation was constructed."""
2512    return tf_stack.convert_stack(self._traceback)
2513
2514  @property
2515  def traceback_with_start_lines(self):
2516    """Same as traceback but includes start line of function definition.
2517
2518    Returns:
2519      A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
2520    """
2521    return tf_stack.convert_stack(self._traceback,
2522                                  include_func_start_lineno=True)
2523
2524  def _set_attr(self, attr_name, attr_value):
2525    """Private method used to set an attribute in the node_def."""
2526    buf = c_api.TF_NewBufferFromString(
2527        compat.as_bytes(attr_value.SerializeToString()))
2528    try:
2529      # pylint: disable=protected-access
2530      c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
2531      # pylint: enable=protected-access
2532    finally:
2533      c_api.TF_DeleteBuffer(buf)
2534
2535  def _set_func_attr(self, attr_name, func_name):
2536    """Private method used to set a function attribute in the node_def."""
2537    func = attr_value_pb2.NameAttrList(name=func_name)
2538    self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
2539
2540  def _set_type_list_attr(self, attr_name, types):
2541    """Private method used to set a function attribute in the node_def."""
2542    if not types: return
2543    if isinstance(types[0], dtypes.DType):
2544      types = [dt.as_datatype_enum for dt in types]
2545    types_list = attr_value_pb2.AttrValue.ListValue(type=types)
2546    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
2547
2548  def _set_shape_list_attr(self, attr_name, shapes):
2549    """Private method used to set a function attribute in the node_def."""
2550    shapes = [s.as_proto() for s in shapes]
2551    shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
2552    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
2553
2554  def _clear_attr(self, attr_name):
2555    """Private method used to clear an attribute in the node_def."""
2556    # pylint: disable=protected-access
2557    c_api.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
2558    # pylint: enable=protected-access
2559
2560  def get_attr(self, name):
2561    """Returns the value of the attr of this op with the given `name`.
2562
2563    Args:
2564      name: The name of the attr to fetch.
2565
2566    Returns:
2567      The value of the attr, as a Python object.
2568
2569    Raises:
2570      ValueError: If this op does not have an attr with the given `name`.
2571    """
2572    fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
2573    try:
2574      with c_api_util.tf_buffer() as buf:
2575        c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
2576        data = c_api.TF_GetBuffer(buf)
2577    except errors.InvalidArgumentError as e:
2578      # Convert to ValueError for backwards compatibility.
2579      raise ValueError(str(e))
2580    x = attr_value_pb2.AttrValue()
2581    x.ParseFromString(data)
2582
2583    oneof_value = x.WhichOneof("value")
2584    if oneof_value is None:
2585      return []
2586    if oneof_value == "list":
2587      for f in fields:
2588        if getattr(x.list, f):
2589          if f == "type":
2590            return [dtypes.as_dtype(t) for t in x.list.type]
2591          else:
2592            return list(getattr(x.list, f))
2593      return []
2594    if oneof_value == "type":
2595      return dtypes.as_dtype(x.type)
2596    assert oneof_value in fields, "Unsupported field type in " + str(x)
2597    return getattr(x, oneof_value)
2598
2599  def run(self, feed_dict=None, session=None):
2600    """Runs this operation in a `Session`.
2601
2602    Calling this method will execute all preceding operations that
2603    produce the inputs needed for this operation.
2604
2605    *N.B.* Before invoking `Operation.run()`, its graph must have been
2606    launched in a session, and either a default session must be
2607    available, or `session` must be specified explicitly.
2608
2609    Args:
2610      feed_dict: A dictionary that maps `Tensor` objects to feed values.
2611        See `tf.Session.run`
2612        for a description of the valid feed values.
2613      session: (Optional.) The `Session` to be used to run to this operation. If
2614        none, the default session will be used.
2615    """
2616    _run_using_default_session(self, feed_dict, self.graph, session)
2617
2618_gradient_registry = registry.Registry("gradient")
2619
2620
2621@tf_export("RegisterGradient")
2622class RegisterGradient(object):
2623  """A decorator for registering the gradient function for an op type.
2624
2625  This decorator is only used when defining a new op type. For an op
2626  with `m` inputs and `n` outputs, the gradient function is a function
2627  that takes the original `Operation` and `n` `Tensor` objects
2628  (representing the gradients with respect to each output of the op),
2629  and returns `m` `Tensor` objects (representing the partial gradients
2630  with respect to each input of the op).
2631
2632  For example, assuming that operations of type `"Sub"` take two
2633  inputs `x` and `y`, and return a single output `x - y`, the
2634  following gradient function would be registered:
2635
2636  ```python
2637  @tf.RegisterGradient("Sub")
2638  def _sub_grad(unused_op, grad):
2639    return grad, tf.negative(grad)
2640  ```
2641
2642  The decorator argument `op_type` is the string type of an
2643  operation. This corresponds to the `OpDef.name` field for the proto
2644  that defines the operation.
2645  """
2646
2647  def __init__(self, op_type):
2648    """Creates a new decorator with `op_type` as the Operation type.
2649
2650    Args:
2651      op_type: The string type of an operation. This corresponds to the
2652        `OpDef.name` field for the proto that defines the operation.
2653    """
2654    if not isinstance(op_type, six.string_types):
2655      raise TypeError("op_type must be a string")
2656    self._op_type = op_type
2657
2658  def __call__(self, f):
2659    """Registers the function `f` as gradient function for `op_type`."""
2660    _gradient_registry.register(f, self._op_type)
2661    return f
2662
2663
2664@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient")
2665@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"])
2666def no_gradient(op_type):
2667  """Specifies that ops of type `op_type` is not differentiable.
2668
2669  This function should *not* be used for operations that have a
2670  well-defined gradient that is not yet implemented.
2671
2672  This function is only used when defining a new op type. It may be
2673  used for ops such as `tf.size()` that are not differentiable.  For
2674  example:
2675
2676  ```python
2677  tf.NotDifferentiable("Size")
2678  ```
2679
2680  The gradient computed for 'op_type' will then propagate zeros.
2681
2682  For ops that have a well-defined gradient but are not yet implemented,
2683  no declaration should be made, and an error *must* be thrown if
2684  an attempt to request its gradient is made.
2685
2686  Args:
2687    op_type: The string type of an operation. This corresponds to the
2688      `OpDef.name` field for the proto that defines the operation.
2689
2690  Raises:
2691    TypeError: If `op_type` is not a string.
2692
2693  """
2694  if not isinstance(op_type, six.string_types):
2695    raise TypeError("op_type must be a string")
2696  _gradient_registry.register(None, op_type)
2697
2698
2699# Aliases for the old names, will be eventually removed.
2700NoGradient = no_gradient
2701NotDifferentiable = no_gradient
2702
2703
2704def get_gradient_function(op):
2705  """Returns the function that computes gradients for "op"."""
2706  if not op.inputs:
2707    return None
2708  try:
2709    op_type = op.get_attr("_gradient_op_type")
2710  except ValueError:
2711    op_type = op.type
2712  return _gradient_registry.lookup(op_type)
2713
2714
2715_shape_registry = registry.Registry("shape functions")
2716_default_shape_function_registry = registry.Registry("default shape functions")
2717
2718# These are set to common_shapes.call_cpp_shape_fn by op generated code
2719# (generated by python_op_gen.cc).
2720# It is set outside ops.py to avoid a circular dependency.
2721_call_cpp_shape_fn = None
2722_call_cpp_shape_fn_and_require_op = None
2723
2724
2725def _set_call_cpp_shape_fn(call_cpp_shape_fn):
2726  """Sets default shape fns from passed common_shapes.call_cpp_shape_fn."""
2727  global _call_cpp_shape_fn, _call_cpp_shape_fn_and_require_op
2728  if _call_cpp_shape_fn:
2729    return  # already registered
2730
2731  def call_without_requiring(op):
2732    return call_cpp_shape_fn(op, require_shape_fn=False)
2733
2734  _call_cpp_shape_fn = call_without_requiring
2735
2736  def call_with_requiring(op):
2737    return call_cpp_shape_fn(op, require_shape_fn=True)
2738
2739  _call_cpp_shape_fn_and_require_op = call_with_requiring
2740
2741
2742class RegisterShape(object):
2743  """No longer used.  Was: A decorator for registering a shape function.
2744
2745  Shape functions must now be registered via the SetShapeFn on the
2746  original Op specification in C++.
2747
2748  """
2749
2750  def __init__(self, op_type):
2751    """Saves the `op_type` as the `Operation` type."""
2752    if not isinstance(op_type, six.string_types):
2753      raise TypeError("op_type must be a string")
2754    self._op_type = op_type
2755
2756  def __call__(self, f):
2757    """Registers "f" as the shape function for "op_type"."""
2758    if f is None:
2759      assert _call_cpp_shape_fn
2760
2761      # None is a special "weak" value that provides a default shape function,
2762      # and can be overridden by a non-None registration.
2763      try:
2764        _default_shape_function_registry.register(_call_cpp_shape_fn,
2765                                                  self._op_type)
2766      except KeyError:
2767        # Ignore duplicate registrations of the weak value. This can
2768        # occur if the op library input to wrapper generation
2769        # inadvertently links in one or more of the standard op
2770        # libraries.
2771        pass
2772    else:
2773      _shape_registry.register(f, self._op_type)
2774    return f
2775
2776
2777def set_shape_and_handle_data_for_outputs(_):
2778  """No op. TODO(b/74620627): Remove this."""
2779  pass
2780
2781
2782class OpStats(object):
2783  """A holder for statistics about an operator.
2784
2785  This class holds information about the resource requirements for an op,
2786  including the size of its weight parameters on-disk and how many FLOPS it
2787  requires to execute forward inference.
2788
2789  If you define a new operation, you can create a function that will return a
2790  set of information about its usage of the CPU and disk space when serialized.
2791  The function itself takes a Graph object that's been set up so you can call
2792  methods like get_tensor_by_name to help calculate the results, and a NodeDef
2793  argument.
2794
2795  """
2796
2797  def __init__(self, statistic_type, value=None):
2798    """Sets up the initial placeholders for the statistics."""
2799    self.statistic_type = statistic_type
2800    self.value = value
2801
2802  @property
2803  def statistic_type(self):
2804    return self._statistic_type
2805
2806  @statistic_type.setter
2807  def statistic_type(self, statistic_type):
2808    self._statistic_type = statistic_type
2809
2810  @property
2811  def value(self):
2812    return self._value
2813
2814  @value.setter
2815  def value(self, value):
2816    self._value = value
2817
2818  def __iadd__(self, other):
2819    if other.statistic_type != self.statistic_type:
2820      raise ValueError("Can't add an OpStat of type %s to one of %s." %
2821                       (self.statistic_type, other.statistic_type))
2822    if self.value is None:
2823      self.value = other.value
2824    elif other.value is not None:
2825      self._value += other.value
2826    return self
2827
2828
2829_stats_registry = registry.Registry("statistical functions")
2830
2831
2832class RegisterStatistics(object):
2833  """A decorator for registering the statistics function for an op type.
2834
2835  This decorator can be defined for an op type so that it gives a
2836  report on the resources used by an instance of an operator, in the
2837  form of an OpStats object.
2838
2839  Well-known types of statistics include these so far:
2840
2841  - flops: When running a graph, the bulk of the computation happens doing
2842    numerical calculations like matrix multiplications. This type allows a node
2843    to return how many floating-point operations it takes to complete. The
2844    total number of FLOPs for a graph is a good guide to its expected latency.
2845
2846  You can add your own statistics just by picking a new type string, registering
2847  functions for the ops you care about, and then calling get_stats_for_node_def.
2848
2849  If a statistic for an op is registered multiple times, a KeyError will be
2850  raised.
2851
2852  Since the statistics is counted on a per-op basis. It is not suitable for
2853  model parameters (capacity), which is expected to be counted only once, even
2854  if it is shared by multiple ops. (e.g. RNN)
2855
2856  For example, you can define a new metric called doohickey for a Foo operation
2857  by placing this in your code:
2858
2859  ```python
2860  @ops.RegisterStatistics("Foo", "doohickey")
2861  def _calc_foo_bojangles(unused_graph, unused_node_def):
2862    return ops.OpStats("doohickey", 20)
2863  ```
2864
2865  Then in client code you can retrieve the value by making this call:
2866
2867  ```python
2868  doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
2869  ```
2870
2871  If the NodeDef is for an op with a registered doohickey function, you'll get
2872  back the calculated amount in doohickey.value, or None if it's not defined.
2873
2874  """
2875
2876  def __init__(self, op_type, statistic_type):
2877    """Saves the `op_type` as the `Operation` type."""
2878    if not isinstance(op_type, six.string_types):
2879      raise TypeError("op_type must be a string.")
2880    if "," in op_type:
2881      raise TypeError("op_type must not contain a comma.")
2882    self._op_type = op_type
2883    if not isinstance(statistic_type, six.string_types):
2884      raise TypeError("statistic_type must be a string.")
2885    if "," in statistic_type:
2886      raise TypeError("statistic_type must not contain a comma.")
2887    self._statistic_type = statistic_type
2888
2889  def __call__(self, f):
2890    """Registers "f" as the statistics function for "op_type"."""
2891    _stats_registry.register(f, self._op_type + "," + self._statistic_type)
2892    return f
2893
2894
2895def get_stats_for_node_def(graph, node, statistic_type):
2896  """Looks up the node's statistics function in the registry and calls it.
2897
2898  This function takes a Graph object and a NodeDef from a GraphDef, and if
2899  there's an associated statistics method, calls it and returns a result. If no
2900  function has been registered for the particular node type, it returns an empty
2901  statistics object.
2902
2903  Args:
2904    graph: A Graph object that's been set up with the node's graph.
2905    node: A NodeDef describing the operator.
2906    statistic_type: A string identifying the statistic we're interested in.
2907  Returns:
2908    An OpStats object containing information about resource usage.
2909  """
2910
2911  try:
2912    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
2913    result = stats_func(graph, node)
2914  except LookupError:
2915    result = OpStats(statistic_type)
2916  return result
2917
2918
2919def _name_from_scope_name(name):
2920  """Returns the name of an op given the name of its scope.
2921
2922  Args:
2923    name: the name of the scope.
2924
2925  Returns:
2926    the name of the op (equal to scope name minus any trailing slash).
2927  """
2928  return name[:-1] if (name and name[-1] == "/") else name
2929
2930
2931_MUTATION_LOCK_GROUP = 0
2932_SESSION_RUN_LOCK_GROUP = 1
2933
2934@tf_export("Graph")
2935class Graph(object):
2936  """A TensorFlow computation, represented as a dataflow graph.
2937
2938  A `Graph` contains a set of
2939  `tf.Operation` objects,
2940  which represent units of computation; and
2941  `tf.Tensor` objects, which represent
2942  the units of data that flow between operations.
2943
2944  A default `Graph` is always registered, and accessible by calling
2945  `tf.get_default_graph`.
2946  To add an operation to the default graph, simply call one of the functions
2947  that defines a new `Operation`:
2948
2949  ```python
2950  c = tf.constant(4.0)
2951  assert c.graph is tf.get_default_graph()
2952  ```
2953
2954  Another typical usage involves the
2955  `tf.Graph.as_default`
2956  context manager, which overrides the current default graph for the
2957  lifetime of the context:
2958
2959  ```python
2960  g = tf.Graph()
2961  with g.as_default():
2962    # Define operations and tensors in `g`.
2963    c = tf.constant(30.0)
2964    assert c.graph is g
2965  ```
2966
2967  Important note: This class *is not* thread-safe for graph construction. All
2968  operations should be created from a single thread, or external
2969  synchronization must be provided. Unless otherwise specified, all methods
2970  are not thread-safe.
2971
2972  A `Graph` instance supports an arbitrary number of "collections"
2973  that are identified by name. For convenience when building a large
2974  graph, collections can store groups of related objects: for
2975  example, the `tf.Variable` uses a collection (named
2976  `tf.GraphKeys.GLOBAL_VARIABLES`) for
2977  all variables that are created during the construction of a graph. The caller
2978  may define additional collections by specifying a new name.
2979  """
2980
2981  def __init__(self):
2982    """Creates a new, empty Graph."""
2983    # Protects core state that can be returned via public accessors.
2984    # Thread-safety is provided on a best-effort basis to support buggy
2985    # programs, and is not guaranteed by the public `tf.Graph` API.
2986    #
2987    # NOTE(mrry): This does not protect the various stacks. A warning will
2988    # be reported if these are used from multiple threads
2989    self._lock = threading.RLock()
2990    # The group lock synchronizes Session.run calls with methods that create
2991    # and mutate ops (e.g. Graph.create_op()). This synchronization is
2992    # necessary because it's illegal to modify an operation after it's been run.
2993    # The group lock allows any number of threads to mutate ops at the same time
2994    # but if any modification is going on, all Session.run calls have to wait.
2995    # Similarly, if one or more Session.run calls are going on, all mutate ops
2996    # have to wait until all Session.run calls have finished.
2997    self._group_lock = lock_util.GroupLock(num_groups=2)
2998    self._nodes_by_id = dict()  # GUARDED_BY(self._lock)
2999    self._next_id_counter = 0  # GUARDED_BY(self._lock)
3000    self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
3001    self._version = 0  # GUARDED_BY(self._lock)
3002    # Maps a name used in the graph to the next id to use for that name.
3003    self._names_in_use = {}
3004    self._stack_state_is_thread_local = False
3005    self._thread_local = threading.local()
3006    # Functions that will be applied to choose a device if none is specified.
3007    # In TF2.x or after switch_to_thread_local(),
3008    # self._thread_local._device_function_stack is used instead.
3009    self._graph_device_function_stack = traceable_stack.TraceableStack()
3010    # Default original_op applied to new ops.
3011    self._default_original_op = None
3012    # Current control flow context. It could be either CondContext or
3013    # WhileContext defined in ops/control_flow_ops.py
3014    self._control_flow_context = None
3015    # A new node will depend of the union of all of the nodes in the stack.
3016    # In TF2.x or after switch_to_thread_local(),
3017    # self._thread_local._control_dependencies_stack is used instead.
3018    self._graph_control_dependencies_stack = []
3019    # Arbitrary collections of objects.
3020    self._collections = {}
3021    # The graph-level random seed
3022    self._seed = None
3023    # A dictionary of attributes that should be applied to all ops.
3024    self._attr_scope_map = {}
3025    # A map from op type to the kernel label that should be used.
3026    self._op_to_kernel_label_map = {}
3027    # A map from op type to an alternative op type that should be used when
3028    # computing gradients.
3029    self._gradient_override_map = {}
3030    # True if the graph is considered "finalized".  In that case no
3031    # new operations can be added.
3032    self._finalized = False
3033    # Functions defined in the graph
3034    self._functions = collections.OrderedDict()
3035    # Default GraphDef versions
3036    self._graph_def_versions = versions_pb2.VersionDef(
3037        producer=versions.GRAPH_DEF_VERSION,
3038        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
3039    self._building_function = False
3040    # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
3041    # self._thread_local._colocation_stack is used instead.
3042    self._graph_colocation_stack = traceable_stack.TraceableStack()
3043    # Set of tensors that are dangerous to feed!
3044    self._unfeedable_tensors = set()
3045    # Set of operations that are dangerous to fetch!
3046    self._unfetchable_ops = set()
3047    # A map of tensor handle placeholder to tensor dtype.
3048    self._handle_feeders = {}
3049    # A map from tensor handle to its read op.
3050    self._handle_readers = {}
3051    # A map from tensor handle to its move op.
3052    self._handle_movers = {}
3053    # A map from tensor handle to its delete op.
3054    self._handle_deleters = {}
3055    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
3056    # will be shared when defining function graphs, for example, so optimizers
3057    # being called inside function definitions behave as if they were seeing the
3058    # actual outside graph).
3059    self._graph_key = "grap-key-%d/" % (uid(),)
3060    self._container = ""
3061    self._registered_ops = op_def_registry.get_registered_ops()
3062    # Set to True if this graph is being built in an
3063    # AutomaticControlDependencies context.
3064    self._add_control_dependencies = False
3065
3066    # TODO(skyewm): fold as much of the above as possible into the C
3067    # implementation
3068    self._scoped_c_graph = c_api_util.ScopedTFGraph()
3069    # The C API requires all ops to have shape functions. Disable this
3070    # requirement (many custom ops do not have shape functions, and we don't
3071    # want to break these existing cases).
3072    c_api.SetRequireShapeInferenceFns(self._c_graph, False)
3073    if tf2.enabled():
3074      self.switch_to_thread_local()
3075
3076  # Note: this method is private because the API of tf.Graph() is public and
3077  # frozen, and this functionality is still not ready for public visibility.
3078  @tf_contextlib.contextmanager
3079  def _variable_creator_scope(self, creator, priority=100):
3080    """Scope which defines a variable creation function.
3081
3082    Args:
3083      creator: A callable taking `next_creator` and `kwargs`. See the
3084        `tf.variable_creator_scope` docstring.
3085      priority: Creators with a higher `priority` are called first. Within the
3086        same priority, creators are called inner-to-outer.
3087
3088    Yields:
3089      `_variable_creator_scope` is a context manager with a side effect, but
3090      doesn't return a value.
3091    """
3092    # This step makes a copy of the existing stack, and it also initializes
3093    # self._thread_local._variable_creator_stack if it doesn't exist yet.
3094    old = list(self._variable_creator_stack)
3095    stack = self._thread_local._variable_creator_stack  # pylint: disable=protected-access
3096    stack.append((priority, creator))
3097    # Sorting is stable, so we'll put higher-priority creators later in the list
3098    # but otherwise maintain registration order.
3099    stack.sort(key=lambda item: item[0])
3100    try:
3101      yield
3102    finally:
3103      self._thread_local._variable_creator_stack = old  # pylint: disable=protected-access
3104
3105  # Note: this method is private because the API of tf.Graph() is public and
3106  # frozen, and this functionality is still not ready for public visibility.
3107  @property
3108  def _variable_creator_stack(self):
3109    if not hasattr(self._thread_local, "_variable_creator_stack"):
3110      self._thread_local._variable_creator_stack = []  # pylint: disable=protected-access
3111    return list(self._thread_local._variable_creator_stack)  # pylint: disable=protected-access
3112
3113  @_variable_creator_stack.setter
3114  def _variable_creator_stack(self, variable_creator_stack):
3115    self._thread_local._variable_creator_stack = variable_creator_stack  # pylint: disable=protected-access
3116
3117  def _check_not_finalized(self):
3118    """Check if the graph is finalized.
3119
3120    Raises:
3121      RuntimeError: If the graph finalized.
3122    """
3123    if self._finalized:
3124      raise RuntimeError("Graph is finalized and cannot be modified.")
3125
3126  def _add_op(self, op):
3127    """Adds 'op' to the graph.
3128
3129    Args:
3130      op: the Operator or Tensor to add.
3131
3132    Raises:
3133      TypeError: if op is not an Operation or Tensor.
3134      ValueError: if the op.name or op._id are already used.
3135    """
3136    self._check_not_finalized()
3137    if not isinstance(op, (Tensor, Operation)):
3138      raise TypeError("op must be a Tensor or Operation: %s" % op)
3139    with self._lock:
3140      # pylint: disable=protected-access
3141      if op._id in self._nodes_by_id:
3142        raise ValueError("cannot add an op with id %d as it already "
3143                         "exists in the graph" % op._id)
3144      if op.name in self._nodes_by_name:
3145        raise ValueError("cannot add op with name %s as that name "
3146                         "is already used" % op.name)
3147      self._nodes_by_id[op._id] = op
3148      self._nodes_by_name[op.name] = op
3149      self._version = max(self._version, op._id)
3150      # pylint: enable=protected-access
3151
3152  @property
3153  def _c_graph(self):
3154    if self._scoped_c_graph:
3155      return self._scoped_c_graph.graph
3156    return None
3157
3158  @property
3159  def version(self):
3160    """Returns a version number that increases as ops are added to the graph.
3161
3162    Note that this is unrelated to the
3163    `tf.Graph.graph_def_versions`.
3164
3165    Returns:
3166       An integer version that increases as ops are added to the graph.
3167    """
3168    if self._finalized:
3169      return self._version
3170
3171    with self._lock:
3172      return self._version
3173
3174  @property
3175  def graph_def_versions(self):
3176    # pylint: disable=line-too-long
3177    """The GraphDef version information of this graph.
3178
3179    For details on the meaning of each version, see
3180    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).
3181
3182    Returns:
3183      A `VersionDef`.
3184    """
3185    # pylint: enable=line-too-long
3186    with c_api_util.tf_buffer() as buf:
3187      c_api.TF_GraphVersions(self._c_graph, buf)
3188      data = c_api.TF_GetBuffer(buf)
3189    version_def = versions_pb2.VersionDef()
3190    version_def.ParseFromString(compat.as_bytes(data))
3191    return version_def
3192
3193  @property
3194  def seed(self):
3195    """The graph-level random seed of this graph."""
3196    return self._seed
3197
3198  @seed.setter
3199  def seed(self, seed):
3200    self._seed = seed
3201
3202  @property
3203  def finalized(self):
3204    """True if this graph has been finalized."""
3205    return self._finalized
3206
3207  def finalize(self):
3208    """Finalizes this graph, making it read-only.
3209
3210    After calling `g.finalize()`, no new operations can be added to
3211    `g`.  This method is used to ensure that no operations are added
3212    to a graph when it is shared between multiple threads, for example
3213    when using a `tf.train.QueueRunner`.
3214    """
3215    self._finalized = True
3216
3217  def _unsafe_unfinalize(self):
3218    """Opposite of `finalize`. Internal interface.
3219
3220    NOTE: Unfinalizing a graph could have negative impact on performance,
3221    especially in a multi-threaded environment.  Unfinalizing a graph
3222    when it is in use by a Session may lead to undefined behavior. Ensure
3223    that all sessions using a graph are closed before calling this method.
3224    """
3225    self._finalized = False
3226
3227  def _get_control_flow_context(self):
3228    """Returns the current control flow context.
3229
3230    Returns:
3231      A context object.
3232    """
3233    return self._control_flow_context
3234
3235  def _set_control_flow_context(self, ctx):
3236    """Sets the current control flow context.
3237
3238    Args:
3239      ctx: a context object.
3240    """
3241    self._control_flow_context = ctx
3242
3243  def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
3244    """If this graph contains functions, copy them to `graph_def`."""
3245    bytesize = starting_bytesize
3246    for f in self._functions.values():
3247      bytesize += f.definition.ByteSize()
3248      if bytesize >= (1 << 31) or bytesize < 0:
3249        raise ValueError("GraphDef cannot be larger than 2GB.")
3250      graph_def.library.function.extend([f.definition])
3251      if f.grad_func_name:
3252        grad_def = function_pb2.GradientDef()
3253        grad_def.function_name = f.name
3254        grad_def.gradient_func = f.grad_func_name
3255        graph_def.library.gradient.extend([grad_def])
3256
3257  def _as_graph_def(self, from_version=None, add_shapes=False):
3258    # pylint: disable=line-too-long
3259    """Returns a serialized `GraphDef` representation of this graph.
3260
3261    The serialized `GraphDef` can be imported into another `Graph`
3262    (using `tf.import_graph_def`) or used with the
3263    [C++ Session API](../../../../api_docs/cc/index.md).
3264
3265    This method is thread-safe.
3266
3267    Args:
3268      from_version: Optional.  If this is set, returns a `GraphDef`
3269        containing only the nodes that were added to this graph since
3270        its `version` property had the given value.
3271      add_shapes: If true, adds an "_output_shapes" list attr to each
3272        node with the inferred shapes of each of its outputs.
3273
3274    Returns:
3275      A tuple containing a
3276      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3277      protocol buffer, and the version of the graph to which that
3278      `GraphDef` corresponds.
3279
3280    Raises:
3281      ValueError: If the `graph_def` would be too large.
3282
3283    """
3284    # pylint: enable=line-too-long
3285    with self._lock:
3286      with c_api_util.tf_buffer() as buf:
3287        c_api.TF_GraphToGraphDef(self._c_graph, buf)
3288        data = c_api.TF_GetBuffer(buf)
3289      graph = graph_pb2.GraphDef()
3290      graph.ParseFromString(compat.as_bytes(data))
3291      # Strip the experimental library field iff it's empty.
3292      if not graph.library.function:
3293        graph.ClearField("library")
3294
3295      if add_shapes:
3296        for node in graph.node:
3297          op = self._nodes_by_name[node.name]
3298          if op.outputs:
3299            node.attr["_output_shapes"].list.shape.extend(
3300                [output.get_shape().as_proto() for output in op.outputs])
3301    return graph, self._version
3302
3303  def as_graph_def(self, from_version=None, add_shapes=False):
3304    # pylint: disable=line-too-long
3305    """Returns a serialized `GraphDef` representation of this graph.
3306
3307    The serialized `GraphDef` can be imported into another `Graph`
3308    (using `tf.import_graph_def`) or used with the
3309    [C++ Session API](../../api_docs/cc/index.md).
3310
3311    This method is thread-safe.
3312
3313    Args:
3314      from_version: Optional.  If this is set, returns a `GraphDef`
3315        containing only the nodes that were added to this graph since
3316        its `version` property had the given value.
3317      add_shapes: If true, adds an "_output_shapes" list attr to each
3318        node with the inferred shapes of each of its outputs.
3319
3320    Returns:
3321      A
3322      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
3323      protocol buffer.
3324
3325    Raises:
3326      ValueError: If the `graph_def` would be too large.
3327    """
3328    # pylint: enable=line-too-long
3329    result, _ = self._as_graph_def(from_version, add_shapes)
3330    return result
3331
3332  def _is_function(self, name):
3333    """Tests whether 'name' is registered in this graph's function library.
3334
3335    Args:
3336      name: string op name.
3337    Returns:
3338      bool indicating whether or not 'name' is registered in function library.
3339    """
3340    return compat.as_str(name) in self._functions
3341
3342  def _get_function(self, name):
3343    """Returns the function definition for 'name'.
3344
3345    Args:
3346      name: string function name.
3347    Returns:
3348      The function def proto.
3349    """
3350    return self._functions.get(compat.as_str(name), None)
3351
3352  def _add_function(self, function):
3353    """Adds a function to the graph.
3354
3355    After the function has been added, you can call to the function by
3356    passing the function name in place of an op name to
3357    `Graph.create_op()`.
3358
3359    Args:
3360      function: A `_DefinedFunction` object.
3361
3362
3363    Raises:
3364      ValueError: if another function is defined with the same name.
3365    """
3366    name = function.name
3367    # Sanity checks on gradient definition.
3368    if (function.grad_func_name is not None) and (function.python_grad_func is
3369                                                  not None):
3370      raise ValueError("Gradient defined twice for function %s" % name)
3371
3372    # Add function to graph
3373    # pylint: disable=protected-access
3374    # Handle functions created without using the C API. TODO(apassos,skyewm)
3375    # remove this when all functions are generated using the C API by default
3376    # as this will be unnecessary.
3377    if not function._c_func:
3378      serialized = function.definition.SerializeToString()
3379      c_func = c_api.TF_FunctionImportFunctionDef(serialized)
3380      function._c_func = c_api_util.ScopedTFFunction(c_func)
3381    gradient = (function._grad_func._c_func.func if function._grad_func
3382                else None)
3383    c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
3384    # pylint: enable=protected-access
3385
3386    self._functions[compat.as_str(name)] = function
3387
3388    # Need a new-enough consumer to support the functions we add to the graph.
3389    if self._graph_def_versions.min_consumer < 12:
3390      self._graph_def_versions.min_consumer = 12
3391
3392  @property
3393  def building_function(self):
3394    """Returns True iff this graph represents a function."""
3395    return self._building_function
3396
3397  # Helper functions to create operations.
3398  @deprecated_args(None,
3399                   "Shapes are always computed; don't use the compute_shapes "
3400                   "as it has no effect.", "compute_shapes")
3401  def create_op(
3402      self,
3403      op_type,
3404      inputs,
3405      dtypes=None,  # pylint: disable=redefined-outer-name
3406      input_types=None,
3407      name=None,
3408      attrs=None,
3409      op_def=None,
3410      compute_shapes=True,
3411      compute_device=True):
3412    """Creates an `Operation` in this graph.
3413
3414    This is a low-level interface for creating an `Operation`. Most
3415    programs will not call this method directly, and instead use the
3416    Python op constructors, such as `tf.constant()`, which add ops to
3417    the default graph.
3418
3419    Args:
3420      op_type: The `Operation` type to create. This corresponds to the
3421        `OpDef.name` field for the proto that defines the operation.
3422      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
3423      dtypes: (Optional) A list of `DType` objects that will be the types of the
3424        tensors that the operation produces.
3425      input_types: (Optional.) A list of `DType`s that will be the types of
3426        the tensors that the operation consumes. By default, uses the base
3427        `DType` of each input in `inputs`. Operations that expect
3428        reference-typed inputs must specify `input_types` explicitly.
3429      name: (Optional.) A string name for the operation. If not specified, a
3430        name is generated based on `op_type`.
3431      attrs: (Optional.) A dictionary where the key is the attribute name (a
3432        string) and the value is the respective `attr` attribute of the
3433        `NodeDef` proto that will represent the operation (an `AttrValue`
3434        proto).
3435      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
3436        the operation will have.
3437      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
3438        computed).
3439      compute_device: (Optional.) If True, device functions will be executed
3440        to compute the device property of the Operation.
3441
3442    Raises:
3443      TypeError: if any of the inputs is not a `Tensor`.
3444      ValueError: if colocation conflicts with existing device assignment.
3445
3446    Returns:
3447      An `Operation` object.
3448    """
3449    del compute_shapes
3450
3451    self._check_not_finalized()
3452    for idx, a in enumerate(inputs):
3453      if not isinstance(a, Tensor):
3454        raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
3455    if name is None:
3456      name = op_type
3457    # If a names ends with a '/' it is a "name scope" and we use it as-is,
3458    # after removing the trailing '/'.
3459    if name and name[-1] == "/":
3460      name = _name_from_scope_name(name)
3461    else:
3462      name = self.unique_name(name)
3463
3464    node_def = _NodeDef(op_type, name, device=None, attrs=attrs)
3465
3466    input_ops = set([t.op for t in inputs])
3467    control_inputs = self._control_dependencies_for_inputs(input_ops)
3468    # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
3469    # Session.run call cannot occur between creating and mutating the op.
3470    with self._mutation_lock():
3471      ret = Operation(
3472          node_def,
3473          self,
3474          inputs=inputs,
3475          output_types=dtypes,
3476          control_inputs=control_inputs,
3477          input_types=input_types,
3478          original_op=self._default_original_op,
3479          op_def=op_def)
3480      self._create_op_helper(ret, compute_device=compute_device)
3481    return ret
3482
3483  def _create_op_from_tf_operation(self, c_op, compute_device=True):
3484    """Creates an `Operation` in this graph from the supplied TF_Operation.
3485
3486    This method is like create_op() except the new Operation is constructed
3487    using `c_op`. The returned Operation will have `c_op` as its _c_op
3488    field. This is used to create Operation objects around TF_Operations created
3489    indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
3490
3491    This function does not call Operation._control_flow_post_processing or
3492    Graph._control_dependencies_for_inputs (since the inputs may not be
3493    available yet). The caller is responsible for calling these methods.
3494
3495    Args:
3496      c_op: a wrapped TF_Operation
3497      compute_device: (Optional.) If True, device functions will be executed
3498        to compute the device property of the Operation.
3499
3500    Returns:
3501      An `Operation` object.
3502    """
3503    self._check_not_finalized()
3504    ret = Operation(c_op, self)
3505    # If a name_scope was created with ret.name but no nodes were created in it,
3506    # the name will still appear in _names_in_use even though the name hasn't
3507    # been used. This is ok, just leave _names_in_use as-is in this case.
3508    # TODO(skyewm): make the C API guarantee no name conflicts.
3509    name_key = ret.name.lower()
3510    if name_key not in self._names_in_use:
3511      self._names_in_use[name_key] = 1
3512    self._create_op_helper(ret, compute_device=compute_device)
3513    return ret
3514
3515  def _create_op_helper(self, op, compute_device=True):
3516    """Common logic for creating an op in this graph."""
3517    # Apply any additional attributes requested. Do not overwrite any existing
3518    # attributes.
3519    for key, value in self._attr_scope_map.items():
3520      try:
3521        op.get_attr(key)
3522      except ValueError:
3523        if callable(value):
3524          value = value(op.node_def)
3525          if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
3526            raise TypeError(
3527                "Callable for scope map key '%s' must return either None or "
3528                "an AttrValue protocol buffer; but it returned: %s" % (key,
3529                                                                       value))
3530        if value:
3531          op._set_attr(key, value)  # pylint: disable=protected-access
3532
3533    # Apply a kernel label if one has been specified for this op type.
3534    try:
3535      kernel_label = self._op_to_kernel_label_map[op.type]
3536      op._set_attr("_kernel",  # pylint: disable=protected-access
3537                   attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
3538    except KeyError:
3539      pass
3540
3541    # Apply the overriding op type for gradients if one has been specified for
3542    # this op type.
3543    try:
3544      mapped_op_type = self._gradient_override_map[op.type]
3545      op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
3546                   attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
3547    except KeyError:
3548      pass
3549
3550    self._record_op_seen_by_control_dependencies(op)
3551
3552    if compute_device:
3553      self._apply_device_functions(op)
3554
3555    # Snapshot the colocation stack metadata before we might generate error
3556    # messages using it.  Note that this snapshot depends on the actual stack
3557    # and is independent of the op's _class attribute.
3558    # pylint: disable=protected-access
3559    op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
3560    # pylint: enable=protected-access
3561
3562    if self._colocation_stack:
3563      all_colocation_groups = []
3564      for colocation_op in self._colocation_stack.peek_objs():
3565        all_colocation_groups.extend(colocation_op.colocation_groups())
3566        if colocation_op.device:
3567          # pylint: disable=protected-access
3568          op._set_device(colocation_op.device)
3569          # pylint: enable=protected-access
3570
3571      all_colocation_groups = sorted(set(all_colocation_groups))
3572      # pylint: disable=protected-access
3573      op._set_attr("_class", attr_value_pb2.AttrValue(
3574          list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
3575      # pylint: enable=protected-access
3576
3577    # Sets "container" attribute if
3578    # (1) self._container is not None
3579    # (2) "is_stateful" is set in OpDef
3580    # (3) "container" attribute is in OpDef
3581    # (4) "container" attribute is None
3582    if self._container and op.op_def.is_stateful:
3583      try:
3584        container_attr = op.get_attr("container")
3585      except ValueError:
3586        # "container" attribute is not in OpDef
3587        pass
3588      else:
3589        if not container_attr:
3590          op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
3591              s=compat.as_bytes(self._container)))
3592
3593  def _add_new_tf_operations(self, compute_devices=True):
3594    """Creates `Operations` in this graph for any new TF_Operations.
3595
3596    This is useful for when TF_Operations are indirectly created by the C API
3597    outside of the Operation constructor (e.g. by TF_ImportGraphDef,
3598    TF_FinishWhile). This ensures there are corresponding Operations for all
3599    TF_Operations in the underlying TF_Graph.
3600
3601    Args:
3602      compute_devices: (Optional.) If True, device functions will be executed
3603        to compute the device properties of each new Operation.
3604
3605    Returns:
3606      A list of the new `Operation` objects.
3607    """
3608    # Create all Operation objects before accessing their inputs since an op may
3609    # be created before its inputs.
3610    new_ops = [
3611        self._create_op_from_tf_operation(c_op, compute_device=compute_devices)
3612        for c_op in c_api_util.new_tf_operations(self)
3613    ]
3614
3615    # pylint: disable=protected-access
3616    for op in new_ops:
3617      new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
3618      op._add_control_inputs(new_control_inputs)
3619      op._control_flow_post_processing()
3620    # pylint: enable=protected-access
3621
3622    return new_ops
3623
3624  def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
3625    """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
3626
3627    This function validates that `obj` represents an element of this
3628    graph, and gives an informative error message if it is not.
3629
3630    This function is the canonical way to get/validate an object of
3631    one of the allowed types from an external argument reference in the
3632    Session API.
3633
3634    This method may be called concurrently from multiple threads.
3635
3636    Args:
3637      obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
3638        Can also be any object with an `_as_graph_element()` method that returns
3639        a value of one of these types.
3640      allow_tensor: If true, `obj` may refer to a `Tensor`.
3641      allow_operation: If true, `obj` may refer to an `Operation`.
3642
3643    Returns:
3644      The `Tensor` or `Operation` in the Graph corresponding to `obj`.
3645
3646    Raises:
3647      TypeError: If `obj` is not a type we support attempting to convert
3648        to types.
3649      ValueError: If `obj` is of an appropriate type but invalid. For
3650        example, an invalid string.
3651      KeyError: If `obj` is not an object in the graph.
3652    """
3653    if self._finalized:
3654      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3655
3656    with self._lock:
3657      return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3658
3659  def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
3660    """See `Graph.as_graph_element()` for details."""
3661    # The vast majority of this function is figuring
3662    # out what an API user might be doing wrong, so
3663    # that we can give helpful error messages.
3664    #
3665    # Ideally, it would be nice to split it up, but we
3666    # need context to generate nice error messages.
3667
3668    if allow_tensor and allow_operation:
3669      types_str = "Tensor or Operation"
3670    elif allow_tensor:
3671      types_str = "Tensor"
3672    elif allow_operation:
3673      types_str = "Operation"
3674    else:
3675      raise ValueError("allow_tensor and allow_operation can't both be False.")
3676
3677    temp_obj = _as_graph_element(obj)
3678    if temp_obj is not None:
3679      obj = temp_obj
3680
3681    # If obj appears to be a name...
3682    if isinstance(obj, compat.bytes_or_text_types):
3683      name = compat.as_str(obj)
3684
3685      if ":" in name and allow_tensor:
3686        # Looks like a Tensor name and can be a Tensor.
3687        try:
3688          op_name, out_n = name.split(":")
3689          out_n = int(out_n)
3690        except:
3691          raise ValueError("The name %s looks a like a Tensor name, but is "
3692                           "not a valid one. Tensor names must be of the "
3693                           "form \"<op_name>:<output_index>\"." % repr(name))
3694        if op_name in self._nodes_by_name:
3695          op = self._nodes_by_name[op_name]
3696        else:
3697          raise KeyError("The name %s refers to a Tensor which does not "
3698                         "exist. The operation, %s, does not exist in the "
3699                         "graph." % (repr(name), repr(op_name)))
3700        try:
3701          return op.outputs[out_n]
3702        except:
3703          raise KeyError("The name %s refers to a Tensor which does not "
3704                         "exist. The operation, %s, exists but only has "
3705                         "%s outputs." % (repr(name), repr(op_name),
3706                                          len(op.outputs)))
3707
3708      elif ":" in name and not allow_tensor:
3709        # Looks like a Tensor name but can't be a Tensor.
3710        raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
3711                         (repr(name), types_str))
3712
3713      elif ":" not in name and allow_operation:
3714        # Looks like an Operation name and can be an Operation.
3715        if name not in self._nodes_by_name:
3716          raise KeyError("The name %s refers to an Operation not in the "
3717                         "graph." % repr(name))
3718        return self._nodes_by_name[name]
3719
3720      elif ":" not in name and not allow_operation:
3721        # Looks like an Operation name but can't be an Operation.
3722        if name in self._nodes_by_name:
3723          # Yep, it's an Operation name
3724          err_msg = ("The name %s refers to an Operation, not a %s." %
3725                     (repr(name), types_str))
3726        else:
3727          err_msg = ("The name %s looks like an (invalid) Operation name, "
3728                     "not a %s." % (repr(name), types_str))
3729        err_msg += (" Tensor names must be of the form "
3730                    "\"<op_name>:<output_index>\".")
3731        raise ValueError(err_msg)
3732
3733    elif isinstance(obj, Tensor) and allow_tensor:
3734      # Actually obj is just the object it's referring to.
3735      if obj.graph is not self:
3736        raise ValueError("Tensor %s is not an element of this graph." % obj)
3737      return obj
3738    elif isinstance(obj, Operation) and allow_operation:
3739      # Actually obj is just the object it's referring to.
3740      if obj.graph is not self:
3741        raise ValueError("Operation %s is not an element of this graph." % obj)
3742      return obj
3743    else:
3744      # We give up!
3745      raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
3746                                                           types_str))
3747
3748  def get_operations(self):
3749    """Return the list of operations in the graph.
3750
3751    You can modify the operations in place, but modifications
3752    to the list such as inserts/delete have no effect on the
3753    list of operations known to the graph.
3754
3755    This method may be called concurrently from multiple threads.
3756
3757    Returns:
3758      A list of Operations.
3759    """
3760    if self._finalized:
3761      return list(self._nodes_by_id.values())
3762
3763    with self._lock:
3764      return list(self._nodes_by_id.values())
3765
3766  def get_operation_by_name(self, name):
3767    """Returns the `Operation` with the given `name`.
3768
3769    This method may be called concurrently from multiple threads.
3770
3771    Args:
3772      name: The name of the `Operation` to return.
3773
3774    Returns:
3775      The `Operation` with the given `name`.
3776
3777    Raises:
3778      TypeError: If `name` is not a string.
3779      KeyError: If `name` does not correspond to an operation in this graph.
3780    """
3781
3782    if not isinstance(name, six.string_types):
3783      raise TypeError("Operation names are strings (or similar), not %s." %
3784                      type(name).__name__)
3785    return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
3786
3787  def _get_operation_by_name_unsafe(self, name):
3788    """Returns the `Operation` with the given `name`.
3789
3790    This is a internal unsafe version of get_operation_by_name. It skips many
3791    checks and does not have user friedly error messages but runs considerably
3792    faster. This method may be called concurrently from multiple threads.
3793
3794    Args:
3795      name: The name of the `Operation` to return.
3796
3797    Returns:
3798      The `Operation` with the given `name`.
3799
3800    Raises:
3801      KeyError: If `name` does not correspond to an operation in this graph.
3802    """
3803
3804    if self._finalized:
3805      return self._nodes_by_name[name]
3806
3807    with self._lock:
3808      return self._nodes_by_name[name]
3809
3810  def _get_operation_by_tf_operation(self, tf_oper):
3811    op_name = c_api.TF_OperationName(tf_oper)
3812    return self._get_operation_by_name_unsafe(op_name)
3813
3814  def get_tensor_by_name(self, name):
3815    """Returns the `Tensor` with the given `name`.
3816
3817    This method may be called concurrently from multiple threads.
3818
3819    Args:
3820      name: The name of the `Tensor` to return.
3821
3822    Returns:
3823      The `Tensor` with the given `name`.
3824
3825    Raises:
3826      TypeError: If `name` is not a string.
3827      KeyError: If `name` does not correspond to a tensor in this graph.
3828    """
3829    # Names should be strings.
3830    if not isinstance(name, six.string_types):
3831      raise TypeError("Tensor names are strings (or similar), not %s." %
3832                      type(name).__name__)
3833    return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
3834
3835  def _get_tensor_by_tf_output(self, tf_output):
3836    """Returns the `Tensor` representing `tf_output`.
3837
3838    Note that there is only one such `Tensor`, i.e. multiple calls to this
3839    function with the same TF_Output value will always return the same `Tensor`
3840    object.
3841
3842    Args:
3843      tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
3844
3845    Returns:
3846      The `Tensor` that represents `tf_output`.
3847    """
3848    op = self._get_operation_by_tf_operation(tf_output.oper)
3849    return op.outputs[tf_output.index]
3850
3851  def _next_id(self):
3852    """Id for next Operation instance. Also increments the internal id."""
3853    self._check_not_finalized()
3854    with self._lock:
3855      self._next_id_counter += 1
3856      return self._next_id_counter
3857
3858  @property
3859  def _last_id(self):
3860    return self._next_id_counter
3861
3862  def _get_op_def(self, type):  # pylint: disable=redefined-builtin
3863    """Returns the `OpDef` proto for `type`. `type` is a string."""
3864    with c_api_util.tf_buffer() as buf:
3865      # pylint: disable=protected-access
3866      c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
3867      # pylint: enable=protected-access
3868      data = c_api.TF_GetBuffer(buf)
3869    op_def = op_def_pb2.OpDef()
3870    op_def.ParseFromString(compat.as_bytes(data))
3871    return op_def
3872
3873  def as_default(self):
3874    """Returns a context manager that makes this `Graph` the default graph.
3875
3876    This method should be used if you want to create multiple graphs
3877    in the same process. For convenience, a global default graph is
3878    provided, and all ops will be added to this graph if you do not
3879    create a new graph explicitly.
3880
3881    Use this method with the `with` keyword to specify that ops created within
3882    the scope of a block should be added to this graph. In this case, once
3883    the scope of the `with` is exited, the previous default graph is set again
3884    as default. There is a stack, so it's ok to have multiple nested levels
3885    of `as_default` calls.
3886
3887    The default graph is a property of the current thread. If you
3888    create a new thread, and wish to use the default graph in that
3889    thread, you must explicitly add a `with g.as_default():` in that
3890    thread's function.
3891
3892    The following code examples are equivalent:
3893
3894    ```python
3895    # 1. Using Graph.as_default():
3896    g = tf.Graph()
3897    with g.as_default():
3898      c = tf.constant(5.0)
3899      assert c.graph is g
3900
3901    # 2. Constructing and making default:
3902    with tf.Graph().as_default() as g:
3903      c = tf.constant(5.0)
3904      assert c.graph is g
3905    ```
3906
3907    If eager execution is enabled ops created under this context manager will be
3908    added to the graph instead of executed eagerly.
3909
3910    Returns:
3911      A context manager for using this graph as the default graph.
3912    """
3913    return _default_graph_stack.get_controller(self)
3914
3915  @property
3916  def collections(self):
3917    """Returns the names of the collections known to this graph."""
3918    return list(self._collections)
3919
3920  def add_to_collection(self, name, value):
3921    """Stores `value` in the collection with the given `name`.
3922
3923    Note that collections are not sets, so it is possible to add a value to
3924    a collection several times.
3925
3926    Args:
3927      name: The key for the collection. The `GraphKeys` class
3928        contains many standard names for collections.
3929      value: The value to add to the collection.
3930    """  # pylint: disable=g-doc-exception
3931    self._check_not_finalized()
3932    with self._lock:
3933      if name not in self._collections:
3934        self._collections[name] = [value]
3935      else:
3936        self._collections[name].append(value)
3937
3938  def add_to_collections(self, names, value):
3939    """Stores `value` in the collections given by `names`.
3940
3941    Note that collections are not sets, so it is possible to add a value to
3942    a collection several times. This function makes sure that duplicates in
3943    `names` are ignored, but it will not check for pre-existing membership of
3944    `value` in any of the collections in `names`.
3945
3946    `names` can be any iterable, but if `names` is a string, it is treated as a
3947    single collection name.
3948
3949    Args:
3950      names: The keys for the collections to add to. The `GraphKeys` class
3951        contains many standard names for collections.
3952      value: The value to add to the collections.
3953    """
3954    # Make sure names are unique, but treat strings as a single collection name
3955    names = (names,) if isinstance(names, six.string_types) else set(names)
3956    for name in names:
3957      self.add_to_collection(name, value)
3958
3959  def get_collection_ref(self, name):
3960    """Returns a list of values in the collection with the given `name`.
3961
3962    If the collection exists, this returns the list itself, which can
3963    be modified in place to change the collection.  If the collection does
3964    not exist, it is created as an empty list and the list is returned.
3965
3966    This is different from `get_collection()` which always returns a copy of
3967    the collection list if it exists and never creates an empty collection.
3968
3969    Args:
3970      name: The key for the collection. For example, the `GraphKeys` class
3971        contains many standard names for collections.
3972
3973    Returns:
3974      The list of values in the collection with the given `name`, or an empty
3975      list if no value has been added to that collection.
3976    """  # pylint: disable=g-doc-exception
3977    with self._lock:
3978      coll_list = self._collections.get(name, None)
3979      if coll_list is None:
3980        coll_list = []
3981        self._collections[name] = coll_list
3982      return coll_list
3983
3984  def get_collection(self, name, scope=None):
3985    """Returns a list of values in the collection with the given `name`.
3986
3987    This is different from `get_collection_ref()` which always returns the
3988    actual collection list if it exists in that it returns a new list each time
3989    it is called.
3990
3991    Args:
3992      name: The key for the collection. For example, the `GraphKeys` class
3993        contains many standard names for collections.
3994      scope: (Optional.) A string. If supplied, the resulting list is filtered
3995        to include only items whose `name` attribute matches `scope` using
3996        `re.match`. Items without a `name` attribute are never returned if a
3997        scope is supplied. The choice of `re.match` means that a `scope` without
3998        special tokens filters by prefix.
3999
4000    Returns:
4001      The list of values in the collection with the given `name`, or
4002      an empty list if no value has been added to that collection. The
4003      list contains the values in the order under which they were
4004      collected.
4005    """  # pylint: disable=g-doc-exception
4006    with self._lock:
4007      collection = self._collections.get(name, None)
4008      if collection is None:
4009        return []
4010      if scope is None:
4011        return list(collection)
4012      else:
4013        c = []
4014        regex = re.compile(scope)
4015        for item in collection:
4016          if hasattr(item, "name") and regex.match(item.name):
4017            c.append(item)
4018        return c
4019
4020  def get_all_collection_keys(self):
4021    """Returns a list of collections used in this graph."""
4022    with self._lock:
4023      return [x for x in self._collections if isinstance(x, six.string_types)]
4024
4025  def clear_collection(self, name):
4026    """Clears all values in a collection.
4027
4028    Args:
4029      name: The key for the collection. The `GraphKeys` class contains many
4030        standard names for collections.
4031    """
4032    self._check_not_finalized()
4033    with self._lock:
4034      if name in self._collections:
4035        del self._collections[name]
4036
4037  @tf_contextlib.contextmanager
4038  def _original_op(self, op):
4039    """Python 'with' handler to help annotate ops with their originator.
4040
4041    An op may have an 'original_op' property that indicates the op on which
4042    it was based. For example a replica op is based on the op that was
4043    replicated and a gradient op is based on the op that was differentiated.
4044
4045    All ops created in the scope of this 'with' handler will have
4046    the given 'op' as their original op.
4047
4048    Args:
4049      op: The Operation that all ops created in this scope will have as their
4050        original op.
4051
4052    Yields:
4053      Nothing.
4054    """
4055    old_original_op = self._default_original_op
4056    self._default_original_op = op
4057    try:
4058      yield
4059    finally:
4060      self._default_original_op = old_original_op
4061
4062  @property
4063  def _name_stack(self):
4064    # This may be called from a thread where name_stack doesn't yet exist.
4065    if not hasattr(self._thread_local, "_name_stack"):
4066      self._thread_local._name_stack = ""
4067    return self._thread_local._name_stack
4068
4069  @_name_stack.setter
4070  def _name_stack(self, name_stack):
4071    self._thread_local._name_stack = name_stack
4072
4073  # pylint: disable=g-doc-return-or-yield,line-too-long
4074  @tf_contextlib.contextmanager
4075  def name_scope(self, name):
4076    r"""Returns a context manager that creates hierarchical names for operations.
4077
4078    A graph maintains a stack of name scopes. A `with name_scope(...):`
4079    statement pushes a new name onto the stack for the lifetime of the context.
4080
4081    The `name` argument will be interpreted as follows:
4082
4083    * A string (not ending with '/') will create a new name scope, in which
4084      `name` is appended to the prefix of all operations created in the
4085      context. If `name` has been used before, it will be made unique by
4086      calling `self.unique_name(name)`.
4087    * A scope previously captured from a `with g.name_scope(...) as
4088      scope:` statement will be treated as an "absolute" name scope, which
4089      makes it possible to re-enter existing scopes.
4090    * A value of `None` or the empty string will reset the current name scope
4091      to the top-level (empty) name scope.
4092
4093    For example:
4094
4095    ```python
4096    with tf.Graph().as_default() as g:
4097      c = tf.constant(5.0, name="c")
4098      assert c.op.name == "c"
4099      c_1 = tf.constant(6.0, name="c")
4100      assert c_1.op.name == "c_1"
4101
4102      # Creates a scope called "nested"
4103      with g.name_scope("nested") as scope:
4104        nested_c = tf.constant(10.0, name="c")
4105        assert nested_c.op.name == "nested/c"
4106
4107        # Creates a nested scope called "inner".
4108        with g.name_scope("inner"):
4109          nested_inner_c = tf.constant(20.0, name="c")
4110          assert nested_inner_c.op.name == "nested/inner/c"
4111
4112        # Create a nested scope called "inner_1".
4113        with g.name_scope("inner"):
4114          nested_inner_1_c = tf.constant(30.0, name="c")
4115          assert nested_inner_1_c.op.name == "nested/inner_1/c"
4116
4117          # Treats `scope` as an absolute name scope, and
4118          # switches to the "nested/" scope.
4119          with g.name_scope(scope):
4120            nested_d = tf.constant(40.0, name="d")
4121            assert nested_d.op.name == "nested/d"
4122
4123            with g.name_scope(""):
4124              e = tf.constant(50.0, name="e")
4125              assert e.op.name == "e"
4126    ```
4127
4128    The name of the scope itself can be captured by `with
4129    g.name_scope(...) as scope:`, which stores the name of the scope
4130    in the variable `scope`. This value can be used to name an
4131    operation that represents the overall result of executing the ops
4132    in a scope. For example:
4133
4134    ```python
4135    inputs = tf.constant(...)
4136    with g.name_scope('my_layer') as scope:
4137      weights = tf.Variable(..., name="weights")
4138      biases = tf.Variable(..., name="biases")
4139      affine = tf.matmul(inputs, weights) + biases
4140      output = tf.nn.relu(affine, name=scope)
4141    ```
4142
4143    NOTE: This constructor validates the given `name`. Valid scope
4144    names match one of the following regular expressions:
4145
4146        [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
4147        [A-Za-z0-9_.\\-/]* (for other scopes)
4148
4149    Args:
4150      name: A name for the scope.
4151
4152    Returns:
4153      A context manager that installs `name` as a new name scope.
4154
4155    Raises:
4156      ValueError: If `name` is not a valid scope name, according to the rules
4157        above.
4158    """
4159    if name:
4160      if isinstance(name, compat.bytes_or_text_types):
4161        name = compat.as_str(name)
4162
4163      if self._name_stack:
4164        # Scopes created in a nested scope may have initial characters
4165        # that are illegal as the initial character of an op name
4166        # (viz. '-', '\', '/', and '_').
4167        if not _VALID_SCOPE_NAME_REGEX.match(name):
4168          raise ValueError("'%s' is not a valid scope name" % name)
4169      else:
4170        # Scopes created in the root must match the more restrictive
4171        # op name regex, which constrains the initial character.
4172        if not _VALID_OP_NAME_REGEX.match(name):
4173          raise ValueError("'%s' is not a valid scope name" % name)
4174    old_stack = self._name_stack
4175    if not name:  # Both for name=None and name="" we re-set to empty scope.
4176      new_stack = None
4177    elif name[-1] == "/":
4178      new_stack = _name_from_scope_name(name)
4179    else:
4180      new_stack = self.unique_name(name)
4181    self._name_stack = new_stack
4182    try:
4183      yield "" if new_stack is None else new_stack + "/"
4184    finally:
4185      self._name_stack = old_stack
4186
4187  # pylint: enable=g-doc-return-or-yield,line-too-long
4188
4189  def unique_name(self, name, mark_as_used=True):
4190    """Return a unique operation name for `name`.
4191
4192    Note: You rarely need to call `unique_name()` directly.  Most of
4193    the time you just need to create `with g.name_scope()` blocks to
4194    generate structured names.
4195
4196    `unique_name` is used to generate structured names, separated by
4197    `"/"`, to help identify operations when debugging a graph.
4198    Operation names are displayed in error messages reported by the
4199    TensorFlow runtime, and in various visualization tools such as
4200    TensorBoard.
4201
4202    If `mark_as_used` is set to `True`, which is the default, a new
4203    unique name is created and marked as in use. If it's set to `False`,
4204    the unique name is returned without actually being marked as used.
4205    This is useful when the caller simply wants to know what the name
4206    to be created will be.
4207
4208    Args:
4209      name: The name for an operation.
4210      mark_as_used: Whether to mark this name as being used.
4211
4212    Returns:
4213      A string to be passed to `create_op()` that will be used
4214      to name the operation being created.
4215    """
4216    if self._name_stack:
4217      name = self._name_stack + "/" + name
4218
4219    # For the sake of checking for names in use, we treat names as case
4220    # insensitive (e.g. foo = Foo).
4221    name_key = name.lower()
4222    i = self._names_in_use.get(name_key, 0)
4223    # Increment the number for "name_key".
4224    if mark_as_used:
4225      self._names_in_use[name_key] = i + 1
4226    if i > 0:
4227      base_name_key = name_key
4228      # Make sure the composed name key is not already used.
4229      while name_key in self._names_in_use:
4230        name_key = "%s_%d" % (base_name_key, i)
4231        i += 1
4232      # Mark the composed name_key as used in case someone wants
4233      # to call unique_name("name_1").
4234      if mark_as_used:
4235        self._names_in_use[name_key] = 1
4236
4237      # Return the new name with the original capitalization of the given name.
4238      name = "%s_%d" % (name, i-1)
4239    return name
4240
4241  def get_name_scope(self):
4242    """Returns the current name scope.
4243
4244    For example:
4245
4246    ```python
4247    with tf.name_scope('scope1'):
4248      with tf.name_scope('scope2'):
4249        print(tf.get_default_graph().get_name_scope())
4250    ```
4251    would print the string `scope1/scope2`.
4252
4253    Returns:
4254      A string representing the current name scope.
4255    """
4256    return self._name_stack
4257
4258  @tf_contextlib.contextmanager
4259  def _colocate_with_for_gradient(self, op, gradient_uid,
4260                                  ignore_existing=False):
4261    with self.colocate_with(op, ignore_existing):
4262      if gradient_uid is not None and self._control_flow_context is not None:
4263        self._control_flow_context.EnterGradientColocation(op, gradient_uid)
4264        try:
4265          yield
4266        finally:
4267          self._control_flow_context.ExitGradientColocation(op, gradient_uid)
4268      else:
4269        yield
4270
4271  @tf_contextlib.contextmanager
4272  def colocate_with(self, op, ignore_existing=False):
4273    """Returns a context manager that specifies an op to colocate with.
4274
4275    Note: this function is not for public use, only for internal libraries.
4276
4277    For example:
4278
4279    ```python
4280    a = tf.Variable([1.0])
4281    with g.colocate_with(a):
4282      b = tf.constant(1.0)
4283      c = tf.add(a, b)
4284    ```
4285
4286    `b` and `c` will always be colocated with `a`, no matter where `a`
4287    is eventually placed.
4288
4289    **NOTE** Using a colocation scope resets any existing device constraints.
4290
4291    If `op` is `None` then `ignore_existing` must be `True` and the new
4292    scope resets all colocation and device constraints.
4293
4294    Args:
4295      op: The op to colocate all created ops with, or `None`.
4296      ignore_existing: If true, only applies colocation of this op within
4297        the context, rather than applying all colocation properties
4298        on the stack.  If `op` is `None`, this value must be `True`.
4299
4300    Raises:
4301      ValueError: if op is None but ignore_existing is False.
4302
4303    Yields:
4304      A context manager that specifies the op with which to colocate
4305      newly created ops.
4306    """
4307    if op is None and not ignore_existing:
4308      raise ValueError("Trying to reset colocation (op is None) but "
4309                       "ignore_existing is not True")
4310    op = _op_to_colocate_with(op)
4311
4312    # By default, colocate_with resets the device function stack,
4313    # since colocate_with is typically used in specific internal
4314    # library functions where colocation is intended to be "stronger"
4315    # than device functions.
4316    #
4317    # In the future, a caller may specify that device_functions win
4318    # over colocation, in which case we can add support.
4319    device_fn_tmp = self._device_function_stack
4320    self._device_function_stack = traceable_stack.TraceableStack()
4321
4322    if ignore_existing:
4323      current_stack = self._colocation_stack
4324      self._colocation_stack = traceable_stack.TraceableStack()
4325
4326    if op is not None:
4327      # offset refers to the stack frame used for storing code location.
4328      # We use 4, the sum of 1 to use our caller's stack frame and 3
4329      # to jump over layers of context managers above us.
4330      self._colocation_stack.push_obj(op, offset=4)
4331
4332    try:
4333      yield
4334    finally:
4335      # Restore device function stack
4336      self._device_function_stack = device_fn_tmp
4337      if op is not None:
4338        self._colocation_stack.pop_obj()
4339
4340      # Reset the colocation stack if requested.
4341      if ignore_existing:
4342        self._colocation_stack = current_stack
4343
4344  def _add_device_to_stack(self, device_name_or_function, offset=0):
4345    """Add device to stack manually, separate from a context manager."""
4346    total_offset = 1 + offset
4347    spec = _UserDeviceSpec(device_name_or_function)
4348    self._device_function_stack.push_obj(spec, offset=total_offset)
4349    return spec
4350
4351  @tf_contextlib.contextmanager
4352  def device(self, device_name_or_function):
4353    # pylint: disable=line-too-long
4354    """Returns a context manager that specifies the default device to use.
4355
4356    The `device_name_or_function` argument may either be a device name
4357    string, a device function, or None:
4358
4359    * If it is a device name string, all operations constructed in
4360      this context will be assigned to the device with that name, unless
4361      overridden by a nested `device()` context.
4362    * If it is a function, it will be treated as a function from
4363      Operation objects to device name strings, and invoked each time
4364      a new Operation is created. The Operation will be assigned to
4365      the device with the returned name.
4366    * If it is None, all `device()` invocations from the enclosing context
4367      will be ignored.
4368
4369    For information about the valid syntax of device name strings, see
4370    the documentation in
4371    [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).
4372
4373    For example:
4374
4375    ```python
4376    with g.device('/device:GPU:0'):
4377      # All operations constructed in this context will be placed
4378      # on GPU 0.
4379      with g.device(None):
4380        # All operations constructed in this context will have no
4381        # assigned device.
4382
4383    # Defines a function from `Operation` to device string.
4384    def matmul_on_gpu(n):
4385      if n.type == "MatMul":
4386        return "/device:GPU:0"
4387      else:
4388        return "/cpu:0"
4389
4390    with g.device(matmul_on_gpu):
4391      # All operations of type "MatMul" constructed in this context
4392      # will be placed on GPU 0; all other operations will be placed
4393      # on CPU 0.
4394    ```
4395
4396    **N.B.** The device scope may be overridden by op wrappers or
4397    other library code. For example, a variable assignment op
4398    `v.assign()` must be colocated with the `tf.Variable` `v`, and
4399    incompatible device scopes will be ignored.
4400
4401    Args:
4402      device_name_or_function: The device name or function to use in
4403        the context.
4404
4405    Yields:
4406      A context manager that specifies the default device to use for newly
4407      created ops.
4408    """
4409    self._add_device_to_stack(device_name_or_function, offset=2)
4410    try:
4411      yield
4412    finally:
4413      self._device_function_stack.pop_obj()
4414
4415  def _apply_device_functions(self, op):
4416    """Applies the current device function stack to the given operation."""
4417    # Apply any device functions in LIFO order, so that the most recently
4418    # pushed function has the first chance to apply a device to the op.
4419    # We apply here because the result can depend on the Operation's
4420    # signature, which is computed in the Operation constructor.
4421    # pylint: disable=protected-access
4422    for device_spec in self._device_function_stack.peek_objs():
4423      if device_spec.function is None:
4424        break
4425      op._set_device(device_spec.function(op))
4426    op._device_code_locations = self._snapshot_device_function_stack_metadata()
4427    # pylint: enable=protected-access
4428
4429  # pylint: disable=g-doc-return-or-yield
4430  @tf_contextlib.contextmanager
4431  def container(self, container_name):
4432    """Returns a context manager that specifies the resource container to use.
4433
4434    Stateful operations, such as variables and queues, can maintain their
4435    states on devices so that they can be shared by multiple processes.
4436    A resource container is a string name under which these stateful
4437    operations are tracked. These resources can be released or cleared
4438    with `tf.Session.reset()`.
4439
4440    For example:
4441
4442    ```python
4443    with g.container('experiment0'):
4444      # All stateful Operations constructed in this context will be placed
4445      # in resource container "experiment0".
4446      v1 = tf.Variable([1.0])
4447      v2 = tf.Variable([2.0])
4448      with g.container("experiment1"):
4449        # All stateful Operations constructed in this context will be
4450        # placed in resource container "experiment1".
4451        v3 = tf.Variable([3.0])
4452        q1 = tf.FIFOQueue(10, tf.float32)
4453      # All stateful Operations constructed in this context will be
4454      # be created in the "experiment0".
4455      v4 = tf.Variable([4.0])
4456      q1 = tf.FIFOQueue(20, tf.float32)
4457      with g.container(""):
4458        # All stateful Operations constructed in this context will be
4459        # be placed in the default resource container.
4460        v5 = tf.Variable([5.0])
4461        q3 = tf.FIFOQueue(30, tf.float32)
4462
4463    # Resets container "experiment0", after which the state of v1, v2, v4, q1
4464    # will become undefined (such as uninitialized).
4465    tf.Session.reset(target, ["experiment0"])
4466    ```
4467
4468    Args:
4469      container_name: container name string.
4470
4471    Returns:
4472      A context manager for defining resource containers for stateful ops,
4473        yields the container name.
4474    """
4475    original_container = self._container
4476    self._container = container_name
4477    try:
4478      yield self._container
4479    finally:
4480      self._container = original_container
4481
4482  # pylint: enable=g-doc-return-or-yield
4483
4484  class _ControlDependenciesController(object):
4485    """Context manager for `control_dependencies()`."""
4486
4487    def __init__(self, graph, control_inputs):
4488      """Create a new `_ControlDependenciesController`.
4489
4490      A `_ControlDependenciesController` is the context manager for
4491      `with tf.control_dependencies()` blocks.  These normally nest,
4492      as described in the documentation for `control_dependencies()`.
4493
4494      The `control_inputs` argument list control dependencies that must be
4495      added to the current set of control dependencies.  Because of
4496      uniquification the set can be empty even if the caller passed a list of
4497      ops.  The special value `None` indicates that we want to start a new
4498      empty set of control dependencies instead of extending the current set.
4499
4500      In that case we also clear the current control flow context, which is an
4501      additional mechanism to add control dependencies.
4502
4503      Args:
4504        graph: The graph that this controller is managing.
4505        control_inputs: List of ops to use as control inputs in addition
4506          to the current control dependencies.  None to indicate that
4507          the dependencies should be cleared.
4508      """
4509      self._graph = graph
4510      if control_inputs is None:
4511        self._control_inputs_val = []
4512        self._new_stack = True
4513      else:
4514        self._control_inputs_val = control_inputs
4515        self._new_stack = False
4516      self._seen_nodes = set()
4517      self._old_stack = None
4518      self._old_control_flow_context = None
4519
4520# pylint: disable=protected-access
4521
4522    def __enter__(self):
4523      if self._new_stack:
4524        # Clear the control_dependencies graph.
4525        self._old_stack = self._graph._control_dependencies_stack
4526        self._graph._control_dependencies_stack = []
4527        # Clear the control_flow_context too.
4528        self._old_control_flow_context = self._graph._get_control_flow_context()
4529        self._graph._set_control_flow_context(None)
4530      self._graph._push_control_dependencies_controller(self)
4531
4532    def __exit__(self, unused_type, unused_value, unused_traceback):
4533      self._graph._pop_control_dependencies_controller(self)
4534      if self._new_stack:
4535        self._graph._control_dependencies_stack = self._old_stack
4536        self._graph._set_control_flow_context(self._old_control_flow_context)
4537
4538# pylint: enable=protected-access
4539
4540    @property
4541    def control_inputs(self):
4542      return self._control_inputs_val
4543
4544    def add_op(self, op):
4545      self._seen_nodes.add(op)
4546
4547    def op_in_group(self, op):
4548      return op in self._seen_nodes
4549
4550  def _push_control_dependencies_controller(self, controller):
4551    self._control_dependencies_stack.append(controller)
4552
4553  def _pop_control_dependencies_controller(self, controller):
4554    assert self._control_dependencies_stack[-1] is controller
4555    self._control_dependencies_stack.pop()
4556
4557  def _current_control_dependencies(self):
4558    ret = set()
4559    for controller in self._control_dependencies_stack:
4560      for op in controller.control_inputs:
4561        ret.add(op)
4562    return ret
4563
4564  def _control_dependencies_for_inputs(self, input_ops):
4565    """For an op that takes `input_ops` as inputs, compute control inputs.
4566
4567    The returned control dependencies should yield an execution that
4568    is equivalent to adding all control inputs in
4569    self._control_dependencies_stack to a newly created op. However,
4570    this function attempts to prune the returned control dependencies
4571    by observing that nodes created within the same `with
4572    control_dependencies(...):` block may have data dependencies that make
4573    the explicit approach redundant.
4574
4575    Args:
4576      input_ops: The data input ops for an op to be created.
4577
4578    Returns:
4579      A list of control inputs for the op to be created.
4580    """
4581    ret = []
4582    for controller in self._control_dependencies_stack:
4583      # If any of the input_ops already depends on the inputs from controller,
4584      # we say that the new op is dominated (by that input), and we therefore
4585      # do not need to add control dependencies for this controller's inputs.
4586      dominated = False
4587      for op in input_ops:
4588        if controller.op_in_group(op):
4589          dominated = True
4590          break
4591      if not dominated:
4592        # Don't add a control input if we already have a data dependency on i.
4593        # NOTE(mrry): We do not currently track transitive data dependencies,
4594        #   so we may add redundant control inputs.
4595        ret.extend([c for c in controller.control_inputs if c not in input_ops])
4596    return ret
4597
4598  def _record_op_seen_by_control_dependencies(self, op):
4599    """Record that the given op depends on all registered control dependencies.
4600
4601    Args:
4602      op: An Operation.
4603    """
4604    for controller in self._control_dependencies_stack:
4605      controller.add_op(op)
4606
4607  def control_dependencies(self, control_inputs):
4608    """Returns a context manager that specifies control dependencies.
4609
4610    Use with the `with` keyword to specify that all operations constructed
4611    within the context should have control dependencies on
4612    `control_inputs`. For example:
4613
4614    ```python
4615    with g.control_dependencies([a, b, c]):
4616      # `d` and `e` will only run after `a`, `b`, and `c` have executed.
4617      d = ...
4618      e = ...
4619    ```
4620
4621    Multiple calls to `control_dependencies()` can be nested, and in
4622    that case a new `Operation` will have control dependencies on the union
4623    of `control_inputs` from all active contexts.
4624
4625    ```python
4626    with g.control_dependencies([a, b]):
4627      # Ops constructed here run after `a` and `b`.
4628      with g.control_dependencies([c, d]):
4629        # Ops constructed here run after `a`, `b`, `c`, and `d`.
4630    ```
4631
4632    You can pass None to clear the control dependencies:
4633
4634    ```python
4635    with g.control_dependencies([a, b]):
4636      # Ops constructed here run after `a` and `b`.
4637      with g.control_dependencies(None):
4638        # Ops constructed here run normally, not waiting for either `a` or `b`.
4639        with g.control_dependencies([c, d]):
4640          # Ops constructed here run after `c` and `d`, also not waiting
4641          # for either `a` or `b`.
4642    ```
4643
4644    *N.B.* The control dependencies context applies *only* to ops that
4645    are constructed within the context. Merely using an op or tensor
4646    in the context does not add a control dependency. The following
4647    example illustrates this point:
4648
4649    ```python
4650    # WRONG
4651    def my_func(pred, tensor):
4652      t = tf.matmul(tensor, tensor)
4653      with tf.control_dependencies([pred]):
4654        # The matmul op is created outside the context, so no control
4655        # dependency will be added.
4656        return t
4657
4658    # RIGHT
4659    def my_func(pred, tensor):
4660      with tf.control_dependencies([pred]):
4661        # The matmul op is created in the context, so a control dependency
4662        # will be added.
4663        return tf.matmul(tensor, tensor)
4664    ```
4665
4666    Also note that though execution of ops created under this scope will trigger
4667    execution of the dependencies, the ops created under this scope might still
4668    be pruned from a normal tensorflow graph. For example, in the following
4669    snippet of code the dependencies are never executed:
4670
4671    ```python
4672      loss = model.loss()
4673      with tf.control_dependencies(dependencies):
4674        loss = loss + tf.constant(1)  # note: dependencies ignored in the
4675                                      # backward pass
4676      return tf.gradients(loss, model.variables)
4677    ```
4678
4679    This is because evaluating the gradient graph does not require evaluating
4680    the constant(1) op created in the forward pass.
4681
4682    Args:
4683      control_inputs: A list of `Operation` or `Tensor` objects which
4684        must be executed or computed before running the operations
4685        defined in the context.  Can also be `None` to clear the control
4686        dependencies.
4687
4688    Returns:
4689     A context manager that specifies control dependencies for all
4690     operations constructed within the context.
4691
4692    Raises:
4693      TypeError: If `control_inputs` is not a list of `Operation` or
4694        `Tensor` objects.
4695    """
4696    if control_inputs is None:
4697      return self._ControlDependenciesController(self, None)
4698    # First convert the inputs to ops, and deduplicate them.
4699    # NOTE(mrry): Other than deduplication, we do not currently track direct
4700    #   or indirect dependencies between control_inputs, which may result in
4701    #   redundant control inputs.
4702    control_ops = []
4703    current = self._current_control_dependencies()
4704    for c in control_inputs:
4705      # The hasattr(handle) is designed to match ResourceVariables. This is so
4706      # control dependencies on a variable or on an unread variable don't
4707      # trigger reads.
4708      if (isinstance(c, IndexedSlices) or
4709          (hasattr(c, "_handle") and hasattr(c, "op"))):
4710        c = c.op
4711      c = self.as_graph_element(c)
4712      if isinstance(c, Tensor):
4713        c = c.op
4714      elif not isinstance(c, Operation):
4715        raise TypeError("Control input must be Operation or Tensor: %s" % c)
4716      if c not in current:
4717        control_ops.append(c)
4718        current.add(c)
4719    return self._ControlDependenciesController(self, control_ops)
4720
4721  # pylint: disable=g-doc-return-or-yield
4722  @tf_contextlib.contextmanager
4723  def _attr_scope(self, attr_map):
4724    """EXPERIMENTAL: A context manager for setting attributes on operators.
4725
4726    This context manager can be used to add additional
4727    attributes to operators within the scope of the context.
4728
4729    For example:
4730
4731       with ops.Graph().as_default() as g:
4732         f_1 = Foo()  # No extra attributes
4733         with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
4734           f_2 = Foo()  # Additional attribute _a=False
4735           with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
4736             f_3 = Foo()  # Additional attribute _a=False
4737             with g._attr_scope({"_a": None}):
4738               f_4 = Foo()  # No additional attributes.
4739
4740    Args:
4741      attr_map: A dictionary mapping attr name strings to
4742        AttrValue protocol buffers or None.
4743
4744    Returns:
4745      A context manager that sets the kernel label to be used for one or more
4746      ops created in that context.
4747
4748    Raises:
4749      TypeError: If attr_map is not a dictionary mapping
4750        strings to AttrValue protobufs.
4751    """
4752    if not isinstance(attr_map, dict):
4753      raise TypeError("attr_map must be a dictionary mapping "
4754                      "strings to AttrValue protocol buffers")
4755    # The saved_attrs dictionary stores any currently-set labels that
4756    # will be overridden by this context manager.
4757    saved_attrs = {}
4758    # Install the given attribute
4759    for name, attr in attr_map.items():
4760      if not (isinstance(name, six.string_types) and
4761              (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
4762               callable(attr))):
4763        raise TypeError("attr_map must be a dictionary mapping "
4764                        "strings to AttrValue protocol buffers or "
4765                        "callables that emit AttrValue protocol buffers")
4766      try:
4767        saved_attrs[name] = self._attr_scope_map[name]
4768      except KeyError:
4769        pass
4770      if attr is None:
4771        del self._attr_scope_map[name]
4772      else:
4773        self._attr_scope_map[name] = attr
4774    try:
4775      yield  # The code within the context runs here.
4776    finally:
4777      # Remove the attributes set for this context, and restore any saved
4778      # attributes.
4779      for name, attr in attr_map.items():
4780        try:
4781          self._attr_scope_map[name] = saved_attrs[name]
4782        except KeyError:
4783          del self._attr_scope_map[name]
4784
4785  # pylint: enable=g-doc-return-or-yield
4786
4787  # pylint: disable=g-doc-return-or-yield
4788  @tf_contextlib.contextmanager
4789  def _kernel_label_map(self, op_to_kernel_label_map):
4790    """EXPERIMENTAL: A context manager for setting kernel labels.
4791
4792    This context manager can be used to select particular
4793    implementations of kernels within the scope of the context.
4794
4795    For example:
4796
4797        with ops.Graph().as_default() as g:
4798          f_1 = Foo()  # Uses the default registered kernel for the Foo op.
4799          with g.kernel_label_map({"Foo": "v_2"}):
4800            f_2 = Foo()  # Uses the registered kernel with label "v_2"
4801                         # for the Foo op.
4802            with g.kernel_label_map({"Foo": "v_3"}):
4803              f_3 = Foo()  # Uses the registered kernel with label "v_3"
4804                           # for the Foo op.
4805              with g.kernel_label_map({"Foo": ""}):
4806                f_4 = Foo()  # Uses the default registered kernel
4807                             # for the Foo op.
4808
4809    Args:
4810      op_to_kernel_label_map: A dictionary mapping op type strings to
4811        kernel label strings.
4812
4813    Returns:
4814      A context manager that sets the kernel label to be used for one or more
4815      ops created in that context.
4816
4817    Raises:
4818      TypeError: If op_to_kernel_label_map is not a dictionary mapping
4819        strings to strings.
4820    """
4821    if not isinstance(op_to_kernel_label_map, dict):
4822      raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
4823                      "strings to strings")
4824    # The saved_labels dictionary stores any currently-set labels that
4825    # will be overridden by this context manager.
4826    saved_labels = {}
4827    # Install the given label
4828    for op_type, label in op_to_kernel_label_map.items():
4829      if not (isinstance(op_type, six.string_types) and
4830              isinstance(label, six.string_types)):
4831        raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
4832                        "strings to strings")
4833      try:
4834        saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
4835      except KeyError:
4836        pass
4837      self._op_to_kernel_label_map[op_type] = label
4838    try:
4839      yield  # The code within the context runs here.
4840    finally:
4841      # Remove the labels set for this context, and restore any saved labels.
4842      for op_type, label in op_to_kernel_label_map.items():
4843        try:
4844          self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
4845        except KeyError:
4846          del self._op_to_kernel_label_map[op_type]
4847
4848  # pylint: enable=g-doc-return-or-yield
4849
4850  # pylint: disable=g-doc-return-or-yield
4851  @tf_contextlib.contextmanager
4852  def gradient_override_map(self, op_type_map):
4853    """EXPERIMENTAL: A context manager for overriding gradient functions.
4854
4855    This context manager can be used to override the gradient function
4856    that will be used for ops within the scope of the context.
4857
4858    For example:
4859
4860    ```python
4861    @tf.RegisterGradient("CustomSquare")
4862    def _custom_square_grad(op, grad):
4863      # ...
4864
4865    with tf.Graph().as_default() as g:
4866      c = tf.constant(5.0)
4867      s_1 = tf.square(c)  # Uses the default gradient for tf.square.
4868      with g.gradient_override_map({"Square": "CustomSquare"}):
4869        s_2 = tf.square(s_2)  # Uses _custom_square_grad to compute the
4870                              # gradient of s_2.
4871    ```
4872
4873    Args:
4874      op_type_map: A dictionary mapping op type strings to alternative op
4875        type strings.
4876
4877    Returns:
4878      A context manager that sets the alternative op type to be used for one
4879      or more ops created in that context.
4880
4881    Raises:
4882      TypeError: If `op_type_map` is not a dictionary mapping strings to
4883        strings.
4884    """
4885    if not isinstance(op_type_map, dict):
4886      raise TypeError("op_type_map must be a dictionary mapping "
4887                      "strings to strings")
4888    # The saved_mappings dictionary stores any currently-set mappings that
4889    # will be overridden by this context manager.
4890    saved_mappings = {}
4891    # Install the given label
4892    for op_type, mapped_op_type in op_type_map.items():
4893      if not (isinstance(op_type, six.string_types) and
4894              isinstance(mapped_op_type, six.string_types)):
4895        raise TypeError("op_type_map must be a dictionary mapping "
4896                        "strings to strings")
4897      try:
4898        saved_mappings[op_type] = self._gradient_override_map[op_type]
4899      except KeyError:
4900        pass
4901      self._gradient_override_map[op_type] = mapped_op_type
4902    try:
4903      yield  # The code within the context runs here.
4904    finally:
4905      # Remove the labels set for this context, and restore any saved labels.
4906      for op_type, mapped_op_type in op_type_map.items():
4907        try:
4908          self._gradient_override_map[op_type] = saved_mappings[op_type]
4909        except KeyError:
4910          del self._gradient_override_map[op_type]
4911
4912  # pylint: enable=g-doc-return-or-yield
4913
4914  def prevent_feeding(self, tensor):
4915    """Marks the given `tensor` as unfeedable in this graph."""
4916    self._unfeedable_tensors.add(tensor)
4917
4918  def is_feedable(self, tensor):
4919    """Returns `True` if and only if `tensor` is feedable."""
4920    return tensor not in self._unfeedable_tensors
4921
4922  def prevent_fetching(self, op):
4923    """Marks the given `op` as unfetchable in this graph."""
4924    self._unfetchable_ops.add(op)
4925
4926  def is_fetchable(self, tensor_or_op):
4927    """Returns `True` if and only if `tensor_or_op` is fetchable."""
4928    if isinstance(tensor_or_op, Tensor):
4929      return tensor_or_op.op not in self._unfetchable_ops
4930    else:
4931      return tensor_or_op not in self._unfetchable_ops
4932
4933  def switch_to_thread_local(self):
4934    """Make device, colocation and dependencies stacks thread-local.
4935
4936    Device, colocation and dependencies stacks are not thread-local be default.
4937    If multiple threads access them, then the state is shared.  This means that
4938    one thread may affect the behavior of another thread.
4939
4940    After this method is called, the stacks become thread-local.  If multiple
4941    threads access them, then the state is not shared.  Each thread uses its own
4942    value; a thread doesn't affect other threads by mutating such a stack.
4943
4944    The initial value for every thread's stack is set to the current value
4945    of the stack when `switch_to_thread_local()` was first called.
4946    """
4947    if not self._stack_state_is_thread_local:
4948      self._stack_state_is_thread_local = True
4949
4950  @property
4951  def _device_function_stack(self):
4952    if self._stack_state_is_thread_local:
4953      # This may be called from a thread where device_function_stack doesn't yet
4954      # exist.
4955      # pylint: disable=protected-access
4956      if not hasattr(self._thread_local, "_device_function_stack"):
4957        stack_copy_for_this_thread = self._graph_device_function_stack.copy()
4958        self._thread_local._device_function_stack = stack_copy_for_this_thread
4959      return self._thread_local._device_function_stack
4960      # pylint: enable=protected-access
4961    else:
4962      return self._graph_device_function_stack
4963
4964  @property
4965  def _device_functions_outer_to_inner(self):
4966    user_device_specs = self._device_function_stack.peek_objs()
4967    device_functions = [spec.function for spec in user_device_specs]
4968    device_functions_outer_to_inner = list(reversed(device_functions))
4969    return device_functions_outer_to_inner
4970
4971  def _snapshot_device_function_stack_metadata(self):
4972    """Return device function stack as a list of TraceableObjects.
4973
4974    Returns:
4975      [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj
4976      member is a displayable name for the user's argument to Graph.device, and
4977      the filename and lineno members point to the code location where
4978      Graph.device was called directly or indirectly by the user.
4979    """
4980    traceable_objects = self._device_function_stack.peek_traceable_objs()
4981    snapshot = []
4982    for obj in traceable_objects:
4983      obj_copy = obj.copy_metadata()
4984      obj_copy.obj = obj.obj.display_name
4985      snapshot.append(obj_copy)
4986    return snapshot
4987
4988  @_device_function_stack.setter
4989  def _device_function_stack(self, device_function_stack):
4990    if self._stack_state_is_thread_local:
4991      # pylint: disable=protected-access
4992      self._thread_local._device_function_stack = device_function_stack
4993      # pylint: enable=protected-access
4994    else:
4995      self._graph_device_function_stack = device_function_stack
4996
4997  @property
4998  def _colocation_stack(self):
4999    """Return thread-local copy of colocation stack."""
5000    if self._stack_state_is_thread_local:
5001      # This may be called from a thread where colocation_stack doesn't yet
5002      # exist.
5003      # pylint: disable=protected-access
5004      if not hasattr(self._thread_local, "_colocation_stack"):
5005        stack_copy_for_this_thread = self._graph_colocation_stack.copy()
5006        self._thread_local._colocation_stack = stack_copy_for_this_thread
5007      return self._thread_local._colocation_stack
5008      # pylint: enable=protected-access
5009    else:
5010      return self._graph_colocation_stack
5011
5012  def _snapshot_colocation_stack_metadata(self):
5013    """Return colocation stack metadata as a dictionary."""
5014    traceable_objects = self._colocation_stack.peek_traceable_objs()
5015    return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects}
5016
5017  @_colocation_stack.setter
5018  def _colocation_stack(self, colocation_stack):
5019    if self._stack_state_is_thread_local:
5020      # pylint: disable=protected-access
5021      self._thread_local._colocation_stack = colocation_stack
5022      # pylint: enable=protected-access
5023    else:
5024      self._graph_colocation_stack = colocation_stack
5025
5026  @property
5027  def _control_dependencies_stack(self):
5028    if self._stack_state_is_thread_local:
5029      # This may be called from a thread where control_dependencies_stack
5030      # doesn't yet exist.
5031      if not hasattr(self._thread_local, "_control_dependencies_stack"):
5032        self._thread_local._control_dependencies_stack = (
5033            self._graph_control_dependencies_stack[:])
5034      return self._thread_local._control_dependencies_stack
5035    else:
5036      return self._graph_control_dependencies_stack
5037
5038  @_control_dependencies_stack.setter
5039  def _control_dependencies_stack(self, control_dependencies):
5040    if self._stack_state_is_thread_local:
5041      self._thread_local._control_dependencies_stack = control_dependencies
5042    else:
5043      self._graph_control_dependencies_stack = control_dependencies
5044
5045  @property
5046  def _distribution_strategy_stack(self):
5047    """A stack to maintain distribution strategy context for each thread."""
5048    if not hasattr(self._thread_local, "_distribution_strategy_stack"):
5049      self._thread_local._distribution_strategy_stack = []  # pylint: disable=protected-access
5050    return self._thread_local._distribution_strategy_stack  # pylint: disable=protected-access
5051
5052  @_distribution_strategy_stack.setter
5053  def _distribution_strategy_stack(self, _distribution_strategy_stack):
5054    self._thread_local._distribution_strategy_stack = (  # pylint: disable=protected-access
5055        _distribution_strategy_stack)
5056
5057  @property
5058  def _auto_cast_variable_read_dtype(self):
5059    """The dtype that instances of `AutoCastVariable` will be casted to.
5060
5061    This is None if `AutoCastVariables` should not be casted.
5062
5063    See `AutoCastVariable` for more information.
5064
5065    Returns:
5066      The dtype that instances of `AutoCastVariable` will be casted to.
5067    """
5068    if not hasattr(self._thread_local, "_auto_cast_variable_read_dtype"):
5069      self._thread_local._auto_cast_variable_read_dtype = None  # pylint: disable=protected-access
5070    return self._thread_local._auto_cast_variable_read_dtype  # pylint: disable=protected-access
5071
5072  @_auto_cast_variable_read_dtype.setter
5073  def _auto_cast_variable_read_dtype(self, _auto_cast_variable_read_dtype):
5074    self._thread_local._auto_cast_variable_read_dtype = (  # pylint: disable=protected-access
5075        _auto_cast_variable_read_dtype)
5076
5077  @tf_contextlib.contextmanager
5078  def _enable_auto_casting_variables(self, dtype):
5079    """Context manager to automatically cast AutoCastVariables.
5080
5081    If an AutoCastVariable `var` is used under this context manager, it will be
5082    casted to `dtype` before being used.
5083
5084    See `AutoCastVariable` for more information.
5085
5086    Args:
5087      dtype: The dtype that AutoCastVariables should be casted to.
5088
5089    Yields:
5090      Nothing.
5091    """
5092    prev_read_dtype = self._auto_cast_variable_read_dtype
5093    try:
5094      self._auto_cast_variable_read_dtype = dtype
5095      yield
5096    finally:
5097      self._auto_cast_variable_read_dtype = prev_read_dtype
5098
5099  def _mutation_lock(self):
5100    """Returns a lock to guard code that creates & mutates ops.
5101
5102    See the comment for self._group_lock for more info.
5103    """
5104    return self._group_lock.group(_MUTATION_LOCK_GROUP)
5105
5106  def _session_run_lock(self):
5107    """Returns a lock to guard code for Session.run.
5108
5109    See the comment for self._group_lock for more info.
5110    """
5111    return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
5112
5113
5114# TODO(agarwal): currently device directives in an outer eager scope will not
5115# apply to inner graph mode code. Fix that.
5116
5117
5118@tf_export(v1=["device"])
5119def device(device_name_or_function):
5120  """Wrapper for `Graph.device()` using the default graph.
5121
5122  See
5123  `tf.Graph.device`
5124  for more details.
5125
5126  Args:
5127    device_name_or_function: The device name or function to use in
5128      the context.
5129
5130  Returns:
5131    A context manager that specifies the default device to use for newly
5132    created ops.
5133
5134  Raises:
5135    RuntimeError: If eager execution is enabled and a function is passed in.
5136  """
5137  if context.executing_eagerly():
5138    # TODO(agarwal): support device functions in EAGER mode.
5139    if callable(device_name_or_function):
5140      raise RuntimeError(
5141          "tf.device does not support functions when eager execution "
5142          "is enabled.")
5143    return context.device(device_name_or_function)
5144  else:
5145    return get_default_graph().device(device_name_or_function)
5146
5147
5148@tf_export("device", v1=[])
5149def device_v2(device_name):
5150  """Specifies the device for ops created/executed in this context.
5151
5152  `device_name` can be fully specified, as in "/job:worker/task:1/device:cpu:0",
5153  or partially specified, containing only a subset of the "/"-separated
5154  fields. Any fields which are specified override device annotations from outer
5155  scopes. For example:
5156
5157  with tf.device('/job:foo'):
5158    # ops created here have devices with /job:foo
5159    with tf.device('/job:bar/task:0/device:gpu:2'):
5160      # ops created here have the fully specified device above
5161    with tf.device('/device:gpu:1'):
5162      # ops created here have the device '/job:foo/device:gpu:1'
5163
5164  Args:
5165    device_name: The device name to use in the context.
5166
5167  Returns:
5168    A context manager that specifies the default device to use for newly
5169    created ops.
5170
5171  Raises:
5172    RuntimeError: If a function is passed in.
5173  """
5174  if callable(device_name):
5175    raise RuntimeError("tf.device does not support functions.")
5176  if context.executing_eagerly():
5177    return context.device(device_name)
5178  else:
5179    return get_default_graph().device(device_name)
5180
5181
5182@tf_export(v1=["container"])
5183def container(container_name):
5184  """Wrapper for `Graph.container()` using the default graph.
5185
5186  Args:
5187    container_name: The container string to use in the context.
5188
5189  Returns:
5190    A context manager that specifies the default container to use for newly
5191    created stateful ops.
5192  """
5193  return get_default_graph().container(container_name)
5194
5195
5196def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
5197  if context.executing_eagerly():
5198    if op is not None:
5199      if not hasattr(op, "device"):
5200        op = internal_convert_to_tensor_or_indexed_slices(op)
5201      return device(op.device)
5202    else:
5203      return NullContextmanager()
5204  else:
5205    default_graph = get_default_graph()
5206    if isinstance(op, EagerTensor):
5207      if default_graph.building_function:
5208        return default_graph.device(op.device)
5209      else:
5210        raise ValueError("Encountered an Eager-defined Tensor during graph "
5211                         "construction, but a function was not being built.")
5212    return default_graph._colocate_with_for_gradient(
5213        op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)
5214
5215
5216# Internal interface to colocate_with. colocate_with has been deprecated from
5217# public API. There are still a few internal uses of colocate_with. Add internal
5218# only API for those uses to avoid deprecation warning.
5219def colocate_with(op, ignore_existing=False):
5220  return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing)
5221
5222
5223@deprecation.deprecated(
5224    date=None,
5225    instructions="Colocations handled automatically by placer.")
5226@tf_export(v1=["colocate_with"])
5227def _colocate_with(op, ignore_existing=False):
5228  return colocate_with(op, ignore_existing)
5229
5230
5231@tf_export("control_dependencies")
5232def control_dependencies(control_inputs):
5233  """Wrapper for `Graph.control_dependencies()` using the default graph.
5234
5235  See `tf.Graph.control_dependencies`
5236  for more details.
5237
5238  When eager execution is enabled, any callable object in the `control_inputs`
5239  list will be called.
5240
5241  Args:
5242    control_inputs: A list of `Operation` or `Tensor` objects which
5243      must be executed or computed before running the operations
5244      defined in the context.  Can also be `None` to clear the control
5245      dependencies. If eager execution is enabled, any callable object in the
5246      `control_inputs` list will be called.
5247
5248  Returns:
5249   A context manager that specifies control dependencies for all
5250   operations constructed within the context.
5251  """
5252  if context.executing_eagerly():
5253    if control_inputs:
5254      # Excute any pending callables.
5255      for control in control_inputs:
5256        if callable(control):
5257          control()
5258    return NullContextmanager()
5259  else:
5260    return get_default_graph().control_dependencies(control_inputs)
5261
5262
5263class _DefaultStack(threading.local):
5264  """A thread-local stack of objects for providing implicit defaults."""
5265
5266  def __init__(self):
5267    super(_DefaultStack, self).__init__()
5268    self._enforce_nesting = True
5269    self.stack = []
5270
5271  def get_default(self):
5272    return self.stack[-1] if len(self.stack) >= 1 else None
5273
5274  def reset(self):
5275    self.stack = []
5276
5277  def is_cleared(self):
5278    return not self.stack
5279
5280  @property
5281  def enforce_nesting(self):
5282    return self._enforce_nesting
5283
5284  @enforce_nesting.setter
5285  def enforce_nesting(self, value):
5286    self._enforce_nesting = value
5287
5288  @tf_contextlib.contextmanager
5289  def get_controller(self, default):
5290    """A context manager for manipulating a default stack."""
5291    self.stack.append(default)
5292    try:
5293      yield default
5294    finally:
5295      # stack may be empty if reset() was called
5296      if self.stack:
5297        if self._enforce_nesting:
5298          if self.stack[-1] is not default:
5299            raise AssertionError(
5300                "Nesting violated for default stack of %s objects" %
5301                type(default))
5302          self.stack.pop()
5303        else:
5304          self.stack.remove(default)
5305
5306
5307_default_session_stack = _DefaultStack()  # pylint: disable=protected-access
5308
5309
5310def default_session(session):
5311  """Python "with" handler for defining a default session.
5312
5313  This function provides a means of registering a session for handling
5314  Tensor.eval() and Operation.run() calls. It is primarily intended for use
5315  by session.Session, but can be used with any object that implements
5316  the Session.run() interface.
5317
5318  Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
5319  invocations within the scope of a block should be executed by a particular
5320  session.
5321
5322  The default session applies to the current thread only, so it is always
5323  possible to inspect the call stack and determine the scope of a default
5324  session. If you create a new thread, and wish to use the default session
5325  in that thread, you must explicitly add a "with ops.default_session(sess):"
5326  block in that thread's function.
5327
5328  Example:
5329    The following code examples are equivalent:
5330
5331    # 1. Using the Session object directly:
5332    sess = ...
5333    c = tf.constant(5.0)
5334    sess.run(c)
5335
5336    # 2. Using default_session():
5337    sess = ...
5338    with ops.default_session(sess):
5339      c = tf.constant(5.0)
5340      result = c.eval()
5341
5342    # 3. Overriding default_session():
5343    sess = ...
5344    with ops.default_session(sess):
5345      c = tf.constant(5.0)
5346      with ops.default_session(...):
5347        c.eval(session=sess)
5348
5349  Args:
5350    session: The session to be installed as the default session.
5351
5352  Returns:
5353    A context manager for the default session.
5354  """
5355  return _default_session_stack.get_controller(session)
5356
5357
5358@tf_export(v1=["get_default_session"])
5359def get_default_session():
5360  """Returns the default session for the current thread.
5361
5362  The returned `Session` will be the innermost session on which a
5363  `Session` or `Session.as_default()` context has been entered.
5364
5365  NOTE: The default session is a property of the current thread. If you
5366  create a new thread, and wish to use the default session in that
5367  thread, you must explicitly add a `with sess.as_default():` in that
5368  thread's function.
5369
5370  Returns:
5371    The default `Session` being used in the current thread.
5372  """
5373  return _default_session_stack.get_default()
5374
5375
5376def _eval_using_default_session(tensors, feed_dict, graph, session=None):
5377  """Uses the default session to evaluate one or more tensors.
5378
5379  Args:
5380    tensors: A single Tensor, or a list of Tensor objects.
5381    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5382      numpy ndarrays, TensorProtos, or strings.
5383    graph: The graph in which the tensors are defined.
5384    session: (Optional) A different session to use to evaluate "tensors".
5385
5386  Returns:
5387    Either a single numpy ndarray if "tensors" is a single tensor; or a list
5388    of numpy ndarrays that each correspond to the respective element in
5389    "tensors".
5390
5391  Raises:
5392    ValueError: If no default session is available; the default session
5393      does not have "graph" as its graph; or if "session" is specified,
5394      and it does not have "graph" as its graph.
5395  """
5396  if session is None:
5397    session = get_default_session()
5398    if session is None:
5399      raise ValueError("Cannot evaluate tensor using `eval()`: No default "
5400                       "session is registered. Use `with "
5401                       "sess.as_default()` or pass an explicit session to "
5402                       "`eval(session=sess)`")
5403    if session.graph is not graph:
5404      raise ValueError("Cannot use the default session to evaluate tensor: "
5405                       "the tensor's graph is different from the session's "
5406                       "graph. Pass an explicit session to "
5407                       "`eval(session=sess)`.")
5408  else:
5409    if session.graph is not graph:
5410      raise ValueError("Cannot use the given session to evaluate tensor: "
5411                       "the tensor's graph is different from the session's "
5412                       "graph.")
5413  return session.run(tensors, feed_dict)
5414
5415
5416def _run_using_default_session(operation, feed_dict, graph, session=None):
5417  """Uses the default session to run "operation".
5418
5419  Args:
5420    operation: The Operation to be run.
5421    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
5422      numpy ndarrays, TensorProtos, or strings.
5423    graph: The graph in which "operation" is defined.
5424    session: (Optional) A different session to use to run "operation".
5425
5426  Raises:
5427    ValueError: If no default session is available; the default session
5428      does not have "graph" as its graph; or if "session" is specified,
5429      and it does not have "graph" as its graph.
5430  """
5431  if session is None:
5432    session = get_default_session()
5433    if session is None:
5434      raise ValueError("Cannot execute operation using `run()`: No default "
5435                       "session is registered. Use `with "
5436                       "sess.as_default():` or pass an explicit session to "
5437                       "`run(session=sess)`")
5438    if session.graph is not graph:
5439      raise ValueError("Cannot use the default session to execute operation: "
5440                       "the operation's graph is different from the "
5441                       "session's graph. Pass an explicit session to "
5442                       "run(session=sess).")
5443  else:
5444    if session.graph is not graph:
5445      raise ValueError("Cannot use the given session to execute operation: "
5446                       "the operation's graph is different from the session's "
5447                       "graph.")
5448  session.run(operation, feed_dict)
5449
5450
5451class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
5452  """A thread-local stack of objects for providing an implicit default graph."""
5453
5454  def __init__(self):
5455    super(_DefaultGraphStack, self).__init__()
5456    self._global_default_graph = None
5457
5458  def get_default(self):
5459    """Override that returns a global default if the stack is empty."""
5460    ret = super(_DefaultGraphStack, self).get_default()
5461    if ret is None:
5462      ret = self._GetGlobalDefaultGraph()
5463    return ret
5464
5465  def _GetGlobalDefaultGraph(self):
5466    if self._global_default_graph is None:
5467      # TODO(mrry): Perhaps log that the default graph is being used, or set
5468      #   provide some other feedback to prevent confusion when a mixture of
5469      #   the global default graph and an explicit graph are combined in the
5470      #   same process.
5471      self._global_default_graph = Graph()
5472    return self._global_default_graph
5473
5474  def reset(self):
5475    super(_DefaultGraphStack, self).reset()
5476    self._global_default_graph = None
5477
5478  @tf_contextlib.contextmanager
5479  def get_controller(self, default):
5480    context.context().context_switches.push(
5481        default.building_function, default.as_default,
5482        default._device_function_stack)
5483    try:
5484      with super(_DefaultGraphStack, self).get_controller(
5485          default) as g, context.graph_mode():
5486        yield g
5487    finally:
5488      # If an exception is raised here it may be hiding a related exception in
5489      # the try-block (just above).
5490      context.context().context_switches.pop()
5491
5492
5493_default_graph_stack = _DefaultGraphStack()
5494
5495
5496# pylint: disable=g-doc-return-or-yield,line-too-long
5497@tf_export("init_scope")
5498@tf_contextlib.contextmanager
5499def init_scope():
5500  """A context manager that lifts ops out of control-flow scopes and function-building graphs.
5501
5502  There is often a need to lift variable initialization ops out of control-flow
5503  scopes, function-building graphs, and gradient tapes. Entering an
5504  `init_scope` is a mechanism for satisfying these desiderata. In particular,
5505  entering an `init_scope` has three effects:
5506
5507    (1) All control dependencies are cleared the moment the scope is entered;
5508        this is equivalent to entering the context manager returned from
5509        `control_dependencies(None)`, which has the side-effect of exiting
5510        control-flow scopes like `tf.cond` and `tf.while_loop`.
5511
5512    (2) All operations that are created while the scope is active are lifted
5513        into the lowest context on the `context_stack` that is not building a
5514        graph function. Here, a context is defined as either a graph or an eager
5515        context. Every context switch, i.e., every installation of a graph as
5516        the default graph and every switch into eager mode, is logged in a
5517        thread-local stack called `context_switches`; the log entry for a
5518        context switch is popped from the stack when the context is exited.
5519        Entering an `init_scope` is equivalent to crawling up
5520        `context_switches`, finding the first context that is not building a
5521        graph function, and entering it. A caveat is that if graph mode is
5522        enabled but the default graph stack is empty, then entering an
5523        `init_scope` will simply install a fresh graph as the default one.
5524
5525    (3) The gradient tape is paused while the scope is active.
5526
5527  When eager execution is enabled, code inside an init_scope block runs with
5528  eager execution enabled even when defining graph functions via
5529  tf.contrib.eager.defun. For example:
5530
5531  ```python
5532  tf.enable_eager_execution()
5533
5534  @tf.contrib.eager.defun
5535  def func():
5536    # A defun-decorated function constructs TensorFlow graphs,
5537    # it does not execute eagerly.
5538    assert not tf.executing_eagerly()
5539    with tf.init_scope():
5540      # Initialization runs with eager execution enabled
5541      assert tf.executing_eagerly()
5542  ```
5543
5544  Raises:
5545    RuntimeError: if graph state is incompatible with this initialization.
5546  """
5547  # pylint: enable=g-doc-return-or-yield,line-too-long
5548
5549  if context.executing_eagerly():
5550    # Fastpath.
5551    with tape.stop_recording():
5552      yield
5553  else:
5554    # Retrieve the active name scope: entering an `init_scope` preserves
5555    # the name scope of the current context.
5556    default_graph = get_default_graph()
5557    scope = default_graph.get_name_scope()
5558    if scope and scope[-1] != "/":
5559      # Names that end with trailing slashes are treated by `name_scope` as
5560      # absolute.
5561      scope = scope + "/"
5562    innermost_nonempty_device_stack = default_graph._device_function_stack  # pylint: disable=protected-access
5563
5564    outer_context = None
5565    if not _default_graph_stack.stack:
5566      # If the default graph stack is empty, then we cannot be building a
5567      # function. Install the global graph (which, in this case, is also the
5568      # default graph) as the outer context.
5569      if default_graph.building_function:
5570        raise RuntimeError("The global graph is building a function.")
5571      outer_context = default_graph.as_default
5572    else:
5573      # Find a context that is not building a function.
5574      for stack_entry in reversed(context.context().context_switches.stack):
5575        if not innermost_nonempty_device_stack:
5576          innermost_nonempty_device_stack = stack_entry.device_stack
5577        if not stack_entry.is_building_function:
5578          outer_context = stack_entry.enter_context_fn
5579          break
5580
5581      if outer_context is None:
5582        # As a last resort, obtain the global default graph; this graph doesn't
5583        # necessarily live on the graph stack (and hence it doesn't necessarily
5584        # live on the context stack), but it is stored in the graph stack's
5585        # encapsulating object.
5586        outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default  # pylint: disable=protected-access
5587
5588    if outer_context is None:
5589      # Sanity check; this shouldn't be triggered.
5590      raise RuntimeError("All graphs are building functions, and no "
5591                         "eager context was previously active.")
5592
5593    outer_graph = None
5594    outer_device_stack = None
5595    try:
5596      with outer_context(), name_scope(scope), control_dependencies(
5597          None), tape.stop_recording():
5598        context_manager = NullContextmanager
5599        context_manager_input = None
5600        if not context.executing_eagerly():
5601          # The device stack is preserved when lifting into a graph. Eager
5602          # execution doesn't implement device stacks and in particular it
5603          # doesn't support device functions, so in general it's not possible
5604          # to do the same when lifting into the eager context.
5605          outer_graph = get_default_graph()
5606          outer_device_stack = outer_graph._device_function_stack  # pylint: disable=protected-access
5607          outer_graph._device_function_stack = innermost_nonempty_device_stack  # pylint: disable=protected-access
5608        elif innermost_nonempty_device_stack is not None:
5609          for device_spec in innermost_nonempty_device_stack.peek_objs():
5610            if device_spec.function is None:
5611              break
5612            if device_spec.raw_string:
5613              context_manager = context.device
5614              context_manager_input = device_spec.raw_string
5615              break
5616            # It is currently not possible to have a device function in V2,
5617            # but in V1 we are unable to apply device functions in eager mode.
5618            # This means that we will silently skip some of the entries on the
5619            # device stack in V1 + eager mode.
5620
5621        with context_manager(context_manager_input):
5622          yield
5623    finally:
5624      # If an exception is raised here it may be hiding a related exception in
5625      # try-block (just above).
5626      if outer_graph is not None:
5627        outer_graph._device_function_stack = outer_device_stack  # pylint: disable=protected-access
5628
5629
5630def executing_eagerly_outside_functions():
5631  """Returns True if executing eagerly, even if inside a graph function."""
5632  # Fastpath for when this is called eagerly (its not necessary to init_scope).
5633  if context.executing_eagerly():
5634    return True
5635
5636  with init_scope():
5637    return context.executing_eagerly()
5638
5639
5640def inside_function():
5641  return get_default_graph().building_function
5642
5643
5644@tf_export(v1=["enable_eager_execution"])
5645def enable_eager_execution(config=None,
5646                           device_policy=None,
5647                           execution_mode=None):
5648  """Enables eager execution for the lifetime of this program.
5649
5650  Eager execution provides an imperative interface to TensorFlow. With eager
5651  execution enabled, TensorFlow functions execute operations immediately (as
5652  opposed to adding to a graph to be executed later in a `tf.Session`) and
5653  return concrete values (as opposed to symbolic references to a node in a
5654  computational graph).
5655
5656  For example:
5657
5658  ```python
5659  tf.enable_eager_execution()
5660
5661  # After eager execution is enabled, operations are executed as they are
5662  # defined and Tensor objects hold concrete values, which can be accessed as
5663  # numpy.ndarray`s through the numpy() method.
5664  assert tf.multiply(6, 7).numpy() == 42
5665  ```
5666
5667  Eager execution cannot be enabled after TensorFlow APIs have been used to
5668  create or execute graphs. It is typically recommended to invoke this function
5669  at program startup and not in a library (as most libraries should be usable
5670  both with and without eager execution).
5671
5672  Args:
5673    config: (Optional.) A `tf.ConfigProto` to use to configure the environment
5674      in which operations are executed. Note that `tf.ConfigProto` is also
5675      used to configure graph execution (via `tf.Session`) and many options
5676      within `tf.ConfigProto` are not implemented (or are irrelevant) when
5677      eager execution is enabled.
5678    device_policy: (Optional.) Policy controlling how operations requiring
5679      inputs on a specific device (e.g., a GPU 0) handle inputs on a different
5680      device  (e.g. GPU 1 or CPU). When set to None, an appropriate value will be
5681      picked automatically. The value picked may change between TensorFlow
5682      releases.
5683      Valid values:
5684      - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the
5685        placement is not correct.
5686      - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not
5687        on the right device but logs a warning.
5688      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors.
5689        Note that this may hide performance problems as there is no notification
5690        provided when operations are blocked on the tensor being copied between
5691        devices.
5692      - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies
5693        int32 tensors, raising errors on the other ones.
5694    execution_mode: (Optional.) Policy controlling how operations dispatched are
5695      actually executed. When set to None, an appropriate value will be picked
5696      automatically. The value picked may change between TensorFlow releases.
5697      Valid values:
5698      - tf.contrib.eager.SYNC: executes each operation synchronously.
5699      - tf.contrib.eager.ASYNC: executes each operation asynchronously. These
5700        operations may return "non-ready" handles.
5701
5702  Raises:
5703    ValueError: If eager execution is enabled after creating/executing a
5704     TensorFlow graph, or if options provided conflict with a previous call
5705     to this function.
5706  """
5707  if context.default_execution_mode != context.EAGER_MODE:
5708    return enable_eager_execution_internal(
5709        config=config,
5710        device_policy=device_policy,
5711        execution_mode=execution_mode,
5712        server_def=None)
5713
5714
5715@tf_export(v1=["disable_eager_execution"])
5716def disable_eager_execution():
5717  """Disables eager execution.
5718
5719  This function can only be called before any Graphs, Ops, or Tensors have been
5720  created. It can be used at the beginning of the program for complex migration
5721  projects from TensorFlow 1.x to 2.x.
5722  """
5723  context.default_execution_mode = context.GRAPH_MODE
5724  c = context.context_safe()
5725  if c is not None:
5726    c._thread_local_data.is_eager = False  # pylint: disable=protected-access
5727
5728
5729def enable_eager_execution_internal(config=None,
5730                                    device_policy=None,
5731                                    execution_mode=None,
5732                                    server_def=None):
5733  """Enables eager execution for the lifetime of this program.
5734
5735  Most of the doc string for enable_eager_execution is relevant here as well.
5736
5737  Args:
5738    config: See enable_eager_execution doc string
5739    device_policy: See enable_eager_execution doc string
5740    execution_mode: See enable_eager_execution doc string
5741    server_def: (Optional.) A tensorflow::ServerDef proto.
5742      Enables execution on remote devices. GrpcServers need to be started by
5743      creating an identical server_def to this, and setting the appropriate
5744      task_indexes, so that the servers can communicate. It will then be
5745      possible to execute operations on remote devices.
5746
5747  Raises:
5748    ValueError
5749
5750  """
5751  if config is not None and not isinstance(config, config_pb2.ConfigProto):
5752    raise TypeError(
5753        "config must be a tf.ConfigProto, but got %s" % type(config))
5754  if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
5755                           context.DEVICE_PLACEMENT_WARN,
5756                           context.DEVICE_PLACEMENT_SILENT,
5757                           context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
5758    raise ValueError(
5759        "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
5760    )
5761  if execution_mode not in (None, context.SYNC, context.ASYNC):
5762    raise ValueError(
5763        "execution_mode must be one of None, tf.contrib.eager.SYNC, "
5764        "tf.contrib.eager.ASYNC")
5765  if context.default_execution_mode == context.GRAPH_MODE:
5766    graph_mode_has_been_used = (
5767        _default_graph_stack._global_default_graph is not None) # pylint: disable=protected-access
5768    if graph_mode_has_been_used:
5769      raise ValueError(
5770          "tf.enable_eager_execution must be called at program startup.")
5771  context.default_execution_mode = context.EAGER_MODE
5772  # pylint: disable=protected-access
5773  if context._context is None:
5774    context._context = context.Context(
5775        config=config,
5776        device_policy=device_policy,
5777        execution_mode=execution_mode,
5778        server_def=server_def)
5779  elif ((config is not None and config is not context._context._config) or
5780        (device_policy is not None and
5781         device_policy is not context._context._device_policy) or
5782        (execution_mode is not None and
5783         execution_mode is not context._context._execution_mode)):
5784    raise ValueError("Trying to change the options of an active eager"
5785                     " execution. Context config: %s, specified config:"
5786                     " %s. Context device policy: %s, specified device"
5787                     " policy: %s. Context execution mode: %s, "
5788                     " specified execution mode %s." %
5789                     (context._context._config, config,
5790                      context._context._device_policy, device_policy,
5791                      context._context._execution_mode, execution_mode))
5792  else:
5793    raise ValueError(
5794        "tf.enable_eager_execution must be called at program startup.")
5795
5796  # Monkey patch to get rid of an unnecessary conditional since the context is
5797  # now initialized.
5798  context.context = context.context_safe
5799
5800
5801def eager_run(main=None, argv=None):
5802  """Runs the program with an optional main function and argv list.
5803
5804  The program will run with eager execution enabled.
5805
5806  Example:
5807  ```python
5808  import tensorflow as tf
5809  # Import subject to future changes:
5810  from tensorflow.contrib.eager.python import tfe
5811
5812  def main(_):
5813    u = tf.constant(6.0)
5814    v = tf.constant(7.0)
5815    print(u * v)
5816
5817  if __name__ == "__main__":
5818    tfe.run()
5819  ```
5820
5821  Args:
5822    main: the main function to run.
5823    argv: the arguments to pass to it.
5824  """
5825  enable_eager_execution()
5826  app.run(main, argv)
5827
5828
5829@tf_export(v1=["reset_default_graph"])
5830def reset_default_graph():
5831  """Clears the default graph stack and resets the global default graph.
5832
5833  NOTE: The default graph is a property of the current thread. This
5834  function applies only to the current thread.  Calling this function while
5835  a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
5836  behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
5837  after calling this function will result in undefined behavior.
5838  Raises:
5839    AssertionError: If this function is called within a nested graph.
5840  """
5841  if not _default_graph_stack.is_cleared():
5842    raise AssertionError("Do not use tf.reset_default_graph() to clear "
5843                         "nested graphs. If you need a cleared graph, "
5844                         "exit the nesting and create a new graph.")
5845  _default_graph_stack.reset()
5846
5847
5848@tf_export(v1=["get_default_graph"])
5849def get_default_graph():
5850  """Returns the default graph for the current thread.
5851
5852  The returned graph will be the innermost graph on which a
5853  `Graph.as_default()` context has been entered, or a global default
5854  graph if none has been explicitly created.
5855
5856  NOTE: The default graph is a property of the current thread. If you
5857  create a new thread, and wish to use the default graph in that
5858  thread, you must explicitly add a `with g.as_default():` in that
5859  thread's function.
5860
5861  Returns:
5862    The default `Graph` being used in the current thread.
5863  """
5864  return _default_graph_stack.get_default()
5865
5866def has_default_graph():
5867  """Returns True if there is a default graph."""
5868  return len(_default_graph_stack.stack) >= 1
5869
5870
5871def get_name_scope():
5872  """Returns the current name scope in the default_graph.
5873
5874  For example:
5875
5876  ```python
5877  with tf.name_scope('scope1'):
5878    with tf.name_scope('scope2'):
5879      print(tf.get_name_scope())
5880  ```
5881  would print the string `scope1/scope2`.
5882
5883  Returns:
5884    A string representing the current name scope.
5885  """
5886  if context.executing_eagerly():
5887    return context.context().scope_name.rstrip("/")
5888  return get_default_graph().get_name_scope()
5889
5890
5891def _assert_same_graph(original_item, item):
5892  """Fail if the 2 items are from different graphs.
5893
5894  Args:
5895    original_item: Original item to check against.
5896    item: Item to check.
5897
5898  Raises:
5899    ValueError: if graphs do not match.
5900  """
5901  if original_item.graph is not item.graph:
5902    raise ValueError("%s must be from the same graph as %s." % (item,
5903                                                                original_item))
5904
5905
5906def _get_graph_from_inputs(op_input_list, graph=None):
5907  """Returns the appropriate graph to use for the given inputs.
5908
5909  This library method provides a consistent algorithm for choosing the graph
5910  in which an Operation should be constructed:
5911
5912  1. If the default graph is being used to construct a function, we
5913     use the default graph.
5914  2. If the "graph" is specified explicitly, we validate that all of the inputs
5915     in "op_input_list" are compatible with that graph.
5916  3. Otherwise, we attempt to select a graph from the first Operation-
5917     or Tensor-valued input in "op_input_list", and validate that all other
5918     such inputs are in the same graph.
5919  4. If the graph was not specified and it could not be inferred from
5920     "op_input_list", we attempt to use the default graph.
5921
5922  Args:
5923    op_input_list: A list of inputs to an operation, which may include `Tensor`,
5924      `Operation`, and other objects that may be converted to a graph element.
5925    graph: (Optional) The explicit graph to use.
5926
5927  Raises:
5928    TypeError: If op_input_list is not a list or tuple, or if graph is not a
5929      Graph.
5930    ValueError: If a graph is explicitly passed and not all inputs are from it,
5931      or if the inputs are from multiple graphs, or we could not find a graph
5932      and there was no default graph.
5933
5934  Returns:
5935    The appropriate graph to use for the given inputs.
5936
5937  """
5938  if get_default_graph().building_function:
5939    return get_default_graph()
5940
5941  op_input_list = tuple(op_input_list)  # Handle generators correctly
5942  if graph and not isinstance(graph, Graph):
5943    raise TypeError("Input graph needs to be a Graph: %s" % graph)
5944
5945  # 1. We validate that all of the inputs are from the same graph. This is
5946  #    either the supplied graph parameter, or the first one selected from one
5947  #    the graph-element-valued inputs. In the latter case, we hold onto
5948  #    that input in original_graph_element so we can provide a more
5949  #    informative error if a mismatch is found.
5950  original_graph_element = None
5951  for op_input in op_input_list:
5952    # Determine if this is a valid graph_element.
5953    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
5954    # up.
5955    graph_element = None
5956    if (isinstance(op_input, (Operation, _TensorLike)) and
5957        ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)):  # pylint: disable=unidiomatic-typecheck
5958      graph_element = op_input
5959    else:
5960      graph_element = _as_graph_element(op_input)
5961
5962    if graph_element is not None:
5963      if not graph:
5964        original_graph_element = graph_element
5965        graph = graph_element.graph
5966      elif original_graph_element is not None:
5967        _assert_same_graph(original_graph_element, graph_element)
5968      elif graph_element.graph is not graph:
5969        raise ValueError("%s is not from the passed-in graph." % graph_element)
5970
5971  # 2. If all else fails, we use the default graph, which is always there.
5972  return graph or get_default_graph()
5973
5974
5975@tf_export(v1=["GraphKeys"])
5976class GraphKeys(object):
5977  """Standard names to use for graph collections.
5978
5979  The standard library uses various well-known names to collect and
5980  retrieve values associated with a graph. For example, the
5981  `tf.Optimizer` subclasses default to optimizing the variables
5982  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
5983  specified, but it is also possible to pass an explicit list of
5984  variables.
5985
5986  The following standard keys are defined:
5987
5988  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
5989    across distributed environment (model variables are subset of these). See
5990    `tf.global_variables`
5991    for more details.
5992    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
5993    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
5994  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
5995    machine. Usually used for temporarily variables, like counters.
5996    Note: use `tf.contrib.framework.local_variable` to add to this collection.
5997  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
5998    model for inference (feed forward). Note: use
5999    `tf.contrib.framework.model_variable` to add to this collection.
6000  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
6001    be trained by an optimizer. See
6002    `tf.trainable_variables`
6003    for more details.
6004  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
6005    graph. See
6006    `tf.summary.merge_all`
6007    for more details.
6008  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
6009    produce input for a computation. See
6010    `tf.train.start_queue_runners`
6011    for more details.
6012  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
6013    keep moving averages.  See
6014    `tf.moving_average_variables`
6015    for more details.
6016  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
6017    construction.
6018
6019  The following standard keys are _defined_, but their collections are **not**
6020  automatically populated as many of the others are:
6021
6022  * `WEIGHTS`
6023  * `BIASES`
6024  * `ACTIVATIONS`
6025  """
6026
6027  # Key to collect Variable objects that are global (shared across machines).
6028  # Default collection for all variables, except local ones.
6029  GLOBAL_VARIABLES = "variables"
6030  # Key to collect local variables that are local to the machine and are not
6031  # saved/restored.
6032  LOCAL_VARIABLES = "local_variables"
6033  # Key to collect local variables which are used to accumulate interal state
6034  # to be used in tf.metrics.*.
6035  METRIC_VARIABLES = "metric_variables"
6036  # Key to collect model variables defined by layers.
6037  MODEL_VARIABLES = "model_variables"
6038  # Key to collect Variable objects that will be trained by the
6039  # optimizers.
6040  TRAINABLE_VARIABLES = "trainable_variables"
6041  # Key to collect summaries.
6042  SUMMARIES = "summaries"
6043  # Key to collect QueueRunners.
6044  QUEUE_RUNNERS = "queue_runners"
6045  # Key to collect table initializers.
6046  TABLE_INITIALIZERS = "table_initializer"
6047  # Key to collect asset filepaths. An asset represents an external resource
6048  # like a vocabulary file.
6049  ASSET_FILEPATHS = "asset_filepaths"
6050  # Key to collect Variable objects that keep moving averages.
6051  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
6052  # Key to collect regularization losses at graph construction.
6053  REGULARIZATION_LOSSES = "regularization_losses"
6054  # Key to collect concatenated sharded variables.
6055  CONCATENATED_VARIABLES = "concatenated_variables"
6056  # Key to collect savers.
6057  SAVERS = "savers"
6058  # Key to collect weights
6059  WEIGHTS = "weights"
6060  # Key to collect biases
6061  BIASES = "biases"
6062  # Key to collect activations
6063  ACTIVATIONS = "activations"
6064  # Key to collect update_ops
6065  UPDATE_OPS = "update_ops"
6066  # Key to collect losses
6067  LOSSES = "losses"
6068  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
6069  SAVEABLE_OBJECTS = "saveable_objects"
6070  # Key to collect all shared resources used by the graph which need to be
6071  # initialized once per cluster.
6072  RESOURCES = "resources"
6073  # Key to collect all shared resources used in this graph which need to be
6074  # initialized once per session.
6075  LOCAL_RESOURCES = "local_resources"
6076  # Trainable resource-style variables.
6077  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"
6078
6079  # Key to indicate various ops.
6080  INIT_OP = "init_op"
6081  LOCAL_INIT_OP = "local_init_op"
6082  READY_OP = "ready_op"
6083  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
6084  SUMMARY_OP = "summary_op"
6085  GLOBAL_STEP = "global_step"
6086
6087  # Used to count the number of evaluations performed during a single evaluation
6088  # run.
6089  EVAL_STEP = "eval_step"
6090  TRAIN_OP = "train_op"
6091
6092  # Key for control flow context.
6093  COND_CONTEXT = "cond_context"
6094  WHILE_CONTEXT = "while_context"
6095
6096  # Used to store v2 summary names.
6097  _SUMMARY_COLLECTION = "_SUMMARY_V2"
6098
6099  # List of all collections that keep track of variables.
6100  _VARIABLE_COLLECTIONS = [
6101      GLOBAL_VARIABLES,
6102      LOCAL_VARIABLES,
6103      METRIC_VARIABLES,
6104      MODEL_VARIABLES,
6105      TRAINABLE_VARIABLES,
6106      MOVING_AVERAGE_VARIABLES,
6107      CONCATENATED_VARIABLES,
6108      TRAINABLE_RESOURCE_VARIABLES,
6109  ]
6110
6111  # Key for streaming model ports.
6112  # NOTE(yuanbyu): internal and experimental.
6113  _STREAMING_MODEL_PORTS = "streaming_model_ports"
6114
6115  @decorator_utils.classproperty
6116  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
6117  def VARIABLES(cls):  # pylint: disable=no-self-argument
6118    return cls.GLOBAL_VARIABLES
6119
6120
6121def dismantle_graph(graph):
6122  """Cleans up reference cycles from a `Graph`.
6123
6124  Helpful for making sure the garbage collector doesn't need to run after a
6125  temporary `Graph` is no longer needed.
6126
6127  Args:
6128    graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
6129      after this function runs.
6130  """
6131  memory.dismantle_ordered_dict(graph._functions)  # pylint: disable=protected-access
6132
6133  # Now clean up Operation<->Graph reference cycles by clearing all of the
6134  # attributes for the Graph and its ops.
6135  graph_operations = graph.get_operations()
6136  for op in graph_operations:
6137    op.__dict__ = {}
6138  graph.__dict__ = {}
6139
6140
6141@tf_export(v1=["add_to_collection"])
6142def add_to_collection(name, value):
6143  """Wrapper for `Graph.add_to_collection()` using the default graph.
6144
6145  See `tf.Graph.add_to_collection`
6146  for more details.
6147
6148  Args:
6149    name: The key for the collection. For example, the `GraphKeys` class
6150      contains many standard names for collections.
6151    value: The value to add to the collection.
6152
6153  @compatibility(eager)
6154  Collections are only supported in eager when variables are created inside an
6155  EagerVariableStore (e.g. as part of a layer or template).
6156  @end_compatibility
6157  """
6158  get_default_graph().add_to_collection(name, value)
6159
6160
6161@tf_export(v1=["add_to_collections"])
6162def add_to_collections(names, value):
6163  """Wrapper for `Graph.add_to_collections()` using the default graph.
6164
6165  See `tf.Graph.add_to_collections`
6166  for more details.
6167
6168  Args:
6169    names: The key for the collections. The `GraphKeys` class
6170      contains many standard names for collections.
6171    value: The value to add to the collections.
6172
6173  @compatibility(eager)
6174  Collections are only supported in eager when variables are created inside an
6175  EagerVariableStore (e.g. as part of a layer or template).
6176  @end_compatibility
6177  """
6178  get_default_graph().add_to_collections(names, value)
6179
6180
6181@tf_export(v1=["get_collection_ref"])
6182def get_collection_ref(key):
6183  """Wrapper for `Graph.get_collection_ref()` using the default graph.
6184
6185  See `tf.Graph.get_collection_ref`
6186  for more details.
6187
6188  Args:
6189    key: The key for the collection. For example, the `GraphKeys` class
6190      contains many standard names for collections.
6191
6192  Returns:
6193    The list of values in the collection with the given `name`, or an empty
6194    list if no value has been added to that collection.  Note that this returns
6195    the collection list itself, which can be modified in place to change the
6196    collection.
6197
6198  @compatibility(eager)
6199  Collections are not supported when eager execution is enabled.
6200  @end_compatibility
6201  """
6202  return get_default_graph().get_collection_ref(key)
6203
6204
6205@tf_export(v1=["get_collection"])
6206def get_collection(key, scope=None):
6207  """Wrapper for `Graph.get_collection()` using the default graph.
6208
6209  See `tf.Graph.get_collection`
6210  for more details.
6211
6212  Args:
6213    key: The key for the collection. For example, the `GraphKeys` class
6214      contains many standard names for collections.
6215    scope: (Optional.) If supplied, the resulting list is filtered to include
6216      only items whose `name` attribute matches using `re.match`. Items
6217      without a `name` attribute are never returned if a scope is supplied and
6218      the choice or `re.match` means that a `scope` without special tokens
6219      filters by prefix.
6220
6221  Returns:
6222    The list of values in the collection with the given `name`, or
6223    an empty list if no value has been added to that collection. The
6224    list contains the values in the order under which they were
6225    collected.
6226
6227  @compatibility(eager)
6228  Collections are not supported when eager execution is enabled.
6229  @end_compatibility
6230  """
6231  return get_default_graph().get_collection(key, scope)
6232
6233
6234def get_all_collection_keys():
6235  """Returns a list of collections used in the default graph."""
6236  return get_default_graph().get_all_collection_keys()
6237
6238
6239name_scope_cache = {}
6240
6241
6242# Named like a function for backwards compatibility with the
6243# @tf_contextlib.contextmanager version, which was switched to a class to avoid
6244# some object creation overhead.
6245@tf_export(v1=["name_scope"])
6246class name_scope(object):  # pylint: disable=invalid-name
6247  """A context manager for use when defining a Python op.
6248
6249  This context manager validates that the given `values` are from the
6250  same graph, makes that graph the default graph, and pushes a
6251  name scope in that graph (see
6252  `tf.Graph.name_scope`
6253  for more details on that).
6254
6255  For example, to define a new Python op called `my_op`:
6256
6257  ```python
6258  def my_op(a, b, c, name=None):
6259    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
6260      a = tf.convert_to_tensor(a, name="a")
6261      b = tf.convert_to_tensor(b, name="b")
6262      c = tf.convert_to_tensor(c, name="c")
6263      # Define some computation that uses `a`, `b`, and `c`.
6264      return foo_op(..., name=scope)
6265  ```
6266  """
6267
6268  @property
6269  def name(self):
6270    return self._name
6271
6272  def __init__(self, name, default_name=None, values=None):
6273    """Initialize the context manager.
6274
6275    Args:
6276      name: The name argument that is passed to the op function.
6277      default_name: The default name to use if the `name` argument is `None`.
6278      values: The list of `Tensor` arguments that are passed to the op function.
6279
6280    Raises:
6281      TypeError: if `default_name` is passed in but not a string.
6282    """
6283    if not (default_name is None or isinstance(default_name, six.string_types)):
6284      raise TypeError(
6285          "`default_name` type (%s) is not a string type. You likely meant to "
6286          "pass this into the `values` kwarg."
6287          % type(default_name))
6288    self._name = default_name if name is None else name
6289    self._default_name = default_name
6290    self._values = values
6291    self._ctx = context.context()
6292    self._in_eager_mode = self._ctx.executing_eagerly()
6293    self._has_symbolic_input_in_eager = False
6294    if self._values and self._in_eager_mode:
6295      # The presence of a graph tensor in `self._values` overrides the context.
6296      for value in self._values:
6297        if hasattr(value, "graph"):
6298          self._has_symbolic_input_in_eager = True
6299          self._name_scope = value.graph.name_scope(self._name)
6300
6301  def __enter__(self):
6302    """Start the scope block.
6303
6304    Returns:
6305      The scope name.
6306
6307    Raises:
6308      ValueError: if neither `name` nor `default_name` is provided
6309        but `values` are.
6310    """
6311    if self._has_symbolic_input_in_eager:
6312      return self._name_scope.__enter__()
6313
6314    if self._in_eager_mode:
6315      self._old_name = self._ctx.scope_name
6316      if not self._name:
6317        scope_name = ""
6318      else:
6319        cache_key = self._name, self._old_name, self._default_name
6320        if cache_key in name_scope_cache:
6321          self._ctx.scope_name = name_scope_cache[cache_key]
6322          return self._ctx.scope_name
6323        elif self._name[-1] == "/":
6324          # A trailing slash breaks out of nested name scopes, indicating a
6325          # fully specified scope name, for compatibility with Graph.name_scope.
6326          scope_name = self._name
6327        else:
6328          name_with_trailing_slash = self._name + "/"
6329          scope_name = (
6330              self._old_name + name_with_trailing_slash
6331              if self._old_name else name_with_trailing_slash)
6332        name_scope_cache[cache_key] = scope_name
6333      self._ctx.scope_name = scope_name
6334      return scope_name
6335    else:
6336      if self._name is None and self._values is not None:
6337        # We only raise an error if values is not None (provided) because
6338        # currently tf.name_scope(None) (values=None then) is sometimes used as
6339        # an idiom to reset to top scope.
6340        raise ValueError(
6341            "At least one of name (%s) and default_name (%s) must be provided."
6342            % (self._name, self._default_name))
6343      if self._values is None:
6344        self._values = []
6345      g = _get_graph_from_inputs(self._values)
6346      self._g_manager = g.as_default()
6347      self._g_manager.__enter__()
6348      try:
6349        self._name_scope = g.name_scope(self._name)
6350        return self._name_scope.__enter__()
6351      except:
6352        self._g_manager.__exit__(*sys.exc_info())
6353        raise
6354
6355  def __exit__(self, type_arg, value_arg, traceback_arg):
6356    if self._has_symbolic_input_in_eager:
6357      self._name_scope.__exit__(type_arg, value_arg, traceback_arg)
6358    elif self._in_eager_mode:
6359      self._ctx.scope_name = self._old_name
6360    else:
6361      self._name_scope.__exit__(type_arg, value_arg, traceback_arg)
6362      self._g_manager.__exit__(type_arg, value_arg, traceback_arg)
6363    return False  # False values do not suppress exceptions
6364
6365
6366@tf_export("name_scope", v1=[])
6367class name_scope_v2(name_scope):
6368  """A context manager for use when defining a Python op.
6369
6370  This context manager pushes a name scope, which will make the name of all
6371  operations added within it have a prefix.
6372
6373  For example, to define a new Python op called `my_op`:
6374
6375  ```python
6376  def my_op(a, b, c, name=None):
6377    with tf.name_scope("MyOp") as scope:
6378      a = tf.convert_to_tensor(a, name="a")
6379      b = tf.convert_to_tensor(b, name="b")
6380      c = tf.convert_to_tensor(c, name="c")
6381      # Define some computation that uses `a`, `b`, and `c`.
6382      return foo_op(..., name=scope)
6383  ```
6384
6385  When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
6386  and `MyOp/c`.
6387
6388  If the scope name already exists, the name will be made unique by appending
6389  `_n`. For example, calling `my_op` the second time will generate `MyOp_1/a`,
6390  etc.
6391  """
6392
6393  def __init__(self, name):
6394    """Initialize the context manager.
6395
6396    Args:
6397      name: The prefix to use on all names created within the name scope.
6398
6399    Raises:
6400      ValueError: If name is None, or not a string.
6401    """
6402    if name is None or not isinstance(name, six.string_types):
6403      raise ValueError("name for name_scope must be a string.")
6404    super(name_scope_v2, self).__init__(name=None, default_name=name)
6405
6406
6407def strip_name_scope(name, export_scope):
6408  """Removes name scope from a name.
6409
6410  Args:
6411    name: A `string` name.
6412    export_scope: Optional `string`. Name scope to remove.
6413
6414  Returns:
6415    Name with name scope removed, or the original name if export_scope
6416    is None.
6417  """
6418  if export_scope:
6419    if export_scope[-1] == "/":
6420      export_scope = export_scope[:-1]
6421
6422    try:
6423      # Strips export_scope/, export_scope///,
6424      # ^export_scope/, loc:@export_scope/.
6425      str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)"
6426      return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
6427    except TypeError as e:
6428      # If the name is not of a type we can process, simply return it.
6429      logging.warning(e)
6430      return name
6431  else:
6432    return name
6433
6434
6435def prepend_name_scope(name, import_scope):
6436  """Prepends name scope to a name.
6437
6438  Args:
6439    name: A `string` name.
6440    import_scope: Optional `string`. Name scope to add.
6441
6442  Returns:
6443    Name with name scope added, or the original name if import_scope
6444    is None.
6445  """
6446  if import_scope:
6447    if import_scope[-1] == "/":
6448      import_scope = import_scope[:-1]
6449
6450    try:
6451      str_to_replace = r"([\^]|loc:@|^)(.*)"
6452      return re.sub(str_to_replace, r"\1" + import_scope + r"/\2",
6453                    compat.as_str(name))
6454    except TypeError as e:
6455      # If the name is not of a type we can process, simply return it.
6456      logging.warning(e)
6457      return name
6458  else:
6459    return name
6460
6461
6462# pylint: disable=g-doc-return-or-yield
6463# pylint: disable=not-context-manager
6464@tf_export(v1=["op_scope"])
6465@tf_contextlib.contextmanager
6466def op_scope(values, name, default_name=None):
6467  """DEPRECATED. Same as name_scope above, just different argument order."""
6468  logging.warn("tf.op_scope(values, name, default_name) is deprecated,"
6469               " use tf.name_scope(name, default_name, values)")
6470  with name_scope(name, default_name=default_name, values=values) as scope:
6471    yield scope
6472
6473
6474_proto_function_registry = registry.Registry("proto functions")
6475
6476
6477def register_proto_function(collection_name,
6478                            proto_type=None,
6479                            to_proto=None,
6480                            from_proto=None):
6481  """Registers `to_proto` and `from_proto` functions for collection_name.
6482
6483  `to_proto` function converts a Python object to the corresponding protocol
6484  buffer, and returns the protocol buffer.
6485
6486  `from_proto` function converts protocol buffer into a Python object, and
6487  returns the object..
6488
6489  Args:
6490    collection_name: Name of the collection.
6491    proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
6492      `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`..
6493    to_proto: Function that implements Python object to protobuf conversion.
6494    from_proto: Function that implements protobuf to Python object conversion.
6495  """
6496  if to_proto and not callable(to_proto):
6497    raise TypeError("to_proto must be callable.")
6498  if from_proto and not callable(from_proto):
6499    raise TypeError("from_proto must be callable.")
6500
6501  _proto_function_registry.register((proto_type, to_proto, from_proto),
6502                                    collection_name)
6503
6504
6505def get_collection_proto_type(collection_name):
6506  """Returns the proto_type for collection_name."""
6507  try:
6508    return _proto_function_registry.lookup(collection_name)[0]
6509  except LookupError:
6510    return None
6511
6512
6513def get_to_proto_function(collection_name):
6514  """Returns the to_proto function for collection_name."""
6515  try:
6516    return _proto_function_registry.lookup(collection_name)[1]
6517  except LookupError:
6518    return None
6519
6520
6521def get_from_proto_function(collection_name):
6522  """Returns the from_proto function for collection_name."""
6523  try:
6524    return _proto_function_registry.lookup(collection_name)[2]
6525  except LookupError:
6526    return None
6527
6528
6529def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
6530  """Produce a nice error if someone converts an Operation to a Tensor."""
6531  raise TypeError(("Can't convert Operation '%s' to Tensor "
6532                   "(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype,
6533                                                               name, as_ref))
6534
6535
6536def _op_to_colocate_with(v):
6537  """Operation object corresponding to v to use for colocation constraints."""
6538  if v is None:
6539    return None
6540  if isinstance(v, Operation):
6541    return v
6542  # We always want to colocate with the reference op.
6543  # When 'v' is a ResourceVariable, the reference op is the handle creating op.
6544  #
6545  # What this should be is:
6546  # if isinstance(v, ResourceVariable):
6547  #   return v.handle.op
6548  # However, that would require a circular import dependency.
6549  # As of October 2018, there were attempts underway to remove
6550  # colocation constraints altogether. Assuming that will
6551  # happen soon, perhaps this hack to work around the circular
6552  # import dependency is acceptable.
6553  if hasattr(v, "handle") and hasattr(v.handle, "op") and isinstance(
6554      v.handle.op, Operation):
6555    return v.handle.op
6556  return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op
6557
6558
6559def _is_keras_symbolic_tensor(x):
6560  return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"
6561
6562
6563register_tensor_conversion_function(Operation, _operation_conversion_error)
6564