1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes and functions used to construct graphs.""" 16# pylint: disable=g-bad-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import re 23import sys 24import threading 25import types 26 27import numpy as np 28import six 29from six.moves import map # pylint: disable=redefined-builtin 30from six.moves import xrange # pylint: disable=redefined-builtin 31 32from tensorflow.core.framework import attr_value_pb2 33from tensorflow.core.framework import function_pb2 34from tensorflow.core.framework import graph_pb2 35from tensorflow.core.framework import node_def_pb2 36from tensorflow.core.framework import op_def_pb2 37from tensorflow.core.framework import versions_pb2 38from tensorflow.core.protobuf import config_pb2 39# pywrap_tensorflow must be imported first to avoid protobuf issues. 
40# (b/143110113) 41# pylint: disable=invalid-import-order,g-bad-import-order,unused-import 42from tensorflow.python import pywrap_tensorflow 43from tensorflow.python import pywrap_tfe 44# pylint: enable=invalid-import-order,g-bad-import-order,unused-import 45from tensorflow.python import tf2 46from tensorflow.python.client import pywrap_tf_session 47from tensorflow.python.eager import context 48from tensorflow.python.eager import core 49from tensorflow.python.eager import monitoring 50from tensorflow.python.eager import tape 51from tensorflow.python.framework import c_api_util 52from tensorflow.python.framework import composite_tensor 53from tensorflow.python.framework import device as pydev 54from tensorflow.python.framework import dtypes 55from tensorflow.python.framework import errors 56from tensorflow.python.framework import indexed_slices 57from tensorflow.python.framework import registry 58from tensorflow.python.framework import tensor_conversion_registry 59from tensorflow.python.framework import tensor_like 60from tensorflow.python.framework import tensor_shape 61from tensorflow.python.framework import traceable_stack 62from tensorflow.python.framework import versions 63from tensorflow.python.ops import control_flow_util 64from tensorflow.python.platform import app 65from tensorflow.python.platform import tf_logging as logging 66from tensorflow.python.util import compat 67from tensorflow.python.util import decorator_utils 68from tensorflow.python.util import deprecation 69from tensorflow.python.util import function_utils 70from tensorflow.python.util import lock_util 71from tensorflow.python.util import memory 72from tensorflow.python.util import object_identity 73from tensorflow.python.util import tf_contextlib 74from tensorflow.python.util import tf_stack 75from tensorflow.python.util.compat import collections_abc 76from tensorflow.python.util.deprecation import deprecated_args 77from tensorflow.python.util.lazy_loader import LazyLoader 78from 
def tensor_id(tensor):
  """Returns the `_id` attribute stored on `tensor`."""
  # pylint: disable=protected-access
  identifier = tensor._id
  # pylint: enable=protected-access
  return identifier


class _UserDeviceSpec(object):
  """Wraps a user-supplied device name or device function.

  The constructor accepts a `pydev.MergeDevice`, an arbitrary callable, `None`
  or any other value (treated as a raw device string) and records:

  * `display_name`: a printable description of the spec.
  * `function`: the object that `string_merge` dispatches to.
  * `raw_string`: the original value when it was given as a raw device string,
    otherwise `None`.
  * `is_null_merge`: taken from the merge function when there is one, else
    `False`.
  * `fast_string_merge`: whether `function` is a `pydev.MergeDevice`.
  """

  def __init__(self, device_name_or_function):
    spec = device_name_or_function
    self._device_name_or_function = spec
    self.display_name = str(spec)
    self.function = spec
    self.raw_string = None

    if isinstance(spec, pydev.MergeDevice):
      self.is_null_merge = spec.is_null_merge
    elif callable(spec):
      self.is_null_merge = False
      self.display_name = self._describe_callable(spec)
    elif spec is None:
      # NOTE(taylorrobie): This MUST be False. None signals a break in the
      #   device stack, so `is_null_merge` must be False for such a case to
      #   allow callers to safely skip over null merges without missing a None.
      self.is_null_merge = False
    else:
      # Anything else is kept verbatim and turned into a merge function.
      self.raw_string = spec
      merge_fn = pydev.merge_device(spec)
      self.function = merge_fn
      self.is_null_merge = merge_fn.is_null_merge

    # This isinstance check is done once here rather than on every
    # `string_merge` call, because it is of non-trivial cost and
    # `string_merge` is typically called many times.
    self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)

  @staticmethod
  def _describe_callable(dev_func):
    """Formats `dev_func` as `name<filename, first line>` for display."""
    name = function_utils.get_func_name(dev_func)
    code = function_utils.get_func_code(dev_func)
    if code:
      origin = (code.co_filename, code.co_firstlineno)
    else:
      # No code object available: fall back to placeholder values.
      origin = ("unknown", -1)
    return "%s<%s, %d>" % ((name,) + origin)

  def string_merge(self, node_def):
    """Returns the device string produced by this spec for `node_def`."""
    if not self.fast_string_merge:
      device = self.function(node_def)
      return compat.as_str(_device_string(device))
    return self.function.shortcut_string_merge(node_def)


class NullContextmanager(object):
  """Context manager that accepts any arguments and does nothing."""

  def __init__(self, *args, **kwargs):
    """Ignores every positional and keyword argument."""

  def __enter__(self):
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # A false return value means exceptions are not suppressed.
    return False


def _override_helper(clazz_object, operator, func):
  """Installs `func` as the implementation of `operator` on `clazz_object`.

  Args:
    clazz_object: the class to override for; either Tensor or SparseTensor.
    operator: the string name of the operator to override.
    func: the function that replaces the overridden operator.

  Raises:
    ValueError: If operator has already been overwritten,
      or if operator is not allowed to be overwritten.
  """
  current = getattr(clazz_object, operator, None)
  # Default method-wrappers / slot wrappers (which is what the comparison
  # operators inherit from `object`) do not count as a previous override.
  default_wrapper_type = type(object.__lt__)
  if current is not None and not isinstance(current, default_wrapper_type):
    raise ValueError("operator %s cannot be overwritten again on class %s." %
                     (operator, clazz_object))
  if operator in Tensor.OVERLOADABLE_OPERATORS:
    setattr(clazz_object, operator, func)
    return
  raise ValueError("Overriding %s is disallowed" % operator)
def _as_graph_element(obj):
  """Convert `obj` to a graph element if possible, otherwise return `None`.

  Args:
    obj: Object to convert.

  Returns:
    The result of `obj._as_graph_element()` if that method is available;
    otherwise `None`.
  """
  conv_fn = getattr(obj, "_as_graph_element", None)
  if conv_fn and callable(conv_fn):
    return conv_fn()
  return None


# Tuple of every type registered through `register_dense_tensor_like_type()`.
# Kept as a tuple so it can be handed straight to `isinstance`.
_TENSOR_LIKE_TYPES = tuple()


def is_dense_tensor_like(t):
  """EXPERIMENTAL: Returns true if `t` implements the tensor interface.

  See `register_dense_tensor_like_type()` for the current definition of a
  "tensor-like type".

  Args:
    t: An object.

  Returns:
    True iff `t` is an instance of one of the registered "tensor-like" types.
  """
  return isinstance(t, _TENSOR_LIKE_TYPES)


def register_dense_tensor_like_type(tensor_type):
  """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.

  A "tensor-like type" can represent a single dense tensor, and implements
  the `name`, `dtype` and `shape` properties.

  Registering a type that is already registered is a no-op, so repeated
  registration does not grow the tuple consulted by `is_dense_tensor_like()`.

  Args:
    tensor_type: A type implementing the tensor interface.

  Raises:
    TypeError: If `tensor_type` does not implement the tensor interface.
  """
  global _TENSOR_LIKE_TYPES
  # Each required attribute must exist on the class and be a `property`.
  for attr in ("name", "dtype", "shape"):
    if not isinstance(getattr(tensor_type, attr, None), property):
      raise TypeError("Type %s does not define a `%s` property" %
                      (tensor_type.__name__, attr))
  # Identity comparison: a type counts as registered only if this exact
  # class object is already present.
  if any(known is tensor_type for known in _TENSOR_LIKE_TYPES):
    return
  # We expect this tuple to be small, so rebuilding it on registration is
  # fine, and keeping a tuple allows efficient `isinstance` checks later.
  _TENSOR_LIKE_TYPES = _TENSOR_LIKE_TYPES + (tensor_type,)
def uid():
  """Returns a unique (within this program execution) integer."""
  fresh_id = pywrap_tfe.TFE_Py_UID()
  return fresh_id


def numpy_text(tensor, is_repr=False):
  """Renders the numpy value of `tensor` as human readable text.

  Args:
    tensor: Object exposing `dtype.is_numpy_compatible` and `_numpy()`.
    is_repr: If true, format the value with `repr`; otherwise with `str`.

  Returns:
    The formatted value, or "<unprintable>" when the dtype is not numpy
    compatible. Multi-line text is prefixed with a newline.
  """
  if not tensor.dtype.is_numpy_compatible:
    rendered = "<unprintable>"
  else:
    formatter = repr if is_repr else str
    # pylint: disable=protected-access
    rendered = formatter(tensor._numpy())
    # pylint: enable=protected-access
  return "\n" + rendered if "\n" in rendered else rendered


@tf_export(v1=["enable_tensor_equality"])
def enable_tensor_equality():
  """Compare Tensors with element-wise comparison and thus be unhashable.

  Element-wise comparison allows expressions such as
  tf.Variable(1.0) == 1.0. Element-wise equality implies that tensors are
  unhashable. Thus tensors can no longer be directly used in sets or as a key in
  a dictionary.
  """
  # pylint: disable=protected-access
  Tensor._USE_EQUALITY = True
  # pylint: enable=protected-access


@tf_export(v1=["disable_tensor_equality"])
def disable_tensor_equality():
  """Compare Tensors by their id and be hashable.

  This is a legacy behaviour of TensorFlow and is highly discouraged.
  """
  # pylint: disable=protected-access
  Tensor._USE_EQUALITY = False
  # pylint: enable=protected-access
@tf_export("Tensor")
class Tensor(_TensorLike):
  """A tensor represents a rectangular array of data.

  When writing a TensorFlow program, the main object you manipulate and pass
  around is the `tf.Tensor`. A `tf.Tensor` object represents a rectangular array
  of arbitrary dimension, filled with data of a specific data type.

  A `tf.Tensor` has the following properties:

  * a data type (float32, int32, or string, for example)
  * a shape

  Each element in the Tensor has the same data type, and the data type is always
  known.

  In eager execution, which is the default mode in TensorFlow, results are
  calculated immediately.

  >>> # Compute some values using a Tensor
  >>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  >>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
  >>> e = tf.matmul(c, d)
  >>> print(e)
  tf.Tensor(
  [[1. 3.]
   [3. 7.]], shape=(2, 2), dtype=float32)


  Note that during eager execution, you may discover your `Tensors` are actually
  of type `EagerTensor`.  This is an internal detail, but it does give you
  access to a useful function, `numpy`:

  >>> type(e)
  <class '...ops.EagerTensor'>
  >>> print(e.numpy())
    [[1. 3.]
     [3. 7.]]

  TensorFlow can define computations without immediately executing them, most
  commonly inside `tf.function`s, as well as in (legacy) Graph mode. In those
  cases, the shape (that is, the rank of the Tensor and the size of
  each dimension) might be only partially known.

  Most operations produce tensors of fully-known shapes if the shapes of their
  inputs are also fully known, but in some cases it's only possible to find the
  shape of a tensor at execution time.

  There are specialized tensors; for these, see `tf.Variable`, `tf.constant`,
  `tf.placeholder`, `tf.SparseTensor`, and `tf.RaggedTensor`.

  For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor).
  """

  # List of Python operators that we allow to override.
  OVERLOADABLE_OPERATORS = {
      # Binary.
      "__add__",
      "__radd__",
      "__sub__",
      "__rsub__",
      "__mul__",
      "__rmul__",
      "__div__",
      "__rdiv__",
      "__truediv__",
      "__rtruediv__",
      "__floordiv__",
      "__rfloordiv__",
      "__mod__",
      "__rmod__",
      "__lt__",
      "__le__",
      "__gt__",
      "__ge__",
      "__ne__",
      "__eq__",
      "__and__",
      "__rand__",
      "__or__",
      "__ror__",
      "__xor__",
      "__rxor__",
      "__getitem__",
      "__pow__",
      "__rpow__",
      # Unary.
      "__invert__",
      "__neg__",
      "__abs__",
      # Matrix multiplication (binary, although listed after the unary ones).
      "__matmul__",
      "__rmatmul__"
  }

  # Whether to allow hashing or numpy-style equality
  _USE_EQUALITY = tf2.enabled()

  def __init__(self, op, value_index, dtype):
    """Creates a new `Tensor`.

    Args:
      op: An `Operation`. `Operation` that computes this tensor.
      value_index: An `int`. Index of the operation's endpoint that produces
        this tensor.
      dtype: A `DType`. Type of elements stored in this tensor.

    Raises:
      TypeError: If the op is not an `Operation`.
    """
    if not isinstance(op, Operation):
      raise TypeError("op needs to be an Operation: %s" % op)
    self._op = op
    self._value_index = value_index
    self._dtype = dtypes.as_dtype(dtype)
    # This will be set by self._as_tf_output().
    self._tf_output = None
    # This will be set lazily by the `shape` property.
    self._shape_val = None
    # List of operations that use this Tensor as input.  We maintain this list
    # to easily navigate a computation graph.
    self._consumers = []
    self._id = uid()
    # Cached "<op name>:<value index>" string, filled in by the `name` property.
    self._name = None

  @staticmethod
  def _create_with_tf_output(op, value_index, dtype, tf_output):
    """Creates a `Tensor` whose cached `_tf_output` is preset to `tf_output`."""
    ret = Tensor(op, value_index, dtype)
    ret._tf_output = tf_output
    return ret

  @property
  def op(self):
    """The `Operation` that produces this tensor as an output."""
    return self._op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._dtype

  @property
  def graph(self):
    """The `Graph` that contains this tensor."""
    return self._op.graph

  @property
  def name(self):
    """The string name of this tensor.

    The name has the form "<op name>:<value index>" and is cached after the
    first access.

    Raises:
      ValueError: If the producing operation has no name.
    """
    if self._name is None:
      if not self._op.name:
        raise ValueError("Operation was not named: %s" % self._op)
      self._name = "%s:%d" % (self._op.name, self._value_index)
    return self._name

  @property
  def device(self):
    """The name of the device on which this tensor will be produced, or None."""
    return self._op.device

  @property
  def shape(self):
    """Returns the `TensorShape` that represents the shape of this tensor.

    The shape is computed using shape inference functions that are
    registered in the Op for each `Operation`.  See
    `tf.TensorShape`
    for more details of what a shape represents.

    The inferred shape of a tensor is used to provide shape
    information without having to execute the underlying kernel. This
    can be used for debugging and providing early error messages. For
    example:

    ```python
    >>> c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    >>> print(c.shape) # will be TensorShape([2, 3])
    (2, 3)

    >>> d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
    >>> print(d.shape)
    (4, 2)

    # Raises an InvalidArgumentError, because `c` and `d` do not have
    # compatible inner dimensions.
    >>> e = tf.matmul(c, d)
    Traceback (most recent call last):
        ...
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Matrix
    size-incompatible: In[0]: [2,3], In[1]: [4,2] [Op:MatMul] name: MatMul/

    # This works because we have compatible shapes.
    >>> f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
    >>> print(f.shape)
    (3, 4)

    ```

    In some cases, the inferred shape may have unknown dimensions. If
    the caller has additional information about the values of these
    dimensions, `Tensor.set_shape()` can be used to augment the
    inferred shape.

    Returns:
      A `tf.TensorShape` representing the shape of this tensor.

    """
    # The shape is fetched from the C API once and cached; `set_shape()`
    # clears the cache.
    if self._shape_val is None:
      self._shape_val = self._c_api_shape()
    return self._shape_val

  def _c_api_shape(self):
    """Returns the TensorShape of this tensor according to the C API."""
    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
    shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper(
        c_graph, self._as_tf_output())
    if unknown_shape:
      return tensor_shape.unknown_shape()
    else:
      # A dimension reported as -1 is translated to None (unknown).
      shape_vec = [None if d == -1 else d for d in shape_vec]
      return tensor_shape.TensorShape(shape_vec)

  @property
  def _shape(self):
    """Deprecated private alias of `shape`; logs a warning on every read."""
    logging.warning("Tensor._shape is private, use Tensor.shape "
                    "instead. Tensor._shape will eventually be removed.")
    return self.shape

  @_shape.setter
  def _shape(self, value):
    """Always raises `ValueError`; use `set_shape()` instead."""
    raise ValueError(
        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")

  def _disallow_when_autograph_disabled(self, task):
    """Raises `OperatorNotAllowedInGraphError` saying AutoGraph is disabled."""
    raise errors.OperatorNotAllowedInGraphError(
        "{} is not allowed: AutoGraph is disabled in this function."
        " Try decorating it directly with @tf.function.".format(task))

  def _disallow_when_autograph_enabled(self, task):
    """Raises `OperatorNotAllowedInGraphError` saying AutoGraph did not run."""
    raise errors.OperatorNotAllowedInGraphError(
        "{} is not allowed: AutoGraph did not convert this function. Try"
        " decorating it directly with @tf.function.".format(task))

  def _disallow_in_graph_mode(self, task):
    """Raises `OperatorNotAllowedInGraphError` for plain Graph execution."""
    raise errors.OperatorNotAllowedInGraphError(
        "{} is not allowed in Graph execution. Use Eager execution or decorate"
        " this function with @tf.function.".format(task))

  def _disallow_bool_casting(self):
    """Raises the error matching the current AutoGraph status for `bool()`."""
    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
      self._disallow_when_autograph_disabled(
          "using a `tf.Tensor` as a Python `bool`")
    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
      self._disallow_when_autograph_enabled(
          "using a `tf.Tensor` as a Python `bool`")
    else:
      # Default: V1-style Graph execution.
      self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`")

  def _disallow_iteration(self):
    """Raises the error matching the current AutoGraph status for iteration."""
    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
      self._disallow_when_autograph_disabled("iterating over `tf.Tensor`")
    elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
      self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
    else:
      # Default: V1-style Graph execution.
      self._disallow_in_graph_mode("iterating over `tf.Tensor`")

  def __iter__(self):
    """Yields `self[i]` for each index along the first dimension.

    Because this is a generator, the checks below run when iteration starts,
    not when `iter()` is called.

    Raises:
      TypeError: If the shape is unknown, the tensor is a scalar, or the first
        dimension is unknown.
    """
    if not context.executing_eagerly():
      self._disallow_iteration()

    shape = self._shape_tuple()
    if shape is None:
      raise TypeError("Cannot iterate over a tensor with unknown shape.")
    if not shape:
      raise TypeError("Cannot iterate over a scalar tensor.")
    if shape[0] is None:
      raise TypeError(
          "Cannot iterate over a tensor with unknown first dimension.")
    for i in xrange(shape[0]):
      yield self[i]

  def _shape_as_list(self):
    """Returns the dimension values as a list, or None if the rank is unknown."""
    if self.shape.ndims is not None:
      return [dim.value for dim in self.shape.dims]
    else:
      return None

  def _shape_tuple(self):
    """Returns the dimension values as a tuple, or None if the rank is unknown."""
    shape = self._shape_as_list()
    if shape is None:
      return None
    return tuple(shape)

  def _rank(self):
    """Integer rank of this Tensor, if known, else None.

    Returns:
      Integer rank or None
    """
    return self.shape.ndims

  def get_shape(self):
    """Alias of `tf.Tensor.shape`."""
    return self.shape

  def set_shape(self, shape):
    """Updates the shape of this tensor.

    This method can be called multiple times, and will merge the given
    `shape` with the current shape of this tensor. It can be used to
    provide additional information about the shape of this tensor that
    cannot be inferred from the graph alone. For example, this can be used
    to provide additional information about the shapes of images:

    ```python
    _, image_data = tf.compat.v1.TFRecordReader(...).read(...)
    image = tf.image.decode_png(image_data, channels=3)

    # The height and width dimensions of `image` are data dependent, and
    # cannot be computed without executing the op.
    print(image.shape)
    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])

    # We know that each image in this dataset is 28 x 28 pixels.
    image.set_shape([28, 28, 3])
    print(image.shape)
    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
    ```

    NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
    result in inconsistencies between the statically-known graph and the runtime
    value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
    instead.

    Args:
      shape: A `TensorShape` representing the shape of this tensor, a
        `TensorShapeProto`, a list, a tuple, or None.

    Raises:
      ValueError: If `shape` is not compatible with the current shape of
        this tensor.
    """
    # Reset cached shape.
    self._shape_val = None

    # We want set_shape to be reflected in the C API graph for when we run it.
    if not isinstance(shape, tensor_shape.TensorShape):
      shape = tensor_shape.TensorShape(shape)
    dim_list = []
    if shape.dims is None:
      unknown_shape = True
    else:
      unknown_shape = False
      # Unknown dimensions are passed to the C API as -1.
      for dim in shape.dims:
        if dim.value is None:
          dim_list.append(-1)
        else:
          dim_list.append(dim.value)
    try:
      pywrap_tf_session.TF_GraphSetTensorShape_wrapper(
          self._op._graph._c_graph,  # pylint: disable=protected-access
          self._as_tf_output(),
          dim_list,
          unknown_shape)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(str(e))

  @property
  def value_index(self):
    """The index of this tensor in the outputs of its `Operation`."""
    return self._value_index

  def consumers(self):
    """Returns a list of `Operation`s that consume this tensor.

    Returns:
      A list of `Operation`s.
    """
    # The C API returns consumer names; each is looked up in this tensor's
    # graph.
    consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper(
        self._as_tf_output())
    # pylint: disable=protected-access
    return [
        self.graph._get_operation_by_name_unsafe(name)
        for name in consumer_names
    ]
    # pylint: enable=protected-access

  def _as_node_def_input(self):
    """Return a value to use for the NodeDef "input" attribute.

    The returned string can be used in a NodeDef "input" attribute
    to indicate that the NodeDef uses this Tensor as input.

    Raises:
      ValueError: if this Tensor's Operation does not have a name.

    Returns:
      a string.
    """
    if not self._op.name:
      raise ValueError("Operation was not named: %s" % self._op)
    # Output 0 is written as the bare op name; other outputs as "name:index".
    if self._value_index == 0:
      return self._op.name
    else:
      return "%s:%d" % (self._op.name, self._value_index)

  def _as_tf_output(self):
    """Returns the cached `tf_output` for this tensor, creating it if needed."""
    # pylint: disable=protected-access
    # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object
    # also guarantees that a dictionary of tf_output objects will retain a
    # deterministic (yet unsorted) order which prevents memory blowup in the
    # cache of executor(s) stored for every session.
    if self._tf_output is None:
      self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index)
    return self._tf_output
    # pylint: enable=protected-access

  def __str__(self):
    """Returns `Tensor("name", shape=..., dtype=..., device=...)`.

    The shape is omitted when the rank is unknown; dtype and device are
    omitted when falsy.
    """
    return "Tensor(\"%s\"%s%s%s)" % (
        self.name,
        (", shape=%s" %
         self.get_shape()) if self.get_shape().ndims is not None else "",
        (", dtype=%s" % self._dtype.name) if self._dtype else "",
        (", device=%s" % self.device) if self.device else "")

  def __repr__(self):
    """Returns `<tf.Tensor 'name' shape=... dtype=...>`."""
    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
                                                   self._dtype.name)

  def __hash__(self):
    """Returns `id(self)`, or raises when tensor equality makes it unhashable.

    Raises:
      TypeError: If `Tensor._USE_EQUALITY` is set,
        `executing_eagerly_outside_functions()` is true, and this tensor either
        has no `graph` attribute or its graph is building a function.
    """
    g = getattr(self, "graph", None)
    if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and
        (g is None or g._building_function)):  # pylint: disable=protected-access
      raise TypeError("Tensor is unhashable if Tensor equality is enabled. "
                      "Instead, use tensor.experimental_ref() as the key.")
    else:
      return id(self)

  def __copy__(self):
    """Returns a new instance of the same class sharing this `__dict__` content.

    `__init__` is bypassed, so the copy also carries the same `_id`.
    """
    # TODO(b/77597810): get rid of Tensor copies.
    cls = self.__class__
    result = cls.__new__(cls)
    result.__dict__.update(self.__dict__)
    return result

  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  def __array__(self):
    """Always raises `NotImplementedError` for this symbolic tensor."""
    raise NotImplementedError("Cannot convert a symbolic Tensor ({}) to a numpy"
                              " array.".format(self.name))

  def __len__(self):
    """Always raises `TypeError`; use `x.shape` for shape information."""
    raise TypeError("len is not well defined for symbolic Tensors. ({}) "
                    "Please call `x.shape` rather than `len(x)` for "
                    "shape information.".format(self.name))

  @staticmethod
  def _override_operator(operator, func):
    """Installs `func` as `operator` on `Tensor` via `_override_helper`."""
    _override_helper(Tensor, operator, func)

  def __bool__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This overload raises a `TypeError` when the user inadvertently
    treats a `Tensor` as a boolean (most commonly in an `if` or `while`
    statement), in code that was not converted by AutoGraph. For example:

    ```python
    if tf.constant(True):  # Will raise.
      # ...

    if tf.constant(5) < tf.constant(7):  # Will raise.
      # ...
    ```

    Raises:
      `TypeError`.
    """
    self._disallow_bool_casting()

  def __nonzero__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This is the Python 2.x counterpart to `__bool__()` above.

    Raises:
      `TypeError`.
    """
    self._disallow_bool_casting()

  def eval(self, feed_dict=None, session=None):
    """Evaluates this tensor in a `Session`.

    Note: If you are not using `compat.v1` libraries, you should not need this,
    (or `feed_dict` or `Session`).  In eager execution (or within `tf.function`)
    you do not need to call `eval`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this tensor. If
        none, the default session will be used.

    Returns:
      A numpy array corresponding to the value of this tensor.
    """
    return _eval_using_default_session(self, feed_dict, self.graph, session)

  def experimental_ref(self):
    # tf.Variable also has the same experimental_ref() API.  If you update the
    # documentation here, please update tf.Variable.experimental_ref() as well.
    """Returns a hashable reference object to this Tensor.

    Warning: Experimental API that could be changed or removed.

    The primary use case for this API is to put tensors in a set/dictionary.
    We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer
    available starting Tensorflow 2.0.

    ```python
    import tensorflow as tf

    x = tf.constant(5)
    y = tf.constant(10)
    z = tf.constant(10)

    # The following will raise an exception starting 2.0
    # TypeError: Tensor is unhashable if Tensor equality is enabled.
    tensor_set = {x, y, z}
    tensor_dict = {x: 'five', y: 'ten', z: 'ten'}
    ```

    Instead, we can use `tensor.experimental_ref()`.

    ```python
    tensor_set = {x.experimental_ref(),
                  y.experimental_ref(),
                  z.experimental_ref()}

    print(x.experimental_ref() in tensor_set)
    ==> True

    tensor_dict = {x.experimental_ref(): 'five',
                   y.experimental_ref(): 'ten',
                   z.experimental_ref(): 'ten'}

    print(tensor_dict[y.experimental_ref()])
    ==> ten
    ```

    Also, the reference object provides `.deref()` function that returns the
    original Tensor.

    ```python
    x = tf.constant(5)
    print(x.experimental_ref().deref())
    ==> tf.Tensor(5, shape=(), dtype=int32)
    ```
    """
    return object_identity.Reference(self)


# TODO(agarwal): consider getting rid of this.
class _EagerTensorBase(Tensor):
  """Base class for EagerTensor.

  The concrete `EagerTensor` type is created from this class by
  `pywrap_tfe.TFE_Py_InitEagerTensor`. Several methods here
  (`_numpy_internal`, `_datatype_enum`, `_shape_tuple`, `_rank`,
  `_num_elements`, `_copy_to_device`, `backing_device`) only raise
  `NotImplementedError` in this class.
  """

  # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and
  # only work for scalars; values are cast as per numpy.
  def __complex__(self):
    """Returns `complex()` of this tensor's NumPy value."""
    return complex(self._numpy())

  def __int__(self):
    """Returns `int()` of this tensor's NumPy value."""
    return int(self._numpy())

  def __long__(self):
    """Returns `long()` of this tensor's NumPy value."""
    # NOTE(review): `long` is a Python 2 builtin; presumably this hook is only
    # ever invoked by Python 2 interpreters -- confirm before touching.
    return long(self._numpy())

  def __float__(self):
    """Returns `float()` of this tensor's NumPy value."""
    return float(self._numpy())

  def __index__(self):
    """Delegates to `__index__` of this tensor's NumPy value."""
    return self._numpy().__index__()

  def __bool__(self):
    """Returns `bool()` of this tensor's NumPy value."""
    return bool(self._numpy())

  # Python 2 spelling of `__bool__`.
  __nonzero__ = __bool__

  def __format__(self, format_spec):
    """Formats this tensor's NumPy value with `format_spec`."""
    return self._numpy().__format__(format_spec)

  def __reduce__(self):
    """Pickle support: rebuild by calling `convert_to_tensor` on the NumPy value."""
    return convert_to_tensor, (self._numpy(),)

  def __copy__(self):
    """Returns `self`."""
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    return self

  def __deepcopy__(self, memo):
    """Returns `self`; `memo` is unused."""
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    del memo
    return self

  def __str__(self):
    """Returns `tf.Tensor(<numpy text>, shape=..., dtype=...)`."""
    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self), self.shape,
                                                  self.dtype.name)

  def __repr__(self):
    """Returns `<tf.Tensor: shape=..., dtype=..., numpy=...>`."""
    return "<tf.Tensor: shape=%s, dtype=%s, numpy=%s>" % (
        self.shape, self.dtype.name, numpy_text(self, is_repr=True))

  def __len__(self):
    """Returns the length of the first dimension in the Tensor.

    Raises:
      TypeError: If `self.shape.ndims` is falsy (e.g. a rank-0 tensor).
    """
    if not self.shape.ndims:
      raise TypeError("Scalar tensor has no `len()`")
    # pylint: disable=protected-access
    try:
      return self._shape_tuple()[0]
    except core._NotOkStatusException as e:
      # Re-raise as the exception type mapped from the status code, with the
      # original exception context suppressed.
      six.raise_from(core._status_to_exception(e.code, e.message), None)

  def _numpy_internal(self):
    """Hook called by `_numpy()`; not implemented in this base class."""
    raise NotImplementedError()

  def _numpy(self):
    """Returns `_numpy_internal()`, translating status errors.

    A `core._NotOkStatusException` raised by `_numpy_internal()` is re-raised
    as the exception returned by `core._status_to_exception`, with the
    original exception context suppressed.
    """
    # pylint: disable=protected-access
    try:
      return self._numpy_internal()
    except core._NotOkStatusException as e:
      six.raise_from(core._status_to_exception(e.code, e.message), None)

  @property
  def dtype(self):
    """The dtype looked up from `_datatype_enum()` in the dtypes intern table."""
    # Note: using the intern table directly here as this is
    # performance-sensitive in some models.
    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access

  def numpy(self):
    """Copy of the contents of this Tensor into a NumPy array or scalar.

    Unlike NumPy arrays, Tensors are immutable, so this method has to copy
    the contents to ensure safety. Use `memoryview` to get a readonly
    view of the contents without doing a copy:

    >>> t = tf.constant([42])
    >>> np.array(memoryview(t))
    array([42], dtype=int32)

    Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor
    is on GPU, it will have to be transferred to CPU first in order for
    `memoryview` to work.

    Returns:
      A NumPy array of the same shape and dtype or a NumPy scalar, if this
      Tensor has rank 0.

    Raises:
      ValueError: If the dtype of this Tensor does not have a compatible
        NumPy dtype.
    """
    # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
    maybe_arr = self._numpy()  # pylint: disable=protected-access
    # Only `np.ndarray` results are copied; any other result is returned as is.
    return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr

  @property
  def backing_device(self):
    """Returns the name of the device holding this tensor's memory.

    `.backing_device` is usually the same as `.device`, which returns
    the device on which the kernel of the operation that produced this tensor
    ran. However, some operations can produce tensors on a different device
    (e.g., an operation that executes on the GPU but produces output tensors
    in host memory).
    """
    raise NotImplementedError()

  def _datatype_enum(self):
    """Key used by `dtype` to index the dtypes intern table.

    Not implemented in this base class.
    """
    raise NotImplementedError()

  def _shape_tuple(self):
    """The shape of this Tensor, as a tuple.

    This is more performant than tuple(shape().as_list()) as it avoids
    two list and one object creation. Marked private for now as from an API
    perspective, it would be better to have a single performant way of
    getting a shape rather than exposing shape() and shape_tuple()
    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
    but ideally one would work things out and remove the need for this method.

    Returns:
      tuple with the shape.
    """
    raise NotImplementedError()

  def _rank(self):
    """Integer rank of this Tensor.

    Unlike regular Tensors, the rank is always known for EagerTensors.

    This is more performant than len(self._shape_tuple())

    Returns:
      Integer rank
    """
    raise NotImplementedError()

  def _num_elements(self):
    """Number of elements of this Tensor.

    Unlike regular Tensors, the number of elements is always known for
    EagerTensors.

    This is more performant than tensor.shape.num_elements

    Returns:
      Long - num elements in the tensor
    """
    raise NotImplementedError()

  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
    """Hook called by `_copy_nograd()`; not implemented in this base class."""
    raise NotImplementedError()

  @staticmethod
  def _override_operator(name, func):
    """Sets attribute `name` on `_EagerTensorBase` to `func`."""
    setattr(_EagerTensorBase, name, func)

  def _copy_nograd(self, ctx=None, device_name=None):
    """Copies tensor to dest device, but doesn't record the operation.

    Args:
      ctx: Optional context; defaults to `context.context()`.
      device_name: Optional destination device name; defaults to
        `ctx.device_name`.

    Returns:
      The result of `self._copy_to_device(device_name)`.
    """
    # Creates a new tensor on the dest device.
    if ctx is None:
      ctx = context.context()
    if device_name is None:
      device_name = ctx.device_name
    # pylint: disable=protected-access
    try:
      ctx.ensure_initialized()
      new_tensor = self._copy_to_device(device_name)
    except core._NotOkStatusException as e:
      six.raise_from(core._status_to_exception(e.code, e.message), None)
    return new_tensor

  def _copy(self, ctx=None, device_name=None):
    """Copies tensor to dest device.

    When executing eagerly, the copy is also recorded on the tape together
    with a gradient function that copies the incoming gradient back to this
    tensor's device.

    Args:
      ctx: Optional context, forwarded to `_copy_nograd`.
      device_name: Optional destination device name, forwarded to
        `_copy_nograd`.

    Returns:
      The tensor returned by `_copy_nograd`.
    """
    new_tensor = self._copy_nograd(ctx, device_name)
    # Record the copy on tape and define backprop copy as well.
    if context.executing_eagerly():
      self_device = self.device

      def grad_fun(dresult):
        # Gradients without a `_copy` method are passed through unchanged.
        return [
            dresult._copy(device_name=self_device)
            if hasattr(dresult, "_copy") else dresult
        ]

      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
    return new_tensor
    # pylint: enable=protected-access

  @property
  def shape(self):
    """The shape of this tensor as a `tensor_shape.TensorShape`.

    The value is built from `_shape_tuple()` the first time it is requested
    and cached in `_tensor_shape` afterwards.
    """
    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
      # pylint: disable=protected-access
      try:
        # `_tensor_shape` is declared and defined in the definition of
        # `EagerTensor`, in C.
        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
      except core._NotOkStatusException as e:
        six.raise_from(core._status_to_exception(e.code, e.message), None)

    return self._tensor_shape

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def _shape_as_list(self):
    """The shape of the tensor as a list."""
    return list(self._shape_tuple())

  @property
  def ndim(self):
    """Returns the number of Tensor dimensions."""
    return self.shape.ndims

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def cpu(self):
    """A copy of this Tensor with contents backed by host memory."""
    return self._copy(context.context(), "CPU:0")

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def gpu(self, gpu_index=0):
    """A copy of this Tensor with contents backed by memory on the GPU.

    Arguments:
      gpu_index: Identifies which GPU the contents of the returned Tensor are
        placed on.

    Returns:
      A GPU-memory backed Tensor object initialized with the same contents
      as this Tensor.
    """
    return self._copy(context.context(), "GPU:" + str(gpu_index))

  def set_shape(self, shape):
    """Checks that `shape` is compatible with this tensor's shape.

    This method only validates; it does not assign anything.

    Args:
      shape: The shape to check against `self.shape`.

    Raises:
      ValueError: If `self.shape` is not compatible with `shape`.
    """
    if not self.shape.is_compatible_with(shape):
      raise ValueError(
          "Tensor's shape %s is not compatible with supplied shape %s" %
          (self.shape, shape))

  # Methods not supported / implemented for Eager Tensors.
  @property
  def op(self):
    """Not available on eager tensors; always raises `AttributeError`."""
    raise AttributeError(
        "Tensor.op is meaningless when eager execution is enabled.")

  @property
  def graph(self):
    """Not available on eager tensors; always raises `AttributeError`."""
    raise AttributeError(
        "Tensor.graph is meaningless when eager execution is enabled.")

  @property
  def name(self):
    """Not available on eager tensors; always raises `AttributeError`."""
    raise AttributeError(
        "Tensor.name is meaningless when eager execution is enabled.")

  @property
  def value_index(self):
    """Not available on eager tensors; always raises `AttributeError`."""
    raise AttributeError(
        "Tensor.value_index is meaningless when eager execution is enabled.")

  def consumers(self):
    """Not available on eager tensors; always raises `NotImplementedError`."""
    raise NotImplementedError(
        "Tensor.consumers is meaningless when eager execution is enabled.")

  def _add_consumer(self, consumer):
    """Not available on eager tensors; always raises `NotImplementedError`."""
    raise NotImplementedError(
        "_add_consumer not supported when eager execution is enabled.")

  def _as_node_def_input(self):
    """Not available on eager tensors; always raises `NotImplementedError`."""
    raise NotImplementedError(
        "_as_node_def_input not supported when eager execution is enabled.")

  def _as_tf_output(self):
    """Not available on eager tensors; always raises `NotImplementedError`."""
    raise NotImplementedError(
        "_as_tf_output not supported when eager execution is enabled.")

  def eval(self, feed_dict=None, session=None):
    """Not available on eager tensors; always raises `NotImplementedError`."""
    raise NotImplementedError(
        "eval is not supported when eager execution is enabled, "
        "is .numpy() what you're looking for?")


# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)


register_dense_tensor_like_type(Tensor)


@tf_export(v1=["convert_to_tensor"])
def convert_to_tensor_v1(value,
                         dtype=None,
                         name=None,
                         preferred_dtype=None,
                         dtype_hint=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    dtype_hint: same meaning as preferred_dtype, and overrides it.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  # Resolve the deprecated `preferred_dtype` spelling against `dtype_hint`.
  preferred_dtype = deprecation.deprecated_argument_lookup(
      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
  # Positional order of the v2 signature is (value, dtype, dtype_hint, name).
  return convert_to_tensor_v2(value, dtype, preferred_dtype, name)


@tf_export("convert_to_tensor", v1=[])
def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  >>> def my_func(arg):
  ...   arg = tf.convert_to_tensor(arg, dtype=tf.float32)
  ...   return arg

  >>> # The following calls are equivalent.
  >>> value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  >>> print(value_1)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  >>> print(value_2)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  >>> print(value_3)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor, used when dtype
      is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so dtype_hint can be used as a soft preference.
      If the conversion to `dtype_hint` is not possible, this argument has no
      effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  return convert_to_tensor(
      value=value,
      dtype=dtype,
      name=name,
      preferred_dtype=dtype_hint,
      as_ref=False)


def _error_prefix(name):
  """Returns `"<name>: "` for error messages, or `""` when `name` is None."""
  return "" if name is None else "%s: " % name


def convert_to_tensor(value,
                      dtype=None,
                      name=None,
                      as_ref=False,
                      preferred_dtype=None,
                      dtype_hint=None,
                      ctx=None,
                      accepted_result_types=(Tensor,)):
  """Implementation of the public convert_to_tensor.

  Args:
    value: The object to convert.
    dtype: Optional required dtype of the result.
    name: Optional name, forwarded to the conversion function (or to
      `graph.capture`) and used as a prefix in error messages.
    as_ref: Forwarded to the conversion function.
    preferred_dtype: Optional dtype tried first when `dtype` is None.
    dtype_hint: Used as `preferred_dtype` when `preferred_dtype` is falsy.
    ctx: Optional context; only consulted when `value` is an `EagerTensor`,
      defaulting to `context.context()`.
    accepted_result_types: Types a conversion function is allowed to return.

  Returns:
    `value` itself if it is already a `Tensor` of a compatible dtype, the
    captured tensor if `value` is an `EagerTensor` seen while building a
    function, and otherwise the result of the first registered conversion
    function that does not return `NotImplemented`.

  Raises:
    RuntimeError: If an `EagerTensor` is passed while not executing eagerly
      and not building a function, or if a conversion function returns a
      value of an unaccepted type or of an incompatible dtype.
    ValueError: If `value` is a `Tensor` whose dtype is not compatible with
      `dtype`.
    TypeError: If no conversion function produced a result, or if a
      conversion with `preferred_dtype` succeeded but returned a different
      base dtype.
  """
  # TODO(b/142518781): Fix all call-sites and remove redundant arg
  preferred_dtype = preferred_dtype or dtype_hint
  if isinstance(value, EagerTensor):
    if ctx is None:
      ctx = context.context()
    if not ctx.executing_eagerly():
      graph = get_default_graph()
      if not graph.building_function:
        raise RuntimeError("Attempting to capture an EagerTensor without "
                           "building a function.")
      return graph.capture(value, name=name)

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if isinstance(value, Tensor):
    # Already a Tensor: only validate the requested dtype, never convert.
    if dtype is not None and not dtype.is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (dtype.name, value.dtype.name, value))
    return value

  if preferred_dtype is not None:
    preferred_dtype = dtypes.as_dtype(preferred_dtype)
  for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError):
        # Could not coerce the conversion to use the preferred dtype.
        pass
      else:
        if (ret is not NotImplemented and
            ret.dtype.base_dtype != preferred_dtype.base_dtype):
          raise TypeError("convert_to_tensor did not convert to "
                          "the preferred dtype: %s vs %s " %
                          (ret.dtype.base_dtype, preferred_dtype.base_dtype))

    # Either no preferred dtype applied or the preferred attempt failed:
    # convert with the required dtype (possibly None).
    if ret is None:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    # This conversion function declined; try the next registered one.
    if ret is NotImplemented:
      continue

    if not isinstance(ret, accepted_result_types):
      raise RuntimeError(
          "%sConversion function %r for type %s returned non-Tensor: %r" %
          (_error_prefix(name), conversion_func, base_type, ret))
    if dtype and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          "%sConversion function %r for type %s returned incompatible "
          "dtype: requested = %s, actual = %s" %
          (_error_prefix(name), conversion_func, base_type, dtype.name,
           ret.dtype.name))
    return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered." %
                  (_error_prefix(name), value, type(value)))


# Second name bound to the same function object as `convert_to_tensor`.
internal_convert_to_tensor = convert_to_tensor


def internal_convert_n_to_tensor(values,
                                 dtype=None,
                                 name=None,
                                 as_ref=False,
                                 preferred_dtype=None,
                                 ctx=None):
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.
    preferred_dtype: Optional element type for the returned tensors, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    ctx: The value of context.context().

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  # Look the context up once and share it across all elements.
  if ctx is None:
    ctx = context.context()
  for i, value in enumerate(values):
    n = None if name is None else "%s_%d" % (name, i)
    ret.append(
        convert_to_tensor(
            value,
            dtype=dtype,
            name=n,
            as_ref=as_ref,
            preferred_dtype=preferred_dtype,
            ctx=ctx))
  return ret


def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.
    preferred_dtype: Optional element type for the returned tensors, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor(
      values=values,
      dtype=dtype,
      name=name,
      preferred_dtype=preferred_dtype,
      as_ref=False)


def convert_to_tensor_or_composite(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
  is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor` or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  return internal_convert_to_tensor_or_composite(
      value=value, dtype=dtype, name=name, as_ref=False)


def internal_convert_to_tensor_or_composite(value,
                                            dtype=None,
                                            name=None,
                                            as_ref=False):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
  is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor`, or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    # NOTE(review): `value_dtype` is None when the composite has no `dtype`
    # attribute, yet the error message below reads `value.dtype` directly --
    # confirm every CompositeTensor reaching this path exposes `dtype`.
    value_dtype = getattr(value, "dtype", None)
    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value_dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
    return value
  else:
    return convert_to_tensor(
        value,
        dtype=dtype,
        name=name,
        as_ref=as_ref,
        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))


def internal_convert_n_to_tensor_or_composite(values,
                                              dtype=None,
                                              name=None,
                                              as_ref=False):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  for i, value in enumerate(values):
    # `None` entries are passed through untouched, keeping their position.
    if value is None:
      ret.append(value)
    else:
      n = None if name is None else "%s_%d" % (name, i)
      ret.append(
          internal_convert_to_tensor_or_composite(
              value, dtype=dtype, name=n, as_ref=as_ref))
  return ret


def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be
      consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_composite(
      values=values, dtype=dtype, name=name, as_ref=False)


def _device_string(dev_spec):
  """Returns `dev_spec.to_string()` for device specs, else `dev_spec` itself."""
  if pydev.is_device_spec(dev_spec):
    return dev_spec.to_string()
  else:
    return dev_spec


def _NodeDef(op_type, name, attrs=None):
  """Create a NodeDef proto.

  Args:
    op_type: Value for the "op" attribute of the NodeDef proto.
    name: Value for the "name" attribute of the NodeDef proto.
    attrs: Dictionary where the key is the attribute name (a string)
      and the value is the respective "attr" attribute of the NodeDef proto (an
      AttrValue).

  Returns:
    A node_def_pb2.NodeDef protocol buffer.
  """
  node_def = node_def_pb2.NodeDef(op=compat.as_bytes(op_type),
                                  name=compat.as_bytes(name))
  if attrs:
    for k, v in six.iteritems(attrs):
      # Each AttrValue is copied into the proto's attr map.
      node_def.attr[k].CopyFrom(v)
  return node_def


# Copied from core/framework/node_def_util.cc
# TODO(mrry,josh11b): Consolidate this validation in C++ code.
# Op names must start with a letter, digit or "."; scope names may also be
# empty or start with any of the other allowed characters.
_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/>]*$")
_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/>]*$")


def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
  """Creates a TF_Operation.

  Args:
    graph: a `Graph`.
    node_def: `node_def_pb2.NodeDef` for the operation to create.
    inputs: A flattened list of `Tensor`s. This function handles grouping
      tensors into lists as per attributes in the `node_def`.
    control_inputs: A list of `Operation`s to set as control dependencies.
    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
      specified, is looked up from the `graph` using `node_def.op`.

  Returns:
    A wrapped TF_Operation*.
  """
  if op_def is None:
    op_def = graph._get_op_def(node_def.op)  # pylint: disable=protected-access
  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
  # Refactor so we don't have to do this here.
1635 inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr) 1636 # pylint: disable=protected-access 1637 op_desc = pywrap_tf_session.TF_NewOperation(graph._c_graph, 1638 compat.as_str(node_def.op), 1639 compat.as_str(node_def.name)) 1640 if node_def.device: 1641 pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device)) 1642 # Add inputs 1643 for op_input in inputs: 1644 if isinstance(op_input, (list, tuple)): 1645 pywrap_tf_session.TF_AddInputList(op_desc, 1646 [t._as_tf_output() for t in op_input]) 1647 else: 1648 pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output()) 1649 1650 # Add control inputs 1651 for control_input in control_inputs: 1652 pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op) 1653 # pylint: enable=protected-access 1654 1655 # Add attrs 1656 for name, attr_value in node_def.attr.items(): 1657 serialized = attr_value.SerializeToString() 1658 # TODO(skyewm): this creates and deletes a new TF_Status for every attr. 1659 # It might be worth creating a convenient way to re-use the same status. 1660 pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name), 1661 serialized) 1662 1663 try: 1664 c_op = pywrap_tf_session.TF_FinishOperation(op_desc) 1665 except errors.InvalidArgumentError as e: 1666 # Convert to ValueError for backwards compatibility. 1667 raise ValueError(str(e)) 1668 1669 return c_op 1670 1671 1672@tf_export("Operation") 1673class Operation(object): 1674 """Represents a graph node that performs computation on tensors. 1675 1676 An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor` 1677 objects as input, and produces zero or more `Tensor` objects as output. 1678 Objects of type `Operation` are created by calling a Python op constructor 1679 (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default` 1680 context manager. 
1681 1682 For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an 1683 `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and 1684 produces `c` as output. 1685 1686 If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be 1687 executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for 1688 calling `tf.compat.v1.get_default_session().run(op)`. 1689 """ 1690 1691 def __init__(self, 1692 node_def, 1693 g, 1694 inputs=None, 1695 output_types=None, 1696 control_inputs=None, 1697 input_types=None, 1698 original_op=None, 1699 op_def=None): 1700 r"""Creates an `Operation`. 1701 1702 NOTE: This constructor validates the name of the `Operation` (passed 1703 as `node_def.name`). Valid `Operation` names match the following 1704 regular expression: 1705 1706 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* 1707 1708 Args: 1709 node_def: `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. Used for 1710 attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and 1711 `device`. The `input` attribute is irrelevant here as it will be 1712 computed when generating the model. 1713 g: `Graph`. The parent graph. 1714 inputs: list of `Tensor` objects. The inputs to this `Operation`. 1715 output_types: list of `DType` objects. List of the types of the `Tensors` 1716 computed by this operation. The length of this list indicates the 1717 number of output endpoints of the `Operation`. 1718 control_inputs: list of operations or tensors from which to have a control 1719 dependency. 1720 input_types: List of `DType` objects representing the types of the tensors 1721 accepted by the `Operation`. By default uses `[x.dtype.base_dtype for x 1722 in inputs]`. Operations that expect reference-typed inputs must specify 1723 these explicitly. 1724 original_op: Optional. Used to associate the new `Operation` with an 1725 existing `Operation` (for example, a replica with the op that was 1726 replicated). 1727 op_def: Optional. 
The `op_def_pb2.OpDef` proto that describes the op type 1728 that this `Operation` represents. 1729 1730 Raises: 1731 TypeError: if control inputs are not Operations or Tensors, 1732 or if `node_def` is not a `NodeDef`, 1733 or if `g` is not a `Graph`, 1734 or if `inputs` are not tensors, 1735 or if `inputs` and `input_types` are incompatible. 1736 ValueError: if the `node_def` name is not valid. 1737 """ 1738 # For internal use only: `node_def` can be set to a TF_Operation to create 1739 # an Operation for that op. This is useful for creating Operations for ops 1740 # indirectly created by C API methods, e.g. the ops created by 1741 # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields 1742 # should be None. 1743 1744 if isinstance(node_def, node_def_pb2.NodeDef): 1745 if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0: 1746 raise ValueError( 1747 "Cannot create a tensor proto whose content is larger than 2GB.") 1748 if not _VALID_OP_NAME_REGEX.match(node_def.name): 1749 raise ValueError("'%s' is not a valid node name" % node_def.name) 1750 c_op = None 1751 elif type(node_def).__name__ == "TF_Operation": 1752 assert inputs is None 1753 assert output_types is None 1754 assert control_inputs is None 1755 assert input_types is None 1756 assert original_op is None 1757 assert op_def is None 1758 c_op = node_def 1759 else: 1760 raise TypeError("node_def needs to be a NodeDef: %s" % node_def) 1761 1762 if not isinstance(g, Graph): 1763 raise TypeError("g needs to be a Graph: %s" % g) 1764 self._graph = g 1765 1766 if inputs is None: 1767 inputs = [] 1768 elif not isinstance(inputs, list): 1769 raise TypeError("inputs needs to be a list of Tensors: %s" % inputs) 1770 for a in inputs: 1771 if not isinstance(a, Tensor): 1772 raise TypeError("input needs to be a Tensor: %s" % a) 1773 if input_types is None: 1774 input_types = [i.dtype.base_dtype for i in inputs] 1775 else: 1776 if not all( 1777 x.is_compatible_with(i.dtype) for i, x in 
zip(inputs, input_types)): 1778 raise TypeError("In op '%s', input types (%s) are not compatible " 1779 "with expected types (%s)" % 1780 (node_def.name, [i.dtype for i in inputs], input_types)) 1781 1782 # Build the list of control inputs. 1783 control_input_ops = [] 1784 if control_inputs: 1785 for c in control_inputs: 1786 control_op = None 1787 if isinstance(c, Operation): 1788 control_op = c 1789 elif isinstance(c, (Tensor, IndexedSlices)): 1790 control_op = c.op 1791 else: 1792 raise TypeError("Control input must be an Operation, " 1793 "a Tensor, or IndexedSlices: %s" % c) 1794 control_input_ops.append(control_op) 1795 1796 # This will be set by self.inputs. 1797 self._inputs_val = None 1798 1799 # pylint: disable=protected-access 1800 self._original_op = original_op 1801 self._traceback = tf_stack.extract_stack() 1802 1803 # List of _UserDevSpecs holding code location of device context manager 1804 # invocations and the users original argument to them. 1805 self._device_code_locations = None 1806 # Dict mapping op name to file and line information for op colocation 1807 # context managers. 1808 self._colocation_code_locations = None 1809 self._control_flow_context = self.graph._get_control_flow_context() 1810 1811 # Gradient function for this op. There are three ways to specify gradient 1812 # function, and first available gradient gets used, in the following order. 1813 # 1. self._gradient_function 1814 # 2. Gradient name registered by "_gradient_op_type" attribute. 1815 # 3. Gradient name registered by op.type. 1816 self._gradient_function = None 1817 1818 # Initialize self._c_op. 
1819 if c_op: 1820 self._c_op = c_op 1821 op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op)) 1822 name = self.name 1823 else: 1824 if op_def is None: 1825 op_def = self._graph._get_op_def(node_def.op) 1826 self._c_op = _create_c_op(self._graph, node_def, inputs, 1827 control_input_ops, op_def) 1828 name = compat.as_str(node_def.name) 1829 # pylint: enable=protected-access 1830 1831 self._is_stateful = op_def.is_stateful 1832 1833 # Initialize self._outputs. 1834 num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op) 1835 self._outputs = [] 1836 for i in range(num_outputs): 1837 tf_output = c_api_util.tf_output(self._c_op, i) 1838 output_type = pywrap_tf_session.TF_OperationOutputType(tf_output) 1839 tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output) # pylint: disable=protected-access 1840 self._outputs.append(tensor) 1841 1842 self._id_value = self._graph._add_op(self, name) # pylint: disable=protected-access 1843 1844 if not c_op: 1845 self._control_flow_post_processing(input_tensors=inputs) 1846 1847 def _control_flow_post_processing(self, input_tensors=None): 1848 """Add this op to its control flow context. 1849 1850 This may add new ops and change this op's inputs. self.inputs must be 1851 available before calling this method. 1852 1853 Args: 1854 input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs 1855 of this op, which should be equivalent to `self.inputs`. Pass this 1856 argument to avoid evaluating `self.inputs` unnecessarily. 
1857 """ 1858 if input_tensors is None: 1859 input_tensors = self.inputs 1860 for input_tensor in input_tensors: 1861 control_flow_util.CheckInputFromValidContext(self, input_tensor.op) 1862 if self._control_flow_context is not None: 1863 self._control_flow_context.AddOp(self) 1864 1865 def colocation_groups(self): 1866 """Returns the list of colocation groups of the op.""" 1867 default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)] 1868 try: 1869 class_attr = self.get_attr("_class") 1870 except ValueError: 1871 # This op has no explicit colocation group, so it is itself its 1872 # own root of a colocation group. 1873 return default_colocation_group 1874 1875 attr_groups = [ 1876 class_name for class_name in class_attr 1877 if class_name.startswith(b"loc:@") 1878 ] 1879 1880 # If there are no colocation groups in the explicit _class field, 1881 # return the default colocation group. 1882 return attr_groups if attr_groups else default_colocation_group 1883 1884 def values(self): 1885 """DEPRECATED: Use outputs.""" 1886 return tuple(self.outputs) 1887 1888 def _get_control_flow_context(self): 1889 """Returns the control flow context of this op. 1890 1891 Returns: 1892 A context object. 1893 """ 1894 return self._control_flow_context 1895 1896 def _set_control_flow_context(self, ctx): 1897 """Sets the current control flow context of this op. 1898 1899 Args: 1900 ctx: a context object. 1901 """ 1902 self._control_flow_context = ctx 1903 1904 @property 1905 def name(self): 1906 """The full name of this operation.""" 1907 return pywrap_tf_session.TF_OperationName(self._c_op) 1908 1909 @property 1910 def _id(self): 1911 """The unique integer id of this operation.""" 1912 return self._id_value 1913 1914 @property 1915 def device(self): 1916 """The name of the device to which this op has been assigned, if any. 
1917 1918 Returns: 1919 The string name of the device to which this op has been 1920 assigned, or an empty string if it has not been assigned to a 1921 device. 1922 """ 1923 return pywrap_tf_session.TF_OperationDevice(self._c_op) 1924 1925 @property 1926 def _device_assignments(self): 1927 """Code locations for device context managers active at op creation. 1928 1929 This property will return a list of traceable_stack.TraceableObject 1930 instances where .obj is a string representing the assigned device 1931 (or information about the function that would be applied to this op 1932 to compute the desired device) and the filename and lineno members 1933 record the location of the relevant device context manager. 1934 1935 For example, suppose file_a contained these lines: 1936 1937 file_a.py: 1938 15: with tf.device('/gpu:0'): 1939 16: node_b = tf.constant(4, name='NODE_B') 1940 1941 Then a TraceableObject t_obj representing the device context manager 1942 would have these member values: 1943 1944 t_obj.obj -> '/gpu:0' 1945 t_obj.filename = 'file_a.py' 1946 t_obj.lineno = 15 1947 1948 and node_b.op._device_assignments would return the list [t_obj]. 1949 1950 Returns: 1951 [str: traceable_stack.TraceableObject, ...] as per this method's 1952 description, above. 1953 """ 1954 return self._device_code_locations or [] 1955 1956 @property 1957 def _colocation_dict(self): 1958 """Code locations for colocation context managers active at op creation. 1959 1960 This property will return a dictionary for which the keys are nodes with 1961 which this Operation is colocated, and for which the values are 1962 traceable_stack.TraceableObject instances. The TraceableObject instances 1963 record the location of the relevant colocation context manager but have the 1964 "obj" field set to None to prevent leaking private data. 
1965 1966 For example, suppose file_a contained these lines: 1967 1968 file_a.py: 1969 14: node_a = tf.constant(3, name='NODE_A') 1970 15: with tf.compat.v1.colocate_with(node_a): 1971 16: node_b = tf.constant(4, name='NODE_B') 1972 1973 Then a TraceableObject t_obj representing the colocation context manager 1974 would have these member values: 1975 1976 t_obj.obj -> None 1977 t_obj.filename = 'file_a.py' 1978 t_obj.lineno = 15 1979 1980 and node_b.op._colocation_dict would return the dictionary 1981 1982 { 'NODE_A': t_obj } 1983 1984 Returns: 1985 {str: traceable_stack.TraceableObject} as per this method's description, 1986 above. 1987 """ 1988 locations_dict = self._colocation_code_locations or {} 1989 return locations_dict.copy() 1990 1991 @property 1992 def _output_types(self): 1993 """List this operation's output types. 1994 1995 Returns: 1996 List of the types of the Tensors computed by this operation. 1997 Each element in the list is an integer whose value is one of 1998 the TF_DataType enums defined in pywrap_tf_session.h 1999 The length of this list indicates the number of output endpoints 2000 of the operation. 
def _set_device(self, device):  # pylint: disable=redefined-outer-name
  """Sets the device of this operation.

  The argument is normalized with `_device_string` and converted with
  `compat.as_str` before being forwarded to `_set_device_from_string`.

  Args:
    device: string or device. The device to set.
  """
  device_str = compat.as_str(_device_string(device))
  self._set_device_from_string(device_str)
2055 2056 Raises: 2057 TypeError: if tensor is not a Tensor, 2058 or if input tensor type is not convertible to dtype. 2059 ValueError: if the Tensor is from a different graph. 2060 """ 2061 if not isinstance(tensor, Tensor): 2062 raise TypeError("tensor must be a Tensor: %s" % tensor) 2063 _assert_same_graph(self, tensor) 2064 2065 # Reset cached inputs. 2066 self._inputs_val = None 2067 pywrap_tf_session.UpdateEdge( 2068 self._graph._c_graph, # pylint: disable=protected-access 2069 tensor._as_tf_output(), # pylint: disable=protected-access 2070 self._tf_input(index)) 2071 2072 def _add_while_inputs(self, tensors): 2073 """See AddWhileInputHack in python_api.h. 2074 2075 NOTE: This is for TF internal use only. Please don't use it. 2076 2077 Args: 2078 tensors: list of Tensors 2079 2080 Raises: 2081 TypeError: if tensor is not a Tensor, 2082 or if input tensor type is not convertible to dtype. 2083 ValueError: if the Tensor is from a different graph. 2084 """ 2085 for tensor in tensors: 2086 if not isinstance(tensor, Tensor): 2087 raise TypeError("tensor must be a Tensor: %s" % tensor) 2088 _assert_same_graph(self, tensor) 2089 2090 # Reset cached inputs. 2091 self._inputs_val = None 2092 pywrap_tf_session.AddWhileInputHack( 2093 self._graph._c_graph, # pylint: disable=protected-access 2094 tensor._as_tf_output(), # pylint: disable=protected-access 2095 self._c_op) 2096 2097 def _add_control_inputs(self, ops): 2098 """Add a list of new control inputs to this operation. 2099 2100 Args: 2101 ops: the list of Operations to add as control input. 2102 2103 Raises: 2104 TypeError: if ops is not a list of Operations. 2105 ValueError: if any op in ops is from a different graph. 
2106 """ 2107 for op in ops: 2108 if not isinstance(op, Operation): 2109 raise TypeError("op must be an Operation: %s" % op) 2110 pywrap_tf_session.AddControlInput( 2111 self._graph._c_graph, # pylint: disable=protected-access 2112 self._c_op, # pylint: disable=protected-access 2113 op._c_op) # pylint: disable=protected-access 2114 2115 def _add_control_input(self, op): 2116 """Add a new control input to this operation. 2117 2118 Args: 2119 op: the Operation to add as control input. 2120 2121 Raises: 2122 TypeError: if op is not an Operation. 2123 ValueError: if op is from a different graph. 2124 """ 2125 if not isinstance(op, Operation): 2126 raise TypeError("op must be an Operation: %s" % op) 2127 pywrap_tf_session.AddControlInput( 2128 self._graph._c_graph, # pylint: disable=protected-access 2129 self._c_op, # pylint: disable=protected-access 2130 op._c_op) # pylint: disable=protected-access 2131 2132 def _remove_all_control_inputs(self): 2133 """Removes any control inputs to this operation.""" 2134 pywrap_tf_session.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access 2135 2136 def _add_outputs(self, types, shapes): 2137 """Adds new Tensors to self.outputs. 2138 2139 Note: this is generally unsafe to use. This is used in certain situations in 2140 conjunction with _set_type_list_attr. 
@property
def inputs(self):
  """The sequence of `Tensor` objects representing the data inputs of this op.

  The tuple is built on first access from the C operation's inputs and cached
  in `self._inputs_val`; later accesses return the cached tuple until that
  attribute is reset to None.
  """
  cached = self._inputs_val
  if cached is not None:
    return cached

  # pylint: disable=protected-access
  to_tensor = self.graph._get_tensor_by_tf_output
  # pylint: enable=protected-access
  cached = tuple(
      to_tensor(tf_output)
      for tf_output in pywrap_tf_session.GetOperationInputs(self._c_op))
  self._inputs_val = cached
  return cached
@property
def _control_outputs(self):
  """The `Operation` objects which have a control dependency on this op.

  Before any of the ops in self._control_outputs can execute tensorflow will
  ensure self has finished executing.

  Returns:
    A list of `Operation` objects, looked up in this op's graph by the name
    of each C operation reported as a control output.
  """
  dependents = []
  for c_op in pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
      self._c_op):
    op_name = pywrap_tf_session.TF_OperationName(c_op)
    # pylint: disable=protected-access
    dependents.append(self.graph._get_operation_by_name_unsafe(op_name))
    # pylint: enable=protected-access
  return dependents
def _set_func_list_attr(self, attr_name, func_names):
  """Private method used to set a list(function) attribute in the node_def.

  Args:
    attr_name: name of the attr to set.
    func_names: iterable of function names; each one is wrapped in a
      `NameAttrList` before being stored.
  """
  name_attrs = []
  for func_name in func_names:
    name_attrs.append(attr_value_pb2.NameAttrList(name=func_name))
  list_value = attr_value_pb2.AttrValue.ListValue(func=name_attrs)
  attr_value = attr_value_pb2.AttrValue(list=list_value)
  self._set_attr(attr_name, attr_value)
def get_attr(self, name):
  """Returns the value of the attr of this op with the given `name`.

  The attr is fetched from the C operation as a serialized `AttrValue` proto
  and converted to a Python object: type attrs become `DType`s, list attrs
  become Python lists, and an attr with no value set yields `[]`.

  Args:
    name: The name of the attr to fetch.

  Returns:
    The value of the attr, as a Python object.

  Raises:
    ValueError: If this op does not have an attr with the given `name`.
  """
  value_fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
  try:
    with c_api_util.tf_buffer() as buf:
      pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
      serialized = pywrap_tf_session.TF_GetBuffer(buf)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(str(e))
  attr_value = attr_value_pb2.AttrValue()
  attr_value.ParseFromString(serialized)

  which = attr_value.WhichOneof("value")
  if which is None:
    return []
  if which == "type":
    return dtypes.as_dtype(attr_value.type)
  if which != "list":
    assert which in value_fields, (
        "Unsupported field type in " + str(attr_value))
    return getattr(attr_value, which)

  # List attr: return the first repeated field that has any entries.
  for field in value_fields:
    entries = getattr(attr_value.list, field)
    if not entries:
      continue
    if field == "type":
      return [dtypes.as_dtype(t) for t in entries]
    return list(entries)
  return []
@tf_export("RegisterGradient")
class RegisterGradient(object):
  """A decorator for registering the gradient function for an op type.

  This decorator is only used when defining a new op type. For an op
  with `m` inputs and `n` outputs, the gradient function is a function
  that takes the original `Operation` and `n` `Tensor` objects
  (representing the gradients with respect to each output of the op),
  and returns `m` `Tensor` objects (representing the partial gradients
  with respect to each input of the op).

  For example, assuming that operations of type `"Sub"` take two
  inputs `x` and `y`, and return a single output `x - y`, the
  following gradient function would be registered:

  ```python
  @tf.RegisterGradient("Sub")
  def _sub_grad(unused_op, grad):
    return grad, tf.negative(grad)
  ```

  The decorator argument `op_type` is the string type of an
  operation. This corresponds to the `OpDef.name` field for the proto
  that defines the operation.
  """

  def __init__(self, op_type):
    """Creates a new decorator with `op_type` as the Operation type.

    Args:
      op_type: The string type of an operation. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.

    Raises:
      TypeError: If `op_type` is not string.
    """
    if isinstance(op_type, six.string_types):
      self._op_type = op_type
    else:
      raise TypeError("op_type must be a string")

  def __call__(self, f):
    """Registers `f` as the gradient function for `op_type`; returns `f`."""
    gradient_fn = f
    _gradient_registry.register(gradient_fn, self._op_type)
    return gradient_fn
class OpStats(object):
  """A holder for statistics about an operator.

  This class holds information about the resource requirements for an op,
  including the size of its weight parameters on-disk and how many FLOPS it
  requires to execute forward inference.

  If you define a new operation, you can create a function that will return a
  set of information about its usage of the CPU and disk space when serialized.
  The function itself takes a Graph object that's been set up so you can call
  methods like get_tensor_by_name to help calculate the results, and a NodeDef
  argument.

  Instances can be accumulated with `+=`; both operands must carry the same
  `statistic_type`.
  """

  def __init__(self, statistic_type, value=None):
    """Sets up the initial placeholders for the statistics.

    Args:
      statistic_type: identifier of the statistic being tracked.
      value: initial value, or None when nothing has been measured yet.
    """
    self.statistic_type = statistic_type
    self.value = value

  @property
  def statistic_type(self):
    """The identifier of the statistic this object holds."""
    return self._statistic_type

  @statistic_type.setter
  def statistic_type(self, statistic_type):
    self._statistic_type = statistic_type

  @property
  def value(self):
    """The current value of the statistic, or None if unset."""
    return self._value

  @value.setter
  def value(self, value):
    self._value = value

  def __iadd__(self, other):
    """Accumulates `other` into this object in place.

    A None value on either side is treated as "no contribution": if this
    object has no value it adopts `other.value`, and a None `other.value`
    leaves this object unchanged.

    Args:
      other: the `OpStats` to add; must have the same `statistic_type`.

    Returns:
      This object.

    Raises:
      ValueError: if the statistic types of the two operands differ.
    """
    if other.statistic_type != self.statistic_type:
      # `other` is the operand being added *to* `self`, so its type comes
      # first in the message.
      raise ValueError("Can't add an OpStat of type %s to one of %s." %
                       (other.statistic_type, self.statistic_type))
    if self.value is None:
      self.value = other.value
    elif other.value is not None:
      self._value += other.value
    return self
RNN) 2588 2589 For example, you can define a new metric called doohickey for a Foo operation 2590 by placing this in your code: 2591 2592 ```python 2593 @ops.RegisterStatistics("Foo", "doohickey") 2594 def _calc_foo_bojangles(unused_graph, unused_node_def): 2595 return ops.OpStats("doohickey", 20) 2596 ``` 2597 2598 Then in client code you can retrieve the value by making this call: 2599 2600 ```python 2601 doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey") 2602 ``` 2603 2604 If the NodeDef is for an op with a registered doohickey function, you'll get 2605 back the calculated amount in doohickey.value, or None if it's not defined. 2606 2607 """ 2608 2609 def __init__(self, op_type, statistic_type): 2610 """Saves the `op_type` as the `Operation` type.""" 2611 if not isinstance(op_type, six.string_types): 2612 raise TypeError("op_type must be a string.") 2613 if "," in op_type: 2614 raise TypeError("op_type must not contain a comma.") 2615 self._op_type = op_type 2616 if not isinstance(statistic_type, six.string_types): 2617 raise TypeError("statistic_type must be a string.") 2618 if "," in statistic_type: 2619 raise TypeError("statistic_type must not contain a comma.") 2620 self._statistic_type = statistic_type 2621 2622 def __call__(self, f): 2623 """Registers "f" as the statistics function for "op_type".""" 2624 _stats_registry.register(f, self._op_type + "," + self._statistic_type) 2625 return f 2626 2627 2628def get_stats_for_node_def(graph, node, statistic_type): 2629 """Looks up the node's statistics function in the registry and calls it. 2630 2631 This function takes a Graph object and a NodeDef from a GraphDef, and if 2632 there's an associated statistics method, calls it and returns a result. If no 2633 function has been registered for the particular node type, it returns an empty 2634 statistics object. 2635 2636 Args: 2637 graph: A Graph object that's been set up with the node's graph. 2638 node: A NodeDef describing the operator. 
def name_from_scope_name(name):
  """Returns the name of an op given the name of its scope.

  Args:
    name: the name of the scope.

  Returns:
    the name of the op: `name` with a single trailing "/" removed if it has
    one, otherwise `name` unchanged (including empty or None input).
  """
  if name and name[-1] == "/":
    return name[:-1]
  return name
def __init__(self):
  """Creates a new, empty Graph.

  Initializes the Python-side bookkeeping (locks, op registries, scoping
  stacks, collections and caches), creates the `c_api_util.ScopedTFGraph`
  wrapper, disables the C API's shape-function requirement and, when
  `tf2.enabled()` is true, finishes by calling `switch_to_thread_local()`.
  """
  # Protects core state that can be returned via public accessors.
  # Thread-safety is provided on a best-effort basis to support buggy
  # programs, and is not guaranteed by the public `tf.Graph` API.
  #
  # NOTE(mrry): This does not protect the various stacks. A warning will
  # be reported if these are used from multiple threads.
  self._lock = threading.RLock()
  # The group lock synchronizes Session.run calls with methods that create
  # and mutate ops (e.g. Graph.create_op()). This synchronization is
  # necessary because it's illegal to modify an operation after it's been run.
  # The group lock allows any number of threads to mutate ops at the same time
  # but if any modification is going on, all Session.run calls have to wait.
  # Similarly, if one or more Session.run calls are going on, all mutate ops
  # have to wait until all Session.run calls have finished.
  self._group_lock = lock_util.GroupLock(num_groups=2)
  self._nodes_by_id = {}  # GUARDED_BY(self._lock)
  self._next_id_counter = 0  # GUARDED_BY(self._lock)
  self._nodes_by_name = {}  # GUARDED_BY(self._lock)
  self._version = 0  # GUARDED_BY(self._lock)
  # Maps a name used in the graph to the next id to use for that name.
  self._names_in_use = {}
  self._stack_state_is_thread_local = False
  self._thread_local = threading.local()
  # Functions that will be applied to choose a device if none is specified.
  # In TF2.x or after switch_to_thread_local(),
  # self._thread_local._device_function_stack is used instead.
  self._graph_device_function_stack = traceable_stack.TraceableStack()
  # Default original_op applied to new ops.
  self._default_original_op = None
  # Current control flow context. It could be either CondContext or
  # WhileContext defined in ops/control_flow_ops.py
  self._control_flow_context = None
  # A new node will depend on the union of all of the nodes in the stack.
  # In TF2.x or after switch_to_thread_local(),
  # self._thread_local._control_dependencies_stack is used instead.
  self._graph_control_dependencies_stack = []
  # Arbitrary collections of objects.
  self._collections = {}
  # The graph-level random seed
  self._seed = None
  # A dictionary of attributes that should be applied to all ops.
  self._attr_scope_map = {}
  # A map from op type to the kernel label that should be used.
  self._op_to_kernel_label_map = {}
  # A map from op type to an alternative op type that should be used when
  # computing gradients.
  self._gradient_override_map = {}
  # A map from op type to a gradient function that should be used instead.
  self._gradient_function_map = {}
  # True if the graph is considered "finalized".  In that case no
  # new operations can be added.
  self._finalized = False
  # Functions defined in the graph
  self._functions = collections.OrderedDict()
  # Default GraphDef versions
  self._graph_def_versions = versions_pb2.VersionDef(
      producer=versions.GRAPH_DEF_VERSION,
      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
  self._building_function = False
  # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
  # self._thread_local._colocation_stack is used instead.
  self._graph_colocation_stack = traceable_stack.TraceableStack()
  # Set of tensors that are dangerous to feed!
  self._unfeedable_tensors = object_identity.ObjectIdentitySet()
  # Set of operations that are dangerous to fetch!
  self._unfetchable_ops = set()
  # A map of tensor handle placeholder to tensor dtype.
  self._handle_feeders = {}
  # A map from tensor handle to its read op.
  self._handle_readers = {}
  # A map from tensor handle to its move op.
  self._handle_movers = {}
  # A map from tensor handle to its delete op.
  self._handle_deleters = {}
  # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
  # will be shared when defining function graphs, for example, so optimizers
  # being called inside function definitions behave as if they were seeing the
  # actual outside graph).
  self._graph_key = "grap-key-%d/" % (uid(),)
  # A string with the last reduction method passed to
  # losses.compute_weighted_loss(), or None. This is required only for
  # backward compatibility with Estimator and optimizer V1 use cases.
  self._last_loss_reduction = None
  # Flag that is used to indicate whether loss has been scaled by optimizer.
  # If this flag has been set, then estimator uses it to scale loss back
  # before reporting. This is required only for backward compatibility with
  # Estimator and optimizer V1 use cases.
  self._is_loss_scaled_by_optimizer = False
  self._container = ""
  # Set to True if this graph is being built in an
  # AutomaticControlDependencies context.
  self._add_control_dependencies = False
  # Cache for OpDef protobufs retrieved via the C API.
  self._op_def_cache = {}
  # Cache for constant results of `broadcast_gradient_args()`. The keys are
  # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
  # values are tuples of reduction indices: (rx, ry).
  self._bcast_grad_args_cache = {}
  # Cache for constant results of `reduced_shape()`. The keys are pairs of
  # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
  # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
  self._reduced_shape_cache = {}

  # TODO(skyewm): fold as much of the above as possible into the C
  # implementation
  # NOTE: must be created before the `self._c_graph` access below.
  self._scoped_c_graph = c_api_util.ScopedTFGraph()
  # The C API requires all ops to have shape functions. Disable this
  # requirement (many custom ops do not have shape functions, and we don't
  # want to break these existing cases).
  pywrap_tf_session.SetRequireShapeInferenceFns(self._c_graph, False)
  # Runs last, after every attribute above has been assigned.
  if tf2.enabled():
    self.switch_to_thread_local()
2844 2845 Raises: 2846 RuntimeError: If variable creator scopes are not properly nested. 2847 """ 2848 # This step keeps a reference to the existing stack, and it also initializes 2849 # self._thread_local._variable_creator_stack if it doesn't exist yet. 2850 old = self._variable_creator_stack 2851 new = list(old) 2852 new.append((priority, creator)) 2853 # Sorting is stable, so we'll put higher-priority creators later in the list 2854 # but otherwise maintain registration order. 2855 new.sort(key=lambda item: item[0]) 2856 self._thread_local._variable_creator_stack = new # pylint: disable=protected-access 2857 try: 2858 yield 2859 finally: 2860 if self._thread_local._variable_creator_stack is not new: # pylint: disable=protected-access 2861 raise RuntimeError( 2862 "Exiting variable_creator_scope without proper nesting.") 2863 self._thread_local._variable_creator_stack = old # pylint: disable=protected-access 2864 2865 # Note: this method is private because the API of tf.Graph() is public and 2866 # frozen, and this functionality is still not ready for public visibility. 2867 @property 2868 def _variable_creator_stack(self): 2869 if not hasattr(self._thread_local, "_variable_creator_stack"): 2870 self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access 2871 2872 # This previously returned a copy of the stack instead of the stack itself, 2873 # to guard against accidental mutation. Consider, however, code that wants 2874 # to save and restore the variable creator stack: 2875 # def f(): 2876 # original_stack = graph._variable_creator_stack 2877 # graph._variable_creator_stack = new_stack 2878 # ... 
# Some code 2879 # graph._variable_creator_stack = original_stack 2880 # 2881 # And lets say you have some code that calls this function with some 2882 # variable_creator: 2883 # def g(): 2884 # with variable_scope.variable_creator_scope(creator): 2885 # f() 2886 # When exiting the variable creator scope, it would see a different stack 2887 # object than it expected leading to a "Exiting variable_creator_scope 2888 # without proper nesting" error. 2889 return self._thread_local._variable_creator_stack # pylint: disable=protected-access 2890 2891 @_variable_creator_stack.setter 2892 def _variable_creator_stack(self, variable_creator_stack): 2893 self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access 2894 2895 def _check_not_finalized(self): 2896 """Check if the graph is finalized. 2897 2898 Raises: 2899 RuntimeError: If the graph finalized. 2900 """ 2901 if self._finalized: 2902 raise RuntimeError("Graph is finalized and cannot be modified.") 2903 2904 def _add_op(self, op, op_name): 2905 """Adds 'op' to the graph and returns the unique ID for the added Operation. 2906 2907 Args: 2908 op: the Operation to add. 2909 op_name: the name of the Operation. 2910 2911 Returns: 2912 An integer that is a unique ID for the added Operation. 2913 """ 2914 self._check_not_finalized() 2915 with self._lock: 2916 self._next_id_counter += 1 2917 op_id = self._next_id_counter 2918 self._nodes_by_id[op_id] = op 2919 self._nodes_by_name[op_name] = op 2920 self._version = max(self._version, op_id) 2921 return op_id 2922 2923 @property 2924 def _c_graph(self): 2925 if self._scoped_c_graph: 2926 return self._scoped_c_graph.graph 2927 return None 2928 2929 @property 2930 def version(self): 2931 """Returns a version number that increases as ops are added to the graph. 2932 2933 Note that this is unrelated to the 2934 `tf.Graph.graph_def_versions`. 2935 2936 Returns: 2937 An integer version that increases as ops are added to the graph. 
@property
def graph_def_versions(self):
  # pylint: disable=line-too-long
  """The GraphDef version information of this graph.

  The value is read from the underlying C graph with `TF_GraphVersions` and
  parsed into a fresh proto on every access.

  For details on the meaning of each version, see
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).

  Returns:
    A `VersionDef`.
  """
  # pylint: enable=line-too-long
  with c_api_util.tf_buffer() as versions_buffer:
    pywrap_tf_session.TF_GraphVersions(self._c_graph, versions_buffer)
    serialized = compat.as_bytes(
        pywrap_tf_session.TF_GetBuffer(versions_buffer))
  parsed_versions = versions_pb2.VersionDef()
  parsed_versions.ParseFromString(serialized)
  return parsed_versions
def _as_graph_def(self, from_version=None, add_shapes=False):
  # pylint: disable=line-too-long
  """Returns a serialized `GraphDef` representation of this graph.

  The serialized `GraphDef` can be imported into another `Graph`
  (using `tf.import_graph_def`) or used with the
  [C++ Session API](../../../../api_docs/cc/index.md).

  This method is thread-safe.

  Args:
    from_version: Optional.  If this is set, returns a `GraphDef` containing
      only the nodes that were added to this graph since its `version`
      property had the given value. NOTE(review): the body below never
      references `from_version`; the whole C graph is serialized regardless.
    add_shapes: If true, adds an "_output_shapes" list attr to each node with
      the inferred shapes of each of its outputs. Library functions whose
      registered object exposes a `graph` with `inputs` additionally get an
      "_input_shapes" attr and per-node "_output_shapes".

  Returns:
    A tuple containing a
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
    protocol buffer, and the version of the graph to which that
    `GraphDef` corresponds.

  Raises:
    ValueError: If the `graph_def` would be too large.

  """
  # pylint: enable=line-too-long
  with self._lock:
    # Serialize the C graph into a buffer, then parse it into a Python proto.
    with c_api_util.tf_buffer() as buf:
      pywrap_tf_session.TF_GraphToGraphDef(self._c_graph, buf)
      data = pywrap_tf_session.TF_GetBuffer(buf)
    graph = graph_pb2.GraphDef()
    graph.ParseFromString(compat.as_bytes(data))
    # Strip the experimental library field iff it's empty.
    if not graph.library.function:
      graph.ClearField("library")

    if add_shapes:
      # Top-level nodes: look up the Python op by name and record the shape
      # of each of its outputs.
      for node in graph.node:
        op = self._nodes_by_name[node.name]
        if op.outputs:
          node.attr["_output_shapes"].list.shape.extend(
              [output.get_shape().as_proto() for output in op.outputs])
      # Library functions: annotate input shapes and per-node output shapes
      # where a Python-side function graph is available.
      for function_def in graph.library.function:
        defined_function = self._functions[function_def.signature.name]
        try:
          func_graph = defined_function.graph
        except AttributeError:
          # _DefinedFunction doesn't have a graph, _EagerDefinedFunction
          # does. Both rely on ops.py, so we can't really isinstance check
          # them.
          continue
        input_shapes = function_def.attr["_input_shapes"]
        try:
          func_graph_inputs = func_graph.inputs
        except AttributeError:
          # No `inputs` on this function graph: skip shape annotation.
          continue
        # TODO(b/141471245): Fix the inconsistency when inputs of func graph
        # are appended during gradient computation of while/cond.
        # `zip` stops at the shorter of the two sequences.
        for input_tensor, _ in zip(func_graph_inputs,
                                   function_def.signature.input_arg):
          if input_tensor.dtype == dtypes.resource:
            # TODO(allenl): Save and restore handle data, then save the
            # resource placeholder's shape. Right now some shape functions get
            # confused if we set the shape of the resource placeholder (to a
            # scalar of course) and there isn't any handle data.
            input_shapes.list.shape.add().CopyFrom(
                tensor_shape.TensorShape(None).as_proto())
          else:
            input_shapes.list.shape.add().CopyFrom(
                input_tensor.get_shape().as_proto())
        for node in function_def.node_def:
          try:
            op = func_graph.get_operation_by_name(node.name)
          except KeyError:
            # No Python op with this name in the function graph: leave the
            # node without an "_output_shapes" attr.
            continue
          outputs = op.outputs

          if op.type == "StatefulPartitionedCall":
            # Filter out any extra outputs (possibly added by function
            # backpropagation rewriting).
            num_outputs = len(node.attr["Tout"].list.type)
            outputs = outputs[:num_outputs]

          node.attr["_output_shapes"].list.shape.extend(
              [output.get_shape().as_proto() for output in outputs])

  return graph, self._version
def _get_function(self, name):
  """Looks up a function registered in this graph's function library.

  Args:
    name: string function name. It is normalized with `compat.as_str` before
      the lookup.

  Returns:
    The object stored under that name in this graph's function library, or
    None if no function with that name has been registered.
  """
  function_key = compat.as_str(name)
  return self._functions.get(function_key)
# Helper functions to create operations.
@deprecated_args(None,
                 "Shapes are always computed; don't use the compute_shapes "
                 "as it has no effect.", "compute_shapes")
def create_op(
    self,
    op_type,
    inputs,
    dtypes=None,  # pylint: disable=redefined-outer-name
    input_types=None,
    name=None,
    attrs=None,
    op_def=None,
    compute_shapes=True,
    compute_device=True):
  """Creates an `Operation` in this graph.

  This is the low-level entry point for adding an `Operation`. Most programs
  never call it directly; they use the Python op constructors, such as
  `tf.constant()`, which add ops to the default graph.

  Every element of `inputs` is checked to be a `Tensor`, and the remaining
  work is delegated to `_create_op_internal`.

  Args:
    op_type: The `Operation` type to create. This corresponds to the
      `OpDef.name` field for the proto that defines the operation.
    inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
    dtypes: (Optional) A list of `DType` objects that will be the types of the
      tensors that the operation produces.
    input_types: (Optional.) A list of `DType`s that will be the types of the
      tensors that the operation consumes. By default, uses the base `DType`
      of each input in `inputs`. Operations that expect reference-typed inputs
      must specify `input_types` explicitly.
    name: (Optional.) A string name for the operation. If not specified, a
      name is generated based on `op_type`.
    attrs: (Optional.) A dictionary where the key is the attribute name (a
      string) and the value is the respective `attr` attribute of the
      `NodeDef` proto that will represent the operation (an `AttrValue`
      proto).
    op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
      the operation will have.
    compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
      computed).
    compute_device: (Optional.) If True, device functions will be executed to
      compute the device property of the Operation.

  Raises:
    TypeError: if any of the inputs is not a `Tensor`.
    ValueError: if colocation conflicts with existing device assignment.

  Returns:
    An `Operation` object.
  """
  # Deprecated argument: accepted for compatibility, then discarded.
  del compute_shapes
  for position, candidate in enumerate(inputs):
    if isinstance(candidate, Tensor):
      continue
    raise TypeError(
        "Input #%d is not a tensor: %s" % (position, candidate))
  created_op = self._create_op_internal(
      op_type, inputs, dtypes, input_types, name, attrs, op_def,
      compute_device)
  return created_op
def _create_op_from_tf_operation(self, c_op, compute_device=True):
  """Wraps an existing TF_Operation in a new `Operation` of this graph.

  Works like create_op(), except that the new Operation is built around
  `c_op`, which becomes its _c_op field. It exists for TF_Operations that the
  C API created indirectly (e.g. by TF_ImportGraphDef, TF_FinishWhile).

  Neither Operation._control_flow_post_processing nor
  Graph._control_dependencies_for_inputs is called here, because the inputs
  may not be available yet. Calling them is the caller's responsibility.

  Args:
    c_op: a wrapped TF_Operation
    compute_device: (Optional.) If True, device functions will be executed to
      compute the device property of the Operation.

  Returns:
    An `Operation` object.
  """
  self._check_not_finalized()
  wrapped_op = Operation(c_op, self)
  # A name_scope may already have registered this name without any node
  # having been created in it; in that case the existing _names_in_use entry
  # is left untouched, which is fine. Otherwise the name is recorded as used.
  # TODO(skyewm): make the C API guarantee no name conflicts.
  self._names_in_use.setdefault(wrapped_op.name.lower(), 1)
  self._create_op_helper(wrapped_op, compute_device=compute_device)
  return wrapped_op
3390 try: 3391 kernel_label = self._op_to_kernel_label_map[op.type] 3392 op._set_attr("_kernel", # pylint: disable=protected-access 3393 attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label))) 3394 except KeyError: 3395 pass 3396 3397 op._gradient_function = self._gradient_function_map.get(op.type) # pylint: disable=protected-access 3398 3399 # Apply the overriding op type for gradients if one has been specified for 3400 # this op type. 3401 try: 3402 mapped_op_type = self._gradient_override_map[op.type] 3403 op._set_attr("_gradient_op_type", # pylint: disable=protected-access 3404 attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type))) 3405 except KeyError: 3406 pass 3407 3408 self._record_op_seen_by_control_dependencies(op) 3409 3410 if compute_device: 3411 self._apply_device_functions(op) 3412 3413 # Snapshot the colocation stack metadata before we might generate error 3414 # messages using it. Note that this snapshot depends on the actual stack 3415 # and is independent of the op's _class attribute. 
def _add_new_tf_operations(self, compute_devices=True):
  """Creates `Operations` in this graph for any new TF_Operations.

  Use this after the C API has created TF_Operations outside of the Operation
  constructor (e.g. by TF_ImportGraphDef, TF_FinishWhile), so that every
  TF_Operation in the underlying TF_Graph gets a corresponding Operation.

  Args:
    compute_devices: (Optional.) If True, device functions will be executed to
      compute the device properties of each new Operation.

  Returns:
    A list of the new `Operation` objects.
  """
  # First pass: wrap every new TF_Operation. Inputs are not touched yet
  # because an op may be created before the ops that produce its inputs.
  created_ops = []
  for tf_operation in c_api_util.new_tf_operations(self):
    created_ops.append(
        self._create_op_from_tf_operation(
            tf_operation, compute_device=compute_devices))

  # Second pass: every wrapper exists now, so inputs can be inspected.
  # pylint: disable=protected-access
  for created_op in created_ops:
    extra_control_inputs = self._control_dependencies_for_inputs(
        created_op.inputs)
    created_op._add_control_inputs(extra_control_inputs)
    created_op._control_flow_post_processing()
  # pylint: enable=protected-access

  return created_ops
def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
  """Resolves `obj` to a `Tensor` or `Operation` of this graph.

  See `Graph.as_graph_element()` for the public contract. This helper does
  not take any lock itself.

  Args:
    obj: A `Tensor`, an `Operation`, a tensor/operation name (bytes or text),
      or an object that `_as_graph_element()` can convert to one of these.
    allow_tensor: If true, `obj` may refer to a `Tensor`.
    allow_operation: If true, `obj` may refer to an `Operation`.

  Returns:
    The `Tensor` or `Operation` in the graph corresponding to `obj`.

  Raises:
    TypeError: If `obj` cannot be interpreted as one of the allowed types.
    ValueError: If both `allow_*` flags are False, if a name is malformed or
      names the wrong kind of element, or if `obj` belongs to another graph.
    KeyError: If a well-formed name does not match anything in this graph.
  """
  # The vast majority of this function is figuring
  # out what an API user might be doing wrong, so
  # that we can give helpful error messages.
  #
  # Ideally, it would be nice to split it up, but we
  # need context to generate nice error messages.

  if allow_tensor and allow_operation:
    types_str = "Tensor or Operation"
  elif allow_tensor:
    types_str = "Tensor"
  elif allow_operation:
    types_str = "Operation"
  else:
    raise ValueError("allow_tensor and allow_operation can't both be False.")

  temp_obj = _as_graph_element(obj)
  if temp_obj is not None:
    obj = temp_obj

  # If obj appears to be a name...
  if isinstance(obj, compat.bytes_or_text_types):
    name = compat.as_str(obj)

    if ":" in name and allow_tensor:
      # Looks like a Tensor name and can be a Tensor.
      try:
        op_name, out_n = name.split(":")
        out_n = int(out_n)
      except ValueError:
        # Raised both by unpacking a name with more than one ":" and by a
        # non-integer output index. Only ValueError is caught so that
        # KeyboardInterrupt/SystemExit are never rewritten into a ValueError.
        raise ValueError("The name %s looks a like a Tensor name, but is "
                         "not a valid one. Tensor names must be of the "
                         "form \"<op_name>:<output_index>\"." % repr(name))
      if op_name in self._nodes_by_name:
        op = self._nodes_by_name[op_name]
      else:
        raise KeyError("The name %s refers to a Tensor which does not "
                       "exist. The operation, %s, does not exist in the "
                       "graph." % (repr(name), repr(op_name)))
      try:
        return op.outputs[out_n]
      except IndexError:
        # The output index is out of range for this op.
        raise KeyError("The name %s refers to a Tensor which does not "
                       "exist. The operation, %s, exists but only has "
                       "%s outputs." %
                       (repr(name), repr(op_name), len(op.outputs)))

    elif ":" in name and not allow_tensor:
      # Looks like a Tensor name but can't be a Tensor.
      raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
                       (repr(name), types_str))

    elif ":" not in name and allow_operation:
      # Looks like an Operation name and can be an Operation.
      if name not in self._nodes_by_name:
        raise KeyError("The name %s refers to an Operation not in the "
                       "graph." % repr(name))
      return self._nodes_by_name[name]

    elif ":" not in name and not allow_operation:
      # Looks like an Operation name but can't be an Operation.
      if name in self._nodes_by_name:
        # Yep, it's an Operation name
        err_msg = ("The name %s refers to an Operation, not a %s." %
                   (repr(name), types_str))
      else:
        err_msg = ("The name %s looks like an (invalid) Operation name, "
                   "not a %s." % (repr(name), types_str))
      err_msg += (" Tensor names must be of the form "
                  "\"<op_name>:<output_index>\".")
      raise ValueError(err_msg)

  elif isinstance(obj, Tensor) and allow_tensor:
    # Actually obj is just the object it's referring to.
    if obj.graph is not self:
      raise ValueError("Tensor %s is not an element of this graph." % obj)
    return obj
  elif isinstance(obj, Operation) and allow_operation:
    # Actually obj is just the object it's referring to.
    if obj.graph is not self:
      raise ValueError("Operation %s is not an element of this graph." % obj)
    return obj
  else:
    # We give up!
    raise TypeError("Can not convert a %s into a %s." %
                    (type(obj).__name__, types_str))
3615 3616 This method may be called concurrently from multiple threads. 3617 3618 Returns: 3619 A list of Operations. 3620 """ 3621 if self._finalized: 3622 return list(self._nodes_by_id.values()) 3623 3624 with self._lock: 3625 return list(self._nodes_by_id.values()) 3626 3627 def get_operation_by_name(self, name): 3628 """Returns the `Operation` with the given `name`. 3629 3630 This method may be called concurrently from multiple threads. 3631 3632 Args: 3633 name: The name of the `Operation` to return. 3634 3635 Returns: 3636 The `Operation` with the given `name`. 3637 3638 Raises: 3639 TypeError: If `name` is not a string. 3640 KeyError: If `name` does not correspond to an operation in this graph. 3641 """ 3642 3643 if not isinstance(name, six.string_types): 3644 raise TypeError("Operation names are strings (or similar), not %s." % 3645 type(name).__name__) 3646 return self.as_graph_element(name, allow_tensor=False, allow_operation=True) 3647 3648 def _get_operation_by_name_unsafe(self, name): 3649 """Returns the `Operation` with the given `name`. 3650 3651 This is a internal unsafe version of get_operation_by_name. It skips many 3652 checks and does not have user friedly error messages but runs considerably 3653 faster. This method may be called concurrently from multiple threads. 3654 3655 Args: 3656 name: The name of the `Operation` to return. 3657 3658 Returns: 3659 The `Operation` with the given `name`. 3660 3661 Raises: 3662 KeyError: If `name` does not correspond to an operation in this graph. 3663 """ 3664 3665 if self._finalized: 3666 return self._nodes_by_name[name] 3667 3668 with self._lock: 3669 return self._nodes_by_name[name] 3670 3671 def _get_operation_by_tf_operation(self, tf_oper): 3672 op_name = pywrap_tf_session.TF_OperationName(tf_oper) 3673 return self._get_operation_by_name_unsafe(op_name) 3674 3675 def get_tensor_by_name(self, name): 3676 """Returns the `Tensor` with the given `name`. 
def _get_op_def(self, type):  # pylint: disable=redefined-builtin
  """Returns the `OpDef` proto for `type`, caching it after the first lookup.

  Args:
    type: A string naming the op type to look up.

  Returns:
    The `op_def_pb2.OpDef` message stored in `self._op_def_cache` for `type`.
    On a cache miss the message is parsed from the bytes obtained through
    the C graph and inserted into the cache before being returned.
  """
  # NOTE: No locking is required because the lookup and insertion operations
  # on Python dictionaries are atomic.
  try:
    return self._op_def_cache[type]
  except KeyError:
    # Cache miss. NOTE(review): presumably TF_GraphGetOpDef fills `buf` with
    # the serialized OpDef for `type` -- confirm against the C API.
    with c_api_util.tf_buffer() as buf:
      # pylint: disable=protected-access
      pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type),
                                         buf)
      # pylint: enable=protected-access
      # Read the buffer contents while the buffer context is still open.
      data = pywrap_tf_session.TF_GetBuffer(buf)
    op_def = op_def_pb2.OpDef()
    op_def.ParseFromString(compat.as_bytes(data))
    self._op_def_cache[type] = op_def
    return op_def
def add_to_collection(self, name, value):
  """Appends `value` to the collection stored under `name`.

  Collections are lists, not sets, so adding the same value several times
  stores it several times. A collection that does not exist yet is created
  on first use.

  Args:
    name: The key for the collection. The `GraphKeys` class contains many
      standard names for collections.
    value: The value to add to the collection.
  """  # pylint: disable=g-doc-exception
  self._check_not_finalized()
  with self._lock:
    # setdefault creates the list on first use, then we append in place.
    self._collections.setdefault(name, []).append(value)
def get_collection(self, name, scope=None):
  """Returns a fresh list of the values in the collection named `name`.

  Unlike `get_collection_ref()`, which hands out the stored list itself, this
  method builds a new list on every call and never creates a missing
  collection.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    The list of values in the collection with the given `name`, or an empty
    list if no value has been added to that collection. Values keep the order
    in which they were collected.
  """  # pylint: disable=g-doc-exception
  with self._lock:
    values = self._collections.get(name, None)
    if values is None:
      return []
    if scope is None:
      return list(values)

    pattern = re.compile(scope)

    def _in_scope(value):
      # Collection items with no name are ignored.
      try:
        return pattern.match(value.name)
      except AttributeError:
        return None

    return [value for value in values if _in_scope(value)]
# pylint: disable=g-doc-return-or-yield,line-too-long
@tf_contextlib.contextmanager
def name_scope(self, name):
  """Returns a context manager that creates hierarchical names for operations.

  A graph maintains a stack of name scopes. A `with name_scope(...):`
  statement pushes a new name onto the stack for the lifetime of the context.

  The `name` argument will be interpreted as follows:

  * A string (not ending with '/') will create a new name scope, in which
    `name` is appended to the prefix of all operations created in the
    context. If `name` has been used before, it will be made unique by
    calling `self.unique_name(name)`.
  * A scope previously captured from a `with g.name_scope(...) as
    scope:` statement will be treated as an "absolute" name scope, which
    makes it possible to re-enter existing scopes.
  * A value of `None` or the empty string will reset the current name scope
    to the top-level (empty) name scope.

  For example:

  ```python
  with tf.Graph().as_default() as g:
    c = tf.constant(5.0, name="c")
    assert c.op.name == "c"
    c_1 = tf.constant(6.0, name="c")
    assert c_1.op.name == "c_1"

    # Creates a scope called "nested"
    with g.name_scope("nested") as scope:
      nested_c = tf.constant(10.0, name="c")
      assert nested_c.op.name == "nested/c"

      # Creates a nested scope called "inner".
      with g.name_scope("inner"):
        nested_inner_c = tf.constant(20.0, name="c")
        assert nested_inner_c.op.name == "nested/inner/c"

      # Create a nested scope called "inner_1".
      with g.name_scope("inner"):
        nested_inner_1_c = tf.constant(30.0, name="c")
        assert nested_inner_1_c.op.name == "nested/inner_1/c"

        # Treats `scope` as an absolute name scope, and
        # switches to the "nested/" scope.
        with g.name_scope(scope):
          nested_d = tf.constant(40.0, name="d")
          assert nested_d.op.name == "nested/d"

          with g.name_scope(""):
            e = tf.constant(50.0, name="e")
            assert e.op.name == "e"
  ```

  The name of the scope itself can be captured by `with
  g.name_scope(...) as scope:`, which stores the name of the scope
  in the variable `scope`. This value can be used to name an
  operation that represents the overall result of executing the ops
  in a scope. For example:

  ```python
  inputs = tf.constant(...)
  with g.name_scope('my_layer') as scope:
    weights = tf.Variable(..., name="weights")
    biases = tf.Variable(..., name="biases")
    affine = tf.matmul(inputs, weights) + biases
    output = tf.nn.relu(affine, name=scope)
  ```

  NOTE: This constructor validates the given `name`. Valid scope
  names match one of the following regular expressions:

      [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
      [A-Za-z0-9_.\\-/]* (for other scopes)

  Args:
    name: A name for the scope.

  Returns:
    A context manager that installs `name` as a new name scope. Entering it
    yields the scope name with a trailing "/", or "" when the scope was
    reset to the top level.

  Raises:
    ValueError: If `name` is not a valid scope name, according to the rules
      above.
  """
  if name:
    if isinstance(name, compat.bytes_or_text_types):
      name = compat.as_str(name)

    if self._name_stack:
      # Scopes created in a nested scope may have initial characters
      # that are illegal as the initial character of an op name
      # (viz. '-', '\', '/', and '_').
      if not _VALID_SCOPE_NAME_REGEX.match(name):
        raise ValueError("'%s' is not a valid scope name" % name)
    else:
      # Scopes created in the root must match the more restrictive
      # op name regex, which constrains the initial character.
      if not _VALID_OP_NAME_REGEX.match(name):
        raise ValueError("'%s' is not a valid scope name" % name)
  old_stack = self._name_stack
  if not name:  # Both for name=None and name="" we reset to empty scope.
    # Store "" rather than None: `_name_stack` defaults to "" and
    # `get_name_scope()` is documented to return a string, so the reset
    # state must not leak None to callers.
    new_stack = ""
    returned_scope = ""
  elif name[-1] == "/":
    new_stack = name_from_scope_name(name)
    returned_scope = new_stack + "/"
  else:
    new_stack = self.unique_name(name)
    returned_scope = new_stack + "/"
  self._name_stack = new_stack
  try:
    yield returned_scope
  finally:
    self._name_stack = old_stack

# pylint: enable=g-doc-return-or-yield,line-too-long
@tf_contextlib.contextmanager
def colocate_with(self, op, ignore_existing=False):
  """Returns a context manager that specifies an op to colocate with.

  Note: this function is not for public use, only for internal libraries.

  For example:

  ```python
  a = tf.Variable([1.0])
  with g.colocate_with(a):
    b = tf.constant(1.0)
    c = tf.add(a, b)
  ```

  `b` and `c` will always be colocated with `a`, no matter where `a`
  is eventually placed.

  **NOTE** Using a colocation scope resets any existing device constraints.

  If `op` is `None` then `ignore_existing` must be `True` and the new
  scope resets all colocation and device constraints.

  Args:
    op: The op to colocate all created ops with, or `None`.
    ignore_existing: If true, only applies colocation of this op within the
      context, rather than applying all colocation properties on the stack.
      If `op` is `None`, this value must be `True`.

  Raises:
    ValueError: if op is None but ignore_existing is False.

  Yields:
    Nothing. While the `with` block runs, the device function stack is
    replaced by an empty one and `op` (if not `None`) is on the colocation
    stack; both are restored when the block exits, even on error.
  """
  if op is None and not ignore_existing:
    raise ValueError("Trying to reset colocation (op is None) but "
                     "ignore_existing is not True")
  # Normalize `op` through the module-level helper before it is pushed.
  op = _op_to_colocate_with(op, self)

  # By default, colocate_with resets the device function stack,
  # since colocate_with is typically used in specific internal
  # library functions where colocation is intended to be "stronger"
  # than device functions.
  #
  # In the future, a caller may specify that device_functions win
  # over colocation, in which case we can add support.
  device_fn_tmp = self._device_function_stack
  self._device_function_stack = traceable_stack.TraceableStack()

  if ignore_existing:
    # Start from an empty colocation stack; the previous one is kept in
    # `current_stack` and put back in the `finally` clause below.
    current_stack = self._colocation_stack
    self._colocation_stack = traceable_stack.TraceableStack()

  if op is not None:
    # offset refers to the stack frame used for storing code location.
    # We use 4, the sum of 1 to use our caller's stack frame and 3
    # to jump over layers of context managers above us.
    self._colocation_stack.push_obj(op, offset=4)

  try:
    yield
  finally:
    # Restore device function stack
    self._device_function_stack = device_fn_tmp
    # Pop before the stack swap below, so the pop hits the stack that
    # received the push.
    if op is not None:
      self._colocation_stack.pop_obj()

    # Reset the colocation stack if requested.
    if ignore_existing:
      self._colocation_stack = current_stack
4237 4238 For example: 4239 4240 ```python 4241 with g.device('/device:GPU:0'): 4242 # All operations constructed in this context will be placed 4243 # on GPU 0. 4244 with g.device(None): 4245 # All operations constructed in this context will have no 4246 # assigned device. 4247 4248 # Defines a function from `Operation` to device string. 4249 def matmul_on_gpu(n): 4250 if n.type == "MatMul": 4251 return "/device:GPU:0" 4252 else: 4253 return "/cpu:0" 4254 4255 with g.device(matmul_on_gpu): 4256 # All operations of type "MatMul" constructed in this context 4257 # will be placed on GPU 0; all other operations will be placed 4258 # on CPU 0. 4259 ``` 4260 4261 **N.B.** The device scope may be overridden by op wrappers or 4262 other library code. For example, a variable assignment op 4263 `v.assign()` must be colocated with the `tf.Variable` `v`, and 4264 incompatible device scopes will be ignored. 4265 4266 Args: 4267 device_name_or_function: The device name or function to use in the 4268 context. 4269 4270 Yields: 4271 A context manager that specifies the default device to use for newly 4272 created ops. 4273 4274 Raises: 4275 RuntimeError: If device scopes are not properly nested. 4276 """ 4277 self._add_device_to_stack(device_name_or_function, offset=2) 4278 old_top_of_stack = self._device_function_stack.peek_top_obj() 4279 try: 4280 yield 4281 finally: 4282 new_top_of_stack = self._device_function_stack.peek_top_obj() 4283 if old_top_of_stack is not new_top_of_stack: 4284 raise RuntimeError("Exiting device scope without proper scope nesting.") 4285 self._device_function_stack.pop_obj() 4286 4287 def _apply_device_functions(self, op): 4288 """Applies the current device function stack to the given operation.""" 4289 # Apply any device functions in LIFO order, so that the most recently 4290 # pushed function has the first chance to apply a device to the op. 
4291 # We apply here because the result can depend on the Operation's 4292 # signature, which is computed in the Operation constructor. 4293 # pylint: disable=protected-access 4294 prior_device_string = None 4295 for device_spec in self._device_function_stack.peek_objs(): 4296 if device_spec.is_null_merge: 4297 continue 4298 4299 if device_spec.function is None: 4300 break 4301 4302 device_string = device_spec.string_merge(op) 4303 4304 # Take advantage of the fact that None is a singleton and Python interns 4305 # strings, since identity checks are faster than equality checks. 4306 if device_string is not prior_device_string: 4307 op._set_device_from_string(device_string) 4308 prior_device_string = device_string 4309 op._device_code_locations = self._snapshot_device_function_stack_metadata() 4310 # pylint: enable=protected-access 4311 4312 # pylint: disable=g-doc-return-or-yield 4313 @tf_contextlib.contextmanager 4314 def container(self, container_name): 4315 """Returns a context manager that specifies the resource container to use. 4316 4317 Stateful operations, such as variables and queues, can maintain their 4318 states on devices so that they can be shared by multiple processes. 4319 A resource container is a string name under which these stateful 4320 operations are tracked. These resources can be released or cleared 4321 with `tf.Session.reset()`. 4322 4323 For example: 4324 4325 ```python 4326 with g.container('experiment0'): 4327 # All stateful Operations constructed in this context will be placed 4328 # in resource container "experiment0". 4329 v1 = tf.Variable([1.0]) 4330 v2 = tf.Variable([2.0]) 4331 with g.container("experiment1"): 4332 # All stateful Operations constructed in this context will be 4333 # placed in resource container "experiment1". 4334 v3 = tf.Variable([3.0]) 4335 q1 = tf.queue.FIFOQueue(10, tf.float32) 4336 # All stateful Operations constructed in this context will be 4337 # be created in the "experiment0". 
4338 v4 = tf.Variable([4.0]) 4339 q1 = tf.queue.FIFOQueue(20, tf.float32) 4340 with g.container(""): 4341 # All stateful Operations constructed in this context will be 4342 # be placed in the default resource container. 4343 v5 = tf.Variable([5.0]) 4344 q3 = tf.queue.FIFOQueue(30, tf.float32) 4345 4346 # Resets container "experiment0", after which the state of v1, v2, v4, q1 4347 # will become undefined (such as uninitialized). 4348 tf.Session.reset(target, ["experiment0"]) 4349 ``` 4350 4351 Args: 4352 container_name: container name string. 4353 4354 Returns: 4355 A context manager for defining resource containers for stateful ops, 4356 yields the container name. 4357 """ 4358 original_container = self._container 4359 self._container = container_name 4360 try: 4361 yield self._container 4362 finally: 4363 self._container = original_container 4364 4365 # pylint: enable=g-doc-return-or-yield 4366 4367 class _ControlDependenciesController(object): 4368 """Context manager for `control_dependencies()`.""" 4369 4370 def __init__(self, graph, control_inputs): 4371 """Create a new `_ControlDependenciesController`. 4372 4373 A `_ControlDependenciesController` is the context manager for 4374 `with tf.control_dependencies()` blocks. These normally nest, 4375 as described in the documentation for `control_dependencies()`. 4376 4377 The `control_inputs` argument list control dependencies that must be 4378 added to the current set of control dependencies. Because of 4379 uniquification the set can be empty even if the caller passed a list of 4380 ops. The special value `None` indicates that we want to start a new 4381 empty set of control dependencies instead of extending the current set. 4382 4383 In that case we also clear the current control flow context, which is an 4384 additional mechanism to add control dependencies. 4385 4386 Args: 4387 graph: The graph that this controller is managing. 
4388 control_inputs: List of ops to use as control inputs in addition to the 4389 current control dependencies. None to indicate that the dependencies 4390 should be cleared. 4391 """ 4392 self._graph = graph 4393 if control_inputs is None: 4394 self._control_inputs_val = [] 4395 self._new_stack = True 4396 else: 4397 self._control_inputs_val = control_inputs 4398 self._new_stack = False 4399 self._seen_nodes = set() 4400 self._old_stack = None 4401 self._old_control_flow_context = None 4402 4403# pylint: disable=protected-access 4404 4405 def __enter__(self): 4406 if self._new_stack: 4407 # Clear the control_dependencies graph. 4408 self._old_stack = self._graph._control_dependencies_stack 4409 self._graph._control_dependencies_stack = [] 4410 # Clear the control_flow_context too. 4411 self._old_control_flow_context = self._graph._get_control_flow_context() 4412 self._graph._set_control_flow_context(None) 4413 self._graph._push_control_dependencies_controller(self) 4414 4415 def __exit__(self, unused_type, unused_value, unused_traceback): 4416 self._graph._pop_control_dependencies_controller(self) 4417 if self._new_stack: 4418 self._graph._control_dependencies_stack = self._old_stack 4419 self._graph._set_control_flow_context(self._old_control_flow_context) 4420 4421# pylint: enable=protected-access 4422 4423 @property 4424 def control_inputs(self): 4425 return self._control_inputs_val 4426 4427 def add_op(self, op): 4428 if isinstance(op, Tensor): 4429 op = op.experimental_ref() 4430 self._seen_nodes.add(op) 4431 4432 def op_in_group(self, op): 4433 if isinstance(op, Tensor): 4434 op = op.experimental_ref() 4435 return op in self._seen_nodes 4436 4437 def _push_control_dependencies_controller(self, controller): 4438 self._control_dependencies_stack.append(controller) 4439 4440 def _pop_control_dependencies_controller(self, controller): 4441 assert self._control_dependencies_stack[-1] is controller 4442 self._control_dependencies_stack.pop() 4443 4444 def 
_current_control_dependencies(self): 4445 ret = set() 4446 for controller in self._control_dependencies_stack: 4447 for op in controller.control_inputs: 4448 ret.add(op) 4449 return ret 4450 4451 def _control_dependencies_for_inputs(self, input_ops): 4452 """For an op that takes `input_ops` as inputs, compute control inputs. 4453 4454 The returned control dependencies should yield an execution that 4455 is equivalent to adding all control inputs in 4456 self._control_dependencies_stack to a newly created op. However, 4457 this function attempts to prune the returned control dependencies 4458 by observing that nodes created within the same `with 4459 control_dependencies(...):` block may have data dependencies that make 4460 the explicit approach redundant. 4461 4462 Args: 4463 input_ops: The data input ops for an op to be created. 4464 4465 Returns: 4466 A list of control inputs for the op to be created. 4467 """ 4468 ret = [] 4469 for controller in self._control_dependencies_stack: 4470 # If any of the input_ops already depends on the inputs from controller, 4471 # we say that the new op is dominated (by that input), and we therefore 4472 # do not need to add control dependencies for this controller's inputs. 4473 dominated = False 4474 for op in input_ops: 4475 if controller.op_in_group(op): 4476 dominated = True 4477 break 4478 if not dominated: 4479 # Don't add a control input if we already have a data dependency on i. 4480 # NOTE(mrry): We do not currently track transitive data dependencies, 4481 # so we may add redundant control inputs. 4482 ret.extend(c for c in controller.control_inputs if c not in input_ops) 4483 return ret 4484 4485 def _record_op_seen_by_control_dependencies(self, op): 4486 """Record that the given op depends on all registered control dependencies. 4487 4488 Args: 4489 op: An Operation. 
4490 """ 4491 for controller in self._control_dependencies_stack: 4492 controller.add_op(op) 4493 4494 def control_dependencies(self, control_inputs): 4495 """Returns a context manager that specifies control dependencies. 4496 4497 Use with the `with` keyword to specify that all operations constructed 4498 within the context should have control dependencies on 4499 `control_inputs`. For example: 4500 4501 ```python 4502 with g.control_dependencies([a, b, c]): 4503 # `d` and `e` will only run after `a`, `b`, and `c` have executed. 4504 d = ... 4505 e = ... 4506 ``` 4507 4508 Multiple calls to `control_dependencies()` can be nested, and in 4509 that case a new `Operation` will have control dependencies on the union 4510 of `control_inputs` from all active contexts. 4511 4512 ```python 4513 with g.control_dependencies([a, b]): 4514 # Ops constructed here run after `a` and `b`. 4515 with g.control_dependencies([c, d]): 4516 # Ops constructed here run after `a`, `b`, `c`, and `d`. 4517 ``` 4518 4519 You can pass None to clear the control dependencies: 4520 4521 ```python 4522 with g.control_dependencies([a, b]): 4523 # Ops constructed here run after `a` and `b`. 4524 with g.control_dependencies(None): 4525 # Ops constructed here run normally, not waiting for either `a` or `b`. 4526 with g.control_dependencies([c, d]): 4527 # Ops constructed here run after `c` and `d`, also not waiting 4528 # for either `a` or `b`. 4529 ``` 4530 4531 *N.B.* The control dependencies context applies *only* to ops that 4532 are constructed within the context. Merely using an op or tensor 4533 in the context does not add a control dependency. The following 4534 example illustrates this point: 4535 4536 ```python 4537 # WRONG 4538 def my_func(pred, tensor): 4539 t = tf.matmul(tensor, tensor) 4540 with tf.control_dependencies([pred]): 4541 # The matmul op is created outside the context, so no control 4542 # dependency will be added. 
4543 return t 4544 4545 # RIGHT 4546 def my_func(pred, tensor): 4547 with tf.control_dependencies([pred]): 4548 # The matmul op is created in the context, so a control dependency 4549 # will be added. 4550 return tf.matmul(tensor, tensor) 4551 ``` 4552 4553 Also note that though execution of ops created under this scope will trigger 4554 execution of the dependencies, the ops created under this scope might still 4555 be pruned from a normal tensorflow graph. For example, in the following 4556 snippet of code the dependencies are never executed: 4557 4558 ```python 4559 loss = model.loss() 4560 with tf.control_dependencies(dependencies): 4561 loss = loss + tf.constant(1) # note: dependencies ignored in the 4562 # backward pass 4563 return tf.gradients(loss, model.variables) 4564 ``` 4565 4566 This is because evaluating the gradient graph does not require evaluating 4567 the constant(1) op created in the forward pass. 4568 4569 Args: 4570 control_inputs: A list of `Operation` or `Tensor` objects which must be 4571 executed or computed before running the operations defined in the 4572 context. Can also be `None` to clear the control dependencies. 4573 4574 Returns: 4575 A context manager that specifies control dependencies for all 4576 operations constructed within the context. 4577 4578 Raises: 4579 TypeError: If `control_inputs` is not a list of `Operation` or 4580 `Tensor` objects. 4581 """ 4582 if control_inputs is None: 4583 return self._ControlDependenciesController(self, None) 4584 # First convert the inputs to ops, and deduplicate them. 4585 # NOTE(mrry): Other than deduplication, we do not currently track direct 4586 # or indirect dependencies between control_inputs, which may result in 4587 # redundant control inputs. 4588 control_ops = [] 4589 current = self._current_control_dependencies() 4590 for c in control_inputs: 4591 # The hasattr(handle) is designed to match ResourceVariables. 
This is so 4592 # control dependencies on a variable or on an unread variable don't 4593 # trigger reads. 4594 if (isinstance(c, IndexedSlices) or 4595 (hasattr(c, "_handle") and hasattr(c, "op"))): 4596 c = c.op 4597 c = self.as_graph_element(c) 4598 if isinstance(c, Tensor): 4599 c = c.op 4600 elif not isinstance(c, Operation): 4601 raise TypeError("Control input must be Operation or Tensor: %s" % c) 4602 if c not in current: 4603 control_ops.append(c) 4604 current.add(c) 4605 return self._ControlDependenciesController(self, control_ops) 4606 4607 # pylint: disable=g-doc-return-or-yield 4608 @tf_contextlib.contextmanager 4609 def _attr_scope(self, attr_map): 4610 """EXPERIMENTAL: A context manager for setting attributes on operators. 4611 4612 This context manager can be used to add additional 4613 attributes to operators within the scope of the context. 4614 4615 For example: 4616 4617 with ops.Graph().as_default() as g: 4618 f_1 = Foo() # No extra attributes 4619 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}): 4620 f_2 = Foo() # Additional attribute _a=False 4621 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}): 4622 f_3 = Foo() # Additional attribute _a=False 4623 with g._attr_scope({"_a": None}): 4624 f_4 = Foo() # No additional attributes. 4625 4626 Args: 4627 attr_map: A dictionary mapping attr name strings to AttrValue protocol 4628 buffers or None. 4629 4630 Returns: 4631 A context manager that sets the kernel label to be used for one or more 4632 ops created in that context. 4633 4634 Raises: 4635 TypeError: If attr_map is not a dictionary mapping 4636 strings to AttrValue protobufs. 4637 """ 4638 if not isinstance(attr_map, dict): 4639 raise TypeError("attr_map must be a dictionary mapping " 4640 "strings to AttrValue protocol buffers") 4641 # The saved_attrs dictionary stores any currently-set labels that 4642 # will be overridden by this context manager. 
4643 saved_attrs = {} 4644 # Install the given attribute 4645 for name, attr in attr_map.items(): 4646 if not (isinstance(name, six.string_types) and 4647 (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or 4648 callable(attr))): 4649 raise TypeError("attr_map must be a dictionary mapping " 4650 "strings to AttrValue protocol buffers or " 4651 "callables that emit AttrValue protocol buffers") 4652 try: 4653 saved_attrs[name] = self._attr_scope_map[name] 4654 except KeyError: 4655 pass 4656 if attr is None: 4657 del self._attr_scope_map[name] 4658 else: 4659 self._attr_scope_map[name] = attr 4660 try: 4661 yield # The code within the context runs here. 4662 finally: 4663 # Remove the attributes set for this context, and restore any saved 4664 # attributes. 4665 for name, attr in attr_map.items(): 4666 try: 4667 self._attr_scope_map[name] = saved_attrs[name] 4668 except KeyError: 4669 del self._attr_scope_map[name] 4670 4671 # pylint: enable=g-doc-return-or-yield 4672 4673 # pylint: disable=g-doc-return-or-yield 4674 @tf_contextlib.contextmanager 4675 def _kernel_label_map(self, op_to_kernel_label_map): 4676 """EXPERIMENTAL: A context manager for setting kernel labels. 4677 4678 This context manager can be used to select particular 4679 implementations of kernels within the scope of the context. 4680 4681 For example: 4682 4683 with ops.Graph().as_default() as g: 4684 f_1 = Foo() # Uses the default registered kernel for the Foo op. 4685 with g.kernel_label_map({"Foo": "v_2"}): 4686 f_2 = Foo() # Uses the registered kernel with label "v_2" 4687 # for the Foo op. 4688 with g.kernel_label_map({"Foo": "v_3"}): 4689 f_3 = Foo() # Uses the registered kernel with label "v_3" 4690 # for the Foo op. 4691 with g.kernel_label_map({"Foo": ""}): 4692 f_4 = Foo() # Uses the default registered kernel 4693 # for the Foo op. 4694 4695 Args: 4696 op_to_kernel_label_map: A dictionary mapping op type strings to kernel 4697 label strings. 
4698 4699 Returns: 4700 A context manager that sets the kernel label to be used for one or more 4701 ops created in that context. 4702 4703 Raises: 4704 TypeError: If op_to_kernel_label_map is not a dictionary mapping 4705 strings to strings. 4706 """ 4707 if not isinstance(op_to_kernel_label_map, dict): 4708 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 4709 "strings to strings") 4710 # The saved_labels dictionary stores any currently-set labels that 4711 # will be overridden by this context manager. 4712 saved_labels = {} 4713 # Install the given label 4714 for op_type, label in op_to_kernel_label_map.items(): 4715 if not (isinstance(op_type, six.string_types) and 4716 isinstance(label, six.string_types)): 4717 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 4718 "strings to strings") 4719 try: 4720 saved_labels[op_type] = self._op_to_kernel_label_map[op_type] 4721 except KeyError: 4722 pass 4723 self._op_to_kernel_label_map[op_type] = label 4724 try: 4725 yield # The code within the context runs here. 4726 finally: 4727 # Remove the labels set for this context, and restore any saved labels. 4728 for op_type, label in op_to_kernel_label_map.items(): 4729 try: 4730 self._op_to_kernel_label_map[op_type] = saved_labels[op_type] 4731 except KeyError: 4732 del self._op_to_kernel_label_map[op_type] 4733 4734 # pylint: enable=g-doc-return-or-yield 4735 4736 @tf_contextlib.contextmanager 4737 def _override_gradient_function(self, gradient_function_map): 4738 """Specify gradient function for the given op type.""" 4739 4740 # This is an internal API and we don't need nested context for this. 
4741 assert not self._gradient_function_map 4742 self._gradient_function_map = gradient_function_map 4743 yield 4744 self._gradient_function_map = {} 4745 4746 # pylint: disable=g-doc-return-or-yield 4747 @tf_contextlib.contextmanager 4748 def gradient_override_map(self, op_type_map): 4749 """EXPERIMENTAL: A context manager for overriding gradient functions. 4750 4751 This context manager can be used to override the gradient function 4752 that will be used for ops within the scope of the context. 4753 4754 For example: 4755 4756 ```python 4757 @tf.RegisterGradient("CustomSquare") 4758 def _custom_square_grad(op, grad): 4759 # ... 4760 4761 with tf.Graph().as_default() as g: 4762 c = tf.constant(5.0) 4763 s_1 = tf.square(c) # Uses the default gradient for tf.square. 4764 with g.gradient_override_map({"Square": "CustomSquare"}): 4765 s_2 = tf.square(s_2) # Uses _custom_square_grad to compute the 4766 # gradient of s_2. 4767 ``` 4768 4769 Args: 4770 op_type_map: A dictionary mapping op type strings to alternative op type 4771 strings. 4772 4773 Returns: 4774 A context manager that sets the alternative op type to be used for one 4775 or more ops created in that context. 4776 4777 Raises: 4778 TypeError: If `op_type_map` is not a dictionary mapping strings to 4779 strings. 4780 """ 4781 if not isinstance(op_type_map, dict): 4782 raise TypeError("op_type_map must be a dictionary mapping " 4783 "strings to strings") 4784 # The saved_mappings dictionary stores any currently-set mappings that 4785 # will be overridden by this context manager. 
4786 saved_mappings = {} 4787 # Install the given label 4788 for op_type, mapped_op_type in op_type_map.items(): 4789 if not (isinstance(op_type, six.string_types) and 4790 isinstance(mapped_op_type, six.string_types)): 4791 raise TypeError("op_type_map must be a dictionary mapping " 4792 "strings to strings") 4793 try: 4794 saved_mappings[op_type] = self._gradient_override_map[op_type] 4795 except KeyError: 4796 pass 4797 self._gradient_override_map[op_type] = mapped_op_type 4798 try: 4799 yield # The code within the context runs here. 4800 finally: 4801 # Remove the labels set for this context, and restore any saved labels. 4802 for op_type, mapped_op_type in op_type_map.items(): 4803 try: 4804 self._gradient_override_map[op_type] = saved_mappings[op_type] 4805 except KeyError: 4806 del self._gradient_override_map[op_type] 4807 4808 # pylint: enable=g-doc-return-or-yield 4809 4810 def prevent_feeding(self, tensor): 4811 """Marks the given `tensor` as unfeedable in this graph.""" 4812 self._unfeedable_tensors.add(tensor) 4813 4814 def is_feedable(self, tensor): 4815 """Returns `True` if and only if `tensor` is feedable.""" 4816 return tensor not in self._unfeedable_tensors 4817 4818 def prevent_fetching(self, op): 4819 """Marks the given `op` as unfetchable in this graph.""" 4820 self._unfetchable_ops.add(op) 4821 4822 def is_fetchable(self, tensor_or_op): 4823 """Returns `True` if and only if `tensor_or_op` is fetchable.""" 4824 if isinstance(tensor_or_op, Tensor): 4825 return tensor_or_op.op not in self._unfetchable_ops 4826 else: 4827 return tensor_or_op not in self._unfetchable_ops 4828 4829 def switch_to_thread_local(self): 4830 """Make device, colocation and dependencies stacks thread-local. 4831 4832 Device, colocation and dependencies stacks are not thread-local be default. 4833 If multiple threads access them, then the state is shared. This means that 4834 one thread may affect the behavior of another thread. 
4835 4836 After this method is called, the stacks become thread-local. If multiple 4837 threads access them, then the state is not shared. Each thread uses its own 4838 value; a thread doesn't affect other threads by mutating such a stack. 4839 4840 The initial value for every thread's stack is set to the current value 4841 of the stack when `switch_to_thread_local()` was first called. 4842 """ 4843 if not self._stack_state_is_thread_local: 4844 self._stack_state_is_thread_local = True 4845 4846 @property 4847 def _device_function_stack(self): 4848 if self._stack_state_is_thread_local: 4849 # This may be called from a thread where device_function_stack doesn't yet 4850 # exist. 4851 # pylint: disable=protected-access 4852 if not hasattr(self._thread_local, "_device_function_stack"): 4853 stack_copy_for_this_thread = self._graph_device_function_stack.copy() 4854 self._thread_local._device_function_stack = stack_copy_for_this_thread 4855 return self._thread_local._device_function_stack 4856 # pylint: enable=protected-access 4857 else: 4858 return self._graph_device_function_stack 4859 4860 @property 4861 def _device_functions_outer_to_inner(self): 4862 user_device_specs = self._device_function_stack.peek_objs() 4863 device_functions = [spec.function for spec in user_device_specs] 4864 device_functions_outer_to_inner = list(reversed(device_functions)) 4865 return device_functions_outer_to_inner 4866 4867 def _snapshot_device_function_stack_metadata(self): 4868 """Return device function stack as a list of TraceableObjects. 4869 4870 Returns: 4871 [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj 4872 member is a displayable name for the user's argument to Graph.device, and 4873 the filename and lineno members point to the code location where 4874 Graph.device was called directly or indirectly by the user. 
4875 """ 4876 snapshot = [] 4877 for obj in self._device_function_stack.peek_traceable_objs(): 4878 obj_copy = obj.copy_metadata() 4879 obj_copy.obj = obj.obj.display_name 4880 snapshot.append(obj_copy) 4881 return snapshot 4882 4883 @_device_function_stack.setter 4884 def _device_function_stack(self, device_function_stack): 4885 if self._stack_state_is_thread_local: 4886 # pylint: disable=protected-access 4887 self._thread_local._device_function_stack = device_function_stack 4888 # pylint: enable=protected-access 4889 else: 4890 self._graph_device_function_stack = device_function_stack 4891 4892 @property 4893 def _colocation_stack(self): 4894 """Return thread-local copy of colocation stack.""" 4895 if self._stack_state_is_thread_local: 4896 # This may be called from a thread where colocation_stack doesn't yet 4897 # exist. 4898 # pylint: disable=protected-access 4899 if not hasattr(self._thread_local, "_colocation_stack"): 4900 stack_copy_for_this_thread = self._graph_colocation_stack.copy() 4901 self._thread_local._colocation_stack = stack_copy_for_this_thread 4902 return self._thread_local._colocation_stack 4903 # pylint: enable=protected-access 4904 else: 4905 return self._graph_colocation_stack 4906 4907 def _snapshot_colocation_stack_metadata(self): 4908 """Return colocation stack metadata as a dictionary.""" 4909 return { 4910 traceable_obj.obj.name: traceable_obj.copy_metadata() 4911 for traceable_obj in self._colocation_stack.peek_traceable_objs() 4912 } 4913 4914 @_colocation_stack.setter 4915 def _colocation_stack(self, colocation_stack): 4916 if self._stack_state_is_thread_local: 4917 # pylint: disable=protected-access 4918 self._thread_local._colocation_stack = colocation_stack 4919 # pylint: enable=protected-access 4920 else: 4921 self._graph_colocation_stack = colocation_stack 4922 4923 @property 4924 def _control_dependencies_stack(self): 4925 if self._stack_state_is_thread_local: 4926 # This may be called from a thread where 
control_dependencies_stack 4927 # doesn't yet exist. 4928 if not hasattr(self._thread_local, "_control_dependencies_stack"): 4929 self._thread_local._control_dependencies_stack = ( 4930 self._graph_control_dependencies_stack[:]) 4931 return self._thread_local._control_dependencies_stack 4932 else: 4933 return self._graph_control_dependencies_stack 4934 4935 @_control_dependencies_stack.setter 4936 def _control_dependencies_stack(self, control_dependencies): 4937 if self._stack_state_is_thread_local: 4938 self._thread_local._control_dependencies_stack = control_dependencies 4939 else: 4940 self._graph_control_dependencies_stack = control_dependencies 4941 4942 @property 4943 def _distribution_strategy_stack(self): 4944 """A stack to maintain distribution strategy context for each thread.""" 4945 if not hasattr(self._thread_local, "_distribution_strategy_stack"): 4946 self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access 4947 return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access 4948 4949 @_distribution_strategy_stack.setter 4950 def _distribution_strategy_stack(self, _distribution_strategy_stack): 4951 self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access 4952 _distribution_strategy_stack) 4953 4954 @property 4955 def _global_distribute_strategy_scope(self): 4956 """For implementing `tf.distribute.set_strategy()`.""" 4957 if not hasattr(self._thread_local, "distribute_strategy_scope"): 4958 self._thread_local.distribute_strategy_scope = None 4959 return self._thread_local.distribute_strategy_scope 4960 4961 @_global_distribute_strategy_scope.setter 4962 def _global_distribute_strategy_scope(self, distribute_strategy_scope): 4963 self._thread_local.distribute_strategy_scope = (distribute_strategy_scope) 4964 4965 @property 4966 def _auto_cast_variable_read_dtype(self): 4967 """The dtype that instances of `AutoCastVariable` will be casted to. 
4968 4969 This is None if `AutoCastVariables` should not be casted. 4970 4971 See `AutoCastVariable` for more information. 4972 4973 Returns: 4974 The dtype that instances of `AutoCastVariable` will be casted to. 4975 """ 4976 if not hasattr(self._thread_local, "_auto_cast_variable_read_dtype"): 4977 self._thread_local._auto_cast_variable_read_dtype = None # pylint: disable=protected-access 4978 return self._thread_local._auto_cast_variable_read_dtype # pylint: disable=protected-access 4979 4980 @_auto_cast_variable_read_dtype.setter 4981 def _auto_cast_variable_read_dtype(self, dtype): 4982 if dtype: 4983 dtype = dtypes.as_dtype(dtype) 4984 self._thread_local._auto_cast_variable_read_dtype = dtype # pylint: disable=protected-access 4985 4986 @tf_contextlib.contextmanager 4987 def _enable_auto_casting_variables(self, dtype): 4988 """Context manager to automatically cast AutoCastVariables. 4989 4990 If an AutoCastVariable `var` is used under this context manager, it will be 4991 casted to `dtype` before being used. 4992 4993 See `AutoCastVariable` for more information. 4994 4995 Args: 4996 dtype: The dtype that AutoCastVariables should be casted to. 4997 4998 Yields: 4999 Nothing. 5000 """ 5001 prev_read_dtype = self._auto_cast_variable_read_dtype 5002 try: 5003 self._auto_cast_variable_read_dtype = dtype 5004 yield 5005 finally: 5006 self._auto_cast_variable_read_dtype = prev_read_dtype 5007 5008 def _mutation_lock(self): 5009 """Returns a lock to guard code that creates & mutates ops. 5010 5011 See the comment for self._group_lock for more info. 5012 """ 5013 return self._group_lock.group(_MUTATION_LOCK_GROUP) 5014 5015 def _session_run_lock(self): 5016 """Returns a lock to guard code for Session.run. 5017 5018 See the comment for self._group_lock for more info. 5019 """ 5020 return self._group_lock.group(_SESSION_RUN_LOCK_GROUP) 5021 5022 5023# TODO(agarwal): currently device directives in an outer eager scope will not 5024# apply to inner graph mode code. 
# TODO(agarwal): currently device directives in an outer eager scope will not
# apply to inner graph mode code. Fix that.


@tf_export(v1=["device"])
def device(device_name_or_function):
  """Wrapper for `Graph.device()` using the default graph.

  See `tf.Graph.device` for more details.

  Args:
    device_name_or_function: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If eager execution is enabled and a function is passed in.
  """
  if context.executing_eagerly():
    if callable(device_name_or_function):
      raise RuntimeError(
          "tf.device does not support functions when eager execution "
          "is enabled.")
    return context.device(device_name_or_function)

  if not executing_eagerly_outside_functions():
    return get_default_graph().device(device_name_or_function)

  @tf_contextlib.contextmanager
  def combined(device_name_or_function):
    # Always enter the default graph's device scope; additionally enter the
    # eager context's device scope unless a device function was given.
    with get_default_graph().device(device_name_or_function):
      if callable(device_name_or_function):
        yield
      else:
        with context.device(device_name_or_function):
          yield

  return combined(device_name_or_function)


@tf_export("device", v1=[])
def device_v2(device_name):
  """Specifies the device for ops created/executed in this context.

  This function specifies the device to be used for ops created/executed in a
  particular context. Nested contexts will inherit and also create/execute
  their ops on the specified device. If a specific device is not required,
  consider not using this function so that a device can be automatically
  assigned.  In general the use of this function is optional. `device_name` can
  be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially
  specified, containing only a subset of the "/"-separated fields. Any fields
  which are specified will override device annotations from outer scopes.

  For example:

  ```python
  with tf.device('/job:foo'):
    # ops created here have devices with /job:foo
    with tf.device('/job:bar/task:0/device:gpu:2'):
      # ops created here have the fully specified device above
    with tf.device('/device:gpu:1'):
      # ops created here have the device '/job:foo/device:gpu:1'
  ```

  Args:
    device_name: The device name to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If a function is passed in.
  """
  if not callable(device_name):
    return device(device_name)
  raise RuntimeError("tf.device does not support functions.")


@tf_export(v1=["container"])
def container(container_name):
  """Wrapper for `Graph.container()` using the default graph.

  Args:
    container_name: The container string to use in the context.

  Returns:
    A context manager that specifies the default container to use for newly
    created stateful ops.
  """
  default_graph = get_default_graph()
  return default_graph.container(container_name)


def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
  """Returns a context manager placing new ops with `op`.

  Args:
    op: The object to colocate with, or None.
    gradient_uid: Forwarded to `Graph._colocate_with_for_gradient` in graph
      mode.
    ignore_existing: Forwarded to `Graph._colocate_with_for_gradient` in graph
      mode.

  Returns:
    In eager mode, a `device(op.device)` scope, or a `NullContextmanager` when
    `op` is None. In graph mode, a device scope for eager tensors captured
    while building a function, otherwise the default graph's colocation scope.

  Raises:
    ValueError: If `op` is an `EagerTensor` in graph mode while no function is
      being built.
  """
  if context.executing_eagerly():
    if op is None:
      return NullContextmanager()
    if not hasattr(op, "device"):
      op = internal_convert_to_tensor_or_indexed_slices(op)
    return device(op.device)

  default_graph = get_default_graph()
  if isinstance(op, EagerTensor):
    if not default_graph.building_function:
      raise ValueError("Encountered an Eager-defined Tensor during graph "
                       "construction, but a function was not being built.")
    return default_graph.device(op.device)
  return default_graph._colocate_with_for_gradient(
      op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)


# Internal interface to colocate_with.
# colocate_with has been deprecated from the public API. There are still a few
# internal uses of colocate_with, so this internal-only API exists for those
# uses to avoid the deprecation warning.
def colocate_with(op, ignore_existing=False):
  """Internal, non-deprecated entry point for colocating ops with `op`."""
  return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing)


@deprecation.deprecated(
    date=None, instructions="Colocations handled automatically by placer.")
@tf_export(v1=["colocate_with"])
def _colocate_with(op, ignore_existing=False):
  """Deprecated public `tf.colocate_with`; forwards to `colocate_with`."""
  return colocate_with(op, ignore_existing)


@tf_export("control_dependencies")
def control_dependencies(control_inputs):
  """Wrapper for `Graph.control_dependencies()` using the default graph.

  See `tf.Graph.control_dependencies`
  for more details.

  When eager execution is enabled, any callable object in the `control_inputs`
  list will be called.

  Args:
    control_inputs: A list of `Operation` or `Tensor` objects which must be
      executed or computed before running the operations defined in the context.
      Can also be `None` to clear the control dependencies. If eager execution
      is enabled, any callable object in the `control_inputs` list will be
      called.

  Returns:
   A context manager that specifies control dependencies for all
   operations constructed within the context.
  """
  if not context.executing_eagerly():
    return get_default_graph().control_dependencies(control_inputs)

  if control_inputs:
    # Execute any pending callables.
    for control in control_inputs:
      if callable(control):
        control()
  return NullContextmanager()


class _DefaultStack(threading.local):
  """A thread-local stack of objects for providing implicit defaults."""

  def __init__(self):
    super(_DefaultStack, self).__init__()
    self._enforce_nesting = True
    self.stack = []

  def get_default(self):
    """Returns the innermost default, or None when the stack is empty."""
    if len(self.stack) >= 1:
      return self.stack[-1]
    return None

  def reset(self):
    """Replaces the stack with a new empty list."""
    self.stack = []

  def is_cleared(self):
    """Returns True when no default is currently installed."""
    return not self.stack

  @property
  def enforce_nesting(self):
    """Whether `get_controller` requires contexts to exit in LIFO order."""
    return self._enforce_nesting

  @enforce_nesting.setter
  def enforce_nesting(self, value):
    self._enforce_nesting = value

  @tf_contextlib.contextmanager
  def get_controller(self, default):
    """A context manager for manipulating a default stack."""
    self.stack.append(default)
    try:
      yield default
    finally:
      if not self.stack:
        # stack may be empty if reset() was called
        pass
      elif not self._enforce_nesting:
        self.stack.remove(default)
      elif self.stack[-1] is default:
        self.stack.pop()
      else:
        raise AssertionError(
            "Nesting violated for default stack of %s objects" %
            type(default))


_default_session_stack = _DefaultStack()  # pylint: disable=protected-access


def default_session(session):
  """Python "with" handler for defining a default session.

  This function provides a means of registering a session for handling
  Tensor.eval() and Operation.run() calls. It is primarily intended for use
  by session.Session, but can be used with any object that implements
  the Session.run() interface.

  Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
  invocations within the scope of a block should be executed by a particular
  session.

  The default session applies to the current thread only, so it is always
  possible to inspect the call stack and determine the scope of a default
  session. If you create a new thread, and wish to use the default session
  in that thread, you must explicitly add a "with ops.default_session(sess):"
  block in that thread's function.

  Example:
    The following code examples are equivalent:

    # 1. Using the Session object directly:
    sess = ...
    c = tf.constant(5.0)
    sess.run(c)

    # 2. Using default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      result = c.eval()

    # 3. Overriding default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      with ops.default_session(...):
        c.eval(session=sess)

  Args:
    session: The session to be installed as the default session.

  Returns:
    A context manager for the default session.
  """
  controller = _default_session_stack.get_controller(session)
  return controller
def _eval_using_default_session(tensors, feed_dict, graph, session=None):
  """Uses the default session to evaluate one or more tensors.

  Args:
    tensors: A single Tensor, or a list of Tensor objects.
    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
      numpy ndarrays, TensorProtos, or strings.
    graph: The graph in which the tensors are defined.
    session: (Optional) A different session to use to evaluate "tensors".

  Returns:
    Either a single numpy ndarray if "tensors" is a single tensor; or a list
    of numpy ndarrays that each correspond to the respective element in
    "tensors".

  Raises:
    ValueError: If no default session is available; the default session
      does not have "graph" as its graph; or if "session" is specified,
      and it does not have "graph" as its graph.
  """
  if session is not None:
    # An explicit session was supplied; it must own the tensor's graph.
    if session.graph is not graph:
      raise ValueError("Cannot use the given session to evaluate tensor: "
                       "the tensor's graph is different from the session's "
                       "graph.")
  else:
    session = get_default_session()
    if session is None:
      raise ValueError("Cannot evaluate tensor using `eval()`: No default "
                       "session is registered. Use `with "
                       "sess.as_default()` or pass an explicit session to "
                       "`eval(session=sess)`")
    if session.graph is not graph:
      raise ValueError("Cannot use the default session to evaluate tensor: "
                       "the tensor's graph is different from the session's "
                       "graph. Pass an explicit session to "
                       "`eval(session=sess)`.")
  return session.run(tensors, feed_dict)


def _run_using_default_session(operation, feed_dict, graph, session=None):
  """Uses the default session to run "operation".

  Args:
    operation: The Operation to be run.
    feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
      numpy ndarrays, TensorProtos, or strings.
    graph: The graph in which "operation" is defined.
    session: (Optional) A different session to use to run "operation".

  Raises:
    ValueError: If no default session is available; the default session
      does not have "graph" as its graph; or if "session" is specified,
      and it does not have "graph" as its graph.
  """
  if session is not None:
    # An explicit session was supplied; it must own the operation's graph.
    if session.graph is not graph:
      raise ValueError("Cannot use the given session to execute operation: "
                       "the operation's graph is different from the session's "
                       "graph.")
  else:
    session = get_default_session()
    if session is None:
      raise ValueError("Cannot execute operation using `run()`: No default "
                       "session is registered. Use `with "
                       "sess.as_default():` or pass an explicit session to "
                       "`run(session=sess)`")
    if session.graph is not graph:
      raise ValueError("Cannot use the default session to execute operation: "
                       "the operation's graph is different from the "
                       "session's graph. Pass an explicit session to "
                       "run(session=sess).")
  # The result of `run` is intentionally discarded.
  session.run(operation, feed_dict)
# Shared helper used in init_scope and executing_eagerly_outside_functions
# to obtain the outermost context that is not building a function, and the
# innermost non empty device stack.
def _get_outer_context_and_inner_device_stack():
  """Finds the context to lift into and the device stack to carry along.

  Returns:
    A pair `(outer_context, device_stack)`. `outer_context` is a callable
    that is expected to return a context manager: either the `as_default`
    method of a graph or the `enter_context_fn` recorded for a context switch.
    `device_stack` is the default graph's device function stack, unless that
    is falsy and a walked context-switch entry supplies a replacement.

  Raises:
    RuntimeError: If the default graph stack is empty but the default graph
      is building a function, or if no outer context could be determined.
  """
  current_graph = get_default_graph()
  device_stack = current_graph._device_function_stack  # pylint: disable=protected-access
  enter_outer = None

  if _default_graph_stack.stack:
    # Walk the context switches from the most recent one backwards, looking
    # for a context that is not building a function.
    for switch in reversed(context.context().context_switches.stack):
      if not device_stack:
        # Remember the first device stack seen once ours turned out falsy.
        device_stack = switch.device_stack
      if not switch.is_building_function:
        enter_outer = switch.enter_context_fn
        break

    if enter_outer is None:
      # As a last resort, obtain the global default graph; this graph doesn't
      # necessarily live on the graph stack (and hence it doesn't necessarily
      # live on the context stack), but it is stored in the graph stack's
      # encapsulating object.
      enter_outer = _default_graph_stack._GetGlobalDefaultGraph().as_default  # pylint: disable=protected-access
  else:
    # If the default graph stack is empty, then we cannot be building a
    # function. Install the global graph (which, in this case, is also the
    # default graph) as the outer context.
    if current_graph.building_function:
      raise RuntimeError("The global graph is building a function.")
    enter_outer = current_graph.as_default

  if enter_outer is None:
    # Sanity check; this shouldn't be triggered.
    raise RuntimeError("All graphs are building functions, and no "
                       "eager context was previously active.")

  return enter_outer, device_stack
Here, a context is defined as either a graph or an eager 5474 context. Every context switch, i.e., every installation of a graph as 5475 the default graph and every switch into eager mode, is logged in a 5476 thread-local stack called `context_switches`; the log entry for a 5477 context switch is popped from the stack when the context is exited. 5478 Entering an `init_scope` is equivalent to crawling up 5479 `context_switches`, finding the first context that is not building a 5480 graph function, and entering it. A caveat is that if graph mode is 5481 enabled but the default graph stack is empty, then entering an 5482 `init_scope` will simply install a fresh graph as the default one. 5483 5484 (3) The gradient tape is paused while the scope is active. 5485 5486 When eager execution is enabled, code inside an init_scope block runs with 5487 eager execution enabled even when tracing a `tf.function`. For example: 5488 5489 ```python 5490 tf.compat.v1.enable_eager_execution() 5491 5492 @tf.function 5493 def func(): 5494 # A function constructs TensorFlow graphs, 5495 # it does not execute eagerly. 5496 assert not tf.executing_eagerly() 5497 with tf.init_scope(): 5498 # Initialization runs with eager execution enabled 5499 assert tf.executing_eagerly() 5500 ``` 5501 5502 Raises: 5503 RuntimeError: if graph state is incompatible with this initialization. 5504 """ 5505 # pylint: enable=g-doc-return-or-yield,line-too-long 5506 5507 if context.executing_eagerly(): 5508 # Fastpath. 5509 with tape.stop_recording(): 5510 yield 5511 else: 5512 # Retrieve the active name scope: entering an `init_scope` preserves 5513 # the name scope of the current context. 5514 scope = get_default_graph().get_name_scope() 5515 if scope and scope[-1] != "/": 5516 # Names that end with trailing slashes are treated by `name_scope` as 5517 # absolute. 
5518 scope = scope + "/" 5519 5520 outer_context, innermost_nonempty_device_stack = ( 5521 _get_outer_context_and_inner_device_stack()) 5522 5523 outer_graph = None 5524 outer_device_stack = None 5525 try: 5526 with outer_context(), name_scope( 5527 scope, skip_on_eager=False), control_dependencies( 5528 None), tape.stop_recording(): 5529 context_manager = NullContextmanager 5530 context_manager_input = None 5531 if not context.executing_eagerly(): 5532 # The device stack is preserved when lifting into a graph. Eager 5533 # execution doesn't implement device stacks and in particular it 5534 # doesn't support device functions, so in general it's not possible 5535 # to do the same when lifting into the eager context. 5536 outer_graph = get_default_graph() 5537 outer_device_stack = outer_graph._device_function_stack # pylint: disable=protected-access 5538 outer_graph._device_function_stack = innermost_nonempty_device_stack # pylint: disable=protected-access 5539 elif innermost_nonempty_device_stack is not None: 5540 for device_spec in innermost_nonempty_device_stack.peek_objs(): 5541 if device_spec.function is None: 5542 break 5543 if device_spec.raw_string: 5544 context_manager = context.device 5545 context_manager_input = device_spec.raw_string 5546 break 5547 # It is currently not possible to have a device function in V2, 5548 # but in V1 we are unable to apply device functions in eager mode. 5549 # This means that we will silently skip some of the entries on the 5550 # device stack in V1 + eager mode. 5551 5552 with context_manager(context_manager_input): 5553 yield 5554 finally: 5555 # If an exception is raised here it may be hiding a related exception in 5556 # try-block (just above). 
def executing_eagerly_outside_functions():
  """Returns True if executing eagerly, even if inside a graph function.

  When the current context is not eager, the outer context found by
  `_get_outer_context_and_inner_device_stack()` is entered and the eager
  check is repeated inside it.
  """
  if not context.executing_eagerly():
    enter_outer_context, _ = _get_outer_context_and_inner_device_stack()
    with enter_outer_context():
      return context.executing_eagerly()
  return True


def inside_function():
  """Returns the `building_function` attribute of the default graph."""
  default_graph = get_default_graph()
  return default_graph.building_function
Note that 5606 `tf.compat.v1.ConfigProto` is also used to configure graph execution (via 5607 `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto` 5608 are not implemented (or are irrelevant) when eager execution is enabled. 5609 device_policy: (Optional.) Policy controlling how operations requiring 5610 inputs on a specific device (e.g., a GPU 0) handle inputs on a different 5611 device (e.g. GPU 1 or CPU). When set to None, an appropriate value will 5612 be picked automatically. The value picked may change between TensorFlow 5613 releases. 5614 Valid values: 5615 - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the 5616 placement is not correct. 5617 - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not 5618 on the right device but logs a warning. 5619 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors. 5620 Note that this may hide performance problems as there is no notification 5621 provided when operations are blocked on the tensor being copied between 5622 devices. 5623 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies 5624 int32 tensors, raising errors on the other ones. 5625 execution_mode: (Optional.) Policy controlling how operations dispatched are 5626 actually executed. When set to None, an appropriate value will be picked 5627 automatically. The value picked may change between TensorFlow releases. 5628 Valid values: 5629 - tf.contrib.eager.SYNC: executes each operation synchronously. 5630 - tf.contrib.eager.ASYNC: executes each operation asynchronously. These 5631 operations may return "non-ready" handles. 5632 5633 Raises: 5634 ValueError: If eager execution is enabled after creating/executing a 5635 TensorFlow graph, or if options provided conflict with a previous call 5636 to this function. 
@tf_export(v1=["disable_eager_execution"])
def disable_eager_execution():
  """Disables eager execution.

  This function can only be called before any Graphs, Ops, or Tensors have been
  created. It can be used at the beginning of the program for complex migration
  projects from TensorFlow 1.x to 2.x.

  Besides switching the default execution mode to graph mode, it records the
  choice in the API usage gauge and, if `context.context_safe()` returns a
  context, clears that context's thread-local eager flag.
  """
  _api_usage_gauge.get_cell().set(False)
  context.default_execution_mode = context.GRAPH_MODE
  existing_context = context.context_safe()
  if existing_context is None:
    # Nothing to update when `context_safe()` reports no context.
    return
  existing_context._thread_local_data.is_eager = False  # pylint: disable=protected-access


def enable_eager_execution_internal(config=None,
                                    device_policy=None,
                                    execution_mode=None,
                                    server_def=None):
  """Enables eager execution for the lifetime of this program.

  Most of the doc string for enable_eager_execution is relevant here as well.

  Args:
    config: See enable_eager_execution doc string
    device_policy: See enable_eager_execution doc string
    execution_mode: See enable_eager_execution doc string
    server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on
      remote devices. GrpcServers need to be started by creating an identical
      server_def to this, and setting the appropriate task_indexes, so that the
      servers can communicate. It will then be possible to execute operations on
      remote devices.

  Raises:
    TypeError: If `config` is given but is not a `ConfigProto`.
    ValueError: If `device_policy` or `execution_mode` is not one of the
      accepted values; if the default execution mode is graph mode and the
      global default graph has already been created; or if a context already
      exists and a non-None option is not the very same object the context
      was created with.
  """
  if not (config is None or isinstance(config, config_pb2.ConfigProto)):
    raise TypeError("config must be a tf.ConfigProto, but got %s" %
                    type(config))

  allowed_device_policies = (None, context.DEVICE_PLACEMENT_EXPLICIT,
                             context.DEVICE_PLACEMENT_WARN,
                             context.DEVICE_PLACEMENT_SILENT,
                             context.DEVICE_PLACEMENT_SILENT_FOR_INT32)
  if device_policy not in allowed_device_policies:
    raise ValueError(
        "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
    )

  allowed_execution_modes = (None, context.SYNC, context.ASYNC)
  if execution_mode not in allowed_execution_modes:
    raise ValueError(
        "execution_mode must be one of None, tf.contrib.eager.SYNC, "
        "tf.contrib.eager.ASYNC")

  if context.default_execution_mode == context.GRAPH_MODE:
    # The global default graph only exists once graph mode has been used.
    if _default_graph_stack._global_default_graph is not None:  # pylint: disable=protected-access
      raise ValueError(
          "tf.enable_eager_execution must be called at program startup.")
  context.default_execution_mode = context.EAGER_MODE

  # pylint: disable=protected-access
  with context._context_lock:
    active = context._context
    if active is None:
      new_context = context.Context(
          config=config,
          device_policy=device_policy,
          execution_mode=execution_mode,
          server_def=server_def)
      context._set_context_locked(new_context)
    elif ((config is not None and config is not active._config) or
          (device_policy is not None and
           device_policy is not active._device_policy) or
          (execution_mode is not None and
           execution_mode is not active._execution_mode)):
      # Options are compared by identity, not by value.
      raise ValueError(
          "Trying to change the options of an active eager"
          " execution. Context config: %s, specified config:"
          " %s. Context device policy: %s, specified device"
          " policy: %s. Context execution mode: %s, "
          " specified execution mode %s." %
          (active._config, config, active._device_policy, device_policy,
           active._execution_mode, execution_mode))
    else:
      # We already created everything, so update the thread local data.
      active._thread_local_data.is_eager = True

  # Monkey patch to get rid of an unnecessary conditional since the context is
  # now initialized.
  context.context = context.context_safe


def eager_run(main=None, argv=None):
  """Enables eager execution, then runs the program through `app.run`.

  Example:
  ```python
  import tensorflow as tf
  # Import subject to future changes:
  from tensorflow.contrib.eager.python import tfe

  def main(_):
    u = tf.constant(6.0)
    v = tf.constant(7.0)
    print(u * v)

  if __name__ == "__main__":
    tfe.run()
  ```

  Args:
    main: the main function to run.
    argv: the arguments to pass to it.
  """
  # Eager mode has to be switched on before the program's main function runs.
  enable_eager_execution()
  app.run(main, argv)
@tf_export(v1=["get_default_graph"])
def get_default_graph():
  """Returns the default graph for the current thread.

  The returned graph will be the innermost graph on which a
  `Graph.as_default()` context has been entered, or a global default
  graph if none has been explicitly created.

  NOTE: The default graph is a property of the current thread. If you
  create a new thread, and wish to use the default graph in that
  thread, you must explicitly add a `with g.as_default():` in that
  thread's function.

  Returns:
    The default `Graph` being used in the current thread.
  """
  default_graph = _default_graph_stack.get_default()
  return default_graph


def has_default_graph():
  """Returns True if the default graph stack holds at least one graph."""
  return len(_default_graph_stack.stack) > 0


def get_name_scope():
  """Returns the current name scope in the default_graph.

  For example:

  ```python
  with tf.name_scope('scope1'):
    with tf.name_scope('scope2'):
      print(tf.get_name_scope())
  ```
  would print the string `scope1/scope2`.

  Returns:
    A string representing the current name scope. When executing eagerly this
    is the context's `scope_name` with trailing slashes stripped; otherwise it
    is the default graph's name scope.
  """
  if not context.executing_eagerly():
    return get_default_graph().get_name_scope()
  eager_scope = context.context().scope_name
  return eager_scope.rstrip("/")
def _get_graph_from_inputs(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  This library method provides a consistent algorithm for choosing the graph
  in which an Operation should be constructed:

  1. If the default graph is being used to construct a function, we
     use the default graph (no validation of the arguments happens).
  2. If the "graph" is specified explicitly, we validate that all of the inputs
     in "op_input_list" are compatible with that graph.
  3. Otherwise, we attempt to select a graph from the first Operation-
     or Tensor-valued input in "op_input_list", and validate that all other
     such inputs are in the same graph.
  4. If the graph was not specified and it could not be inferred from
     "op_input_list", we use the default graph.

  Args:
    op_input_list: An iterable of inputs to an operation, which may include
      `Tensor`, `Operation`, and other objects that may be converted to a
      graph element. It is materialized into a tuple, so generators work.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If graph is truthy and is not a Graph.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs.

  Returns:
    The appropriate graph to use for the given inputs.
  """
  default_graph = get_default_graph()
  if default_graph.building_function:
    return default_graph

  inputs = tuple(op_input_list)  # Handle generators correctly
  if graph and not isinstance(graph, Graph):
    raise TypeError("Input graph needs to be a Graph: %s" % graph)

  # Validate that all of the inputs are from the same graph. This is either
  # the supplied graph parameter, or the graph of the first graph-element-valued
  # input. In the latter case, that input is remembered in `first_element` so a
  # mismatch can be reported against it.
  first_element = None
  for candidate in inputs:
    # Determine if this is a valid graph element.
    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean
    # this up.
    used_directly = (
        isinstance(candidate, (Operation, _TensorLike)) and
        (not isinstance(candidate, Tensor) or type(candidate) == Tensor))  # pylint: disable=unidiomatic-typecheck
    element = candidate if used_directly else _as_graph_element(candidate)
    if element is None:
      continue

    if not graph:
      # First graph element seen and no graph yet: adopt its graph.
      first_element = element
      graph = element.graph
    elif first_element is not None:
      _assert_same_graph(first_element, element)
    elif element.graph is not graph:
      raise ValueError("%s is not from the passed-in graph." % element)

  # If all else fails, we use the default graph, which is always there.
  return graph or default_graph
See 5927 `tf.compat.v1.global_variables` 5928 for more details. 5929 Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`, 5930 and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`. 5931 * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each 5932 machine. Usually used for temporarily variables, like counters. 5933 Note: use `tf.contrib.framework.local_variable` to add to this collection. 5934 * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the 5935 model for inference (feed forward). Note: use 5936 `tf.contrib.framework.model_variable` to add to this collection. 5937 * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will 5938 be trained by an optimizer. See 5939 `tf.compat.v1.trainable_variables` 5940 for more details. 5941 * `SUMMARIES`: the summary `Tensor` objects that have been created in the 5942 graph. See 5943 `tf.compat.v1.summary.merge_all` 5944 for more details. 5945 * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to 5946 produce input for a computation. See 5947 `tf.compat.v1.train.start_queue_runners` 5948 for more details. 5949 * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also 5950 keep moving averages. See 5951 `tf.compat.v1.moving_average_variables` 5952 for more details. 5953 * `REGULARIZATION_LOSSES`: regularization losses collected during graph 5954 construction. 5955 5956 The following standard keys are _defined_, but their collections are **not** 5957 automatically populated as many of the others are: 5958 5959 * `WEIGHTS` 5960 * `BIASES` 5961 * `ACTIVATIONS` 5962 """ 5963 5964 # Key to collect Variable objects that are global (shared across machines). 5965 # Default collection for all variables, except local ones. 5966 GLOBAL_VARIABLES = "variables" 5967 # Key to collect local variables that are local to the machine and are not 5968 # saved/restored. 
5969 LOCAL_VARIABLES = "local_variables" 5970 # Key to collect local variables which are used to accumulate interal state 5971 # to be used in tf.metrics.*. 5972 METRIC_VARIABLES = "metric_variables" 5973 # Key to collect model variables defined by layers. 5974 MODEL_VARIABLES = "model_variables" 5975 # Key to collect Variable objects that will be trained by the 5976 # optimizers. 5977 TRAINABLE_VARIABLES = "trainable_variables" 5978 # Key to collect summaries. 5979 SUMMARIES = "summaries" 5980 # Key to collect QueueRunners. 5981 QUEUE_RUNNERS = "queue_runners" 5982 # Key to collect table initializers. 5983 TABLE_INITIALIZERS = "table_initializer" 5984 # Key to collect asset filepaths. An asset represents an external resource 5985 # like a vocabulary file. 5986 ASSET_FILEPATHS = "asset_filepaths" 5987 # Key to collect Variable objects that keep moving averages. 5988 MOVING_AVERAGE_VARIABLES = "moving_average_variables" 5989 # Key to collect regularization losses at graph construction. 5990 REGULARIZATION_LOSSES = "regularization_losses" 5991 # Key to collect concatenated sharded variables. 5992 CONCATENATED_VARIABLES = "concatenated_variables" 5993 # Key to collect savers. 5994 SAVERS = "savers" 5995 # Key to collect weights 5996 WEIGHTS = "weights" 5997 # Key to collect biases 5998 BIASES = "biases" 5999 # Key to collect activations 6000 ACTIVATIONS = "activations" 6001 # Key to collect update_ops 6002 UPDATE_OPS = "update_ops" 6003 # Key to collect losses 6004 LOSSES = "losses" 6005 # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. 6006 SAVEABLE_OBJECTS = "saveable_objects" 6007 # Key to collect all shared resources used by the graph which need to be 6008 # initialized once per cluster. 6009 RESOURCES = "resources" 6010 # Key to collect all shared resources used in this graph which need to be 6011 # initialized once per session. 6012 LOCAL_RESOURCES = "local_resources" 6013 # Trainable resource-style variables. 
6014 TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables" 6015 6016 # Key to indicate various ops. 6017 INIT_OP = "init_op" 6018 LOCAL_INIT_OP = "local_init_op" 6019 READY_OP = "ready_op" 6020 READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op" 6021 SUMMARY_OP = "summary_op" 6022 GLOBAL_STEP = "global_step" 6023 6024 # Used to count the number of evaluations performed during a single evaluation 6025 # run. 6026 EVAL_STEP = "eval_step" 6027 TRAIN_OP = "train_op" 6028 6029 # Key for control flow context. 6030 COND_CONTEXT = "cond_context" 6031 WHILE_CONTEXT = "while_context" 6032 6033 # Used to store v2 summary names. 6034 _SUMMARY_COLLECTION = "_SUMMARY_V2" 6035 6036 # List of all collections that keep track of variables. 6037 _VARIABLE_COLLECTIONS = [ 6038 GLOBAL_VARIABLES, 6039 LOCAL_VARIABLES, 6040 METRIC_VARIABLES, 6041 MODEL_VARIABLES, 6042 TRAINABLE_VARIABLES, 6043 MOVING_AVERAGE_VARIABLES, 6044 CONCATENATED_VARIABLES, 6045 TRAINABLE_RESOURCE_VARIABLES, 6046 ] 6047 6048 # Key for streaming model ports. 6049 # NOTE(yuanbyu): internal and experimental. 6050 _STREAMING_MODEL_PORTS = "streaming_model_ports" 6051 6052 @decorator_utils.classproperty 6053 @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.") 6054 def VARIABLES(cls): # pylint: disable=no-self-argument 6055 return cls.GLOBAL_VARIABLES 6056 6057 6058def dismantle_graph(graph): 6059 """Cleans up reference cycles from a `Graph`. 6060 6061 Helpful for making sure the garbage collector doesn't need to run after a 6062 temporary `Graph` is no longer needed. 6063 6064 Args: 6065 graph: A `Graph` object to destroy. Neither it nor any of its ops are usable 6066 after this function runs. 6067 """ 6068 memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access 6069 6070 # Now clean up Operation<->Graph reference cycles by clearing all of the 6071 # attributes for the Graph and its ops. 
@tf_export(v1=["add_to_collection"])
def add_to_collection(name, value):
  """Wrapper for `Graph.add_to_collection()` using the default graph.

  See `tf.Graph.add_to_collection`
  for more details.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    value: The value to add to the collection.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  default_graph = get_default_graph()
  default_graph.add_to_collection(name, value)


@tf_export(v1=["add_to_collections"])
def add_to_collections(names, value):
  """Wrapper for `Graph.add_to_collections()` using the default graph.

  See `tf.Graph.add_to_collections`
  for more details.

  Args:
    names: The key for the collections. The `GraphKeys` class contains many
      standard names for collections.
    value: The value to add to the collections.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  default_graph = get_default_graph()
  default_graph.add_to_collections(names, value)
@tf_export(v1=["get_collection"])
def get_collection(key, scope=None):
  """Wrapper for `Graph.get_collection()` using the default graph.

  See `tf.Graph.get_collection`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include
      only items whose `name` attribute matches using `re.match`. Items without
      a `name` attribute are never returned if a scope is supplied, and the
      choice of `re.match` means that a `scope` without special tokens filters
      by prefix.

  Returns:
    The list of values in the collection with the given `name`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  default_graph = get_default_graph()
  return default_graph.get_collection(key, scope)


def get_all_collection_keys():
  """Returns a list of collections used in the default graph."""
  default_graph = get_default_graph()
  return default_graph.get_all_collection_keys()


def name_scope(name, default_name=None, values=None, skip_on_eager=True):
  """Internal-only entry point for `name_scope*`.

  Internal ops do not use the public API and instead rely on
  `ops.name_scope` regardless of the execution mode. This function
  dispatches to the correct `name_scope*` implementation based on
  the arguments provided and the current mode. Specifically,

  * `internal_name_scope_v1` is used in graph mode;
  * in eager mode with `skip_on_eager` set, a `NullContextmanager` is returned;
  * otherwise, if `values` contains a graph tensor, that tensor's
    `Graph.name_scope` is used;
  * otherwise `name_scope_v2` is used.

  Args:
    name: The name argument that is passed to the op function.
    default_name: The default name to use if the `name` argument is `None`.
    values: The list of `Tensor` arguments that are passed to the op function.
    skip_on_eager: Indicates to return NullContextmanager if executing eagerly.
      By default this is True since naming tensors and operations in eager mode
      has little use and causes unnecessary performance overhead. However, it
      is important to preserve variable names since they are often useful for
      debugging and saved models.

  Returns:
    `name_scope*` context manager.
  """
  if not context.context().executing_eagerly():
    return internal_name_scope_v1(name, default_name, values)

  if skip_on_eager:
    return NullContextmanager()

  if name is None:
    name = default_name
  if values:
    # The presence of a graph tensor in `values` overrides the context.
    # TODO(slebedev): this is Keras-specific and should be removed.
    for value in values:
      # Only exact `Tensor` instances count; subclasses are ignored.
      if type(value) == Tensor:  # pylint: disable=unidiomatic-typecheck
        return value.graph.name_scope(name)

  return name_scope_v2(name or "")
6231 default_name: The default name to use if the `name` argument is `None`. 6232 values: The list of `Tensor` arguments that are passed to the op function. 6233 6234 Raises: 6235 TypeError: if `default_name` is passed in but not a string. 6236 """ 6237 if not (default_name is None or isinstance(default_name, six.string_types)): 6238 raise TypeError( 6239 "`default_name` type (%s) is not a string type. You likely meant to " 6240 "pass this into the `values` kwarg." % type(default_name)) 6241 self._name = default_name if name is None else name 6242 self._default_name = default_name 6243 self._values = values 6244 6245 def __enter__(self): 6246 """Start the scope block. 6247 6248 Returns: 6249 The scope name. 6250 6251 Raises: 6252 ValueError: if neither `name` nor `default_name` is provided 6253 but `values` are. 6254 """ 6255 if self._name is None and self._values is not None: 6256 # We only raise an error if values is not None (provided) because 6257 # currently tf.name_scope(None) (values=None then) is sometimes used as 6258 # an idiom to reset to top scope. 6259 raise ValueError( 6260 "At least one of name (%s) and default_name (%s) must be provided." 6261 % (self._name, self._default_name)) 6262 6263 g = get_default_graph() 6264 if self._values and not g.building_function: 6265 # Specialize based on the knowledge that `_get_graph_from_inputs()` 6266 # ignores `inputs` when building a function. 
6267 g_from_inputs = _get_graph_from_inputs(self._values) 6268 if g_from_inputs is not g: 6269 g = g_from_inputs 6270 self._g_manager = g.as_default() 6271 self._g_manager.__enter__() 6272 else: 6273 self._g_manager = None 6274 else: 6275 self._g_manager = None 6276 6277 try: 6278 self._name_scope = g.name_scope(self._name) 6279 return self._name_scope.__enter__() 6280 except: 6281 if self._g_manager is not None: 6282 self._g_manager.__exit__(*sys.exc_info()) 6283 raise 6284 6285 def __exit__(self, *exc_info): 6286 self._name_scope.__exit__(*exc_info) 6287 if self._g_manager is not None: 6288 self._g_manager.__exit__(*exc_info) 6289 6290 6291# Named like a function for backwards compatibility with the 6292# @tf_contextlib.contextmanager version, which was switched to a class to avoid 6293# some object creation overhead. 6294@tf_export(v1=["name_scope"]) 6295class name_scope_v1(object): # pylint: disable=invalid-name 6296 """A context manager for use when defining a Python op. 6297 6298 This context manager validates that the given `values` are from the 6299 same graph, makes that graph the default graph, and pushes a 6300 name scope in that graph (see 6301 `tf.Graph.name_scope` 6302 for more details on that). 6303 6304 For example, to define a new Python op called `my_op`: 6305 6306 ```python 6307 def my_op(a, b, c, name=None): 6308 with tf.name_scope(name, "MyOp", [a, b, c]) as scope: 6309 a = tf.convert_to_tensor(a, name="a") 6310 b = tf.convert_to_tensor(b, name="b") 6311 c = tf.convert_to_tensor(c, name="c") 6312 # Define some computation that uses `a`, `b`, and `c`. 6313 return foo_op(..., name=scope) 6314 ``` 6315 """ 6316 6317 @property 6318 def name(self): 6319 return self._name 6320 6321 def __init__(self, name, default_name=None, values=None): 6322 """Initialize the context manager. 6323 6324 Args: 6325 name: The name argument that is passed to the op function. 6326 default_name: The default name to use if the `name` argument is `None`. 
6327 values: The list of `Tensor` arguments that are passed to the op function. 6328 6329 Raises: 6330 TypeError: if `default_name` is passed in but not a string. 6331 """ 6332 self._name_scope = name_scope( 6333 name, default_name, values, skip_on_eager=False) 6334 self._name = default_name if name is None else name 6335 6336 def __enter__(self): 6337 return self._name_scope.__enter__() 6338 6339 def __exit__(self, *exc_info): 6340 return self._name_scope.__exit__(*exc_info) 6341 6342 6343def enter_eager_name_scope(ctx, name): 6344 """Updates the eager context to enter the given name scope.""" 6345 old_name = ctx.scope_name 6346 if not name: 6347 scope_name = "" 6348 else: 6349 if name.endswith("/"): 6350 # A trailing slash breaks out of nested name scopes, indicating a 6351 # fully specified scope name, for compatibility with Graph.name_scope. 6352 scope_name = name 6353 else: 6354 scope_name = name + "/" 6355 if old_name: 6356 scope_name = old_name + scope_name 6357 ctx.scope_name = scope_name 6358 return scope_name, old_name 6359 6360 6361@tf_export("name_scope", v1=[]) 6362class name_scope_v2(object): 6363 """A context manager for use when defining a Python op. 6364 6365 This context manager pushes a name scope, which will make the name of all 6366 operations added within it have a prefix. 6367 6368 For example, to define a new Python op called `my_op`: 6369 6370 ```python 6371 def my_op(a, b, c, name=None): 6372 with tf.name_scope("MyOp") as scope: 6373 a = tf.convert_to_tensor(a, name="a") 6374 b = tf.convert_to_tensor(b, name="b") 6375 c = tf.convert_to_tensor(c, name="c") 6376 # Define some computation that uses `a`, `b`, and `c`. 6377 return foo_op(..., name=scope) 6378 ``` 6379 6380 When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`, 6381 and `MyOp/c`. 6382 6383 If the scope name already exists, the name will be made unique by appending 6384 `_n`. For example, calling `my_op` the second time will generate `MyOp_1/a`, 6385 etc. 
6386 """ 6387 6388 def __init__(self, name): 6389 """Initialize the context manager. 6390 6391 Args: 6392 name: The prefix to use on all names created within the name scope. 6393 6394 Raises: 6395 ValueError: If name is None, or not a string. 6396 """ 6397 if name is None or not isinstance(name, six.string_types): 6398 raise ValueError("name for name_scope must be a string.") 6399 self._name = name 6400 self._exit_fns = [] 6401 6402 @property 6403 def name(self): 6404 return self._name 6405 6406 def __enter__(self): 6407 """Start the scope block. 6408 6409 Returns: 6410 The scope name. 6411 6412 Raises: 6413 ValueError: if neither `name` nor `default_name` is provided 6414 but `values` are. 6415 """ 6416 ctx = context.context() 6417 if ctx.executing_eagerly(): 6418 scope_name, old_scope_name = enter_eager_name_scope(ctx, self._name) 6419 self._exit_fns.append( 6420 lambda *a: setattr(ctx, "scope_name", old_scope_name)) 6421 else: 6422 scope = get_default_graph().name_scope(self._name) 6423 scope_name = scope.__enter__() 6424 self._exit_fns.append(scope.__exit__) 6425 return scope_name 6426 6427 def __exit__(self, type_arg, value_arg, traceback_arg): 6428 exit_fn = self._exit_fns.pop() 6429 exit_fn(type_arg, value_arg, traceback_arg) 6430 return False # False values do not suppress exceptions 6431 6432 6433def strip_name_scope(name, export_scope): 6434 """Removes name scope from a name. 6435 6436 Args: 6437 name: A `string` name. 6438 export_scope: Optional `string`. Name scope to remove. 6439 6440 Returns: 6441 Name with name scope removed, or the original name if export_scope 6442 is None. 6443 """ 6444 if export_scope: 6445 if export_scope[-1] == "/": 6446 export_scope = export_scope[:-1] 6447 6448 try: 6449 # Strips export_scope/, export_scope///, 6450 # ^export_scope/, loc:@export_scope/. 
6451 str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)" 6452 return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1) 6453 except TypeError as e: 6454 # If the name is not of a type we can process, simply return it. 6455 logging.warning(e) 6456 return name 6457 else: 6458 return name 6459 6460 6461def prepend_name_scope(name, import_scope): 6462 """Prepends name scope to a name. 6463 6464 Args: 6465 name: A `string` name. 6466 import_scope: Optional `string`. Name scope to add. 6467 6468 Returns: 6469 Name with name scope added, or the original name if import_scope 6470 is None. 6471 """ 6472 if import_scope: 6473 if import_scope[-1] == "/": 6474 import_scope = import_scope[:-1] 6475 6476 try: 6477 str_to_replace = r"([\^]|loc:@|^)(.*)" 6478 return re.sub(str_to_replace, r"\1" + import_scope + r"/\2", 6479 compat.as_str(name)) 6480 except TypeError as e: 6481 # If the name is not of a type we can process, simply return it. 6482 logging.warning(e) 6483 return name 6484 else: 6485 return name 6486 6487 6488# pylint: disable=g-doc-return-or-yield 6489# pylint: disable=not-context-manager 6490@tf_export(v1=["op_scope"]) 6491@tf_contextlib.contextmanager 6492def op_scope(values, name, default_name=None): 6493 """DEPRECATED. Same as name_scope above, just different argument order.""" 6494 logging.warn("tf.op_scope(values, name, default_name) is deprecated," 6495 " use tf.name_scope(name, default_name, values)") 6496 with name_scope(name, default_name=default_name, values=values) as scope: 6497 yield scope 6498 6499 6500_proto_function_registry = registry.Registry("proto functions") 6501 6502 6503def register_proto_function(collection_name, 6504 proto_type=None, 6505 to_proto=None, 6506 from_proto=None): 6507 """Registers `to_proto` and `from_proto` functions for collection_name. 6508 6509 `to_proto` function converts a Python object to the corresponding protocol 6510 buffer, and returns the protocol buffer. 
6511 6512 `from_proto` function converts protocol buffer into a Python object, and 6513 returns the object.. 6514 6515 Args: 6516 collection_name: Name of the collection. 6517 proto_type: Protobuf type, such as `saver_pb2.SaverDef`, 6518 `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`.. 6519 to_proto: Function that implements Python object to protobuf conversion. 6520 from_proto: Function that implements protobuf to Python object conversion. 6521 """ 6522 if to_proto and not callable(to_proto): 6523 raise TypeError("to_proto must be callable.") 6524 if from_proto and not callable(from_proto): 6525 raise TypeError("from_proto must be callable.") 6526 6527 _proto_function_registry.register((proto_type, to_proto, from_proto), 6528 collection_name) 6529 6530 6531def get_collection_proto_type(collection_name): 6532 """Returns the proto_type for collection_name.""" 6533 try: 6534 return _proto_function_registry.lookup(collection_name)[0] 6535 except LookupError: 6536 return None 6537 6538 6539def get_to_proto_function(collection_name): 6540 """Returns the to_proto function for collection_name.""" 6541 try: 6542 return _proto_function_registry.lookup(collection_name)[1] 6543 except LookupError: 6544 return None 6545 6546 6547def get_from_proto_function(collection_name): 6548 """Returns the from_proto function for collection_name.""" 6549 try: 6550 return _proto_function_registry.lookup(collection_name)[2] 6551 except LookupError: 6552 return None 6553 6554 6555def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): 6556 """Produce a nice error if someone converts an Operation to a Tensor.""" 6557 raise TypeError(("Can't convert Operation '%s' to Tensor " 6558 "(target dtype=%r, name=%r, as_ref=%r)") % 6559 (op.name, dtype, name, as_ref)) 6560 6561 6562def _op_to_colocate_with(v, graph): 6563 """Operation object corresponding to v to use for colocation constraints.""" 6564 if v is None: 6565 return None 6566 if isinstance(v, Operation): 6567 
return v 6568 # We always want to colocate with the reference op. 6569 # When 'v' is a ResourceVariable, the reference op is the handle creating op. 6570 # 6571 # What this should be is: 6572 # if isinstance(v, ResourceVariable): 6573 # return v.handle.op 6574 # However, that would require a circular import dependency. 6575 # As of October 2018, there were attempts underway to remove 6576 # colocation constraints altogether. Assuming that will 6577 # happen soon, perhaps this hack to work around the circular 6578 # import dependency is acceptable. 6579 if hasattr(v, "handle") and isinstance(v.handle, Tensor): 6580 if graph.building_function: 6581 return graph.capture(v.handle).op 6582 else: 6583 return v.handle.op 6584 return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op 6585 6586 6587def _is_keras_symbolic_tensor(x): 6588 return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph" 6589 6590 6591tensor_conversion_registry.register_tensor_conversion_function( 6592 Operation, _operation_conversion_error) 6593 6594 6595# These symbols were originally defined in this module; import them for 6596# backwards compatibility until all references have been updated to access 6597# them from the indexed_slices.py module. 
IndexedSlices = indexed_slices.IndexedSlices
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
convert_to_tensor_or_indexed_slices = (
    indexed_slices.convert_to_tensor_or_indexed_slices)
convert_n_to_tensor_or_indexed_slices = (
    indexed_slices.convert_n_to_tensor_or_indexed_slices)
internal_convert_to_tensor_or_indexed_slices = (
    indexed_slices.internal_convert_to_tensor_or_indexed_slices)
internal_convert_n_to_tensor_or_indexed_slices = (
    indexed_slices.internal_convert_n_to_tensor_or_indexed_slices)
register_tensor_conversion_function = (
    tensor_conversion_registry.register_tensor_conversion_function)


# Helper functions for op wrapper modules generated by `python_op_gen`.


def to_raw_op(f):
  """Turns the op wrapper function `f` into a raw op wrapper.

  Raw op wrappers are left out of the generated docs and accept keyword
  arguments only.

  Args:
    f: The op wrapper function to convert.

  Returns:
    The raw version of `f`.
  """
  # `tf_export` would fail on double registration if the wrapper shared its
  # `__dict__` with `f`, so work on a fresh function object built from the
  # same code, globals, name, defaults and closure.
  fresh_copy = types.FunctionType(
      f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
  undocumented = do_not_generate_docs(fresh_copy)
  return kwarg_only(undocumented)


def raise_from_not_ok_status(e, name):
  """Raises the exception matching status `e`, tagging the message with `name`.

  Args:
    e: A status-like object; its `message` and `code` attributes are read.
    name: Optional string appended to the message as " name: <name>"; nothing
      is appended when it is `None`.
  """
  message = e.message
  if name is not None:
    message = message + (" name: " + name)
  # pylint: disable=protected-access
  status_error = core._status_to_exception(e.code, message)
  # pylint: enable=protected-access
  six.raise_from(status_error, None)


def add_exit_callback_to_default_func_graph(fn):
  """Registers `fn` to run when the default function graph goes out of scope.

  Usage:

  ```python
  @tf.function
  def fn(x, v):
    expensive = expensive_object(v)
    add_exit_callback_to_default_func_graph(lambda: expensive.release())
    return g(x, expensive)

  fn(x=tf.constant(...), v=...)
  # `expensive` has been released.
  ```

  Args:
    fn: A callable that takes no arguments and whose output is ignored.
      To be executed when exiting func graph scope.

  Raises:
    RuntimeError: If executed when the current default graph is not a FuncGraph,
      or not currently executing in function creation mode (e.g., if inside
      an init_scope).
  """
  graph = get_default_graph()
  # pylint: disable=protected-access
  if graph._building_function:
    graph._add_scope_exit_callback(fn)
    return
  # pylint: enable=protected-access
  raise RuntimeError(
      "Cannot add scope exit callbacks when not building a function. "
      "Default graph: {}".format(graph))


def _reconstruct_sequence_inputs(op_def, inputs, attrs):
  """Regroups a flat list of input tensors into scalar and sequence inputs.

  Args:
    op_def: The `op_def_pb2.OpDef` (for knowing the input types)
    inputs: a list of input `Tensor`s to the op.
    attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
      how long each sequence is)

  Returns:
    A list of `Tensor`s (corresponding to scalar inputs) and lists of
    `Tensor`s (corresponding to sequence inputs).
  """
  regrouped = []
  cursor = 0  # Index of the next unconsumed entry of `inputs`.
  for arg in op_def.input_arg:
    if arg.number_attr:
      # Sequence whose length is stored in an integer attr.
      span = attrs[arg.number_attr].i
    elif arg.type_list_attr:
      # Sequence with one entry per type listed in a type-list attr.
      span = len(attrs[arg.type_list_attr].list.type)
    else:
      # Plain argument: takes exactly one input and is not wrapped in a list.
      regrouped.append(inputs[cursor])
      cursor += 1
      continue
    regrouped.append(inputs[cursor:cursor + span])
    cursor += span

  # Every flat input must have been assigned to some argument.
  assert cursor == len(inputs)
  return regrouped