1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes and functions used to construct graphs.""" 16# pylint: disable=g-bad-name 17import collections 18import copy 19import re 20import sys 21import threading 22import types 23from absl import app 24 25import numpy as np 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.core.framework import full_type_pb2 29from tensorflow.core.framework import function_pb2 30from tensorflow.core.framework import graph_pb2 31from tensorflow.core.framework import node_def_pb2 32from tensorflow.core.framework import op_def_pb2 33from tensorflow.core.framework import versions_pb2 34from tensorflow.core.protobuf import config_pb2 35# pywrap_tensorflow must be imported first to avoid protobuf issues. 
36# (b/143110113) 37# pylint: disable=invalid-import-order,g-bad-import-order,unused-import 38from tensorflow.python import pywrap_tensorflow 39from tensorflow.python import pywrap_tfe 40# pylint: enable=invalid-import-order,g-bad-import-order,unused-import 41from tensorflow.python import tf2 42from tensorflow.python.client import pywrap_tf_session 43from tensorflow.python.eager import context 44from tensorflow.python.eager import core 45from tensorflow.python.eager import monitoring 46from tensorflow.python.eager import tape 47from tensorflow.python.framework import c_api_util 48from tensorflow.python.framework import composite_tensor 49from tensorflow.python.framework import cpp_shape_inference_pb2 50from tensorflow.python.framework import device as pydev 51from tensorflow.python.framework import dtypes 52from tensorflow.python.framework import errors 53from tensorflow.python.framework import indexed_slices 54from tensorflow.python.framework import registry 55from tensorflow.python.framework import tensor_conversion_registry 56from tensorflow.python.framework import tensor_shape 57from tensorflow.python.framework import traceable_stack 58from tensorflow.python.framework import versions 59from tensorflow.python.ops import control_flow_util 60from tensorflow.python.platform import tf_logging as logging 61from tensorflow.python.profiler import trace as profiler_trace 62from tensorflow.python.types import core as core_tf_types 63from tensorflow.python.types import internal 64from tensorflow.python.util import compat 65from tensorflow.python.util import decorator_utils 66from tensorflow.python.util import deprecation 67from tensorflow.python.util import dispatch 68from tensorflow.python.util import function_utils 69from tensorflow.python.util import lock_util 70from tensorflow.python.util import memory 71from tensorflow.python.util import object_identity 72from tensorflow.python.util import tf_contextlib 73from tensorflow.python.util import tf_stack 74from 
tensorflow.python.util import traceback_utils
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import kwarg_only
from tensorflow.python.util.tf_export import tf_export

# TODO(b/218887885): Loaded lazily due to a circular dependency with this file.
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
ag_ctx = LazyLoader(
    "ag_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")


# Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported.
_USE_C_API = True
_USE_C_SHAPES = True

# Boolean monitoring gauges recording whether particular public APIs have been
# used; each gauge's description string states what it tracks.
_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/ops_eager_execution",
    "Whether ops.enable_eager_execution() is called.")

# Set by enable_tensor_equality() / disable_tensor_equality() below.
_tensor_equality_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/enable_tensor_equality",
    "Whether ops.enable_tensor_equality() is called.")

_control_flow_api_gauge = monitoring.BoolGauge(
    "/tensorflow/api/enable_control_flow_v2",
    "Whether enable_control_flow_v2() is called.")

# NOTE(review): "guage" misspells "gauge". The name is kept as-is because it
# may be referenced by this spelling elsewhere; confirm before renaming.
_tf_function_api_guage = monitoring.BoolGauge(
    "/tensorflow/api/tf_function",
    "Whether tf.function() is used.")

# Module-level alias for the dtypes module's private intern table.
# pylint: disable=protected-access
_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE
# pylint: enable=protected-access


def tensor_id(tensor):
  """Returns a unique identifier for this Tensor.

  Reads the private `_id` attribute. For `Tensor` objects this attribute is
  assigned from `uid()` in `Tensor.__init__`.

  Args:
    tensor: The tensor whose identifier is requested.

  Returns:
    The value of `tensor._id`.
  """
  return tensor._id  # pylint: disable=protected-access


class _UserDeviceSpec(object):
  """Store user-specified device and provide computation of merged device.

  Attributes:
    display_name: Human-readable description of the spec. For a plain callable
      this has the form "name<file, first_line>" (file "unknown" and line -1
      when no code object is available); otherwise it is `str()` of the
      original argument.
    function: The device function. For values that are not a
      `pydev.MergeDevice`, callable, or None, this is the result of
      `pydev.merge_device(...)`; otherwise it is the original argument (which
      may be None).
    raw_string: The original argument when it was not a `pydev.MergeDevice`,
      callable, or None; otherwise None.
    is_null_merge: Whether this spec is a null merge. Always False for plain
      callables and for None.
    fast_string_merge: True when `function` is a `pydev.MergeDevice`, which
      lets `string_merge` take the `shortcut_string_merge` path.
  """

  def __init__(self, device_name_or_function):
    """Initializes the spec from a user-supplied device value.

    Args:
      device_name_or_function: A `pydev.MergeDevice`, a callable, None, or any
        other value (presumably a device name string), which is passed to
        `pydev.merge_device`.
    """
    self._device_name_or_function = device_name_or_function
    self.display_name = str(self._device_name_or_function)
    self.function = device_name_or_function
    self.raw_string = None

    # The MergeDevice check comes first, so a MergeDevice is never handled by
    # the generic callable branch below.
    if isinstance(device_name_or_function, pydev.MergeDevice):
      self.is_null_merge = device_name_or_function.is_null_merge

    elif callable(device_name_or_function):
      self.is_null_merge = False
      dev_func = self._device_name_or_function
      func_name = function_utils.get_func_name(dev_func)
      func_code = function_utils.get_func_code(dev_func)
      if func_code:
        fname = func_code.co_filename
        lineno = func_code.co_firstlineno
      else:
        fname = "unknown"
        lineno = -1
      self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)

    elif device_name_or_function is None:
      # NOTE(taylorrobie): This MUST be False. None signals a break in the
      # device stack, so `is_null_merge` must be False for such a case to
      # allow callers to safely skip over null merges without missing a None.
      self.is_null_merge = False

    else:
      self.raw_string = device_name_or_function
      self.function = pydev.merge_device(device_name_or_function)
      self.is_null_merge = self.function.is_null_merge

    # We perform this check in __init__ because it is of non-trivial cost,
    # and self.string_merge is typically called many times.
    self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)

  def string_merge(self, node_def):
    """Returns the device string this spec produces for `node_def`.

    When `self.function` is a `pydev.MergeDevice`, its
    `shortcut_string_merge` is used directly. Otherwise `self.function` is
    called on `node_def` and the result is normalized with `_device_string`
    and `compat.as_str`.

    NOTE(review): for a spec built from None, `self.function` is None and the
    slow path would raise `TypeError`. Callers presumably handle None specs
    before calling this; confirm.

    Args:
      node_def: The node passed to the device function (presumably a
        `NodeDef`).

    Returns:
      The merged device string.
    """
    if self.fast_string_merge:
      return self.function.shortcut_string_merge(node_def)

    return compat.as_str(_device_string(self.function(node_def)))


class NullContextmanager(object):
  """A context manager that does nothing.

  Any constructor arguments are accepted and ignored. `__enter__` returns None,
  and `__exit__` returns False, so exceptions raised inside the `with` body
  propagate unchanged.
  """

  def __init__(self, *args, **kwargs):
    pass

  def __enter__(self):
    pass

  def __exit__(self, type_arg, value_arg, traceback_arg):
    return False  # False values do not suppress exceptions


def _override_helper(clazz_object, operator, func):
  """Overrides (string) operator on Tensors to call func.

  Sets the attribute named `operator` on `clazz_object` to `func`.

  Args:
    clazz_object: the class to override for; either Tensor or SparseTensor.
    operator: the string name of the operator to override.
    func: the function that replaces the overridden operator.

  Raises:
    ValueError: If `operator` is not in `Tensor.OVERLOADABLE_OPERATORS`.
  """
  if operator not in Tensor.OVERLOADABLE_OPERATORS:
    raise ValueError(f"Overriding {operator} is disallowed. "
                     f"Allowed operators are {Tensor.OVERLOADABLE_OPERATORS}.")
  setattr(clazz_object, operator, func)


def _as_graph_element(obj):
  """Convert `obj` to a graph element if possible, otherwise return `None`.

  Args:
    obj: Object to convert.

  Returns:
    The result of `obj._as_graph_element()` if that method is available;
        otherwise `None`.
  """
  conv_fn = getattr(obj, "_as_graph_element", None)
  # Only invoke the attribute when it exists (is truthy) and is callable.
  if conv_fn and callable(conv_fn):
    return conv_fn()
  return None


# Deprecated - do not use.
# This API is kept to avoid breaking estimator and tensorflow-mesh, which depend
# on this internal API. The stub should be safe to use after TF 2.3 is released.
# NOTE(review): "safe to use" above likely means "safe to remove"; confirm.
def is_dense_tensor_like(t):
  """Returns True if `t` is an instance of `core_tf_types.Tensor`."""
  return isinstance(t, core_tf_types.Tensor)


def uid():
  """A unique (within this program execution) integer.

  Returns:
    The value returned by `pywrap_tfe.TFE_Py_UID()`.
  """
  return pywrap_tfe.TFE_Py_UID()


def numpy_text(tensor, is_repr=False):
  """Human readable representation of a tensor's numpy value.

  Args:
    tensor: The tensor to format. Its `_numpy()` value is used when
      `tensor.dtype.is_numpy_compatible` is true.
    is_repr: If True, format the numpy value with `repr()`; otherwise with
      `str()`.

  Returns:
    The formatted text, or "<unprintable>" when the dtype is not numpy
    compatible. Text containing a newline is prefixed with a newline.
  """
  if tensor.dtype.is_numpy_compatible:
    # pylint: disable=protected-access
    text = repr(tensor._numpy()) if is_repr else str(tensor._numpy())
    # pylint: enable=protected-access
  else:
    text = "<unprintable>"
  if "\n" in text:
    # Start multi-line text on a fresh line.
    text = "\n" + text
  return text


def value_text(tensor, is_repr=False):
  """Either the NumPy value or a custom TensorFlow formatting of `tensor`.

  Custom formatting is used for custom device tensors, e.g. parallel tensors
  with multiple components on different devices.

  Args:
    tensor: The tensor to format.
    is_repr: Controls the style/verbosity of formatting. When True, the text is
      prefixed with "value=" (custom summarizer path) or "numpy=" (numpy path),
      and the numpy path formats with `repr()` instead of `str()`.

  Returns:
    The formatted tensor.
  """
  # pylint: disable=protected-access  # friend access
  if tensor._prefer_custom_summarizer():
    text = tensor._summarize_value()
    # pylint: enable=protected-access
    if is_repr:
      text = "value=" + text
  else:
    text = numpy_text(tensor, is_repr=is_repr)
    if is_repr:
      text = "numpy=" + text
  return text


@tf_export(v1=["enable_tensor_equality"])
def enable_tensor_equality():
  """Compare Tensors with element-wise comparison and thus be unhashable.

  Comparing tensors element-wise allows comparisons such as
  tf.Variable(1.0) == 1.0. Element-wise equality implies that tensors are
  unhashable. Thus tensors can no longer be directly used in sets or as a key in
  a dictionary.

  This sets the class attribute `Tensor._USE_EQUALITY` to True and records the
  call in the `/tensorflow/api/enable_tensor_equality` gauge.
  """
  logging.vlog(1, "Enabling tensor equality")
  _tensor_equality_api_usage_gauge.get_cell().set(True)
  Tensor._USE_EQUALITY = True  # pylint: disable=protected-access


@tf_export(v1=["disable_tensor_equality"])
def disable_tensor_equality():
  """Compare Tensors by their id and be hashable.

  This is a legacy behaviour of TensorFlow and is highly discouraged.

  This sets the class attribute `Tensor._USE_EQUALITY` to False and sets the
  `/tensorflow/api/enable_tensor_equality` gauge to False.
  """
  logging.vlog(1, "Disabling tensor equality")
  _tensor_equality_api_usage_gauge.get_cell().set(False)
  Tensor._USE_EQUALITY = False  # pylint: disable=protected-access


# TODO(mdan): This object should subclass Symbol, not just Tensor.
@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"])
class Tensor(internal.NativeObject, core_tf_types.Tensor):
  """A `tf.Tensor` represents a multidimensional array of elements.

  All elements are of a single known data type.

  When writing a TensorFlow program, the main object that is
  manipulated and passed around is the `tf.Tensor`.
298 299 A `tf.Tensor` has the following properties: 300 301 * a single data type (float32, int32, or string, for example) 302 * a shape 303 304 TensorFlow supports eager execution and graph execution. In eager 305 execution, operations are evaluated immediately. In graph 306 execution, a computational graph is constructed for later 307 evaluation. 308 309 TensorFlow defaults to eager execution. In the example below, the 310 matrix multiplication results are calculated immediately. 311 312 >>> # Compute some values using a Tensor 313 >>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]]) 314 >>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]]) 315 >>> e = tf.matmul(c, d) 316 >>> print(e) 317 tf.Tensor( 318 [[1. 3.] 319 [3. 7.]], shape=(2, 2), dtype=float32) 320 321 Note that during eager execution, you may discover your `Tensors` are actually 322 of type `EagerTensor`. This is an internal detail, but it does give you 323 access to a useful function, `numpy`: 324 325 >>> type(e) 326 <class '...ops.EagerTensor'> 327 >>> print(e.numpy()) 328 [[1. 3.] 329 [3. 7.]] 330 331 In TensorFlow, `tf.function`s are a common way to define graph execution. 332 333 A Tensor's shape (that is, the rank of the Tensor and the size of 334 each dimension) may not always be fully known. In `tf.function` 335 definitions, the shape may only be partially known. 336 337 Most operations produce tensors of fully-known shapes if the shapes of their 338 inputs are also fully known, but in some cases it's only possible to find the 339 shape of a tensor at execution time. 340 341 A number of specialized tensors are available: see `tf.Variable`, 342 `tf.constant`, `tf.placeholder`, `tf.sparse.SparseTensor`, and 343 `tf.RaggedTensor`. 
344 345 Caution: when constructing a tensor from a numpy array or pandas dataframe 346 the underlying buffer may be re-used: 347 348 ```python 349 a = np.array([1, 2, 3]) 350 b = tf.constant(a) 351 a[0] = 4 352 print(b) # tf.Tensor([4 2 3], shape=(3,), dtype=int64) 353 ``` 354 355 Note: this is an implementation detail that is subject to change and users 356 should not rely on this behaviour. 357 358 For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor). 359 360 """ 361 362 # List of Python operators that we allow to override. 363 OVERLOADABLE_OPERATORS = { 364 # Binary. 365 "__add__", 366 "__radd__", 367 "__sub__", 368 "__rsub__", 369 "__mul__", 370 "__rmul__", 371 "__div__", 372 "__rdiv__", 373 "__truediv__", 374 "__rtruediv__", 375 "__floordiv__", 376 "__rfloordiv__", 377 "__mod__", 378 "__rmod__", 379 "__lt__", 380 "__le__", 381 "__gt__", 382 "__ge__", 383 "__ne__", 384 "__eq__", 385 "__and__", 386 "__rand__", 387 "__or__", 388 "__ror__", 389 "__xor__", 390 "__rxor__", 391 "__getitem__", 392 "__pow__", 393 "__rpow__", 394 # Unary. 395 "__invert__", 396 "__neg__", 397 "__abs__", 398 "__matmul__", 399 "__rmatmul__" 400 } 401 402 # Whether to allow hashing or numpy-style equality 403 _USE_EQUALITY = tf2.enabled() 404 405 def __init__(self, op, value_index, dtype): 406 """Creates a new `Tensor`. 407 408 Args: 409 op: An `Operation`. `Operation` that computes this tensor. 410 value_index: An `int`. Index of the operation's endpoint that produces 411 this tensor. 412 dtype: A `DType`. Type of elements stored in this tensor. 413 414 Raises: 415 TypeError: If the op is not an `Operation`. 416 """ 417 if not isinstance(op, Operation): 418 raise TypeError(f"op needs to be an Operation. " 419 f"An instance of type {type(op).__name__} is provided.") 420 421 self._op = op 422 self._value_index = value_index 423 self._dtype = dtypes.as_dtype(dtype) 424 # This will be set by self._as_tf_output(). 
425 self._tf_output = None 426 # This will be set by self.shape(). 427 self._shape_val = None 428 # List of operations that use this Tensor as input. We maintain this list 429 # to easily navigate a computation graph. 430 self._consumers = [] 431 self._id = uid() 432 self._name = None 433 434 def __getattr__(self, name): 435 if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size", 436 "tolist", "data"}: 437 # TODO(wangpeng): Export the enable_numpy_behavior knob 438 raise AttributeError( 439 f"{type(self).__name__} object has no attribute '{name}'. " + """ 440 If you are looking for numpy-related methods, please run the following: 441 from tensorflow.python.ops.numpy_ops import np_config 442 np_config.enable_numpy_behavior() 443 """) 444 self.__getattribute__(name) 445 446 @staticmethod 447 def _create_with_tf_output(op, value_index, dtype, tf_output): 448 ret = Tensor(op, value_index, dtype) 449 ret._tf_output = tf_output 450 return ret 451 452 @property 453 def op(self): 454 """The `Operation` that produces this tensor as an output.""" 455 return self._op 456 457 @property 458 def dtype(self): 459 """The `DType` of elements in this tensor.""" 460 return self._dtype 461 462 @property 463 def graph(self): 464 """The `Graph` that contains this tensor.""" 465 return self._op.graph 466 467 @property 468 def name(self): 469 """The string name of this tensor.""" 470 if self._name is None: 471 assert self._op.name 472 self._name = "%s:%d" % (self._op.name, self._value_index) 473 return self._name 474 475 @property 476 def device(self): 477 """The name of the device on which this tensor will be produced, or None.""" 478 return self._op.device 479 480 @property 481 def shape(self): 482 """Returns a `tf.TensorShape` that represents the shape of this tensor. 483 484 >>> t = tf.constant([1,2,3,4,5]) 485 >>> t.shape 486 TensorShape([5]) 487 488 `tf.Tensor.shape` is equivalent to `tf.Tensor.get_shape()`. 
489 490 In a `tf.function` or when building a model using 491 `tf.keras.Input`, they return the build-time shape of the 492 tensor, which may be partially unknown. 493 494 A `tf.TensorShape` is not a tensor. Use `tf.shape(t)` to get a tensor 495 containing the shape, calculated at runtime. 496 497 See `tf.Tensor.get_shape()`, and `tf.TensorShape` for details and examples. 498 """ 499 if self._shape_val is None: 500 self._shape_val = self._c_api_shape() 501 return self._shape_val 502 503 def _c_api_shape(self): 504 """Returns the TensorShape of this tensor according to the C API.""" 505 with self._op._graph._c_graph.get() as c_graph: # pylint: disable=protected-access 506 shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper( 507 c_graph, self._as_tf_output()) 508 if unknown_shape: 509 return tensor_shape.unknown_shape() 510 else: 511 shape_vec = [None if d == -1 else d for d in shape_vec] 512 return tensor_shape.TensorShape(shape_vec) 513 514 @property 515 def _shape(self): 516 logging.warning("Tensor._shape is private, use Tensor.shape " 517 "instead. Tensor._shape will eventually be removed.") 518 return self.shape 519 520 @_shape.setter 521 def _shape(self, value): 522 raise ValueError( 523 "Tensor._shape cannot be assigned, use Tensor.set_shape instead.") 524 525 def _disallow_when_autograph_unavailable(self, task): 526 raise errors.OperatorNotAllowedInGraphError( 527 f"{task} is not allowed: AutoGraph is unavailable in this runtime. See" 528 " https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code" 529 " for more information.") 530 531 def _disallow_when_autograph_disabled(self, task): 532 raise errors.OperatorNotAllowedInGraphError( 533 f"{task} is not allowed: AutoGraph is disabled in this function." 
534 " Try decorating it directly with @tf.function.") 535 536 def _disallow_when_autograph_enabled(self, task): 537 raise errors.OperatorNotAllowedInGraphError( 538 f"{task} is not allowed: AutoGraph did convert this function. This" 539 " might indicate you are trying to use an unsupported feature.") 540 541 def _disallow_in_graph_mode(self, task): 542 raise errors.OperatorNotAllowedInGraphError( 543 f"{task} is not allowed in Graph execution. Use Eager execution or" 544 " decorate this function with @tf.function.") 545 546 def _disallow_bool_casting(self): 547 if not ag_ctx.INSPECT_SOURCE_SUPPORTED: 548 self._disallow_when_autograph_unavailable( 549 "Using a symbolic `tf.Tensor` as a Python `bool`") 550 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: 551 self._disallow_when_autograph_disabled( 552 "Using a symbolic `tf.Tensor` as a Python `bool`") 553 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED: 554 self._disallow_when_autograph_enabled( 555 "Using a symbolic `tf.Tensor` as a Python `bool`") 556 else: 557 # Default: V1-style Graph execution. 558 self._disallow_in_graph_mode( 559 "Using a symbolic `tf.Tensor` as a Python `bool`") 560 561 def _disallow_iteration(self): 562 if not ag_ctx.INSPECT_SOURCE_SUPPORTED: 563 self._disallow_when_autograph_unavailable( 564 "Iterating over a symbolic `tf.Tensor`") 565 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: 566 self._disallow_when_autograph_disabled( 567 "Iterating over a symbolic `tf.Tensor`") 568 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED: 569 self._disallow_when_autograph_enabled( 570 "Iterating over a symbolic `tf.Tensor`") 571 else: 572 # Default: V1-style Graph execution. 
573 self._disallow_in_graph_mode("Iterating over a symbolic `tf.Tensor`") 574 575 def __iter__(self): 576 if not context.executing_eagerly(): 577 self._disallow_iteration() 578 579 shape = self._shape_tuple() 580 if shape is None: 581 raise TypeError("Cannot iterate over a tensor with unknown shape.") 582 if not shape: 583 raise TypeError("Cannot iterate over a scalar tensor.") 584 if shape[0] is None: 585 raise TypeError( 586 "Cannot iterate over a tensor with unknown first dimension.") 587 return _TensorIterator(self, shape[0]) 588 589 def _shape_as_list(self): 590 if self.shape.ndims is not None: 591 return [dim.value for dim in self.shape.dims] 592 else: 593 return None 594 595 def _shape_tuple(self): 596 shape = self._shape_as_list() 597 if shape is None: 598 return None 599 return tuple(shape) 600 601 def _rank(self): 602 """Integer rank of this Tensor, if known, else None. 603 604 Returns: 605 Integer rank or None 606 """ 607 return self.shape.ndims 608 609 def get_shape(self): 610 """Returns a `tf.TensorShape` that represents the shape of this tensor. 611 612 In eager execution the shape is always fully-known. 613 614 >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 615 >>> print(a.shape) 616 (2, 3) 617 618 `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`. 619 620 621 When executing in a `tf.function` or building a model using 622 `tf.keras.Input`, `Tensor.shape` may return a partial shape (including 623 `None` for unknown dimensions). See `tf.TensorShape` for more details. 624 625 >>> inputs = tf.keras.Input(shape = [10]) 626 >>> # Unknown batch size 627 >>> print(inputs.shape) 628 (None, 10) 629 630 The shape is computed using shape inference functions that are 631 registered for each `tf.Operation`. 632 633 The returned `tf.TensorShape` is determined at *build* time, without 634 executing the underlying kernel. It is not a `tf.Tensor`. 
If you need a 635 shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or 636 use the `tf.shape(tensor)` function, which returns the tensor's shape at 637 *execution* time. 638 639 This is useful for debugging and providing early errors. For 640 example, when tracing a `tf.function`, no ops are being executed, shapes 641 may be unknown (See the [Concrete Functions 642 Guide](https://www.tensorflow.org/guide/concrete_function) for details). 643 644 >>> @tf.function 645 ... def my_matmul(a, b): 646 ... result = a@b 647 ... # the `print` executes during tracing. 648 ... print("Result shape: ", result.shape) 649 ... return result 650 651 The shape inference functions propagate shapes to the extent possible: 652 653 >>> f = my_matmul.get_concrete_function( 654 ... tf.TensorSpec([None,3]), 655 ... tf.TensorSpec([3,5])) 656 Result shape: (None, 5) 657 658 Tracing may fail if a shape missmatch can be detected: 659 660 >>> cf = my_matmul.get_concrete_function( 661 ... tf.TensorSpec([None,3]), 662 ... tf.TensorSpec([4,5])) 663 Traceback (most recent call last): 664 ... 665 ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op: 666 'MatMul') with input shapes: [?,3], [4,5]. 667 668 In some cases, the inferred shape may have unknown dimensions. If 669 the caller has additional information about the values of these 670 dimensions, `tf.ensure_shape` or `Tensor.set_shape()` can be used to augment 671 the inferred shape. 672 673 >>> @tf.function 674 ... def my_fun(a): 675 ... a = tf.ensure_shape(a, [5, 5]) 676 ... # the `print` executes during tracing. 677 ... print("Result shape: ", a.shape) 678 ... return a 679 680 >>> cf = my_fun.get_concrete_function( 681 ... tf.TensorSpec([None, None])) 682 Result shape: (5, 5) 683 684 Returns: 685 A `tf.TensorShape` representing the shape of this tensor. 686 687 """ 688 return self.shape 689 690 def set_shape(self, shape): 691 """Updates the shape of this tensor. 
692 693 Note: It is recommended to use `tf.ensure_shape` instead of 694 `Tensor.set_shape`, because `tf.ensure_shape` provides better checking for 695 programming errors and can create guarantees for compiler 696 optimization. 697 698 With eager execution this operates as a shape assertion. 699 Here the shapes match: 700 701 >>> t = tf.constant([[1,2,3]]) 702 >>> t.set_shape([1, 3]) 703 704 Passing a `None` in the new shape allows any value for that axis: 705 706 >>> t.set_shape([1,None]) 707 708 An error is raised if an incompatible shape is passed. 709 710 >>> t.set_shape([1,5]) 711 Traceback (most recent call last): 712 ... 713 ValueError: Tensor's shape (1, 3) is not compatible with supplied 714 shape [1, 5] 715 716 When executing in a `tf.function`, or building a model using 717 `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with 718 the current shape of this tensor, and set the tensor's shape to the 719 merged value (see `tf.TensorShape.merge_with` for details): 720 721 >>> t = tf.keras.Input(shape=[None, None, 3]) 722 >>> print(t.shape) 723 (None, None, None, 3) 724 725 Dimensions set to `None` are not updated: 726 727 >>> t.set_shape([None, 224, 224, None]) 728 >>> print(t.shape) 729 (None, 224, 224, 3) 730 731 The main use case for this is to provide additional shape information 732 that cannot be inferred from the graph alone. 733 734 For example if you know all the images in a dataset have shape [28,28,3] you 735 can set it with `tf.set_shape`: 736 737 >>> @tf.function 738 ... def load_image(filename): 739 ... raw = tf.io.read_file(filename) 740 ... image = tf.image.decode_png(raw, channels=3) 741 ... # the `print` executes during tracing. 742 ... print("Initial shape: ", image.shape) 743 ... image.set_shape([28, 28, 3]) 744 ... print("Final shape: ", image.shape) 745 ... return image 746 747 Trace the function, see the [Concrete Functions 748 Guide](https://www.tensorflow.org/guide/concrete_function) for details. 
749 750 >>> cf = load_image.get_concrete_function( 751 ... tf.TensorSpec([], dtype=tf.string)) 752 Initial shape: (None, None, 3) 753 Final shape: (28, 28, 3) 754 755 Similarly the `tf.io.parse_tensor` function could return a tensor with 756 any shape, even the `tf.rank` is unknown. If you know that all your 757 serialized tensors will be 2d, set it with `set_shape`: 758 759 >>> @tf.function 760 ... def my_parse(string_tensor): 761 ... result = tf.io.parse_tensor(string_tensor, out_type=tf.float32) 762 ... # the `print` executes during tracing. 763 ... print("Initial shape: ", result.shape) 764 ... result.set_shape([None, None]) 765 ... print("Final shape: ", result.shape) 766 ... return result 767 768 Trace the function 769 770 >>> concrete_parse = my_parse.get_concrete_function( 771 ... tf.TensorSpec([], dtype=tf.string)) 772 Initial shape: <unknown> 773 Final shape: (None, None) 774 775 Make sure it works: 776 777 >>> t = tf.ones([5,3], dtype=tf.float32) 778 >>> serialized = tf.io.serialize_tensor(t) 779 >>> print(serialized.dtype) 780 <dtype: 'string'> 781 >>> print(serialized.shape) 782 () 783 >>> t2 = concrete_parse(serialized) 784 >>> print(t2.shape) 785 (5, 3) 786 787 Caution: `set_shape` ensures that the applied shape is compatible with 788 the existing shape, but it does not check at runtime. Setting 789 incorrect shapes can result in inconsistencies between the 790 statically-known graph and the runtime value of tensors. For runtime 791 validation of the shape, use `tf.ensure_shape` instead. It also modifies 792 the `shape` of the tensor. 793 794 >>> # Serialize a rank-3 tensor 795 >>> t = tf.ones([5,5,5], dtype=tf.float32) 796 >>> serialized = tf.io.serialize_tensor(t) 797 >>> # The function still runs, even though it `set_shape([None,None])` 798 >>> t2 = concrete_parse(serialized) 799 >>> print(t2.shape) 800 (5, 5, 5) 801 802 Args: 803 shape: A `TensorShape` representing the shape of this tensor, a 804 `TensorShapeProto`, a list, a tuple, or None. 
805 806 Raises: 807 ValueError: If `shape` is not compatible with the current shape of 808 this tensor. 809 """ 810 # Reset cached shape. 811 self._shape_val = None 812 813 # We want set_shape to be reflected in the C API graph for when we run it. 814 if not isinstance(shape, tensor_shape.TensorShape): 815 shape = tensor_shape.TensorShape(shape) 816 dim_list = [] 817 if shape.dims is None: 818 unknown_shape = True 819 else: 820 unknown_shape = False 821 for dim in shape.dims: 822 if dim.value is None: 823 dim_list.append(-1) 824 else: 825 dim_list.append(dim.value) 826 try: 827 with self._op._graph._c_graph.get() as c_graph: # pylint: disable=protected-access 828 pywrap_tf_session.TF_GraphSetTensorShape_wrapper( 829 c_graph, self._as_tf_output(), dim_list, unknown_shape) 830 except errors.InvalidArgumentError as e: 831 # Convert to ValueError for backwards compatibility. 832 raise ValueError(e.message) 833 834 @property 835 def value_index(self): 836 """The index of this tensor in the outputs of its `Operation`.""" 837 return self._value_index 838 839 def consumers(self): 840 """Returns a list of `Operation`s that consume this tensor. 841 842 Returns: 843 A list of `Operation`s. 844 """ 845 consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper( 846 self._as_tf_output()) 847 # pylint: disable=protected-access 848 return [ 849 self.graph._get_operation_by_name_unsafe(name) 850 for name in consumer_names 851 ] 852 # pylint: enable=protected-access 853 854 def _as_node_def_input(self): 855 """Return a value to use for the NodeDef "input" attribute. 856 857 The returned string can be used in a NodeDef "input" attribute 858 to indicate that the NodeDef uses this Tensor as input. 859 860 Raises: 861 ValueError: if this Tensor's Operation does not have a name. 862 863 Returns: 864 a string. 
865 """ 866 assert self._op.name 867 if self._value_index == 0: 868 return self._op.name 869 else: 870 return "%s:%d" % (self._op.name, self._value_index) 871 872 def _as_tf_output(self): 873 # pylint: disable=protected-access 874 # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object 875 # also guarantees that a dictionary of tf_output objects will retain a 876 # deterministic (yet unsorted) order which prevents memory blowup in the 877 # cache of executor(s) stored for every session. 878 if self._tf_output is None: 879 self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index) 880 return self._tf_output 881 # pylint: enable=protected-access 882 883 def __str__(self): 884 return "Tensor(\"%s\"%s%s%s)" % ( 885 self.name, 886 (", shape=%s" % 887 self.get_shape()) if self.get_shape().ndims is not None else "", 888 (", dtype=%s" % self._dtype.name) if self._dtype else "", 889 (", device=%s" % self.device) if self.device else "") 890 891 def __repr__(self): 892 return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(), 893 self._dtype.name) 894 895 def __hash__(self): 896 g = getattr(self, "graph", None) 897 if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and 898 (g is None or g.building_function)): 899 raise TypeError("Tensor is unhashable. " 900 "Instead, use tensor.ref() as the key.") 901 else: 902 return id(self) 903 904 def __copy__(self): 905 # TODO(b/77597810): get rid of Tensor copies. 906 cls = self.__class__ 907 result = cls.__new__(cls) 908 result.__dict__.update(self.__dict__) 909 return result 910 911 # NOTE(mrry): This enables the Tensor's overloaded "right" binary 912 # operators to run when the left operand is an ndarray, because it 913 # accords the Tensor class higher priority than an ndarray, or a 914 # numpy matrix. 915 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 916 # mechanism, which allows more control over how Tensors interact 917 # with ndarrays. 
918 __array_priority__ = 100 919 920 def __array__(self, dtype=None): 921 del dtype 922 raise NotImplementedError( 923 f"Cannot convert a symbolic tf.Tensor ({self.name}) to a numpy array." 924 f" This error may indicate that you're trying to pass a Tensor to" 925 f" a NumPy call, which is not supported.") 926 927 def __len__(self): 928 raise TypeError(f"len is not well defined for a symbolic Tensor " 929 f"({self.name}). Please call `x.shape` rather than " 930 f"`len(x)` for shape information.") 931 932 # TODO(mdan): This convoluted machinery is hard to maintain. Clean up. 933 @staticmethod 934 def _override_operator(operator, func): 935 _override_helper(Tensor, operator, func) 936 937 def __bool__(self): 938 """Dummy method to prevent a tensor from being used as a Python `bool`. 939 940 This overload raises a `TypeError` when the user inadvertently 941 treats a `Tensor` as a boolean (most commonly in an `if` or `while` 942 statement), in code that was not converted by AutoGraph. For example: 943 944 ```python 945 if tf.constant(True): # Will raise. 946 # ... 947 948 if tf.constant(5) < tf.constant(7): # Will raise. 949 # ... 950 ``` 951 952 Raises: 953 `TypeError`. 954 """ 955 self._disallow_bool_casting() 956 957 def __nonzero__(self): 958 """Dummy method to prevent a tensor from being used as a Python `bool`. 959 960 This is the Python 2.x counterpart to `__bool__()` above. 961 962 Raises: 963 `TypeError`. 964 """ 965 self._disallow_bool_casting() 966 967 def eval(self, feed_dict=None, session=None): 968 """Evaluates this tensor in a `Session`. 969 970 Note: If you are not using `compat.v1` libraries, you should not need this, 971 (or `feed_dict` or `Session`). In eager execution (or within `tf.function`) 972 you do not need to call `eval`. 973 974 Calling this method will execute all preceding operations that 975 produce the inputs needed for the operation that produces this 976 tensor. 
977 978 *N.B.* Before invoking `Tensor.eval()`, its graph must have been 979 launched in a session, and either a default session must be 980 available, or `session` must be specified explicitly. 981 982 Args: 983 feed_dict: A dictionary that maps `Tensor` objects to feed values. See 984 `tf.Session.run` for a description of the valid feed values. 985 session: (Optional.) The `Session` to be used to evaluate this tensor. If 986 none, the default session will be used. 987 988 Returns: 989 A numpy array corresponding to the value of this tensor. 990 """ 991 return _eval_using_default_session(self, feed_dict, self.graph, session) 992 993 @deprecation.deprecated(None, "Use ref() instead.") 994 def experimental_ref(self): 995 return self.ref() 996 997 def ref(self): 998 # tf.Variable also has the same ref() API. If you update the 999 # documentation here, please update tf.Variable.ref() as well. 1000 """Returns a hashable reference object to this Tensor. 1001 1002 The primary use case for this API is to put tensors in a set/dictionary. 1003 We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer 1004 available starting Tensorflow 2.0. 1005 1006 The following will raise an exception starting 2.0 1007 1008 >>> x = tf.constant(5) 1009 >>> y = tf.constant(10) 1010 >>> z = tf.constant(10) 1011 >>> tensor_set = {x, y, z} 1012 Traceback (most recent call last): 1013 ... 1014 TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key. 1015 >>> tensor_dict = {x: 'five', y: 'ten'} 1016 Traceback (most recent call last): 1017 ... 1018 TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key. 1019 1020 Instead, we can use `tensor.ref()`. 
1021 1022 >>> tensor_set = {x.ref(), y.ref(), z.ref()} 1023 >>> x.ref() in tensor_set 1024 True 1025 >>> tensor_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'} 1026 >>> tensor_dict[y.ref()] 1027 'ten' 1028 1029 Also, the reference object provides `.deref()` function that returns the 1030 original Tensor. 1031 1032 >>> x = tf.constant(5) 1033 >>> x.ref().deref() 1034 <tf.Tensor: shape=(), dtype=int32, numpy=5> 1035 """ 1036 return object_identity.Reference(self) 1037 1038 def __tf_tracing_type__(self, signature_context): 1039 return tensor_spec.TensorSpec( 1040 self.shape, self.dtype).__tf_tracing_type__(signature_context) 1041 1042 1043# TODO(agarwal): consider getting rid of this. 1044# TODO(mdan): This object should not subclass ops.Tensor. 1045class _EagerTensorBase(Tensor): 1046 """Base class for EagerTensor.""" 1047 1048 # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and 1049 # only work for scalars; values are cast as per numpy. 1050 def __complex__(self): 1051 return complex(self._numpy()) 1052 1053 def __int__(self): 1054 return int(self._numpy()) 1055 1056 def __long__(self): 1057 return long(self._numpy()) 1058 1059 def __float__(self): 1060 return float(self._numpy()) 1061 1062 def __index__(self): 1063 return self._numpy().__index__() 1064 1065 def __bool__(self): 1066 return bool(self._numpy()) 1067 1068 __nonzero__ = __bool__ 1069 1070 def __format__(self, format_spec): 1071 if self._prefer_custom_summarizer(): 1072 return self._summarize_value().__format__(format_spec) 1073 elif self.dtype.is_numpy_compatible: 1074 # Not numpy_text here, otherwise the __format__ behaves differently. 1075 return self._numpy().__format__(format_spec) 1076 else: 1077 return "<unprintable>".__format__(format_spec) 1078 1079 def __reduce__(self): 1080 return convert_to_tensor, (self._numpy(),) 1081 1082 def __copy__(self): 1083 # Eager Tensors are immutable so it's safe to return themselves as a copy. 
1084 return self 1085 1086 def __deepcopy__(self, memo): 1087 # Eager Tensors are immutable so it's safe to return themselves as a copy. 1088 del memo 1089 return self 1090 1091 def __str__(self): 1092 return "tf.Tensor(%s, shape=%s, dtype=%s)" % ( 1093 value_text(self, is_repr=False), self.shape, self.dtype.name) 1094 1095 def __repr__(self): 1096 return "<tf.Tensor: shape=%s, dtype=%s, %s>" % ( 1097 self.shape, self.dtype.name, value_text(self, is_repr=True)) 1098 1099 def __len__(self): 1100 """Returns the length of the first dimension in the Tensor.""" 1101 if not self.shape.ndims: 1102 raise TypeError("Scalar tensor has no `len()`") 1103 # pylint: disable=protected-access 1104 try: 1105 return self._shape_tuple()[0] 1106 except core._NotOkStatusException as e: 1107 raise core._status_to_exception(e) from None 1108 1109 def __array__(self, dtype=None): 1110 a = self._numpy() 1111 if not dtype: 1112 return a 1113 1114 return np.array(a, dtype=dtype) 1115 1116 def _numpy_internal(self): 1117 raise NotImplementedError() 1118 1119 def _numpy(self): 1120 try: 1121 return self._numpy_internal() 1122 except core._NotOkStatusException as e: # pylint: disable=protected-access 1123 raise core._status_to_exception(e) from None # pylint: disable=protected-access 1124 1125 @property 1126 def dtype(self): 1127 # Note: using the intern table directly here as this is 1128 # performance-sensitive in some models. 1129 return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access 1130 1131 def numpy(self): 1132 """Copy of the contents of this Tensor into a NumPy array or scalar. 1133 1134 Unlike NumPy arrays, Tensors are immutable, so this method has to copy 1135 the contents to ensure safety. Use `memoryview` to get a readonly 1136 view of the contents without doing a copy: 1137 1138 >>> t = tf.constant([42]) 1139 >>> np.array(memoryview(t)) 1140 array([42], dtype=int32) 1141 1142 Note that `memoryview` is only zero-copy for Tensors on CPU. 
If a Tensor 1143 is on GPU, it will have to be transferred to CPU first in order for 1144 `memoryview` to work. 1145 1146 Returns: 1147 A NumPy array of the same shape and dtype or a NumPy scalar, if this 1148 Tensor has rank 0. 1149 1150 Raises: 1151 ValueError: If the dtype of this Tensor does not have a compatible 1152 NumPy dtype. 1153 """ 1154 # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors. 1155 maybe_arr = self._numpy() # pylint: disable=protected-access 1156 return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr 1157 1158 @property 1159 def backing_device(self): 1160 """Returns the name of the device holding this tensor's memory. 1161 1162 `.backing_device` is usually the same as `.device`, which returns 1163 the device on which the kernel of the operation that produced this tensor 1164 ran. However, some operations can produce tensors on a different device 1165 (e.g., an operation that executes on the GPU but produces output tensors 1166 in host memory). 1167 """ 1168 raise NotImplementedError() 1169 1170 def _datatype_enum(self): 1171 raise NotImplementedError() 1172 1173 def _shape_tuple(self): 1174 """The shape of this Tensor, as a tuple. 1175 1176 This is more performant than tuple(shape().as_list()) as it avoids 1177 two list and one object creation. Marked private for now as from an API 1178 perspective, it would be better to have a single performant way of 1179 getting a shape rather than exposing shape() and shape_tuple() 1180 (and heaven forbid, shape_list() etc. as well!). Punting on that for now, 1181 but ideally one would work things out and remove the need for this method. 1182 1183 Returns: 1184 tuple with the shape. 1185 """ 1186 raise NotImplementedError() 1187 1188 def _rank(self): 1189 """Integer rank of this Tensor. 1190 1191 Unlike regular Tensors, the rank is always known for EagerTensors. 
1192 1193 This is more performant than len(self._shape_tuple()) 1194 1195 Returns: 1196 Integer rank 1197 """ 1198 raise NotImplementedError() 1199 1200 def _num_elements(self): 1201 """Number of elements of this Tensor. 1202 1203 Unlike regular Tensors, the number of elements is always known for 1204 EagerTensors. 1205 1206 This is more performant than tensor.shape.num_elements 1207 1208 Returns: 1209 Long - num elements in the tensor 1210 """ 1211 raise NotImplementedError() 1212 1213 def _copy_to_device(self, device_name): # pylint: disable=redefined-outer-name 1214 raise NotImplementedError() 1215 1216 @staticmethod 1217 def _override_operator(name, func): 1218 setattr(_EagerTensorBase, name, func) 1219 1220 def _copy_nograd(self, ctx=None, device_name=None): 1221 """Copies tensor to dest device, but doesn't record the operation.""" 1222 # Creates a new tensor on the dest device. 1223 if ctx is None: 1224 ctx = context.context() 1225 if device_name is None: 1226 device_name = ctx.device_name 1227 # pylint: disable=protected-access 1228 try: 1229 ctx.ensure_initialized() 1230 new_tensor = self._copy_to_device(device_name) 1231 except core._NotOkStatusException as e: 1232 raise core._status_to_exception(e) from None 1233 return new_tensor 1234 1235 def _copy(self, ctx=None, device_name=None): 1236 """Copies tensor to dest device.""" 1237 new_tensor = self._copy_nograd(ctx, device_name) 1238 # Record the copy on tape and define backprop copy as well. 
1239 if context.executing_eagerly(): 1240 self_device = self.device 1241 1242 def grad_fun(dresult): 1243 return [ 1244 dresult._copy(device_name=self_device) 1245 if hasattr(dresult, "_copy") else dresult 1246 ] 1247 1248 tape.record_operation("_copy", [new_tensor], [self], grad_fun) 1249 return new_tensor 1250 # pylint: enable=protected-access 1251 1252 @property 1253 def shape(self): 1254 if self._tensor_shape is None: # pylint: disable=access-member-before-definition 1255 # pylint: disable=protected-access 1256 try: 1257 # `_tensor_shape` is declared and defined in the definition of 1258 # `EagerTensor`, in C. 1259 self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple()) 1260 except core._NotOkStatusException as e: 1261 raise core._status_to_exception(e) from None 1262 1263 return self._tensor_shape 1264 1265 def get_shape(self): 1266 """Alias of Tensor.shape.""" 1267 return self.shape 1268 1269 def _shape_as_list(self): 1270 """The shape of the tensor as a list.""" 1271 return list(self._shape_tuple()) 1272 1273 @property 1274 def ndim(self): 1275 """Returns the number of Tensor dimensions.""" 1276 return self.shape.ndims 1277 1278 @deprecation.deprecated(None, "Use tf.identity instead.") 1279 def cpu(self): 1280 """A copy of this Tensor with contents backed by host memory.""" 1281 return self._copy(context.context(), "CPU:0") 1282 1283 @deprecation.deprecated(None, "Use tf.identity instead.") 1284 def gpu(self, gpu_index=0): 1285 """A copy of this Tensor with contents backed by memory on the GPU. 1286 1287 Args: 1288 gpu_index: Identifies which GPU to place the contents on the returned 1289 Tensor in. 1290 1291 Returns: 1292 A GPU-memory backed Tensor object initialized with the same contents 1293 as this Tensor. 
1294 """ 1295 return self._copy(context.context(), "GPU:" + str(gpu_index)) 1296 1297 def set_shape(self, shape): 1298 if not self.shape.is_compatible_with(shape): 1299 raise ValueError(f"Tensor's shape {self.shape} is not compatible " 1300 f"with supplied shape {shape}.") 1301 1302 # Methods not supported / implemented for Eager Tensors. 1303 @property 1304 def op(self): 1305 raise AttributeError( 1306 "Tensor.op is undefined when eager execution is enabled.") 1307 1308 @property 1309 def graph(self): 1310 raise AttributeError( 1311 "Tensor.graph is undefined when eager execution is enabled.") 1312 1313 @property 1314 def name(self): 1315 raise AttributeError( 1316 "Tensor.name is undefined when eager execution is enabled.") 1317 1318 @property 1319 def value_index(self): 1320 raise AttributeError( 1321 "Tensor.value_index is undefined when eager execution is enabled.") 1322 1323 def consumers(self): 1324 raise NotImplementedError( 1325 "Tensor.consumers is undefined when eager execution is enabled.") 1326 1327 def _add_consumer(self, consumer): 1328 raise NotImplementedError( 1329 "_add_consumer not supported when eager execution is enabled.") 1330 1331 def _as_node_def_input(self): 1332 raise NotImplementedError( 1333 "_as_node_def_input not supported when eager execution is enabled.") 1334 1335 def _as_tf_output(self): 1336 raise NotImplementedError( 1337 "_as_tf_output not supported when eager execution is enabled.") 1338 1339 def eval(self, feed_dict=None, session=None): 1340 raise NotImplementedError( 1341 "eval is not supported when eager execution is enabled, " 1342 "is .numpy() what you're looking for?") 1343 1344 1345# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and 1346# registers it with the current module. 1347# It is exposed as an __internal__ api for now (b/171081052), though we 1348# expect it to be eventually covered by tf Tensor types and typing. 
EagerTensor = tf_export("__internal__.EagerTensor", v1=[])(
    pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase))


@tf_export(v1=["convert_to_tensor"])
@dispatch.add_dispatch_support
def convert_to_tensor_v1_with_dispatch(
    value,
    dtype=None,
    name=None,
    preferred_dtype=None,
    dtype_hint=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    dtype_hint: same meaning as preferred_dtype, and overrides it.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  return convert_to_tensor_v1(value, dtype=dtype, name=name,
                              preferred_dtype=preferred_dtype,
                              dtype_hint=dtype_hint)


def convert_to_tensor_v1(value,
                         dtype=None,
                         name=None,
                         preferred_dtype=None,
                         dtype_hint=None):
  """Converts the given `value` to a `Tensor` (with the TF1 API)."""
  # Reconcile the deprecated `preferred_dtype` with its replacement
  # `dtype_hint` before delegating to the TF2 implementation.
  preferred_dtype = deprecation.deprecated_argument_lookup(
      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
  return convert_to_tensor_v2(value, dtype, preferred_dtype, name)


@tf_export("convert_to_tensor", v1=[])
@dispatch.add_dispatch_support
def convert_to_tensor_v2_with_dispatch(
    value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars.

  For example:

  >>> import numpy as np
  >>> def my_func(arg):
  ...   arg = tf.convert_to_tensor(arg, dtype=tf.float32)
  ...   return arg

  >>> # The following calls are equivalent.
  ...
  >>> value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  >>> print(value_1)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  >>> print(value_2)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  >>> print(value_3)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor, used when dtype
      is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so dtype_hint can be used as a soft preference.
      If the conversion to `dtype_hint` is not possible, this argument has no
      effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  return convert_to_tensor_v2(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)


def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`."""
  return convert_to_tensor(
      value=value,
      dtype=dtype,
      name=name,
      preferred_dtype=dtype_hint,
      as_ref=False)


def _add_error_prefix(msg, *, name=None):
  """Returns `msg`, prefixed with "`name`: " when `name` is not None."""
  return msg if name is None else f"{name}: {msg}"


def pack_eager_tensors(tensors, ctx=None):
  """Pack multiple `EagerTensor`s of the same dtype and shape.

  Args:
    tensors: a list of EagerTensors to pack.
    ctx: context.context(). Defaults to `context.context()` when None.

  Returns:
    A packed EagerTensor. Its `_handle_data` is copied from the first tensor
    when that is not None.

  Raises:
    TypeError: if `tensors` is not a list, or any element is not an
      `EagerTensor`.
    ValueError: if `tensors` is empty, or the elements disagree in dtype,
      shape, or (for resource tensors) handle data.
  """
  if not isinstance(tensors, list):
    raise TypeError(f"tensors must be a list, but got a {type(tensors)}")

  if not tensors:
    raise ValueError("Cannot pack an empty list of tensors.")

  # Validate element types before reading attributes from `tensors[0]`, so a
  # non-EagerTensor first element raises the TypeError below rather than an
  # AttributeError on `_handle_data`.
  for t in tensors:
    if not isinstance(t, EagerTensor):
      raise TypeError(f"All tensors being packed must be EagerTensor. "
                      f"Found an item of type {type(t)}.")

  dtype = tensors[0].dtype
  shape = tensors[0].shape
  handle_data = tensors[0]._handle_data  # pylint: disable=protected-access
  is_resource = dtype == dtypes.resource
  for i, t in enumerate(tensors):
    if t.dtype != dtype:
      raise ValueError(
          f"All tensors being packed should have the same dtype {dtype}, "
          f"but the {i}-th tensor is of dtype {t.dtype}")
    if t.shape != shape:
      raise ValueError(
          f"All tensors being packed should have the same shape {shape}, "
          f"but the {i}-th tensor is of shape {t.shape}")
    # pylint: disable=protected-access
    if is_resource and t._handle_data != handle_data:
      raise ValueError(
          f"All tensors being packed should have the same handle data "
          f"{handle_data}, "
          f"but the {i}-th tensor is of handle data {t._handle_data}")
    # pylint: enable=protected-access

  if ctx is None:
    ctx = context.context()

  # Propagate handle data for resource variables
  packed_tensor = ctx.pack_eager_tensors(tensors)
  if handle_data is not None:
    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access

  def grad_fun(_):
    raise ValueError(
        "Computing gradients through pack_eager_tensors is not supported.")

  tape.record_operation("pack_eager_tensors", [packed_tensor], tensors,
                        grad_fun)

  return packed_tensor


@profiler_trace.trace_wrapper("convert_to_tensor")
def convert_to_tensor(value,
                      dtype=None,
                      name=None,
                      as_ref=False,
                      preferred_dtype=None,
                      dtype_hint=None,
                      ctx=None,
                      accepted_result_types=(Tensor,)):
  """Implementation of the public convert_to_tensor.

  Args:
    value: The object to convert.
    dtype: Optional required dtype (normalized with `dtypes.as_dtype`). The
      result must be compatible with it.
    name: Optional name passed to graph capture / conversion functions and
      used as a prefix for error messages.
    as_ref: Forwarded to the registered conversion functions.
    preferred_dtype: Optional soft dtype preference, tried first when `dtype`
      is None.
    dtype_hint: Used as the preference when `preferred_dtype` is falsy.
    ctx: Context consulted only when `value` is an `EagerTensor`; defaults to
      `context.context()`.
    accepted_result_types: Types that a conversion function's result must be
      an instance of.

  Returns:
    `value` itself if it is already a compatible `Tensor`; an EagerTensor
    captured into the default graph when not executing eagerly; the result
    of `type(value).__tf_tensor__` if defined; otherwise the result of the
    first registered conversion function that does not return
    `NotImplemented`.

  Raises:
    RuntimeError: If an `EagerTensor` is captured outside function building,
      or a conversion function returns a value of the wrong type or dtype.
    ValueError: If `value` is a `Tensor` whose dtype is incompatible with
      `dtype`.
    TypeError: If no conversion function is registered for `type(value)`.
  """
  # TODO(b/142518781): Fix all call-sites and remove redundant arg
  preferred_dtype = preferred_dtype or dtype_hint
  if isinstance(value, EagerTensor):
    if ctx is None:
      ctx = context.context()
    if not ctx.executing_eagerly():
      graph = get_default_graph()
      if not graph.building_function:
        raise RuntimeError(
            _add_error_prefix(
                "Attempting to capture an EagerTensor without "
                "building a function.",
                name=name))
      return graph.capture(value, name=name)

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if isinstance(value, Tensor):
    if dtype is not None and not dtype.is_compatible_with(value.dtype):
      raise ValueError(
          _add_error_prefix(
              f"Tensor conversion requested dtype {dtype.name} "
              f"for Tensor with dtype {value.dtype.name}: {value!r}",
              name=name))
    return value

  if preferred_dtype is not None:
    preferred_dtype = dtypes.as_dtype(preferred_dtype)

  # See below for the reason why it's `type(value)` and not just `value`.
  # https://docs.python.org/3.8/reference/datamodel.html#special-lookup
  overload = getattr(type(value), "__tf_tensor__", None)
  if overload is not None:
    return overload(value, dtype, name)  #  pylint: disable=not-callable

  for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError):
        # Could not coerce the conversion to use the preferred dtype.
        pass
      else:
        if (ret is not NotImplemented and
            ret.dtype.base_dtype != preferred_dtype.base_dtype):
          raise RuntimeError(
              _add_error_prefix(
                  f"Conversion function {conversion_func!r} for type "
                  f"{base_type} returned incompatible dtype: requested = "
                  f"{preferred_dtype.base_dtype.name}, "
                  f"actual = {ret.dtype.base_dtype.name}",
                  name=name))

    if ret is None:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    if ret is NotImplemented:
      continue

    if not isinstance(ret, accepted_result_types):
      raise RuntimeError(
          _add_error_prefix(
              f"Conversion function {conversion_func!r} for type "
              f"{base_type} returned non-Tensor: {ret!r}",
              name=name))
    if dtype and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          _add_error_prefix(
              f"Conversion function {conversion_func} for type {base_type} "
              f"returned incompatible dtype: requested = {dtype.name}, "
              f"actual = {ret.dtype.name}",
              name=name))
    return ret
  raise TypeError(
      _add_error_prefix(
          f"Cannot convert {value!r} with type {type(value)} to Tensor: "
          f"no conversion function registered.",
          name=name))


internal_convert_to_tensor = convert_to_tensor


def internal_convert_n_to_tensor(values,
                                 dtype=None,
                                 name=None,
                                 as_ref=False,
                                 preferred_dtype=None,
                                 ctx=None):
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.
    preferred_dtype: Optional element type for the returned tensors, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    ctx: The value of context.context().

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  # Resolve the context once rather than once per element.
  if ctx is None:
    ctx = context.context()
  for i, value in enumerate(values):
    n = None if name is None else "%s_%d" % (name, i)
    ret.append(
        convert_to_tensor(
            value,
            dtype=dtype,
            name=n,
            as_ref=as_ref,
            preferred_dtype=preferred_dtype,
            ctx=ctx))
  return ret


def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.
    preferred_dtype: Optional element type for the returned tensors, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor(
      values=values,
      dtype=dtype,
      name=name,
      preferred_dtype=preferred_dtype,
      as_ref=False)


def convert_to_tensor_or_composite(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  If `value` is a `CompositeTensor` it is returned unmodified.  Otherwise, it
  is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor` or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  return internal_convert_to_tensor_or_composite(
      value=value, dtype=dtype, name=name, as_ref=False)


def internal_convert_to_tensor_or_composite(value,
                                            dtype=None,
                                            name=None,
                                            as_ref=False):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  If `value` is a `CompositeTensor` it is returned unmodified.  Otherwise, it
  is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor`, or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    value_dtype = getattr(value, "dtype", None)
    if dtype:
      requested_dtype = dtypes.as_dtype(dtype)
      if not requested_dtype.is_compatible_with(value_dtype):
        raise ValueError(f"Tensor conversion dtype mismatch. "
                         f"Requested dtype is {requested_dtype.name}, "
                         f"Tensor has dtype {value.dtype.name}: {value!r}")
    return value
  else:
    return convert_to_tensor(
        value,
        dtype=dtype,
        name=name,
        as_ref=as_ref,
        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))


def internal_convert_n_to_tensor_or_composite(values,
                                              dtype=None,
                                              name=None,
                                              as_ref=False):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  for i, value in enumerate(values):
    if value is None:
      # `None` entries are passed through unchanged.
      ret.append(value)
    else:
      n = None if name is None else "%s_%d" % (name, i)
      ret.append(
          internal_convert_to_tensor_or_composite(
              value, dtype=dtype, name=n, as_ref=as_ref))
  return ret


def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be
      consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_composite(
      values=values, dtype=dtype, name=name, as_ref=False)


def _device_string(dev_spec):
  """Returns `dev_spec.to_string()` for device specs, else `dev_spec` as-is."""
  if pydev.is_device_spec(dev_spec):
    return dev_spec.to_string()
  else:
    return dev_spec


def _NodeDef(op_type, name, attrs=None):
  """Create a NodeDef proto.

  Args:
    op_type: Value for the "op" attribute of the NodeDef proto.
    name: Value for the "name" attribute of the NodeDef proto.
    attrs: Dictionary where the key is the attribute name (a string)
      and the value is the respective "attr" attribute of the NodeDef proto (an
      AttrValue).

  Returns:
    A node_def_pb2.NodeDef protocol buffer.
  """
  node_def = node_def_pb2.NodeDef(op=compat.as_bytes(op_type),
                                  name=compat.as_bytes(name))
  if attrs:
    for k, v in attrs.items():
      node_def.attr[k].CopyFrom(v)
  return node_def


# Copied from core/framework/node_def_util.cc
# TODO(mrry,josh11b): Consolidate this validation in C++ code.
_VALID_OP_NAME_REGEX = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
_VALID_SCOPE_NAME_REGEX = re.compile(r"^[A-Za-z0-9_.\\/>-]*$")


@tf_export("__internal__.create_c_op", v1=[])
@traceback_utils.filter_traceback
def _create_c_op(graph,
                 node_def,
                 inputs,
                 control_inputs,
                 op_def=None,
                 extract_traceback=True):
  """Creates a TF_Operation.

  Args:
    graph: a `Graph`.
    node_def: `node_def_pb2.NodeDef` for the operation to create.
    inputs: A flattened list of `Tensor`s. This function handles grouping
      tensors into lists as per attributes in the `node_def`.
    control_inputs: A list of `Operation`s to set as control dependencies.
    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
      specified, is looked up from the `graph` using `node_def.op`.
    extract_traceback: if True, extract the current Python traceback to the
      TF_Operation.

  Returns:
    A wrapped TF_Operation*.
  """
  if op_def is None:
    op_def = graph._get_op_def(node_def.op)  # pylint: disable=protected-access
  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
  # Refactor so we don't have to do this here.
1934 inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr) 1935 # pylint: disable=protected-access 1936 with graph._c_graph.get() as c_graph: 1937 op_desc = pywrap_tf_session.TF_NewOperation(c_graph, 1938 compat.as_str(node_def.op), 1939 compat.as_str(node_def.name)) 1940 if node_def.device: 1941 pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device)) 1942 # Add inputs 1943 for op_input in inputs: 1944 if isinstance(op_input, (list, tuple)): 1945 pywrap_tf_session.TF_AddInputList(op_desc, 1946 [t._as_tf_output() for t in op_input]) 1947 else: 1948 pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output()) 1949 1950 # Add control inputs 1951 for control_input in control_inputs: 1952 pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op) 1953 # pylint: enable=protected-access 1954 1955 # Add attrs 1956 for name, attr_value in node_def.attr.items(): 1957 serialized = attr_value.SerializeToString() 1958 # TODO(skyewm): this creates and deletes a new TF_Status for every attr. 1959 # It might be worth creating a convenient way to re-use the same status. 1960 pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name), 1961 serialized) 1962 1963 try: 1964 c_op = pywrap_tf_session.TF_FinishOperation(op_desc) 1965 except errors.InvalidArgumentError as e: 1966 # Convert to ValueError for backwards compatibility. 1967 raise ValueError(e.message) 1968 1969 # Record the current Python stack trace as the creating stacktrace of this 1970 # TF_Operation. 1971 if extract_traceback: 1972 tf_stack.extract_stack_for_op(c_op, stacklevel=3) 1973 1974 return c_op 1975 1976 1977@tf_export("Operation") 1978class Operation(object): 1979 """Represents a graph node that performs computation on tensors. 1980 1981 An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor` 1982 objects as input, and produces zero or more `Tensor` objects as output. 
  Objects of type `Operation` are created by calling a Python op constructor
  (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default`
  context manager.

  For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an
  `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and
  produces `c` as output.

  If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be
  executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for
  calling `tf.compat.v1.get_default_session().run(op)`.
  """

  def __init__(self,
               node_def,
               g,
               inputs=None,
               output_types=None,
               control_inputs=None,
               input_types=None,
               original_op=None,
               op_def=None):
    r"""Creates an `Operation`.

    NOTE: This constructor validates the name of the `Operation` (passed
    as `node_def.name`). Valid `Operation` names match the following
    regular expression:

        [A-Za-z0-9.][A-Za-z0-9_.\\/>-]*

    Args:
      node_def: `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. Used for
        attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and
        `device`. The `input` attribute is irrelevant here as it will be
        computed when generating the model.
      g: `Graph`. The parent graph.
      inputs: list of `Tensor` objects. The inputs to this `Operation`.
      output_types: list of `DType` objects. List of the types of the `Tensors`
        computed by this operation. The length of this list indicates the
        number of output endpoints of the `Operation`. Currently ignored:
        output types are read back from the created TF_Operation.
      control_inputs: list of operations or tensors from which to have a control
        dependency.
      input_types: List of `DType` objects representing the types of the tensors
        accepted by the `Operation`. By default uses `[x.dtype.base_dtype for x
        in inputs]`. Operations that expect reference-typed inputs must specify
        these explicitly.
      original_op: Optional. Used to associate the new `Operation` with an
        existing `Operation` (for example, a replica with the op that was
        replicated).
      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type
        that this `Operation` represents.

    Raises:
      TypeError: if control inputs are not Operations or Tensors,
        or if `node_def` is not a `NodeDef`,
        or if `g` is not a `Graph`,
        or if `inputs` are not tensors,
        or if `inputs` and `input_types` are incompatible.
      ValueError: if the `node_def` name is not valid, if `node_def` is too
        large to serialize, or if it is a raw `TF_Operation`.
    """
    if not isinstance(g, Graph):
      raise TypeError(f"Argument g must be a Graph. "
                      f"Received an instance of type {type(g)}")

    # TODO(feyu): This message is redundant with the check below. We raise it
    # to help users to migrate. Remove this after 07/01/2022.
    # NOTE(review): the message points users at `Operation.from_c_op`, but the
    # classmethod defined on this class is named `_from_c_op`.
    if isinstance(node_def, pywrap_tf_session.TF_Operation):
      raise ValueError(
          "Calling Operation() with node_def of a TF_Operation is deprecated. "
          "Please switch to Operation.from_c_op.")

    if not isinstance(node_def, node_def_pb2.NodeDef):
      raise TypeError(f"Argument node_def must be a NodeDef. "
                      f"Received an instance of type: {type(node_def)}.")
    # Reject NodeDefs at or beyond the 2GB (2**31 bytes) serialized size.
    if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
      raise ValueError(
          f"Cannot create a tensor proto whose content is larger than 2GB. "
          f"Size of tensor is {node_def.ByteSize()} bytes.")

    # TODO(mdan): This does not belong here. Graph::AddNode should handle it.
    if not _VALID_OP_NAME_REGEX.match(node_def.name):
      raise ValueError(
          f"`{node_def.name}` is not a valid node name. "
          f"Accepted names conform to Regex /{_VALID_OP_NAME_REGEX}/")

    # FIXME(b/225400189): output_types is unused. Consider remove it from
    # the argument list.
    del output_types

    if inputs is None:
      inputs = []
    elif not isinstance(inputs, list):
      raise TypeError(f"Argument inputs shall be a list of Tensors. "
                      f"Received an instance of type {type(inputs)}")
    for a in inputs:
      if not isinstance(a, Tensor):
        raise TypeError(f"Items of argument inputs shall be Tensor. "
                        f"Received an instance of type {type(a)}.")
    # `input_types` is only used for this compatibility check; it is not
    # forwarded to `_create_c_op` below.
    if input_types is None:
      input_types = [i.dtype.base_dtype for i in inputs]
    else:
      if not all(
          x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)):
        raise TypeError("In op '%s', input types (%s) are not compatible "
                        "with expected types (%s)" %
                        (node_def.name, [i.dtype for i in inputs], input_types))

    # Build the list of control inputs. Tensors and IndexedSlices contribute
    # the op that produced them.
    control_input_ops = []
    if control_inputs:
      for c in control_inputs:
        control_op = None
        if isinstance(c, Operation):
          control_op = c
        elif isinstance(c, (Tensor, IndexedSlices)):
          control_op = c.op
        else:
          raise TypeError(f"Control input must be an Operation, "
                          f"a Tensor, or IndexedSlices. "
                          f"Received an instance of type {type(c)}.")
        control_input_ops.append(control_op)

    # Initialize c_op from node_def and other inputs
    c_op = _create_c_op(g, node_def, inputs, control_input_ops, op_def=op_def)
    self._init_from_c_op(c_op=c_op, g=g)

    # Set after `_init_from_c_op`, which resets `_original_op` to None.
    self._original_op = original_op

    # Post process for control flows.
    self._control_flow_post_processing(input_tensors=inputs)

    # Removes this frame from the Python traceback.
    # We adjust stacklevel directly to avoid triggering serialization.
    self.traceback._stacklevel += 1  # pylint: disable=protected-access

  @classmethod
  def _from_c_op(cls, c_op, g):
    """Create an Operation from a TF_Operation.

    For internal use only: This is useful for creating Operation for ops
    indirectly created by C API methods, e.g. the ops created by
    TF_ImportGraphDef.

    Args:
      c_op: a TF_Operation.
      g: A Graph.

    Returns:
      an Operation object.
    """
    # Bypass `__init__`, which would build a new TF_Operation through
    # `_create_c_op`; here we only wrap the existing `c_op`.
    self = object.__new__(cls)

    self._init_from_c_op(c_op=c_op, g=g)  # pylint: disable=protected-access
    return self

  def _init_from_c_op(self, c_op, g):
    """Initializes Operation from a TF_Operation.

    Sets up the Python-side state (graph, cached inputs, code locations,
    control flow context, outputs) and registers this op with `g`.

    Args:
      c_op: a TF_Operation.
      g: The `Graph` that owns `c_op`.

    Raises:
      TypeError: if `g` is not a `Graph` or `c_op` is not a TF_Operation.
    """

    if not isinstance(g, Graph):
      raise TypeError(f"Operation initialization requires a Graph, "
                      f"got {type(g)} for argument g.")

    if not isinstance(c_op, pywrap_tf_session.TF_Operation):
      raise TypeError(f"Operation initialization requires a TF_Operation, "
                      f"got {type(c_op)} for argument c_op.")

    self._original_op = None

    self._graph = g
    self._c_op = c_op

    # This will be set by self.inputs.
    self._inputs_val = None

    # List of _UserDevSpecs holding code location of device context manager
    # invocations and the users original argument to them.
    self._device_code_locations = None
    # Dict mapping op name to file and line information for op colocation
    # context managers.
    self._colocation_code_locations = None
    self._control_flow_context = g._get_control_flow_context()  # pylint: disable=protected-access

    # Gradient function for this op. There are three ways to specify gradient
    # function, and first available gradient gets used, in the following order.
    # 1. self._gradient_function
    # 2. Gradient name registered by "_gradient_op_type" attribute.
    # 3. Gradient name registered by op.type.
    self._gradient_function = None

    op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op))  # pylint: disable=protected-access

    self._is_stateful = op_def.is_stateful

    # Initialize self._outputs.
    # One Python Tensor wrapper is created per C output endpoint.
    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
    self._outputs = []
    for i in range(num_outputs):
      tf_output = c_api_util.tf_output(self._c_op, i)
      output_type = pywrap_tf_session.TF_OperationOutputType(tf_output)
      tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output)  # pylint: disable=protected-access
      self._outputs.append(tensor)

    self._id_value = g._add_op(self, self.name)  # pylint: disable=protected-access

  def _control_flow_post_processing(self, input_tensors=None):
    """Add this op to its control flow context.

    This may add new ops and change this op's inputs. self.inputs must be
    available before calling this method.

    Args:
      input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
        of this op, which should be equivalent to `self.inputs`. Pass this
        argument to avoid evaluating `self.inputs` unnecessarily.
    """
    if input_tensors is None:
      input_tensors = self.inputs
    for input_tensor in input_tensors:
      control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
    if self._control_flow_context is not None:
      self._control_flow_context.AddOp(self)

  def colocation_groups(self):
    """Returns the list of colocation groups of the op."""
    default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)]
    try:
      class_attr = self.get_attr("_class")
    except ValueError:
      # This op has no explicit colocation group, so it is itself its
      # own root of a colocation group.
      return default_colocation_group

    # Only `_class` entries with the "loc:@" prefix denote colocation groups.
    attr_groups = [
        class_name for class_name in class_attr
        if class_name.startswith(b"loc:@")
    ]

    # If there are no colocation groups in the explicit _class field,
    # return the default colocation group.
    return attr_groups if attr_groups else default_colocation_group

  def values(self):
    """DEPRECATED: Use outputs."""
    return tuple(self.outputs)

  def _get_control_flow_context(self):
    """Returns the control flow context of this op.

    Returns:
      A context object.
    """
    return self._control_flow_context

  def _set_control_flow_context(self, ctx):
    """Sets the current control flow context of this op.

    Args:
      ctx: a context object.
    """
    self._control_flow_context = ctx

  @property
  def name(self):
    """The full name of this operation."""
    return pywrap_tf_session.TF_OperationName(self._c_op)

  @property
  def _id(self):
    """The unique integer id of this operation."""
    # Assigned by `Graph._add_op` during `_init_from_c_op`.
    return self._id_value

  @property
  def device(self):
    """The name of the device to which this op has been assigned, if any.

    Returns:
      The string name of the device to which this op has been
      assigned, or an empty string if it has not been assigned to a
      device.
    """
    return pywrap_tf_session.TF_OperationDevice(self._c_op)

  @property
  def _device_assignments(self):
    """Code locations for device context managers active at op creation.

    This property will return a list of traceable_stack.TraceableObject
    instances where .obj is a string representing the assigned device
    (or information about the function that would be applied to this op
    to compute the desired device) and the filename and lineno members
    record the location of the relevant device context manager.

    For example, suppose file_a contained these lines:

      file_a.py:
        15: with tf.device('/gpu:0'):
        16:   node_b = tf.constant(4, name='NODE_B')

    Then a TraceableObject t_obj representing the device context manager
    would have these member values:

      t_obj.obj -> '/gpu:0'
      t_obj.filename = 'file_a.py'
      t_obj.lineno = 15

    and node_b.op._device_assignments would return the list [t_obj].

    Returns:
      [str: traceable_stack.TraceableObject, ...] as per this method's
      description, above.
    """
    return self._device_code_locations or []

  @property
  def _colocation_dict(self):
    """Code locations for colocation context managers active at op creation.

    This property will return a dictionary for which the keys are nodes with
    which this Operation is colocated, and for which the values are
    traceable_stack.TraceableObject instances. The TraceableObject instances
    record the location of the relevant colocation context manager but have the
    "obj" field set to None to prevent leaking private data.

    For example, suppose file_a contained these lines:

      file_a.py:
        14: node_a = tf.constant(3, name='NODE_A')
        15: with tf.compat.v1.colocate_with(node_a):
        16:   node_b = tf.constant(4, name='NODE_B')

    Then a TraceableObject t_obj representing the colocation context manager
    would have these member values:

      t_obj.obj -> None
      t_obj.filename = 'file_a.py'
      t_obj.lineno = 15

    and node_b.op._colocation_dict would return the dictionary

      { 'NODE_A': t_obj }

    Returns:
      {str: traceable_stack.TraceableObject} as per this method's description,
      above.
    """
    # Return a shallow copy so callers cannot mutate the stored mapping.
    locations_dict = self._colocation_code_locations or {}
    return locations_dict.copy()

  @property
  def _output_types(self):
    """List this operation's output types.

    Returns:
      List of the types of the Tensors computed by this operation.
      Each element in the list is an integer whose value is one of
      the TF_DataType enums defined in pywrap_tf_session.h
      The length of this list indicates the number of output endpoints
      of the operation.
    """
    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
    output_types = [
        int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i)))
        for i in range(num_outputs)
    ]

    return output_types

  def _tf_output(self, output_idx):
    """Create and return a new TF_Output for output_idx'th output of this op."""
    tf_output = pywrap_tf_session.TF_Output()
    tf_output.oper = self._c_op
    tf_output.index = output_idx
    return tf_output

  def _tf_input(self, input_idx):
    """Create and return a new TF_Input for input_idx'th input of this op."""
    tf_input = pywrap_tf_session.TF_Input()
    tf_input.oper = self._c_op
    tf_input.index = input_idx
    return tf_input

  def _set_device(self, device):  # pylint: disable=redefined-outer-name
    """Set the device of this operation.

    Args:
      device: string or device. The device to set.
    """
    self._set_device_from_string(compat.as_str(_device_string(device)))

  def _set_device_from_string(self, device_str):
    """Fast path to set device if the type is known to be a string.

    This function is called frequently enough during graph construction that
    there are non-trivial performance gains if the caller can guarantee that
    the specified device is already a string.

    Args:
      device_str: A string specifying where to place this op.
    """
    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      pywrap_tf_session.SetRequestedDevice(
          c_graph,
          self._c_op,  # pylint: disable=protected-access
          device_str)

  def _update_input(self, index, tensor):
    """Update the input to this operation at the given index.

    NOTE: This is for TF internal use only. Please don't use it.

    Args:
      index: the index of the input to update.
      tensor: the Tensor to be used as the input at the given index.

    Raises:
      TypeError: if tensor is not a Tensor,
        or if input tensor type is not convertible to dtype.
      ValueError: if the Tensor is from a different graph.
    """
    if not isinstance(tensor, Tensor):
      raise TypeError("tensor must be a Tensor: %s" % tensor)
    _assert_same_graph(self, tensor)

    # Reset cached inputs.
    # `inputs` rebuilds `_inputs_val` from the C graph on its next access.
    self._inputs_val = None
    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      pywrap_tf_session.UpdateEdge(
          c_graph,
          tensor._as_tf_output(),  # pylint: disable=protected-access
          self._tf_input(index))

  def _add_while_inputs(self, tensors):
    """See AddWhileInputHack in python_api.h.

    NOTE: This is for TF internal use only. Please don't use it.

    Args:
      tensors: list of Tensors

    Raises:
      TypeError: if tensor is not a Tensor,
        or if input tensor type is not convertible to dtype.
      ValueError: if the Tensor is from a different graph.
    """
    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      for tensor in tensors:
        if not isinstance(tensor, Tensor):
          raise TypeError("tensor must be a Tensor: %s" % tensor)
        _assert_same_graph(self, tensor)

        # Reset cached inputs.
        self._inputs_val = None
        pywrap_tf_session.AddWhileInputHack(
            c_graph,  # pylint: disable=protected-access
            tensor._as_tf_output(),  # pylint: disable=protected-access
            self._c_op)

  def _add_control_inputs(self, ops):
    """Add a list of new control inputs to this operation.

    Args:
      ops: the list of Operations to add as control input.

    Raises:
      TypeError: if ops is not a list of Operations.
      ValueError: if any op in ops is from a different graph.
    """
    # Ops are validated one at a time, so control inputs added before an
    # invalid element stay added.
    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      for op in ops:
        if not isinstance(op, Operation):
          raise TypeError("op must be an Operation: %s" % op)
        pywrap_tf_session.AddControlInput(
            c_graph,
            self._c_op,  # pylint: disable=protected-access
            op._c_op)  # pylint: disable=protected-access

  def _add_control_input(self, op):
    """Add a new control input to this operation.

    Args:
      op: the Operation to add as control input.

    Raises:
      TypeError: if op is not an Operation.
      ValueError: if op is from a different graph.
    """
    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      if not isinstance(op, Operation):
        raise TypeError("op must be an Operation: %s" % op)
      pywrap_tf_session.AddControlInput(
          c_graph,
          self._c_op,  # pylint: disable=protected-access
          op._c_op)  # pylint: disable=protected-access

  def _remove_all_control_inputs(self):
    """Removes any control inputs to this operation."""
    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      pywrap_tf_session.RemoveAllControlInputs(c_graph, self._c_op)  # pylint: disable=protected-access

  def _add_outputs(self, types, shapes):
    """Adds new Tensors to self.outputs.

    Note: this is generally unsafe to use.
    This is used in certain situations in
    conjunction with _set_type_list_attr.

    Args:
      types: list of DTypes
      shapes: list of TensorShapes
    """
    assert len(types) == len(shapes)
    # New tensors are numbered after the existing outputs.
    orig_num_outputs = len(self.outputs)
    for i in range(len(types)):
      t = Tensor(self, orig_num_outputs + i, types[i])
      self._outputs.append(t)
      t.set_shape(shapes[i])

  def __str__(self):
    return str(self.node_def)

  def __repr__(self):
    return "<tf.Operation '%s' type=%s>" % (self.name, self.type)

  def __tf_tensor__(self, dtype=None, name=None):
    """Raises a helpful error."""
    raise TypeError("can't convert Operation '{}' to Tensor".format(self.name))

  @property
  def outputs(self):
    """The list of `Tensor` objects representing the outputs of this op."""
    return self._outputs

  @property
  def inputs(self):
    """The sequence of `Tensor` objects representing the data inputs of this op."""
    # Computed lazily from the C graph and cached in `_inputs_val`; methods
    # that rewire inputs (e.g. `_update_input`) reset the cache to None.
    if self._inputs_val is None:
      # pylint: disable=protected-access
      self._inputs_val = tuple(
          self.graph._get_tensor_by_tf_output(i)
          for i in pywrap_tf_session.GetOperationInputs(self._c_op))
      # pylint: enable=protected-access
    return self._inputs_val

  @property
  def _input_types(self):
    """List of `DType`s of this op's data inputs, read from the C graph."""
    num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op)
    input_types = [
        dtypes.as_dtype(
            pywrap_tf_session.TF_OperationInputType(self._tf_input(i)))
        for i in range(num_inputs)
    ]
    return input_types

  @property
  def control_inputs(self):
    """The `Operation` objects on which this op has a control dependency.

    Before this op is executed, TensorFlow will ensure that the
    operations in `self.control_inputs` have finished executing. This
    mechanism can be used to run ops sequentially for performance
    reasons, or to ensure that the side effects of an op are observed
    in the correct order.

    Returns:
      A list of `Operation` objects.

    """
    control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper(
        self._c_op)
    # pylint: disable=protected-access
    return [
        self.graph._get_operation_by_name_unsafe(
            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
    ]
    # pylint: enable=protected-access

  @property
  def _control_outputs(self):
    """The `Operation` objects which have a control dependency on this op.

    Before any of the ops in self._control_outputs can execute tensorflow will
    ensure self has finished executing.

    Returns:
      A list of `Operation` objects.

    """
    control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
        self._c_op)
    # pylint: disable=protected-access
    return [
        self.graph._get_operation_by_name_unsafe(
            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
    ]
    # pylint: enable=protected-access

  @property
  def type(self):
    """The type of the op (e.g. `"MatMul"`)."""
    return pywrap_tf_session.TF_OperationOpType(self._c_op)

  @property
  def graph(self):
    """The `Graph` that contains this operation."""
    return self._graph

  @property
  def node_def(self):
    # pylint: disable=line-too-long
    """Returns the `NodeDef` representation of this operation.

    Returns:
      A
      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
      protocol buffer.
    """
    # pylint: enable=line-too-long
    # The serialized bytes are read while the buffer is still open.
    with c_api_util.tf_buffer() as buf:
      pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
      data = pywrap_tf_session.TF_GetBuffer(buf)
    node_def = node_def_pb2.NodeDef()
    node_def.ParseFromString(compat.as_bytes(data))
    return node_def

  @property
  def op_def(self):
    # pylint: disable=line-too-long
    """Returns the `OpDef` proto that represents the type of this op.

    Returns:
      An
      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
      protocol buffer.
    """
    # pylint: enable=line-too-long
    return self._graph._get_op_def(self.type)

  @property
  def traceback(self):
    """Returns the call stack from when this operation was constructed."""
    # FIXME(b/225423591): This object contains a dangling reference if _c_op
    # goes out of scope.
    return pywrap_tf_session.TF_OperationGetStackTrace(self._c_op)

  def _set_attr(self, attr_name, attr_value):
    """Private method used to set an attribute in the node_def."""
    buf = pywrap_tf_session.TF_NewBufferFromString(
        compat.as_bytes(attr_value.SerializeToString()))
    # The buffer is freed even if setting the attr raises.
    try:
      self._set_attr_with_buf(attr_name, buf)
    finally:
      pywrap_tf_session.TF_DeleteBuffer(buf)

  def _set_attr_with_buf(self, attr_name, attr_buf):
    """Set an attr in the node_def with a pre-allocated buffer."""
    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      # pylint: disable=protected-access
      pywrap_tf_session.SetAttr(c_graph, self._c_op, attr_name, attr_buf)
      # pylint: enable=protected-access

  def _set_func_attr(self, attr_name, func_name):
    """Private method used to set a function attribute in the node_def."""
    func = attr_value_pb2.NameAttrList(name=func_name)
    self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))

  def _set_func_list_attr(self, attr_name, func_names):
    """Private method used to set a list(function) attribute in the node_def."""
    funcs = [attr_value_pb2.NameAttrList(name=func_name)
             for func_name in func_names]
    funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))

  def _set_type_list_attr(self, attr_name, types):
    """Private method used to set a list(type) attribute in the node_def."""
    # An empty list is a no-op: the attr is left untouched.
    if not types:
      return
    # Accept either `DType` objects or raw enums; only the first element is
    # inspected to decide which.
    if isinstance(types[0], dtypes.DType):
      types = [dt.as_datatype_enum for dt in types]
    types_list = attr_value_pb2.AttrValue.ListValue(type=types)
    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))

  def _set_shape_list_attr(self, attr_name, shapes):
    """Private method used to set a list(shape) attribute in the node_def."""
    shapes = [s.as_proto() for s in shapes]
    shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))

  def _clear_attr(self, attr_name):
    """Private method used to clear an attribute in the node_def."""
    with self._graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      # pylint: disable=protected-access
      pywrap_tf_session.ClearAttr(c_graph, self._c_op, attr_name)
      # pylint: enable=protected-access

  def get_attr(self, name):
    """Returns the value of the attr of this op with the given `name`.

    Args:
      name: The name of the attr to fetch.

    Returns:
      The value of the attr, as a Python object.

    Raises:
      ValueError: If this op does not have an attr with the given `name`.
    """
    # Candidate AttrValue fields, in the order they are probed for list attrs.
    fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
    try:
      with c_api_util.tf_buffer() as buf:
        pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
        data = pywrap_tf_session.TF_GetBuffer(buf)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(e.message)
    x = attr_value_pb2.AttrValue()
    x.ParseFromString(data)

    oneof_value = x.WhichOneof("value")
    # An attr with no value set is reported as an empty list.
    if oneof_value is None:
      return []
    if oneof_value == "list":
      # Return the first non-empty list field; type enums become `DType`s.
      for f in fields:
        if getattr(x.list, f):
          if f == "type":
            return [dtypes.as_dtype(t) for t in x.list.type]
          else:
            return list(getattr(x.list, f))
      return []
    if oneof_value == "type":
      return dtypes.as_dtype(x.type)
    assert oneof_value in fields, "Unsupported field type in " + str(x)
    return getattr(x, oneof_value)

  def _get_attr_type(self, name):
    """Returns the `DType` value of the attr of this op with the given `name`."""
    try:
      dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
      return _DTYPES_INTERN_TABLE[dtype_enum]
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(e.message)

  def _get_attr_bool(self, name):
    """Returns the `bool` value of the attr of this op with the given `name`."""
    try:
      return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(e.message)

  def _get_attr_int(self, name):
    """Returns the `int` value of the attr of this op with the given `name`."""
    try:
      return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
2738 raise ValueError(e.message) 2739 2740 def experimental_set_type(self, type_proto): 2741 """Sets the corresponding node's `experimental_type` field. 2742 2743 See the description of `NodeDef.experimental_type` for more info. 2744 2745 Args: 2746 type_proto: A FullTypeDef proto message. The root type_if of this object 2747 must be `TFT_PRODUCT`, even for ops which only have a singlre return 2748 value. 2749 """ 2750 with self._graph._c_graph.get() as c_graph: # pylint: disable=protected-access 2751 if (type_proto.type_id 2752 not in (full_type_pb2.TFT_UNSET, full_type_pb2.TFT_PRODUCT)): 2753 raise ValueError("error setting the type of ", self.name, 2754 ": expected TFT_UNSET or TFT_PRODUCT, got ", 2755 type_proto.type_id) 2756 pywrap_tf_session.SetFullType(c_graph, self._c_op, 2757 type_proto.SerializeToString()) # pylint:disable=protected-access 2758 2759 def run(self, feed_dict=None, session=None): 2760 """Runs this operation in a `Session`. 2761 2762 Calling this method will execute all preceding operations that 2763 produce the inputs needed for this operation. 2764 2765 *N.B.* Before invoking `Operation.run()`, its graph must have been 2766 launched in a session, and either a default session must be 2767 available, or `session` must be specified explicitly. 2768 2769 Args: 2770 feed_dict: A dictionary that maps `Tensor` objects to feed values. See 2771 `tf.Session.run` for a description of the valid feed values. 2772 session: (Optional.) The `Session` to be used to run to this operation. If 2773 none, the default session will be used. 2774 """ 2775 _run_using_default_session(self, feed_dict, self.graph, session) 2776 2777# TODO(b/185395742): Clean up usages of _gradient_registry 2778gradient_registry = _gradient_registry = registry.Registry("gradient") 2779 2780 2781@tf_export("RegisterGradient") 2782class RegisterGradient(object): 2783 """A decorator for registering the gradient function for an op type. 
2784 2785 This decorator is only used when defining a new op type. For an op 2786 with `m` inputs and `n` outputs, the gradient function is a function 2787 that takes the original `Operation` and `n` `Tensor` objects 2788 (representing the gradients with respect to each output of the op), 2789 and returns `m` `Tensor` objects (representing the partial gradients 2790 with respect to each input of the op). 2791 2792 For example, assuming that operations of type `"Sub"` take two 2793 inputs `x` and `y`, and return a single output `x - y`, the 2794 following gradient function would be registered: 2795 2796 ```python 2797 @tf.RegisterGradient("Sub") 2798 def _sub_grad(unused_op, grad): 2799 return grad, tf.negative(grad) 2800 ``` 2801 2802 The decorator argument `op_type` is the string type of an 2803 operation. This corresponds to the `OpDef.name` field for the proto 2804 that defines the operation. 2805 """ 2806 2807 __slots__ = ["_op_type"] 2808 2809 def __init__(self, op_type): 2810 """Creates a new decorator with `op_type` as the Operation type. 2811 2812 Args: 2813 op_type: The string type of an operation. This corresponds to the 2814 `OpDef.name` field for the proto that defines the operation. 2815 2816 Raises: 2817 TypeError: If `op_type` is not string. 2818 """ 2819 if not isinstance(op_type, str): 2820 raise TypeError("op_type must be a string") 2821 self._op_type = op_type 2822 2823 def __call__(self, f): 2824 """Registers the function `f` as gradient function for `op_type`.""" 2825 gradient_registry.register(f, self._op_type) 2826 return f 2827 2828 2829@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient") 2830@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"]) 2831def no_gradient(op_type): 2832 """Specifies that ops of type `op_type` is not differentiable. 2833 2834 This function should *not* be used for operations that have a 2835 well-defined gradient that is not yet implemented. 
2836 2837 This function is only used when defining a new op type. It may be 2838 used for ops such as `tf.size()` that are not differentiable. For 2839 example: 2840 2841 ```python 2842 tf.no_gradient("Size") 2843 ``` 2844 2845 The gradient computed for 'op_type' will then propagate zeros. 2846 2847 For ops that have a well-defined gradient but are not yet implemented, 2848 no declaration should be made, and an error *must* be thrown if 2849 an attempt to request its gradient is made. 2850 2851 Args: 2852 op_type: The string type of an operation. This corresponds to the 2853 `OpDef.name` field for the proto that defines the operation. 2854 2855 Raises: 2856 TypeError: If `op_type` is not a string. 2857 2858 """ 2859 if not isinstance(op_type, str): 2860 raise TypeError("op_type must be a string") 2861 gradient_registry.register(None, op_type) 2862 2863 2864# Aliases for the old names, will be eventually removed. 2865NoGradient = no_gradient 2866NotDifferentiable = no_gradient 2867 2868 2869def get_gradient_function(op): 2870 """Returns the function that computes gradients for "op".""" 2871 if not op.inputs: 2872 return None 2873 2874 gradient_function = op._gradient_function # pylint: disable=protected-access 2875 if gradient_function: 2876 return gradient_function 2877 2878 try: 2879 op_type = op.get_attr("_gradient_op_type") 2880 except ValueError: 2881 op_type = op.type 2882 return gradient_registry.lookup(op_type) 2883 2884 2885def set_shape_and_handle_data_for_outputs(_): 2886 """No op. TODO(b/74620627): Remove this.""" 2887 pass 2888 2889 2890class OpStats(object): 2891 """A holder for statistics about an operator. 2892 2893 This class holds information about the resource requirements for an op, 2894 including the size of its weight parameters on-disk and how many FLOPS it 2895 requires to execute forward inference. 
2896 2897 If you define a new operation, you can create a function that will return a 2898 set of information about its usage of the CPU and disk space when serialized. 2899 The function itself takes a Graph object that's been set up so you can call 2900 methods like get_tensor_by_name to help calculate the results, and a NodeDef 2901 argument. 2902 2903 """ 2904 2905 __slots__ = ["_statistic_type", "_value"] 2906 2907 def __init__(self, statistic_type, value=None): 2908 """Sets up the initial placeholders for the statistics.""" 2909 self.statistic_type = statistic_type 2910 self.value = value 2911 2912 @property 2913 def statistic_type(self): 2914 return self._statistic_type 2915 2916 @statistic_type.setter 2917 def statistic_type(self, statistic_type): 2918 self._statistic_type = statistic_type 2919 2920 @property 2921 def value(self): 2922 return self._value 2923 2924 @value.setter 2925 def value(self, value): 2926 self._value = value 2927 2928 def __iadd__(self, other): 2929 if other.statistic_type != self.statistic_type: 2930 raise ValueError("Can't add an OpStat of type %s to one of %s." % 2931 (self.statistic_type, other.statistic_type)) 2932 if self.value is None: 2933 self.value = other.value 2934 elif other.value is not None: 2935 self._value += other.value 2936 return self 2937 2938 2939_stats_registry = registry.Registry("statistical functions") 2940 2941 2942class RegisterStatistics(object): 2943 """A decorator for registering the statistics function for an op type. 2944 2945 This decorator can be defined for an op type so that it gives a 2946 report on the resources used by an instance of an operator, in the 2947 form of an OpStats object. 2948 2949 Well-known types of statistics include these so far: 2950 2951 - flops: When running a graph, the bulk of the computation happens doing 2952 numerical calculations like matrix multiplications. This type allows a node 2953 to return how many floating-point operations it takes to complete. 
The 2954 total number of FLOPs for a graph is a good guide to its expected latency. 2955 2956 You can add your own statistics just by picking a new type string, registering 2957 functions for the ops you care about, and then calling get_stats_for_node_def. 2958 2959 If a statistic for an op is registered multiple times, a KeyError will be 2960 raised. 2961 2962 Since the statistics is counted on a per-op basis. It is not suitable for 2963 model parameters (capacity), which is expected to be counted only once, even 2964 if it is shared by multiple ops. (e.g. RNN) 2965 2966 For example, you can define a new metric called doohickey for a Foo operation 2967 by placing this in your code: 2968 2969 ```python 2970 @ops.RegisterStatistics("Foo", "doohickey") 2971 def _calc_foo_bojangles(unused_graph, unused_node_def): 2972 return ops.OpStats("doohickey", 20) 2973 ``` 2974 2975 Then in client code you can retrieve the value by making this call: 2976 2977 ```python 2978 doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey") 2979 ``` 2980 2981 If the NodeDef is for an op with a registered doohickey function, you'll get 2982 back the calculated amount in doohickey.value, or None if it's not defined. 
2983 2984 """ 2985 2986 __slots__ = ["_op_type", "_statistic_type"] 2987 2988 def __init__(self, op_type, statistic_type): 2989 """Saves the `op_type` as the `Operation` type.""" 2990 if not isinstance(op_type, str): 2991 raise TypeError("op_type must be a string.") 2992 if "," in op_type: 2993 raise TypeError("op_type must not contain a comma.") 2994 self._op_type = op_type 2995 if not isinstance(statistic_type, str): 2996 raise TypeError("statistic_type must be a string.") 2997 if "," in statistic_type: 2998 raise TypeError("statistic_type must not contain a comma.") 2999 self._statistic_type = statistic_type 3000 3001 def __call__(self, f): 3002 """Registers "f" as the statistics function for "op_type".""" 3003 _stats_registry.register(f, self._op_type + "," + self._statistic_type) 3004 return f 3005 3006 3007def get_stats_for_node_def(graph, node, statistic_type): 3008 """Looks up the node's statistics function in the registry and calls it. 3009 3010 This function takes a Graph object and a NodeDef from a GraphDef, and if 3011 there's an associated statistics method, calls it and returns a result. If no 3012 function has been registered for the particular node type, it returns an empty 3013 statistics object. 3014 3015 Args: 3016 graph: A Graph object that's been set up with the node's graph. 3017 node: A NodeDef describing the operator. 3018 statistic_type: A string identifying the statistic we're interested in. 3019 3020 Returns: 3021 An OpStats object containing information about resource usage. 3022 """ 3023 3024 try: 3025 stats_func = _stats_registry.lookup(node.op + "," + statistic_type) 3026 result = stats_func(graph, node) 3027 except LookupError: 3028 result = OpStats(statistic_type) 3029 return result 3030 3031 3032def name_from_scope_name(name): 3033 """Returns the name of an op given the name of its scope. 3034 3035 Args: 3036 name: the name of the scope. 3037 3038 Returns: 3039 the name of the op (equal to scope name minus any trailing slash). 
"""
  return name[:-1] if (name and name[-1] == "/") else name


# Group indices, presumably for `Graph._group_lock` (a GroupLock created with
# num_groups=2 in `Graph.__init__`): one group for op mutation, one for
# Session.run calls.
_MUTATION_LOCK_GROUP = 0
_SESSION_RUN_LOCK_GROUP = 1


@tf_contextlib.contextmanager
def resource_creator_scope(resource_type, resource_creator):
  """Installs `resource_creator` for `resource_type` on the default graph.

  Delegates to `Graph._resource_creator_scope` of `get_default_graph()`; see
  that method for the creator signature and nesting rules.

  Args:
    resource_type: The resource type key (or a list/tuple of keys).
    resource_creator: The creator callable to install.

  Yields:
    Nothing; the creator is active while the context is open.
  """
  with get_default_graph()._resource_creator_scope(resource_type,  # pylint: disable=protected-access
                                                   resource_creator):
    yield


@tf_export("Graph")
class Graph(object):
  """A TensorFlow computation, represented as a dataflow graph.

  Graphs are used by `tf.function`s to represent the function's computations.
  Each graph contains a set of `tf.Operation` objects, which represent units of
  computation; and `tf.Tensor` objects, which represent the units of data that
  flow between operations.

  ### Using graphs directly (deprecated)

  A `tf.Graph` can be constructed and used directly without a `tf.function`, as
  was required in TensorFlow 1, but this is deprecated and it is recommended to
  use a `tf.function` instead. If a graph is directly used, other deprecated
  TensorFlow 1 classes are also required to execute the graph, such as a
  `tf.compat.v1.Session`.

  A default graph can be registered with the `tf.Graph.as_default` context
  manager. Then, operations will be added to the graph instead of being executed
  eagerly. For example:

  ```python
  g = tf.Graph()
  with g.as_default():
    # Define operations and tensors in `g`.
    c = tf.constant(30.0)
    assert c.graph is g
  ```

  `tf.compat.v1.get_default_graph()` can be used to obtain the default graph.

  Important note: This class *is not* thread-safe for graph construction. All
  operations should be created from a single thread, or external
  synchronization must be provided. Unless otherwise specified, all methods
  are not thread-safe.

  A `Graph` instance supports an arbitrary number of "collections"
  that are identified by name. For convenience when building a large
  graph, collections can store groups of related objects: for
  example, the `tf.Variable` uses a collection (named
  `tf.GraphKeys.GLOBAL_VARIABLES`) for
  all variables that are created during the construction of a graph. The caller
  may define additional collections by specifying a new name.
  """

  def __init__(self):
    """Creates a new, empty Graph."""
    # Protects core state that can be returned via public accessors.
    # Thread-safety is provided on a best-effort basis to support buggy
    # programs, and is not guaranteed by the public `tf.Graph` API.
    #
    # NOTE(mrry): This does not protect the various stacks. A warning will
    # be reported if these are used from multiple threads.
    self._lock = threading.RLock()
    # The group lock synchronizes Session.run calls with methods that create
    # and mutate ops (e.g. Graph.create_op()). This synchronization is
    # necessary because it's illegal to modify an operation after it's been run.
    # The group lock allows any number of threads to mutate ops at the same time
    # but if any modification is going on, all Session.run calls have to wait.
    # Similarly, if one or more Session.run calls are going on, all mutate ops
    # have to wait until all Session.run calls have finished.
    self._group_lock = lock_util.GroupLock(num_groups=2)
    self._nodes_by_id = {}  # GUARDED_BY(self._lock)
    self._next_id_counter = 0  # GUARDED_BY(self._lock)
    self._nodes_by_name = {}  # GUARDED_BY(self._lock)
    self._version = 0  # GUARDED_BY(self._lock)
    # Maps a name used in the graph to the next id to use for that name.
    self._names_in_use = {}
    self._stack_state_is_thread_local = False
    self._thread_local = threading.local()
    # Functions that will be applied to choose a device if none is specified.
    # In TF2.x or after switch_to_thread_local(),
    # self._thread_local._device_function_stack is used instead.
    self._graph_device_function_stack = traceable_stack.TraceableStack()
    # Default original_op applied to new ops.
    self._default_original_op = None
    # Current control flow context. It could be either CondContext or
    # WhileContext defined in ops/control_flow_ops.py
    self._control_flow_context = None
    # A new node will depend on the union of all of the nodes in the stack.
    # In TF2.x or after switch_to_thread_local(),
    # self._thread_local._control_dependencies_stack is used instead.
    self._graph_control_dependencies_stack = []
    # Arbitrary collections of objects.
    self._collections = {}
    # The graph-level random seed
    self._seed = None
    # A dictionary of attributes that should be applied to all ops.
    self._attr_scope_map = {}
    # A map from op type to the kernel label that should be used.
    self._op_to_kernel_label_map = {}
    # A map from op type to an alternative op type that should be used when
    # computing gradients.
    self._gradient_override_map = {}
    # A map from op type to a gradient function that should be used instead.
    self._gradient_function_map = {}
    # True if the graph is considered "finalized". In that case no
    # new operations can be added.
    self._finalized = False
    # Functions defined in the graph
    self._functions = collections.OrderedDict()
    # Default GraphDef versions
    self._graph_def_versions = versions_pb2.VersionDef(
        producer=versions.GRAPH_DEF_VERSION,
        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
    self._building_function = False
    # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
    # self._thread_local._colocation_stack is used instead.
    self._graph_colocation_stack = traceable_stack.TraceableStack()
    # Set of tensors that are dangerous to feed!
    self._unfeedable_tensors = object_identity.ObjectIdentitySet()
    # Set of operations that are dangerous to fetch!
    self._unfetchable_ops = set()
    # A map of tensor handle placeholder to tensor dtype.
    self._handle_feeders = {}
    # A map from tensor handle to its read op.
    self._handle_readers = {}
    # A map from tensor handle to its move op.
    self._handle_movers = {}
    # A map from tensor handle to its delete op.
    self._handle_deleters = {}
    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
    # will be shared when defining function graphs, for example, so optimizers
    # being called inside function definitions behave as if they were seeing the
    # actual outside graph).
    self._graph_key = "graph-key-%d/" % (uid(),)
    # A string with the last reduction method passed to
    # losses.compute_weighted_loss(), or None. This is required only for
    # backward compatibility with Estimator and optimizer V1 use cases.
    self._last_loss_reduction = None
    # Flag that is used to indicate whether loss has been scaled by optimizer.
    # If this flag has been set, then estimator uses it to scale loss back
    # before reporting. This is required only for backward compatibility with
    # Estimator and optimizer V1 use cases.
    self._is_loss_scaled_by_optimizer = False
    self._container = ""

    # The current AutomaticControlDependencies context manager.
    self.experimental_acd_manager = None
    # Set to True if this graph is being built in an
    # AutomaticControlDependencies context.
    # Deprecated: use experimental_acd_manager instead.
    self._add_control_dependencies = False

    # Cache for OpDef protobufs retrieved via the C API.
    self._op_def_cache = {}
    # Cache for constant results of `broadcast_gradient_args()`. The keys are
    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
    # values are tuples of reduction indices: (rx, ry).
    self._bcast_grad_args_cache = {}
    # Cache for constant results of `reduced_shape()`. The keys are pairs of
    # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
    # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
    self._reduced_shape_cache = {}

    # TODO(skyewm): fold as much of the above as possible into the C
    # implementation
    self._c_graph = c_api_util.ScopedTFGraph(self._graph_key)
    # The C API requires all ops to have shape functions. Disable this
    # requirement (many custom ops do not have shape functions, and we don't
    # want to break these existing cases).
    with self._c_graph.get() as c_graph:
      pywrap_tf_session.SetRequireShapeInferenceFns(c_graph, False)
    # Done last: switching moves the per-thread stacks initialized above.
    if tf2.enabled():
      self.switch_to_thread_local()

  # Note: this method is private because the API of tf.Graph() is public and
  # frozen, and this functionality is still not ready for public visibility.
  @tf_contextlib.contextmanager
  def _variable_creator_scope(self, creator, priority=100):
    """Scope which defines a variable creation function.

    Args:
      creator: A callable taking `next_creator` and `kwargs`. See the
        `tf.variable_creator_scope` docstring.
      priority: Creators with a higher `priority` are called first. Within the
        same priority, creators are called inner-to-outer.

    Yields:
      `_variable_creator_scope` is a context manager with a side effect, but
      doesn't return a value.

    Raises:
      RuntimeError: If variable creator scopes are not properly nested.
3239 """ 3240 # This step keeps a reference to the existing stack, and it also initializes 3241 # self._thread_local._variable_creator_stack if it doesn't exist yet. 3242 old = self._variable_creator_stack 3243 new = list(old) 3244 new.append((priority, creator)) 3245 # Sorting is stable, so we'll put higher-priority creators later in the list 3246 # but otherwise maintain registration order. 3247 new.sort(key=lambda item: item[0]) 3248 self._thread_local._variable_creator_stack = new # pylint: disable=protected-access 3249 try: 3250 yield 3251 finally: 3252 if self._thread_local._variable_creator_stack is not new: # pylint: disable=protected-access 3253 raise RuntimeError( 3254 "Exiting variable_creator_scope without proper nesting.") 3255 self._thread_local._variable_creator_stack = old # pylint: disable=protected-access 3256 3257 # TODO(b/192405401): unify resource_creator_scope with variable_creator_scope. 3258 # pylint: disable=protected-access 3259 @tf_contextlib.contextmanager 3260 def _resource_creator_scope(self, resource_type, creator): 3261 """Scope which defines a resource creation function used by some resource. 3262 3263 The resource should be a subclass of CapturableResource with a class method 3264 `cls._resource_type`, the output of which is what the `resource_type` 3265 argument should be. By default, `cls._resource_type` returns the class name, 3266 `cls.__name__`. Given a scope, creators being added with the same 3267 `resource_type` argument will be composed together to apply to all classes 3268 with this `_resource_type`. 3269 3270 3271 `creator` is expected to be a function with the following signature: 3272 3273 ``` 3274 def resource_creator(next_creator, *a, **kwargs) 3275 ``` 3276 3277 The creator is supposed to eventually call the next_creator to create an 3278 instance if it does want to create an instance and not call 3279 the class initialization method directly. This helps make creators 3280 composable. 
A creator may choose to create multiple instances, return 3281 already existing instances, or simply register that an instance was created 3282 and defer to the next creator in line. Creators can also modify keyword 3283 arguments seen by the next creators. 3284 3285 Valid keyword arguments in `kwargs` depends on the specific resource 3286 class. For StaticHashTable, this may be: 3287 * initializer: The table initializer to use. 3288 * default_value: The value to use if a key is missing in the table. 3289 * name: Optional name for the table, default to None. 3290 3291 3292 Args: 3293 resource_type: the output of the resource class's `_resource_type` method. 3294 creator: the passed creator for the resource. 3295 3296 Yields: 3297 A scope in which the creator is active 3298 3299 Raises: 3300 RuntimeError: If resource_creator_scope is existed without proper nesting. 3301 """ 3302 # This step keeps a reference to the existing stack, and it also initializes 3303 # self._thread_local._variable_creator_stack if it doesn't exist yet. 
3304 old = self._resource_creator_stack 3305 new = copy.deepcopy(old) 3306 if isinstance(resource_type, (list, tuple)): 3307 for r in resource_type: 3308 new[r].append(creator) 3309 else: 3310 new[resource_type].append(creator) 3311 self._thread_local._resource_creator_stack = new 3312 try: 3313 yield 3314 finally: 3315 if self._thread_local._resource_creator_stack is not new: 3316 raise RuntimeError( 3317 "Exiting resource_creator_scope without proper nesting.") 3318 self._thread_local._resource_creator_stack = old 3319 3320 @property 3321 def _resource_creator_stack(self): 3322 if not hasattr(self._thread_local, "_resource_creator_stack"): 3323 self._thread_local._resource_creator_stack = collections.defaultdict(list) 3324 return self._thread_local._resource_creator_stack 3325 3326 @_resource_creator_stack.setter 3327 def _resource_creator_stack(self, resource_creator_stack): 3328 self._thread_local._resource_creator_stack = resource_creator_stack 3329 # pylint: enable=protected-access 3330 3331 # Note: this method is private because the API of tf.Graph() is public and 3332 # frozen, and this functionality is still not ready for public visibility. 3333 @property 3334 def _variable_creator_stack(self): 3335 if not hasattr(self._thread_local, "_variable_creator_stack"): 3336 self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access 3337 3338 # This previously returned a copy of the stack instead of the stack itself, 3339 # to guard against accidental mutation. Consider, however, code that wants 3340 # to save and restore the variable creator stack: 3341 # def f(): 3342 # original_stack = graph._variable_creator_stack 3343 # graph._variable_creator_stack = new_stack 3344 # ... 
# Some code 3345 # graph._variable_creator_stack = original_stack 3346 # 3347 # And lets say you have some code that calls this function with some 3348 # variable_creator: 3349 # def g(): 3350 # with variable_scope.variable_creator_scope(creator): 3351 # f() 3352 # When exiting the variable creator scope, it would see a different stack 3353 # object than it expected leading to a "Exiting variable_creator_scope 3354 # without proper nesting" error. 3355 return self._thread_local._variable_creator_stack # pylint: disable=protected-access 3356 3357 @_variable_creator_stack.setter 3358 def _variable_creator_stack(self, variable_creator_stack): 3359 self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access 3360 3361 def _check_not_finalized(self): 3362 """Check if the graph is finalized. 3363 3364 Raises: 3365 RuntimeError: If the graph finalized. 3366 """ 3367 if self._finalized: 3368 raise RuntimeError("Graph is finalized and cannot be modified.") 3369 3370 def _add_op(self, op, op_name): 3371 """Adds 'op' to the graph and returns the unique ID for the added Operation. 3372 3373 Args: 3374 op: the Operation to add. 3375 op_name: the name of the Operation. 3376 3377 Returns: 3378 An integer that is a unique ID for the added Operation. 3379 """ 3380 self._check_not_finalized() 3381 with self._lock: 3382 self._next_id_counter += 1 3383 op_id = self._next_id_counter 3384 self._nodes_by_id[op_id] = op 3385 self._nodes_by_name[op_name] = op 3386 self._version = max(self._version, op_id) 3387 return op_id 3388 3389 @property 3390 def version(self): 3391 """Returns a version number that increases as ops are added to the graph. 3392 3393 Note that this is unrelated to the 3394 `tf.Graph.graph_def_versions`. 3395 3396 Returns: 3397 An integer version that increases as ops are added to the graph. 
"""
    # `_add_op` refuses to run on a finalized graph, so the counter presumably
    # cannot change then and the lock is skipped (`_unsafe_unfinalize` can
    # reopen the graph).
    if self._finalized:
      return self._version

    with self._lock:
      return self._version

  @property
  def graph_def_versions(self):
    # pylint: disable=line-too-long
    """The GraphDef version information of this graph.

    For details on the meaning of each version, see
    [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).

    Returns:
      A `VersionDef`.
    """
    # pylint: enable=line-too-long
    # Read from the C graph (not from `self._graph_def_versions`).
    with c_api_util.tf_buffer() as buf:
      with self._c_graph.get() as c_graph:
        pywrap_tf_session.TF_GraphVersions(c_graph, buf)
      data = pywrap_tf_session.TF_GetBuffer(buf)
    version_def = versions_pb2.VersionDef()
    version_def.ParseFromString(compat.as_bytes(data))
    return version_def

  @property
  def seed(self):
    """The graph-level random seed of this graph."""
    return self._seed

  @seed.setter
  def seed(self, seed):
    self._seed = seed

  @property
  def finalized(self):
    """True if this graph has been finalized."""
    return self._finalized

  def finalize(self):
    """Finalizes this graph, making it read-only.

    After calling `g.finalize()`, no new operations can be added to
    `g`. This method is used to ensure that no operations are added
    to a graph when it is shared between multiple threads, for example
    when using a `tf.compat.v1.train.QueueRunner`.
    """
    self._finalized = True

  def _unsafe_unfinalize(self):
    """Opposite of `finalize`.

    Internal interface.

    NOTE: Unfinalizing a graph could have negative impact on performance,
    especially in a multi-threaded environment. Unfinalizing a graph
    when it is in use by a Session may lead to undefined behavior. Ensure
    that all sessions using a graph are closed before calling this method.
    """
    self._finalized = False

  def _get_control_flow_context(self):
    """Returns the current control flow context.

    Returns:
      A context object.
    """
    return self._control_flow_context

  def _set_control_flow_context(self, ctx):
    """Sets the current control flow context.

    Args:
      ctx: a context object.
    """
    self._control_flow_context = ctx

  def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
    """If this graph contains functions, copy them to `graph_def`.

    Args:
      graph_def: The `GraphDef` whose `library` receives the function
        definitions and, when present, their gradient definitions.
      starting_bytesize: Byte size already used by `graph_def`; the running
        total including each copied function is checked against 2GB.

    Raises:
      ValueError: If the accumulated size reaches 2GB (or overflows negative).
    """
    bytesize = starting_bytesize
    for f in self._functions.values():
      bytesize += f.definition.ByteSize()
      if bytesize >= (1 << 31) or bytesize < 0:
        raise ValueError("GraphDef cannot be larger than 2GB.")
      graph_def.library.function.extend([f.definition])
      if f.grad_func_name:
        grad_def = function_pb2.GradientDef()
        grad_def.function_name = f.name
        grad_def.gradient_func = f.grad_func_name
        graph_def.library.gradient.extend([grad_def])

  def _as_graph_def(self, from_version=None, add_shapes=False):
    # pylint: disable=line-too-long
    """Returns a serialized `GraphDef` representation of this graph.

    The serialized `GraphDef` can be imported into another `Graph`
    (using `tf.import_graph_def`) or used with the
    [C++ Session API](https://chromium.googlesource.com/external/github.com/tensorflow/tensorflow/+/r0.10/tensorflow/g3doc/api_docs/cc/index.md).

    This method is thread-safe.

    Args:
      from_version: Optional. If this is set, returns a `GraphDef` containing
        only the nodes that were added to this graph since its `version`
        property had the given value.
      add_shapes: If true, adds an "_output_shapes" list attr to each node with
        the inferred shapes of each of its outputs.

    Returns:
      A tuple containing a
      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer, and the version of the graph to which that
      `GraphDef` corresponds.

    Raises:
      ValueError: If the `graph_def` would be too large.

    """
    # pylint: enable=line-too-long
    # NOTE(review): `from_version` is not referenced anywhere in this body;
    # the full graph is always serialized. Confirm before relying on it.
    with self._lock:
      with c_api_util.tf_buffer() as buf:
        with self._c_graph.get() as c_graph:
          pywrap_tf_session.TF_GraphToGraphDef(c_graph, buf)
        data = pywrap_tf_session.TF_GetBuffer(buf)
      graph = graph_pb2.GraphDef()
      graph.ParseFromString(compat.as_bytes(data))
      # Strip the experimental library field iff it's empty.
      if not graph.library.function:
        graph.ClearField("library")

      if add_shapes:
        # Top-level nodes: attach the shapes of the matching Python op's outputs.
        for node in graph.node:
          op = self._nodes_by_name[node.name]
          if op.outputs:
            node.attr["_output_shapes"].list.shape.extend(
                [output.get_shape().as_proto() for output in op.outputs])
        # Function bodies: fill input shapes, resource handle data and
        # per-node output shapes from each function's Python graph, if any.
        for function_def in graph.library.function:
          defined_function = self._functions[function_def.signature.name]
          try:
            func_graph = defined_function.graph
          except AttributeError:
            # _DefinedFunction doesn't have a graph, _EagerDefinedFunction
            # does. Both rely on ops.py, so we can't really isinstance check
            # them.
            continue
          input_shapes = function_def.attr["_input_shapes"]
          try:
            func_graph_inputs = func_graph.inputs
          except AttributeError:
            continue
          # TODO(b/141471245): Fix the inconsistency when inputs of func graph
          # are appended during gradient computation of while/cond.
          assert len(input_shapes.list.shape) in [0, len(func_graph_inputs)]
          # If the function_def has inputs already filled out, skip this step.
          if not input_shapes.list.shape:
            for input_tensor, arg_def in zip(func_graph_inputs,
                                             function_def.signature.input_arg):
              input_shapes.list.shape.add().CopyFrom(
                  input_tensor.get_shape().as_proto())
              if input_tensor.dtype == dtypes.resource:
                _copy_handle_data_to_arg_def(input_tensor, arg_def)

          for output_tensor, arg_def in zip(func_graph.outputs,
                                            function_def.signature.output_arg):
            if output_tensor.dtype == dtypes.resource:
              _copy_handle_data_to_arg_def(output_tensor, arg_def)

          for node in function_def.node_def:
            try:
              op = func_graph.get_operation_by_name(node.name)
            except KeyError:
              continue
            outputs = op.outputs

            if op.type == "StatefulPartitionedCall":
              # Filter out any extra outputs (possibly added by function
              # backpropagation rewriting).
              num_outputs = len(node.attr["Tout"].list.type)
              outputs = outputs[:num_outputs]

            node.attr["_output_shapes"].list.shape.extend(
                [output.get_shape().as_proto() for output in outputs])

    return graph, self._version

  def as_graph_def(self, from_version=None, add_shapes=False):
    # pylint: disable=line-too-long
    """Returns a serialized `GraphDef` representation of this graph.

    The serialized `GraphDef` can be imported into another `Graph`
    (using `tf.import_graph_def`) or used with the
    [C++ Session API](../../api_docs/cc/index.md).

    This method is thread-safe.

    Args:
      from_version: Optional. If this is set, returns a `GraphDef` containing
        only the nodes that were added to this graph since its `version`
        property had the given value.
      add_shapes: If true, adds an "_output_shapes" list attr to each node with
        the inferred shapes of each of its outputs.

    Returns:
      A
      [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer.

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    # pylint: enable=line-too-long
    # Thin wrapper: drops the version number returned by `_as_graph_def`.
    result, _ = self._as_graph_def(from_version, add_shapes)
    return result

  def _is_function(self, name):
    """Tests whether 'name' is registered in this graph's function library.

    Args:
      name: string op name.

    Returns:
      bool indicating whether or not 'name' is registered in function library.
    """
    return compat.as_str(name) in self._functions

  def _get_function(self, name):
    """Returns the function definition for 'name'.

    Args:
      name: string function name.

    Returns:
      The function def proto, or None if 'name' is not registered.
    """
    return self._functions.get(compat.as_str(name), None)

  def _add_function(self, function):
    """Adds a function to the graph.

    After the function has been added, you can call to the function by
    passing the function name in place of an op name to
    `Graph.create_op()`.

    Args:
      function: A `_DefinedFunction` object.

    Raises:
      ValueError: if another function is defined with the same name.
    """
    self._check_not_finalized()

    name = function.name
    # Sanity checks on gradient definition.
3653 if (function.grad_func_name is not None) and (function.python_grad_func is 3654 not None): 3655 raise ValueError("Gradient defined twice for function %s" % name) 3656 3657 # Add function to graph 3658 # pylint: disable=protected-access 3659 with self._c_graph.get() as c_graph: 3660 with function._c_func.get() as func: 3661 if function._grad_func: 3662 with function._grad_func._c_func.get() as gradient: 3663 pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, gradient) 3664 else: 3665 pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, None) 3666 # pylint: enable=protected-access 3667 3668 self._functions[compat.as_str(name)] = function 3669 3670 # Need a new-enough consumer to support the functions we add to the graph. 3671 if self._graph_def_versions.min_consumer < 12: 3672 self._graph_def_versions.min_consumer = 12 3673 3674 @property 3675 def building_function(self): 3676 """Returns True iff this graph represents a function.""" 3677 return self._building_function 3678 3679 # Helper functions to create operations. 3680 @deprecated_args(None, 3681 "Shapes are always computed; don't use the compute_shapes " 3682 "as it has no effect.", "compute_shapes") 3683 @traceback_utils.filter_traceback 3684 def create_op( 3685 self, 3686 op_type, 3687 inputs, 3688 dtypes=None, # pylint: disable=redefined-outer-name 3689 input_types=None, 3690 name=None, 3691 attrs=None, 3692 op_def=None, 3693 compute_shapes=True, 3694 compute_device=True): 3695 """Creates an `Operation` in this graph. 3696 3697 This is a low-level interface for creating an `Operation`. Most 3698 programs will not call this method directly, and instead use the 3699 Python op constructors, such as `tf.constant()`, which add ops to 3700 the default graph. 3701 3702 Args: 3703 op_type: The `Operation` type to create. This corresponds to the 3704 `OpDef.name` field for the proto that defines the operation. 3705 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 
3706 dtypes: (Optional) A list of `DType` objects that will be the types of the 3707 tensors that the operation produces. 3708 input_types: (Optional.) A list of `DType`s that will be the types of the 3709 tensors that the operation consumes. By default, uses the base `DType` 3710 of each input in `inputs`. Operations that expect reference-typed inputs 3711 must specify `input_types` explicitly. 3712 name: (Optional.) A string name for the operation. If not specified, a 3713 name is generated based on `op_type`. 3714 attrs: (Optional.) A dictionary where the key is the attribute name (a 3715 string) and the value is the respective `attr` attribute of the 3716 `NodeDef` proto that will represent the operation (an `AttrValue` 3717 proto). 3718 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 3719 the operation will have. 3720 compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always 3721 computed). 3722 compute_device: (Optional.) If True, device functions will be executed to 3723 compute the device property of the Operation. 3724 3725 Raises: 3726 TypeError: if any of the inputs is not a `Tensor`. 3727 ValueError: if colocation conflicts with existing device assignment. 3728 3729 Returns: 3730 An `Operation` object. 3731 """ 3732 del compute_shapes 3733 for idx, a in enumerate(inputs): 3734 if not isinstance(a, Tensor): 3735 raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) 3736 return self._create_op_internal(op_type, inputs, dtypes, input_types, name, 3737 attrs, op_def, compute_device) 3738 3739 def _create_op_internal( 3740 self, 3741 op_type, 3742 inputs, 3743 dtypes=None, # pylint: disable=redefined-outer-name 3744 input_types=None, 3745 name=None, 3746 attrs=None, 3747 op_def=None, 3748 compute_device=True): 3749 """Creates an `Operation` in this graph. 3750 3751 Implements `Graph.create_op()` without the overhead of the deprecation 3752 wrapper. 3753 3754 Args: 3755 op_type: The `Operation` type to create. 
This corresponds to the 3756 `OpDef.name` field for the proto that defines the operation. 3757 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 3758 dtypes: (Optional) A list of `DType` objects that will be the types of the 3759 tensors that the operation produces. 3760 input_types: (Optional.) A list of `DType`s that will be the types of the 3761 tensors that the operation consumes. By default, uses the base `DType` 3762 of each input in `inputs`. Operations that expect reference-typed inputs 3763 must specify `input_types` explicitly. 3764 name: (Optional.) A string name for the operation. If not specified, a 3765 name is generated based on `op_type`. 3766 attrs: (Optional.) A dictionary where the key is the attribute name (a 3767 string) and the value is the respective `attr` attribute of the 3768 `NodeDef` proto that will represent the operation (an `AttrValue` 3769 proto). 3770 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 3771 the operation will have. 3772 compute_device: (Optional.) If True, device functions will be executed to 3773 compute the device property of the Operation. 3774 3775 Raises: 3776 ValueError: if colocation conflicts with existing device assignment. 3777 3778 Returns: 3779 An `Operation` object. 3780 """ 3781 self._check_not_finalized() 3782 if name is None: 3783 name = op_type 3784 # If a names ends with a '/' it is a "name scope" and we use it as-is, 3785 # after removing the trailing '/'. 3786 if name and name[-1] == "/": 3787 name = name_from_scope_name(name) 3788 else: 3789 name = self.unique_name(name) 3790 3791 node_def = _NodeDef(op_type, name, attrs) 3792 3793 input_ops = set(t.op for t in inputs) 3794 control_inputs = self._control_dependencies_for_inputs(input_ops) 3795 # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a 3796 # Session.run call cannot occur between creating and mutating the op. 
3797 with self._mutation_lock(): 3798 ret = Operation( 3799 node_def, 3800 self, 3801 inputs=inputs, 3802 output_types=dtypes, 3803 control_inputs=control_inputs, 3804 input_types=input_types, 3805 original_op=self._default_original_op, 3806 op_def=op_def) 3807 self._create_op_helper(ret, compute_device=compute_device) 3808 return ret 3809 3810 def _create_op_from_tf_operation(self, c_op, compute_device=True): 3811 """Creates an `Operation` in this graph from the supplied TF_Operation. 3812 3813 This method is like create_op() except the new Operation is constructed 3814 using `c_op`. The returned Operation will have `c_op` as its _c_op 3815 field. This is used to create Operation objects around TF_Operations created 3816 indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile). 3817 3818 This function does not call Operation._control_flow_post_processing or 3819 Graph._control_dependencies_for_inputs (since the inputs may not be 3820 available yet). The caller is responsible for calling these methods. 3821 3822 Args: 3823 c_op: a wrapped TF_Operation 3824 compute_device: (Optional.) If True, device functions will be executed to 3825 compute the device property of the Operation. 3826 3827 Returns: 3828 An `Operation` object. 3829 """ 3830 self._check_not_finalized() 3831 ret = Operation._from_c_op(c_op=c_op, g=self) # pylint: disable=protected-access 3832 # If a name_scope was created with ret.name but no nodes were created in it, 3833 # the name will still appear in _names_in_use even though the name hasn't 3834 # been used. This is ok, just leave _names_in_use as-is in this case. 3835 # TODO(skyewm): make the C API guarantee no name conflicts. 
3836 name_key = ret.name.lower() 3837 if name_key not in self._names_in_use: 3838 self._names_in_use[name_key] = 1 3839 self._create_op_helper(ret, compute_device=compute_device) 3840 return ret 3841 3842 def _create_op_helper(self, op, compute_device=True): 3843 """Common logic for creating an op in this graph.""" 3844 # Apply any additional attributes requested. Do not overwrite any existing 3845 # attributes. 3846 for key, value in self._attr_scope_map.items(): 3847 try: 3848 op.get_attr(key) 3849 except ValueError: 3850 if callable(value): 3851 value = value(op.node_def) 3852 if not isinstance(value, (type(None), attr_value_pb2.AttrValue)): 3853 raise TypeError( 3854 "Callable for scope map key '%s' must return either None or " 3855 "an AttrValue protocol buffer; but it returned: %s" % 3856 (key, value)) 3857 if value: 3858 op._set_attr(key, value) # pylint: disable=protected-access 3859 3860 # Apply a kernel label if one has been specified for this op type. 3861 try: 3862 kernel_label = self._op_to_kernel_label_map[op.type] 3863 op._set_attr("_kernel", # pylint: disable=protected-access 3864 attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label))) 3865 except KeyError: 3866 pass 3867 3868 op._gradient_function = self._gradient_function_map.get(op.type) # pylint: disable=protected-access 3869 3870 # Apply the overriding op type for gradients if one has been specified for 3871 # this op type. 3872 try: 3873 mapped_op_type = self._gradient_override_map[op.type] 3874 op._set_attr("_gradient_op_type", # pylint: disable=protected-access 3875 attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type))) 3876 except KeyError: 3877 pass 3878 3879 self._record_op_seen_by_control_dependencies(op) 3880 3881 if compute_device: 3882 self._apply_device_functions(op) 3883 3884 # Snapshot the colocation stack metadata before we might generate error 3885 # messages using it. 
Note that this snapshot depends on the actual stack 3886 # and is independent of the op's _class attribute. 3887 # pylint: disable=protected-access 3888 op._colocation_code_locations = self._snapshot_colocation_stack_metadata() 3889 # pylint: enable=protected-access 3890 3891 if self._colocation_stack: 3892 all_colocation_groups = [] 3893 is_device_set = False 3894 for colocation_op in self._colocation_stack.peek_objs(): 3895 try: 3896 all_colocation_groups.extend(colocation_op.colocation_groups()) 3897 except AttributeError: 3898 pass 3899 if colocation_op.device and not is_device_set: 3900 # pylint: disable=protected-access 3901 op._set_device(colocation_op.device) 3902 # pylint: enable=protected-access 3903 is_device_set = True 3904 3905 all_colocation_groups = sorted(set(all_colocation_groups)) 3906 # pylint: disable=protected-access 3907 op._set_attr( 3908 "_class", 3909 attr_value_pb2.AttrValue( 3910 list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) 3911 # pylint: enable=protected-access 3912 3913 # Sets "container" attribute if 3914 # (1) self._container is not None 3915 # (2) "is_stateful" is set in OpDef 3916 # (3) "container" attribute is in OpDef 3917 # (4) "container" attribute is None 3918 if self._container and op._is_stateful: # pylint: disable=protected-access 3919 try: 3920 container_attr = op.get_attr("container") 3921 except ValueError: 3922 # "container" attribute is not in OpDef 3923 pass 3924 else: 3925 if not container_attr: 3926 op._set_attr("container", attr_value_pb2.AttrValue( # pylint: disable=protected-access 3927 s=compat.as_bytes(self._container))) 3928 3929 def _add_new_tf_operations(self, compute_devices=True): 3930 """Creates `Operations` in this graph for any new TF_Operations. 3931 3932 This is useful for when TF_Operations are indirectly created by the C API 3933 outside of the Operation constructor (e.g. by TF_ImportGraphDef, 3934 TF_FinishWhile). 
This ensures there are corresponding Operations for all 3935 TF_Operations in the underlying TF_Graph. 3936 3937 Args: 3938 compute_devices: (Optional.) If True, device functions will be executed to 3939 compute the device properties of each new Operation. 3940 3941 Returns: 3942 A list of the new `Operation` objects. 3943 """ 3944 self._check_not_finalized() 3945 3946 # Create all Operation objects before accessing their inputs since an op may 3947 # be created before its inputs. 3948 new_ops = [ 3949 self._create_op_from_tf_operation(c_op, compute_device=compute_devices) 3950 for c_op in c_api_util.new_tf_operations(self) 3951 ] 3952 3953 # pylint: disable=protected-access 3954 for op in new_ops: 3955 new_control_inputs = self._control_dependencies_for_inputs(op.inputs) 3956 op._add_control_inputs(new_control_inputs) 3957 op._control_flow_post_processing() 3958 # pylint: enable=protected-access 3959 3960 return new_ops 3961 3962 def as_graph_element(self, obj, allow_tensor=True, allow_operation=True): 3963 """Returns the object referred to by `obj`, as an `Operation` or `Tensor`. 3964 3965 This function validates that `obj` represents an element of this 3966 graph, and gives an informative error message if it is not. 3967 3968 This function is the canonical way to get/validate an object of 3969 one of the allowed types from an external argument reference in the 3970 Session API. 3971 3972 This method may be called concurrently from multiple threads. 3973 3974 Args: 3975 obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. Can 3976 also be any object with an `_as_graph_element()` method that returns a 3977 value of one of these types. Note: `_as_graph_element` will be called 3978 inside the graph's lock and so may not modify the graph. 3979 allow_tensor: If true, `obj` may refer to a `Tensor`. 3980 allow_operation: If true, `obj` may refer to an `Operation`. 3981 3982 Returns: 3983 The `Tensor` or `Operation` in the Graph corresponding to `obj`. 
3984 3985 Raises: 3986 TypeError: If `obj` is not a type we support attempting to convert 3987 to types. 3988 ValueError: If `obj` is of an appropriate type but invalid. For 3989 example, an invalid string. 3990 KeyError: If `obj` is not an object in the graph. 3991 """ 3992 if self._finalized: 3993 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 3994 3995 with self._lock: 3996 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 3997 3998 def _as_graph_element_locked(self, obj, allow_tensor, allow_operation): 3999 """See `Graph.as_graph_element()` for details.""" 4000 # The vast majority of this function is figuring 4001 # out what an API user might be doing wrong, so 4002 # that we can give helpful error messages. 4003 # 4004 # Ideally, it would be nice to split it up, but we 4005 # need context to generate nice error messages. 4006 4007 if allow_tensor and allow_operation: 4008 types_str = "Tensor or Operation" 4009 elif allow_tensor: 4010 types_str = "Tensor" 4011 elif allow_operation: 4012 types_str = "Operation" 4013 else: 4014 raise ValueError("allow_tensor and allow_operation can't both be False.") 4015 4016 temp_obj = _as_graph_element(obj) 4017 if temp_obj is not None: 4018 obj = temp_obj 4019 4020 # If obj appears to be a name... 4021 if isinstance(obj, compat.bytes_or_text_types): 4022 name = compat.as_str(obj) 4023 4024 if ":" in name and allow_tensor: 4025 # Looks like a Tensor name and can be a Tensor. 4026 try: 4027 op_name, out_n = name.split(":") 4028 out_n = int(out_n) 4029 except: 4030 raise ValueError("The name %s looks a like a Tensor name, but is " 4031 "not a valid one. Tensor names must be of the " 4032 "form \"<op_name>:<output_index>\"." % repr(name)) 4033 if op_name in self._nodes_by_name: 4034 op = self._nodes_by_name[op_name] 4035 else: 4036 raise KeyError("The name %s refers to a Tensor which does not " 4037 "exist. The operation, %s, does not exist in the " 4038 "graph." 
% (repr(name), repr(op_name))) 4039 try: 4040 return op.outputs[out_n] 4041 except: 4042 raise KeyError("The name %s refers to a Tensor which does not " 4043 "exist. The operation, %s, exists but only has " 4044 "%s outputs." % 4045 (repr(name), repr(op_name), len(op.outputs))) 4046 4047 elif ":" in name and not allow_tensor: 4048 # Looks like a Tensor name but can't be a Tensor. 4049 raise ValueError("Name %s appears to refer to a Tensor, not a %s." % 4050 (repr(name), types_str)) 4051 4052 elif ":" not in name and allow_operation: 4053 # Looks like an Operation name and can be an Operation. 4054 if name not in self._nodes_by_name: 4055 raise KeyError("The name %s refers to an Operation not in the " 4056 "graph." % repr(name)) 4057 return self._nodes_by_name[name] 4058 4059 elif ":" not in name and not allow_operation: 4060 # Looks like an Operation name but can't be an Operation. 4061 if name in self._nodes_by_name: 4062 # Yep, it's an Operation name 4063 err_msg = ("The name %s refers to an Operation, not a %s." % 4064 (repr(name), types_str)) 4065 else: 4066 err_msg = ("The name %s looks like an (invalid) Operation name, " 4067 "not a %s." % (repr(name), types_str)) 4068 err_msg += (" Tensor names must be of the form " 4069 "\"<op_name>:<output_index>\".") 4070 raise ValueError(err_msg) 4071 4072 elif isinstance(obj, Tensor) and allow_tensor: 4073 # Actually obj is just the object it's referring to. 4074 if obj.graph is not self: 4075 raise ValueError("Tensor %s is not an element of this graph." % obj) 4076 return obj 4077 elif isinstance(obj, Operation) and allow_operation: 4078 # Actually obj is just the object it's referring to. 4079 if obj.graph is not self: 4080 raise ValueError("Operation %s is not an element of this graph." % obj) 4081 return obj 4082 else: 4083 # We give up! 4084 raise TypeError("Can not convert a %s into a %s." 
% 4085 (type(obj).__name__, types_str)) 4086 4087 def get_operations(self): 4088 """Return the list of operations in the graph. 4089 4090 You can modify the operations in place, but modifications 4091 to the list such as inserts/delete have no effect on the 4092 list of operations known to the graph. 4093 4094 This method may be called concurrently from multiple threads. 4095 4096 Returns: 4097 A list of Operations. 4098 """ 4099 if self._finalized: 4100 return list(self._nodes_by_id.values()) 4101 4102 with self._lock: 4103 return list(self._nodes_by_id.values()) 4104 4105 def get_operation_by_name(self, name): 4106 """Returns the `Operation` with the given `name`. 4107 4108 This method may be called concurrently from multiple threads. 4109 4110 Args: 4111 name: The name of the `Operation` to return. 4112 4113 Returns: 4114 The `Operation` with the given `name`. 4115 4116 Raises: 4117 TypeError: If `name` is not a string. 4118 KeyError: If `name` does not correspond to an operation in this graph. 4119 """ 4120 4121 if not isinstance(name, str): 4122 raise TypeError("Operation names are strings (or similar), not %s." % 4123 type(name).__name__) 4124 return self.as_graph_element(name, allow_tensor=False, allow_operation=True) 4125 4126 def _get_operation_by_name_unsafe(self, name): 4127 """Returns the `Operation` with the given `name`. 4128 4129 This is a internal unsafe version of get_operation_by_name. It skips many 4130 checks and does not have user friendly error messages but runs considerably 4131 faster. This method may be called concurrently from multiple threads. 4132 4133 Args: 4134 name: The name of the `Operation` to return. 4135 4136 Returns: 4137 The `Operation` with the given `name`. 4138 4139 Raises: 4140 KeyError: If `name` does not correspond to an operation in this graph. 
4141 """ 4142 4143 if self._finalized: 4144 return self._nodes_by_name[name] 4145 4146 with self._lock: 4147 return self._nodes_by_name[name] 4148 4149 def _get_operation_by_tf_operation(self, tf_oper): 4150 op_name = pywrap_tf_session.TF_OperationName(tf_oper) 4151 return self._get_operation_by_name_unsafe(op_name) 4152 4153 def get_tensor_by_name(self, name): 4154 """Returns the `Tensor` with the given `name`. 4155 4156 This method may be called concurrently from multiple threads. 4157 4158 Args: 4159 name: The name of the `Tensor` to return. 4160 4161 Returns: 4162 The `Tensor` with the given `name`. 4163 4164 Raises: 4165 TypeError: If `name` is not a string. 4166 KeyError: If `name` does not correspond to a tensor in this graph. 4167 """ 4168 # Names should be strings. 4169 if not isinstance(name, str): 4170 raise TypeError("Tensor names are strings (or similar), not %s." % 4171 type(name).__name__) 4172 return self.as_graph_element(name, allow_tensor=True, allow_operation=False) 4173 4174 def _get_tensor_by_tf_output(self, tf_output): 4175 """Returns the `Tensor` representing `tf_output`. 4176 4177 Note that there is only one such `Tensor`, i.e. multiple calls to this 4178 function with the same TF_Output value will always return the same `Tensor` 4179 object. 4180 4181 Args: 4182 tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`). 4183 4184 Returns: 4185 The `Tensor` that represents `tf_output`. 4186 """ 4187 op = self._get_operation_by_tf_operation(tf_output.oper) 4188 return op.outputs[tf_output.index] 4189 4190 @property 4191 def _last_id(self): 4192 return self._next_id_counter 4193 4194 def _get_op_def(self, type): # pylint: disable=redefined-builtin 4195 """Returns the `OpDef` proto for `type`. `type` is a string.""" 4196 # NOTE: No locking is required because the lookup and insertion operations 4197 # on Python dictionaries are atomic. 
4198 try: 4199 return self._op_def_cache[type] 4200 except KeyError: 4201 with c_api_util.tf_buffer() as buf: 4202 with self._c_graph.get() as c_graph: 4203 pywrap_tf_session.TF_GraphGetOpDef(c_graph, compat.as_bytes(type), 4204 buf) 4205 data = pywrap_tf_session.TF_GetBuffer(buf) 4206 op_def = op_def_pb2.OpDef() 4207 op_def.ParseFromString(compat.as_bytes(data)) 4208 self._op_def_cache[type] = op_def 4209 return op_def 4210 4211 def as_default(self): 4212 """Returns a context manager that makes this `Graph` the default graph. 4213 4214 This method should be used if you want to create multiple graphs 4215 in the same process. For convenience, a global default graph is 4216 provided, and all ops will be added to this graph if you do not 4217 create a new graph explicitly. 4218 4219 Use this method with the `with` keyword to specify that ops created within 4220 the scope of a block should be added to this graph. In this case, once 4221 the scope of the `with` is exited, the previous default graph is set again 4222 as default. There is a stack, so it's ok to have multiple nested levels 4223 of `as_default` calls. 4224 4225 The default graph is a property of the current thread. If you 4226 create a new thread, and wish to use the default graph in that 4227 thread, you must explicitly add a `with g.as_default():` in that 4228 thread's function. 4229 4230 The following code examples are equivalent: 4231 4232 ```python 4233 # 1. Using Graph.as_default(): 4234 g = tf.Graph() 4235 with g.as_default(): 4236 c = tf.constant(5.0) 4237 assert c.graph is g 4238 4239 # 2. Constructing and making default: 4240 with tf.Graph().as_default() as g: 4241 c = tf.constant(5.0) 4242 assert c.graph is g 4243 ``` 4244 4245 If eager execution is enabled ops created under this context manager will be 4246 added to the graph instead of executed eagerly. 4247 4248 Returns: 4249 A context manager for using this graph as the default graph. 
4250 """ 4251 return _default_graph_stack.get_controller(self) 4252 4253 @property 4254 def collections(self): 4255 """Returns the names of the collections known to this graph.""" 4256 return list(self._collections) 4257 4258 def add_to_collection(self, name, value): 4259 """Stores `value` in the collection with the given `name`. 4260 4261 Note that collections are not sets, so it is possible to add a value to 4262 a collection several times. 4263 4264 Args: 4265 name: The key for the collection. The `GraphKeys` class contains many 4266 standard names for collections. 4267 value: The value to add to the collection. 4268 """ # pylint: disable=g-doc-exception 4269 self._check_not_finalized() 4270 with self._lock: 4271 if name not in self._collections: 4272 self._collections[name] = [value] 4273 else: 4274 self._collections[name].append(value) 4275 4276 def add_to_collections(self, names, value): 4277 """Stores `value` in the collections given by `names`. 4278 4279 Note that collections are not sets, so it is possible to add a value to 4280 a collection several times. This function makes sure that duplicates in 4281 `names` are ignored, but it will not check for pre-existing membership of 4282 `value` in any of the collections in `names`. 4283 4284 `names` can be any iterable, but if `names` is a string, it is treated as a 4285 single collection name. 4286 4287 Args: 4288 names: The keys for the collections to add to. The `GraphKeys` class 4289 contains many standard names for collections. 4290 value: The value to add to the collections. 4291 """ 4292 # Make sure names are unique, but treat strings as a single collection name 4293 names = (names,) if isinstance(names, str) else set(names) 4294 for name in names: 4295 self.add_to_collection(name, value) 4296 4297 def get_collection_ref(self, name): 4298 """Returns a list of values in the collection with the given `name`. 
4299 4300 If the collection exists, this returns the list itself, which can 4301 be modified in place to change the collection. If the collection does 4302 not exist, it is created as an empty list and the list is returned. 4303 4304 This is different from `get_collection()` which always returns a copy of 4305 the collection list if it exists and never creates an empty collection. 4306 4307 Args: 4308 name: The key for the collection. For example, the `GraphKeys` class 4309 contains many standard names for collections. 4310 4311 Returns: 4312 The list of values in the collection with the given `name`, or an empty 4313 list if no value has been added to that collection. 4314 """ # pylint: disable=g-doc-exception 4315 with self._lock: 4316 coll_list = self._collections.get(name, None) 4317 if coll_list is None: 4318 coll_list = [] 4319 self._collections[name] = coll_list 4320 return coll_list 4321 4322 def get_collection(self, name, scope=None): 4323 """Returns a list of values in the collection with the given `name`. 4324 4325 This is different from `get_collection_ref()` which always returns the 4326 actual collection list if it exists in that it returns a new list each time 4327 it is called. 4328 4329 Args: 4330 name: The key for the collection. For example, the `GraphKeys` class 4331 contains many standard names for collections. 4332 scope: (Optional.) A string. If supplied, the resulting list is filtered 4333 to include only items whose `name` attribute matches `scope` using 4334 `re.match`. Items without a `name` attribute are never returned if a 4335 scope is supplied. The choice of `re.match` means that a `scope` without 4336 special tokens filters by prefix. 4337 4338 Returns: 4339 The list of values in the collection with the given `name`, or 4340 an empty list if no value has been added to that collection. The 4341 list contains the values in the order under which they were 4342 collected. 
"""  # pylint: disable=g-doc-exception
    with self._lock:
      collection = self._collections.get(name, None)
      if collection is None:
        return []
      if scope is None:
        # Return a shallow copy so callers cannot mutate the stored collection.
        return list(collection)
      else:
        c = []
        # `re.match` anchors only at the start of `item.name`, so `scope` acts
        # as a prefix filter rather than requiring a full-string match.
        regex = re.compile(scope)
        for item in collection:
          try:
            if regex.match(item.name):
              c.append(item)
          except AttributeError:
            # Collection items with no name are ignored.
            pass
        return c

  def get_all_collection_keys(self):
    """Returns a list of collections used in this graph.

    Only string keys are reported; any non-string keys present in the
    underlying collection mapping are skipped.
    """
    with self._lock:
      return [x for x in self._collections if isinstance(x, str)]

  def clear_collection(self, name):
    """Clears all values in a collection.

    Clearing a collection that does not exist is a no-op. The method first
    calls `self._check_not_finalized()` (presumably raising if the graph has
    been finalized -- confirm against that helper).

    Args:
      name: The key for the collection. The `GraphKeys` class contains many
        standard names for collections.
    """
    self._check_not_finalized()
    with self._lock:
      if name in self._collections:
        del self._collections[name]

  @tf_contextlib.contextmanager
  def _original_op(self, op):
    """Python 'with' handler to help annotate ops with their originator.

    An op may have an 'original_op' property that indicates the op on which
    it was based. For example a replica op is based on the op that was
    replicated and a gradient op is based on the op that was differentiated.

    All ops created in the scope of this 'with' handler will have
    the given 'op' as their original op.

    Args:
      op: The Operation that all ops created in this scope will have as their
        original op.

    Yields:
      Nothing.
    """
    old_original_op = self._default_original_op
    self._default_original_op = op
    try:
      yield
    finally:
      # Restore the enclosing value even if the body raises, so scopes nest.
      self._default_original_op = old_original_op

  @property
  def _name_stack(self):
    """The current name-scope string ("" at the root).

    Stored on `self._thread_local` (presumably a `threading.local`), so each
    thread sees its own value; a thread that has never set it starts at "".
    """
    # This may be called from a thread where name_stack doesn't yet exist.
    if not hasattr(self._thread_local, "_name_stack"):
      self._thread_local._name_stack = ""
    return self._thread_local._name_stack

  @_name_stack.setter
  def _name_stack(self, name_stack):
    """Sets the name-scope string on `self._thread_local`."""
    self._thread_local._name_stack = name_stack

  # pylint: disable=g-doc-return-or-yield,line-too-long
  @tf_contextlib.contextmanager
  def name_scope(self, name):
    """Returns a context manager that creates hierarchical names for operations.

    A graph maintains a stack of name scopes. A `with name_scope(...):`
    statement pushes a new name onto the stack for the lifetime of the context.

    The `name` argument will be interpreted as follows:

    * A string (not ending with '/') will create a new name scope, in which
      `name` is appended to the prefix of all operations created in the
      context. If `name` has been used before, it will be made unique by
      calling `self.unique_name(name)`.
    * A scope previously captured from a `with g.name_scope(...) as
      scope:` statement will be treated as an "absolute" name scope, which
      makes it possible to re-enter existing scopes.
    * A value of `None` or the empty string will reset the current name scope
      to the top-level (empty) name scope.

    For example:

    ```python
    with tf.Graph().as_default() as g:
      c = tf.constant(5.0, name="c")
      assert c.op.name == "c"
      c_1 = tf.constant(6.0, name="c")
      assert c_1.op.name == "c_1"

      # Creates a scope called "nested"
      with g.name_scope("nested") as scope:
        nested_c = tf.constant(10.0, name="c")
        assert nested_c.op.name == "nested/c"

        # Creates a nested scope called "inner".
        with g.name_scope("inner"):
          nested_inner_c = tf.constant(20.0, name="c")
          assert nested_inner_c.op.name == "nested/inner/c"

        # Create a nested scope called "inner_1".
        with g.name_scope("inner"):
          nested_inner_1_c = tf.constant(30.0, name="c")
          assert nested_inner_1_c.op.name == "nested/inner_1/c"

          # Treats `scope` as an absolute name scope, and
          # switches to the "nested/" scope.
          with g.name_scope(scope):
            nested_d = tf.constant(40.0, name="d")
            assert nested_d.op.name == "nested/d"

            with g.name_scope(""):
              e = tf.constant(50.0, name="e")
              assert e.op.name == "e"
    ```

    The name of the scope itself can be captured by `with
    g.name_scope(...) as scope:`, which stores the name of the scope
    in the variable `scope`. This value can be used to name an
    operation that represents the overall result of executing the ops
    in a scope. For example:

    ```python
    inputs = tf.constant(...)
    with g.name_scope('my_layer') as scope:
      weights = tf.Variable(..., name="weights")
      biases = tf.Variable(..., name="biases")
      affine = tf.matmul(inputs, weights) + biases
      output = tf.nn.relu(affine, name=scope)
    ```

    NOTE: This constructor validates the given `name`. Valid scope
    names match one of the following regular expressions:

        [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
        [A-Za-z0-9_.\\-/]* (for other scopes)

    Args:
      name: A name for the scope.

    Returns:
      A context manager that installs `name` as a new name scope.

    Raises:
      ValueError: If `name` is not a valid scope name, according to the rules
        above.
    """
    if name:
      # Accept bytes as well as text; normalize to `str` via `compat.as_str`.
      if isinstance(name, compat.bytes_or_text_types):
        name = compat.as_str(name)

      if self._name_stack:
        # Scopes created in a nested scope may have initial characters
        # that are illegal as the initial character of an op name
        # (viz. '-', '\', '/', and '_').
        if not _VALID_SCOPE_NAME_REGEX.match(name):
          raise ValueError(
              f"'{name}' is not a valid scope name. A scope name has to match "
              f"the following pattern: {_VALID_SCOPE_NAME_REGEX.pattern}")
      else:
        # Scopes created in the root must match the more restrictive
        # op name regex, which constrains the initial character.
        if not _VALID_OP_NAME_REGEX.match(name):
          raise ValueError(
              f"'{name}' is not a valid root scope name. A root scope name has "
              f"to match the following pattern: {_VALID_OP_NAME_REGEX.pattern}")
    old_stack = self._name_stack
    if not name:  # Both for name=None and name="" we re-set to empty scope.
      new_stack = ""
      returned_scope = ""
    elif name[-1] == "/":
      # A trailing "/" marks a previously captured scope: re-enter it as an
      # absolute scope instead of uniquifying it.
      new_stack = name_from_scope_name(name)
      returned_scope = name
    else:
      # Relative scope: uniquified (and marked as used, the `unique_name`
      # default) beneath the current name stack.
      new_stack = self.unique_name(name)
      returned_scope = new_stack + "/"
    self._name_stack = new_stack
    try:
      yield returned_scope
    finally:
      # Always pop back to the enclosing scope, even if the body raises.
      self._name_stack = old_stack

  # pylint: enable=g-doc-return-or-yield,line-too-long

  def unique_name(self, name, mark_as_used=True):
    """Return a unique operation name for `name`.

    Note: You rarely need to call `unique_name()` directly. Most of
    the time you just need to create `with g.name_scope()` blocks to
    generate structured names.

    `unique_name` is used to generate structured names, separated by
    `"/"`, to help identify operations when debugging a graph.
    Operation names are displayed in error messages reported by the
    TensorFlow runtime, and in various visualization tools such as
    TensorBoard.

    If `mark_as_used` is set to `True`, which is the default, a new
    unique name is created and marked as in use. If it's set to `False`,
    the unique name is returned without actually being marked as used.
    This is useful when the caller simply wants to know what the name
    to be created will be.

    Args:
      name: The name for an operation.
      mark_as_used: Whether to mark this name as being used.
4560 4561 Returns: 4562 A string to be passed to `create_op()` that will be used 4563 to name the operation being created. 4564 """ 4565 if self._name_stack: 4566 name = self._name_stack + "/" + name 4567 4568 # For the sake of checking for names in use, we treat names as case 4569 # insensitive (e.g. foo = Foo). 4570 name_key = name.lower() 4571 i = self._names_in_use.get(name_key, 0) 4572 # Increment the number for "name_key". 4573 if mark_as_used: 4574 self._names_in_use[name_key] = i + 1 4575 if i > 0: 4576 base_name_key = name_key 4577 # Make sure the composed name key is not already used. 4578 while name_key in self._names_in_use: 4579 name_key = "%s_%d" % (base_name_key, i) 4580 i += 1 4581 # Mark the composed name_key as used in case someone wants 4582 # to call unique_name("name_1"). 4583 if mark_as_used: 4584 self._names_in_use[name_key] = 1 4585 4586 # Return the new name with the original capitalization of the given name. 4587 name = "%s_%d" % (name, i - 1) 4588 return name 4589 4590 def get_name_scope(self): 4591 """Returns the current name scope. 4592 4593 For example: 4594 4595 ```python 4596 with tf.name_scope('scope1'): 4597 with tf.name_scope('scope2'): 4598 print(tf.compat.v1.get_default_graph().get_name_scope()) 4599 ``` 4600 would print the string `scope1/scope2`. 4601 4602 Returns: 4603 A string representing the current name scope. 
4604 """ 4605 return self._name_stack 4606 4607 @tf_contextlib.contextmanager 4608 def _colocate_with_for_gradient(self, op, gradient_uid, 4609 ignore_existing=False): 4610 with self.colocate_with(op, ignore_existing): 4611 if gradient_uid is not None: 4612 ctx = _get_enclosing_context(self) 4613 if ctx is not None: 4614 ctx.EnterGradientColocation(op, gradient_uid) 4615 try: 4616 yield 4617 finally: 4618 ctx.ExitGradientColocation(op, gradient_uid) 4619 else: 4620 yield 4621 else: 4622 yield 4623 4624 @tf_contextlib.contextmanager 4625 def colocate_with(self, op, ignore_existing=False): 4626 """Returns a context manager that specifies an op to colocate with. 4627 4628 Note: this function is not for public use, only for internal libraries. 4629 4630 For example: 4631 4632 ```python 4633 a = tf.Variable([1.0]) 4634 with g.colocate_with(a): 4635 b = tf.constant(1.0) 4636 c = tf.add(a, b) 4637 ``` 4638 4639 `b` and `c` will always be colocated with `a`, no matter where `a` 4640 is eventually placed. 4641 4642 **NOTE** Using a colocation scope resets any existing device constraints. 4643 4644 If `op` is `None` then `ignore_existing` must be `True` and the new 4645 scope resets all colocation and device constraints. 4646 4647 Args: 4648 op: The op to colocate all created ops with, or `None`. 4649 ignore_existing: If true, only applies colocation of this op within the 4650 context, rather than applying all colocation properties on the stack. 4651 If `op` is `None`, this value must be `True`. 4652 4653 Raises: 4654 ValueError: if op is None but ignore_existing is False. 4655 4656 Yields: 4657 A context manager that specifies the op with which to colocate 4658 newly created ops. 
4659 """ 4660 if op is None and not ignore_existing: 4661 raise ValueError("Trying to reset colocation (op is None) but " 4662 "ignore_existing is not True") 4663 op, device_only_candidate = _op_to_colocate_with(op, self) 4664 4665 # By default, colocate_with resets the device function stack, 4666 # since colocate_with is typically used in specific internal 4667 # library functions where colocation is intended to be "stronger" 4668 # than device functions. 4669 # 4670 # In the future, a caller may specify that device_functions win 4671 # over colocation, in which case we can add support. 4672 device_fn_tmp = self._device_function_stack 4673 self._device_function_stack = traceable_stack.TraceableStack() 4674 4675 if ignore_existing: 4676 current_stack = self._colocation_stack 4677 self._colocation_stack = traceable_stack.TraceableStack() 4678 4679 if op is not None: 4680 # offset refers to the stack frame used for storing code location. 4681 # We use 4, the sum of 1 to use our caller's stack frame and 3 4682 # to jump over layers of context managers above us. 4683 if device_only_candidate is not None: 4684 self._colocation_stack.push_obj(device_only_candidate, offset=4) 4685 self._colocation_stack.push_obj(op, offset=4) 4686 elif not ignore_existing: 4687 raise ValueError("Trying to reset colocation (op is None) but " 4688 "ignore_existing is not True") 4689 try: 4690 yield 4691 finally: 4692 # Restore device function stack 4693 self._device_function_stack = device_fn_tmp 4694 if op is not None: 4695 self._colocation_stack.pop_obj() 4696 if device_only_candidate is not None: 4697 self._colocation_stack.pop_obj() 4698 4699 # Reset the colocation stack if requested. 
4700 if ignore_existing: 4701 self._colocation_stack = current_stack 4702 4703 def _add_device_to_stack(self, device_name_or_function, offset=0): 4704 """Add device to stack manually, separate from a context manager.""" 4705 total_offset = 1 + offset 4706 spec = _UserDeviceSpec(device_name_or_function) 4707 self._device_function_stack.push_obj(spec, offset=total_offset) 4708 return spec 4709 4710 @tf_contextlib.contextmanager 4711 def device(self, device_name_or_function): 4712 # pylint: disable=line-too-long 4713 """Returns a context manager that specifies the default device to use. 4714 4715 The `device_name_or_function` argument may either be a device name 4716 string, a device function, or None: 4717 4718 * If it is a device name string, all operations constructed in 4719 this context will be assigned to the device with that name, unless 4720 overridden by a nested `device()` context. 4721 * If it is a function, it will be treated as a function from 4722 Operation objects to device name strings, and invoked each time 4723 a new Operation is created. The Operation will be assigned to 4724 the device with the returned name. 4725 * If it is None, all `device()` invocations from the enclosing context 4726 will be ignored. 4727 4728 For information about the valid syntax of device name strings, see 4729 the documentation in 4730 [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h). 4731 4732 For example: 4733 4734 ```python 4735 with g.device('/device:GPU:0'): 4736 # All operations constructed in this context will be placed 4737 # on GPU 0. 4738 with g.device(None): 4739 # All operations constructed in this context will have no 4740 # assigned device. 4741 4742 # Defines a function from `Operation` to device string. 
4743 def matmul_on_gpu(n): 4744 if n.type == "MatMul": 4745 return "/device:GPU:0" 4746 else: 4747 return "/cpu:0" 4748 4749 with g.device(matmul_on_gpu): 4750 # All operations of type "MatMul" constructed in this context 4751 # will be placed on GPU 0; all other operations will be placed 4752 # on CPU 0. 4753 ``` 4754 4755 **N.B.** The device scope may be overridden by op wrappers or 4756 other library code. For example, a variable assignment op 4757 `v.assign()` must be colocated with the `tf.Variable` `v`, and 4758 incompatible device scopes will be ignored. 4759 4760 Args: 4761 device_name_or_function: The device name or function to use in the 4762 context. 4763 4764 Yields: 4765 A context manager that specifies the default device to use for newly 4766 created ops. 4767 4768 Raises: 4769 RuntimeError: If device scopes are not properly nested. 4770 """ 4771 self._add_device_to_stack(device_name_or_function, offset=2) 4772 old_top_of_stack = self._device_function_stack.peek_top_obj() 4773 try: 4774 yield 4775 finally: 4776 new_top_of_stack = self._device_function_stack.peek_top_obj() 4777 if old_top_of_stack is not new_top_of_stack: 4778 raise RuntimeError("Exiting device scope without proper scope nesting.") 4779 self._device_function_stack.pop_obj() 4780 4781 def _apply_device_functions(self, op): 4782 """Applies the current device function stack to the given operation.""" 4783 # Apply any device functions in LIFO order, so that the most recently 4784 # pushed function has the first chance to apply a device to the op. 4785 # We apply here because the result can depend on the Operation's 4786 # signature, which is computed in the Operation constructor. 
# pylint: disable=protected-access
    prior_device_string = None
    for device_spec in self._device_function_stack.peek_objs():
      if device_spec.is_null_merge:
        continue

      # A spec whose `function` is None (presumably produced by
      # `device(None)`) stops the walk, so enclosing scopes are ignored.
      if device_spec.function is None:
        break

      device_string = device_spec.string_merge(op)

      # Take advantage of the fact that None is a singleton and Python interns
      # strings, since identity checks are faster than equality checks.
      if device_string is not prior_device_string:
        op._set_device_from_string(device_string)
        prior_device_string = device_string
    op._device_code_locations = self._snapshot_device_function_stack_metadata()
    # pylint: enable=protected-access

  # pylint: disable=g-doc-return-or-yield
  @tf_contextlib.contextmanager
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Stateful operations, such as variables and queues, can maintain their
    states on devices so that they can be shared by multiple processes.
    A resource container is a string name under which these stateful
    operations are tracked. These resources can be released or cleared
    with `tf.Session.reset()`.

    For example:

    ```python
    with g.container('experiment0'):
      # All stateful Operations constructed in this context will be placed
      # in resource container "experiment0".
      v1 = tf.Variable([1.0])
      v2 = tf.Variable([2.0])
      with g.container("experiment1"):
        # All stateful Operations constructed in this context will be
        # placed in resource container "experiment1".
        v3 = tf.Variable([3.0])
        q1 = tf.queue.FIFOQueue(10, tf.float32)
      # All stateful Operations constructed in this context will be
      # created in resource container "experiment0".
      v4 = tf.Variable([4.0])
      q2 = tf.queue.FIFOQueue(20, tf.float32)
      with g.container(""):
        # All stateful Operations constructed in this context will be
        # placed in the default resource container.
        v5 = tf.Variable([5.0])
        q3 = tf.queue.FIFOQueue(30, tf.float32)

    # Resets container "experiment0", after which the state of v1, v2, v4, q2
    # will become undefined (such as uninitialized).
    tf.Session.reset(target, ["experiment0"])
    ```

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
        yields the container name.
    """
    original_container = self._container
    self._container = container_name
    try:
      yield self._container
    finally:
      # Restore the enclosing container even if the body raises.
      self._container = original_container

  # pylint: enable=g-doc-return-or-yield

  class _ControlDependenciesController(object):
    """Context manager for `control_dependencies()`."""

    def __init__(self, graph, control_inputs):
      """Create a new `_ControlDependenciesController`.

      A `_ControlDependenciesController` is the context manager for
      `with tf.control_dependencies()` blocks. These normally nest,
      as described in the documentation for `control_dependencies()`.

      The `control_inputs` argument list control dependencies that must be
      added to the current set of control dependencies. Because of
      uniquification the set can be empty even if the caller passed a list of
      ops. The special value `None` indicates that we want to start a new
      empty set of control dependencies instead of extending the current set.

      In that case we also clear the current control flow context, which is an
      additional mechanism to add control dependencies.

      Args:
        graph: The graph that this controller is managing.
        control_inputs: List of ops to use as control inputs in addition to the
          current control dependencies. None to indicate that the dependencies
          should be cleared.
      """
      self._graph = graph
      if control_inputs is None:
        self._control_inputs_val = []
        self._new_stack = True
      else:
        self._control_inputs_val = control_inputs
        self._new_stack = False
      # Nodes created while this controller was active (see `add_op`).
      self._seen_nodes = set()
      # Saved graph state, only populated when `_new_stack` is True.
      self._old_stack = None
      self._old_control_flow_context = None

# pylint: disable=protected-access

    def __enter__(self):
      if self._new_stack:
        # Clear the control_dependencies graph.
        self._old_stack = self._graph._control_dependencies_stack
        self._graph._control_dependencies_stack = []
        # Clear the control_flow_context too.
        self._old_control_flow_context = self._graph._get_control_flow_context()
        self._graph._set_control_flow_context(None)
      self._graph._push_control_dependencies_controller(self)

    def __exit__(self, unused_type, unused_value, unused_traceback):
      # Pop first: the controller is on the (possibly fresh) stack installed in
      # __enter__, which is replaced by the saved outer stack just below.
      self._graph._pop_control_dependencies_controller(self)
      if self._new_stack:
        self._graph._control_dependencies_stack = self._old_stack
        self._graph._set_control_flow_context(self._old_control_flow_context)

# pylint: enable=protected-access

    @property
    def control_inputs(self):
      return self._control_inputs_val

    def add_op(self, op):
      # Tensors are stored via `.ref()`, presumably because a Tensor itself is
      # not usable as a set member -- confirm against `Tensor.__hash__`.
      if isinstance(op, Tensor):
        op = op.ref()
      self._seen_nodes.add(op)

    def op_in_group(self, op):
      # Same key normalization as `add_op`.
      if isinstance(op, Tensor):
        op = op.ref()
      return op in self._seen_nodes

  def _push_control_dependencies_controller(self, controller):
    """Pushes `controller` onto the active control-dependencies stack."""
    self._control_dependencies_stack.append(controller)

  def _pop_control_dependencies_controller(self, controller):
    """Pops `controller`, which must be the top of the stack (asserted)."""
    assert self._control_dependencies_stack[-1] is controller
    self._control_dependencies_stack.pop()

  def _current_control_dependencies(self):
    """Returns the union of all active controllers' control inputs as a set."""
    ret = set()
    for controller in self._control_dependencies_stack:
      for op in controller.control_inputs:
        ret.add(op)
    return ret

  def _control_dependencies_for_inputs(self, input_ops):
    """For an op that takes `input_ops` as inputs, compute control inputs.

    The returned control dependencies should yield an execution that
    is equivalent to adding all control inputs in
    self._control_dependencies_stack to a newly created op. However,
    this function attempts to prune the returned control dependencies
    by observing that nodes created within the same `with
    control_dependencies(...):` block may have data dependencies that make
    the explicit approach redundant.

    Args:
      input_ops: The data input ops for an op to be created.

    Returns:
      A list of control inputs for the op to be created.
    """
    ret = []
    for controller in self._control_dependencies_stack:
      # If any of the input_ops already depends on the inputs from controller,
      # we say that the new op is dominated (by that input), and we therefore
      # do not need to add control dependencies for this controller's inputs.
      dominated = False
      for op in input_ops:
        if controller.op_in_group(op):
          dominated = True
          break
      if not dominated:
        # Don't add a control input if we already have a data dependency on i.
        # NOTE(mrry): We do not currently track transitive data dependencies,
        #   so we may add redundant control inputs.
        ret.extend(c for c in controller.control_inputs if c not in input_ops)
    return ret

  def _record_op_seen_by_control_dependencies(self, op):
    """Record that the given op depends on all registered control dependencies.

    Args:
      op: An Operation.
4984 """ 4985 for controller in self._control_dependencies_stack: 4986 controller.add_op(op) 4987 4988 def control_dependencies(self, control_inputs): 4989 """Returns a context manager that specifies control dependencies. 4990 4991 Use with the `with` keyword to specify that all operations constructed 4992 within the context should have control dependencies on 4993 `control_inputs`. For example: 4994 4995 ```python 4996 with g.control_dependencies([a, b, c]): 4997 # `d` and `e` will only run after `a`, `b`, and `c` have executed. 4998 d = ... 4999 e = ... 5000 ``` 5001 5002 Multiple calls to `control_dependencies()` can be nested, and in 5003 that case a new `Operation` will have control dependencies on the union 5004 of `control_inputs` from all active contexts. 5005 5006 ```python 5007 with g.control_dependencies([a, b]): 5008 # Ops constructed here run after `a` and `b`. 5009 with g.control_dependencies([c, d]): 5010 # Ops constructed here run after `a`, `b`, `c`, and `d`. 5011 ``` 5012 5013 You can pass None to clear the control dependencies: 5014 5015 ```python 5016 with g.control_dependencies([a, b]): 5017 # Ops constructed here run after `a` and `b`. 5018 with g.control_dependencies(None): 5019 # Ops constructed here run normally, not waiting for either `a` or `b`. 5020 with g.control_dependencies([c, d]): 5021 # Ops constructed here run after `c` and `d`, also not waiting 5022 # for either `a` or `b`. 5023 ``` 5024 5025 *N.B.* The control dependencies context applies *only* to ops that 5026 are constructed within the context. Merely using an op or tensor 5027 in the context does not add a control dependency. The following 5028 example illustrates this point: 5029 5030 ```python 5031 # WRONG 5032 def my_func(pred, tensor): 5033 t = tf.matmul(tensor, tensor) 5034 with tf.control_dependencies([pred]): 5035 # The matmul op is created outside the context, so no control 5036 # dependency will be added. 
5037 return t 5038 5039 # RIGHT 5040 def my_func(pred, tensor): 5041 with tf.control_dependencies([pred]): 5042 # The matmul op is created in the context, so a control dependency 5043 # will be added. 5044 return tf.matmul(tensor, tensor) 5045 ``` 5046 5047 Also note that though execution of ops created under this scope will trigger 5048 execution of the dependencies, the ops created under this scope might still 5049 be pruned from a normal tensorflow graph. For example, in the following 5050 snippet of code the dependencies are never executed: 5051 5052 ```python 5053 loss = model.loss() 5054 with tf.control_dependencies(dependencies): 5055 loss = loss + tf.constant(1) # note: dependencies ignored in the 5056 # backward pass 5057 return tf.gradients(loss, model.variables) 5058 ``` 5059 5060 This is because evaluating the gradient graph does not require evaluating 5061 the constant(1) op created in the forward pass. 5062 5063 Args: 5064 control_inputs: A list of `Operation` or `Tensor` objects which must be 5065 executed or computed before running the operations defined in the 5066 context. Can also be `None` to clear the control dependencies. 5067 5068 Returns: 5069 A context manager that specifies control dependencies for all 5070 operations constructed within the context. 5071 5072 Raises: 5073 TypeError: If `control_inputs` is not a list of `Operation` or 5074 `Tensor` objects. 5075 """ 5076 if control_inputs is None: 5077 return self._ControlDependenciesController(self, None) 5078 # First convert the inputs to ops, and deduplicate them. 5079 # NOTE(mrry): Other than deduplication, we do not currently track direct 5080 # or indirect dependencies between control_inputs, which may result in 5081 # redundant control inputs. 5082 control_ops = [] 5083 current = self._current_control_dependencies() 5084 for c in control_inputs: 5085 # The hasattr(handle) is designed to match ResourceVariables. 
This is so 5086 # control dependencies on a variable or on an unread variable don't 5087 # trigger reads. 5088 if (isinstance(c, IndexedSlices) or 5089 (hasattr(c, "_handle") and hasattr(c, "op"))): 5090 c = c.op 5091 c = self.as_graph_element(c) 5092 if isinstance(c, Tensor): 5093 c = c.op 5094 elif not isinstance(c, Operation): 5095 raise TypeError("Control input must be Operation or Tensor: %s" % c) 5096 if c not in current: 5097 control_ops.append(c) 5098 current.add(c) 5099 # Mark this op with an attribute indicating that it is used as a manual 5100 # control dep in order to allow tracking how common utilization of 5101 # manual control deps in graphs run through the MLIR Bridge are. See 5102 # go/manual-control-dependencies-bridge for details. 5103 # pylint: disable=protected-access 5104 c._set_attr("_has_manual_control_dependencies", 5105 attr_value_pb2.AttrValue(b=True)) 5106 # pylint: enable=protected-access 5107 return self._ControlDependenciesController(self, control_ops) 5108 5109 # pylint: disable=g-doc-return-or-yield 5110 @tf_contextlib.contextmanager 5111 def _attr_scope(self, attr_map): 5112 """EXPERIMENTAL: A context manager for setting attributes on operators. 5113 5114 This context manager can be used to add additional 5115 attributes to operators within the scope of the context. 5116 5117 For example: 5118 5119 with ops.Graph().as_default() as g: 5120 f_1 = Foo() # No extra attributes 5121 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}): 5122 f_2 = Foo() # Additional attribute _a=False 5123 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}): 5124 f_3 = Foo() # Additional attribute _a=False 5125 with g._attr_scope({"_a": None}): 5126 f_4 = Foo() # No additional attributes. 5127 5128 Args: 5129 attr_map: A dictionary mapping attr name strings to AttrValue protocol 5130 buffers or None. 5131 5132 Returns: 5133 A context manager that sets the kernel label to be used for one or more 5134 ops created in that context. 
5135 5136 Raises: 5137 TypeError: If attr_map is not a dictionary mapping 5138 strings to AttrValue protobufs. 5139 """ 5140 if not isinstance(attr_map, dict): 5141 raise TypeError("attr_map must be a dictionary mapping " 5142 "strings to AttrValue protocol buffers") 5143 # The saved_attrs dictionary stores any currently-set labels that 5144 # will be overridden by this context manager. 5145 saved_attrs = {} 5146 # Install the given attribute 5147 for name, attr in attr_map.items(): 5148 if not (isinstance(name, str) and 5149 (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or 5150 callable(attr))): 5151 raise TypeError("attr_map must be a dictionary mapping " 5152 "strings to AttrValue protocol buffers or " 5153 "callables that emit AttrValue protocol buffers") 5154 try: 5155 saved_attrs[name] = self._attr_scope_map[name] 5156 except KeyError: 5157 pass 5158 if attr is None: 5159 del self._attr_scope_map[name] 5160 else: 5161 self._attr_scope_map[name] = attr 5162 try: 5163 yield # The code within the context runs here. 5164 finally: 5165 # Remove the attributes set for this context, and restore any saved 5166 # attributes. 5167 for name, attr in attr_map.items(): 5168 try: 5169 self._attr_scope_map[name] = saved_attrs[name] 5170 except KeyError: 5171 del self._attr_scope_map[name] 5172 5173 # pylint: enable=g-doc-return-or-yield 5174 5175 # pylint: disable=g-doc-return-or-yield 5176 @tf_contextlib.contextmanager 5177 def _kernel_label_map(self, op_to_kernel_label_map): 5178 """EXPERIMENTAL: A context manager for setting kernel labels. 5179 5180 This context manager can be used to select particular 5181 implementations of kernels within the scope of the context. 5182 5183 For example: 5184 5185 with ops.Graph().as_default() as g: 5186 f_1 = Foo() # Uses the default registered kernel for the Foo op. 5187 with g.kernel_label_map({"Foo": "v_2"}): 5188 f_2 = Foo() # Uses the registered kernel with label "v_2" 5189 # for the Foo op. 
5190 with g.kernel_label_map({"Foo": "v_3"}): 5191 f_3 = Foo() # Uses the registered kernel with label "v_3" 5192 # for the Foo op. 5193 with g.kernel_label_map({"Foo": ""}): 5194 f_4 = Foo() # Uses the default registered kernel 5195 # for the Foo op. 5196 5197 Args: 5198 op_to_kernel_label_map: A dictionary mapping op type strings to kernel 5199 label strings. 5200 5201 Returns: 5202 A context manager that sets the kernel label to be used for one or more 5203 ops created in that context. 5204 5205 Raises: 5206 TypeError: If op_to_kernel_label_map is not a dictionary mapping 5207 strings to strings. 5208 """ 5209 if not isinstance(op_to_kernel_label_map, dict): 5210 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 5211 "strings to strings") 5212 # The saved_labels dictionary stores any currently-set labels that 5213 # will be overridden by this context manager. 5214 saved_labels = {} 5215 # Install the given label 5216 for op_type, label in op_to_kernel_label_map.items(): 5217 if not (isinstance(op_type, str) and 5218 isinstance(label, str)): 5219 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 5220 "strings to strings") 5221 try: 5222 saved_labels[op_type] = self._op_to_kernel_label_map[op_type] 5223 except KeyError: 5224 pass 5225 self._op_to_kernel_label_map[op_type] = label 5226 try: 5227 yield # The code within the context runs here. 5228 finally: 5229 # Remove the labels set for this context, and restore any saved labels. 5230 for op_type, label in op_to_kernel_label_map.items(): 5231 try: 5232 self._op_to_kernel_label_map[op_type] = saved_labels[op_type] 5233 except KeyError: 5234 del self._op_to_kernel_label_map[op_type] 5235 5236 # pylint: enable=g-doc-return-or-yield 5237 5238 @tf_contextlib.contextmanager 5239 def _override_gradient_function(self, gradient_function_map): 5240 """Specify gradient function for the given op type.""" 5241 5242 # This is an internal API and we don't need nested context for this. 
5243 # TODO(mdan): make it a proper context manager. 5244 assert not self._gradient_function_map 5245 self._gradient_function_map = gradient_function_map 5246 try: 5247 yield 5248 finally: 5249 self._gradient_function_map = {} 5250 5251 # pylint: disable=g-doc-return-or-yield 5252 @tf_contextlib.contextmanager 5253 def gradient_override_map(self, op_type_map): 5254 """EXPERIMENTAL: A context manager for overriding gradient functions. 5255 5256 This context manager can be used to override the gradient function 5257 that will be used for ops within the scope of the context. 5258 5259 For example: 5260 5261 ```python 5262 @tf.RegisterGradient("CustomSquare") 5263 def _custom_square_grad(op, grad): 5264 # ... 5265 5266 with tf.Graph().as_default() as g: 5267 c = tf.constant(5.0) 5268 s_1 = tf.square(c) # Uses the default gradient for tf.square. 5269 with g.gradient_override_map({"Square": "CustomSquare"}): 5270 s_2 = tf.square(s_2) # Uses _custom_square_grad to compute the 5271 # gradient of s_2. 5272 ``` 5273 5274 Args: 5275 op_type_map: A dictionary mapping op type strings to alternative op type 5276 strings. 5277 5278 Returns: 5279 A context manager that sets the alternative op type to be used for one 5280 or more ops created in that context. 5281 5282 Raises: 5283 TypeError: If `op_type_map` is not a dictionary mapping strings to 5284 strings. 5285 """ 5286 if not isinstance(op_type_map, dict): 5287 raise TypeError("op_type_map must be a dictionary mapping " 5288 "strings to strings") 5289 # The saved_mappings dictionary stores any currently-set mappings that 5290 # will be overridden by this context manager. 
5291 saved_mappings = {} 5292 # Install the given label 5293 for op_type, mapped_op_type in op_type_map.items(): 5294 if not (isinstance(op_type, str) and 5295 isinstance(mapped_op_type, str)): 5296 raise TypeError("op_type_map must be a dictionary mapping " 5297 "strings to strings") 5298 try: 5299 saved_mappings[op_type] = self._gradient_override_map[op_type] 5300 except KeyError: 5301 pass 5302 self._gradient_override_map[op_type] = mapped_op_type 5303 try: 5304 yield # The code within the context runs here. 5305 finally: 5306 # Remove the labels set for this context, and restore any saved labels. 5307 for op_type, mapped_op_type in op_type_map.items(): 5308 try: 5309 self._gradient_override_map[op_type] = saved_mappings[op_type] 5310 except KeyError: 5311 del self._gradient_override_map[op_type] 5312 5313 # pylint: enable=g-doc-return-or-yield 5314 5315 def prevent_feeding(self, tensor): 5316 """Marks the given `tensor` as unfeedable in this graph.""" 5317 self._unfeedable_tensors.add(tensor) 5318 5319 def is_feedable(self, tensor): 5320 """Returns `True` if and only if `tensor` is feedable.""" 5321 return tensor not in self._unfeedable_tensors 5322 5323 def prevent_fetching(self, op): 5324 """Marks the given `op` as unfetchable in this graph.""" 5325 self._unfetchable_ops.add(op) 5326 5327 def is_fetchable(self, tensor_or_op): 5328 """Returns `True` if and only if `tensor_or_op` is fetchable.""" 5329 if isinstance(tensor_or_op, Tensor): 5330 return tensor_or_op.op not in self._unfetchable_ops 5331 else: 5332 return tensor_or_op not in self._unfetchable_ops 5333 5334 def switch_to_thread_local(self): 5335 """Make device, colocation and dependencies stacks thread-local. 5336 5337 Device, colocation and dependencies stacks are not thread-local be default. 5338 If multiple threads access them, then the state is shared. This means that 5339 one thread may affect the behavior of another thread. 5340 5341 After this method is called, the stacks become thread-local. 
If multiple 5342 threads access them, then the state is not shared. Each thread uses its own 5343 value; a thread doesn't affect other threads by mutating such a stack. 5344 5345 The initial value for every thread's stack is set to the current value 5346 of the stack when `switch_to_thread_local()` was first called. 5347 """ 5348 if not self._stack_state_is_thread_local: 5349 self._stack_state_is_thread_local = True 5350 5351 @property 5352 def _device_function_stack(self): 5353 if self._stack_state_is_thread_local: 5354 # This may be called from a thread where device_function_stack doesn't yet 5355 # exist. 5356 # pylint: disable=protected-access 5357 if not hasattr(self._thread_local, "_device_function_stack"): 5358 stack_copy_for_this_thread = self._graph_device_function_stack.copy() 5359 self._thread_local._device_function_stack = stack_copy_for_this_thread 5360 return self._thread_local._device_function_stack 5361 # pylint: enable=protected-access 5362 else: 5363 return self._graph_device_function_stack 5364 5365 @property 5366 def _device_functions_outer_to_inner(self): 5367 user_device_specs = self._device_function_stack.peek_objs() 5368 device_functions = [spec.function for spec in user_device_specs] 5369 device_functions_outer_to_inner = list(reversed(device_functions)) 5370 return device_functions_outer_to_inner 5371 5372 def _snapshot_device_function_stack_metadata(self): 5373 """Return device function stack as a list of TraceableObjects. 5374 5375 Returns: 5376 [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj 5377 member is a displayable name for the user's argument to Graph.device, and 5378 the filename and lineno members point to the code location where 5379 Graph.device was called directly or indirectly by the user. 
5380 """ 5381 snapshot = [] 5382 for obj in self._device_function_stack.peek_traceable_objs(): 5383 obj_copy = obj.copy_metadata() 5384 obj_copy.obj = obj.obj.display_name 5385 snapshot.append(obj_copy) 5386 return snapshot 5387 5388 @_device_function_stack.setter 5389 def _device_function_stack(self, device_function_stack): 5390 if self._stack_state_is_thread_local: 5391 # pylint: disable=protected-access 5392 self._thread_local._device_function_stack = device_function_stack 5393 # pylint: enable=protected-access 5394 else: 5395 self._graph_device_function_stack = device_function_stack 5396 5397 @property 5398 def _colocation_stack(self): 5399 """Return thread-local copy of colocation stack.""" 5400 if self._stack_state_is_thread_local: 5401 # This may be called from a thread where colocation_stack doesn't yet 5402 # exist. 5403 # pylint: disable=protected-access 5404 if not hasattr(self._thread_local, "_colocation_stack"): 5405 stack_copy_for_this_thread = self._graph_colocation_stack.copy() 5406 self._thread_local._colocation_stack = stack_copy_for_this_thread 5407 return self._thread_local._colocation_stack 5408 # pylint: enable=protected-access 5409 else: 5410 return self._graph_colocation_stack 5411 5412 def _snapshot_colocation_stack_metadata(self): 5413 """Return colocation stack metadata as a dictionary.""" 5414 return { 5415 traceable_obj.obj.name: traceable_obj.copy_metadata() 5416 for traceable_obj in self._colocation_stack.peek_traceable_objs() 5417 } 5418 5419 @_colocation_stack.setter 5420 def _colocation_stack(self, colocation_stack): 5421 if self._stack_state_is_thread_local: 5422 # pylint: disable=protected-access 5423 self._thread_local._colocation_stack = colocation_stack 5424 # pylint: enable=protected-access 5425 else: 5426 self._graph_colocation_stack = colocation_stack 5427 5428 @property 5429 def _control_dependencies_stack(self): 5430 if self._stack_state_is_thread_local: 5431 # This may be called from a thread where 
control_dependencies_stack 5432 # doesn't yet exist. 5433 if not hasattr(self._thread_local, "_control_dependencies_stack"): 5434 self._thread_local._control_dependencies_stack = ( 5435 self._graph_control_dependencies_stack[:]) 5436 return self._thread_local._control_dependencies_stack 5437 else: 5438 return self._graph_control_dependencies_stack 5439 5440 @_control_dependencies_stack.setter 5441 def _control_dependencies_stack(self, control_dependencies): 5442 if self._stack_state_is_thread_local: 5443 self._thread_local._control_dependencies_stack = control_dependencies 5444 else: 5445 self._graph_control_dependencies_stack = control_dependencies 5446 5447 @property 5448 def _distribution_strategy_stack(self): 5449 """A stack to maintain distribution strategy context for each thread.""" 5450 if not hasattr(self._thread_local, "_distribution_strategy_stack"): 5451 self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access 5452 return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access 5453 5454 @_distribution_strategy_stack.setter 5455 def _distribution_strategy_stack(self, _distribution_strategy_stack): 5456 self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access 5457 _distribution_strategy_stack) 5458 5459 @property 5460 def _global_distribute_strategy_scope(self): 5461 """For implementing `tf.distribute.set_strategy()`.""" 5462 if not hasattr(self._thread_local, "distribute_strategy_scope"): 5463 self._thread_local.distribute_strategy_scope = None 5464 return self._thread_local.distribute_strategy_scope 5465 5466 @_global_distribute_strategy_scope.setter 5467 def _global_distribute_strategy_scope(self, distribute_strategy_scope): 5468 self._thread_local.distribute_strategy_scope = (distribute_strategy_scope) 5469 5470 def _mutation_lock(self): 5471 """Returns a lock to guard code that creates & mutates ops. 5472 5473 See the comment for self._group_lock for more info. 
5474 """ 5475 return self._group_lock.group(_MUTATION_LOCK_GROUP) 5476 5477 def _session_run_lock(self): 5478 """Returns a lock to guard code for Session.run. 5479 5480 See the comment for self._group_lock for more info. 5481 """ 5482 return self._group_lock.group(_SESSION_RUN_LOCK_GROUP) 5483 5484 5485# TODO(agarwal): currently device directives in an outer eager scope will not 5486# apply to inner graph mode code. Fix that. 5487 5488 5489@tf_export(v1=["device"]) 5490def device(device_name_or_function): 5491 """Wrapper for `Graph.device()` using the default graph. 5492 5493 See `tf.Graph.device` for more details. 5494 5495 Args: 5496 device_name_or_function: The device name or function to use in the context. 5497 5498 Returns: 5499 A context manager that specifies the default device to use for newly 5500 created ops. 5501 5502 Raises: 5503 RuntimeError: If eager execution is enabled and a function is passed in. 5504 """ 5505 if context.executing_eagerly(): 5506 if callable(device_name_or_function): 5507 raise RuntimeError( 5508 "tf.device does not support functions when eager execution " 5509 "is enabled.") 5510 return context.device(device_name_or_function) 5511 elif executing_eagerly_outside_functions(): 5512 @tf_contextlib.contextmanager 5513 def combined(device_name_or_function): 5514 with get_default_graph().device(device_name_or_function): 5515 if not callable(device_name_or_function): 5516 with context.device(device_name_or_function): 5517 yield 5518 else: 5519 yield 5520 return combined(device_name_or_function) 5521 else: 5522 return get_default_graph().device(device_name_or_function) 5523 5524 5525@tf_export("device", v1=[]) 5526def device_v2(device_name): 5527 """Specifies the device for ops created/executed in this context. 5528 5529 This function specifies the device to be used for ops created/executed in a 5530 particular context. Nested contexts will inherit and also create/execute 5531 their ops on the specified device. 
If a specific device is not required, 5532 consider not using this function so that a device can be automatically 5533 assigned. In general the use of this function is optional. `device_name` can 5534 be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially 5535 specified, containing only a subset of the "/"-separated fields. Any fields 5536 which are specified will override device annotations from outer scopes. 5537 5538 For example: 5539 5540 ```python 5541 with tf.device('/job:foo'): 5542 # ops created here have devices with /job:foo 5543 with tf.device('/job:bar/task:0/device:gpu:2'): 5544 # ops created here have the fully specified device above 5545 with tf.device('/device:gpu:1'): 5546 # ops created here have the device '/job:foo/device:gpu:1' 5547 ``` 5548 5549 Args: 5550 device_name: The device name to use in the context. 5551 5552 Returns: 5553 A context manager that specifies the default device to use for newly 5554 created ops. 5555 5556 Raises: 5557 RuntimeError: If a function is passed in. 5558 """ 5559 if callable(device_name): 5560 raise RuntimeError("tf.device does not support functions.") 5561 return device(device_name) 5562 5563 5564@tf_export(v1=["container"]) 5565def container(container_name): 5566 """Wrapper for `Graph.container()` using the default graph. 5567 5568 Args: 5569 container_name: The container string to use in the context. 5570 5571 Returns: 5572 A context manager that specifies the default container to use for newly 5573 created stateful ops. 
5574 """ 5575 return get_default_graph().container(container_name) 5576 5577 5578def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False): 5579 if context.executing_eagerly(): 5580 if op is not None: 5581 if not hasattr(op, "device"): 5582 op = internal_convert_to_tensor_or_indexed_slices(op) 5583 return device(op.device) 5584 else: 5585 return NullContextmanager() 5586 else: 5587 default_graph = get_default_graph() 5588 if isinstance(op, EagerTensor): 5589 if default_graph.building_function: 5590 return default_graph.device(op.device) 5591 else: 5592 raise ValueError("Encountered an Eager-defined Tensor during graph " 5593 "construction, but a function was not being built.") 5594 return default_graph._colocate_with_for_gradient( 5595 op, gradient_uid=gradient_uid, ignore_existing=ignore_existing) 5596 5597 5598# Internal interface to colocate_with. colocate_with has been deprecated from 5599# public API. There are still a few internal uses of colocate_with. Add internal 5600# only API for those uses to avoid deprecation warning. 5601def colocate_with(op, ignore_existing=False): 5602 return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing) 5603 5604 5605@deprecation.deprecated( 5606 date=None, instructions="Colocations handled automatically by placer.") 5607@tf_export(v1=["colocate_with"]) 5608def _colocate_with(op, ignore_existing=False): 5609 return colocate_with(op, ignore_existing) 5610 5611 5612@tf_export("control_dependencies") 5613def control_dependencies(control_inputs): 5614 """Wrapper for `Graph.control_dependencies()` using the default graph. 5615 5616 See `tf.Graph.control_dependencies` for more details. 5617 5618 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require 5619 this method, as ops execute in the expected order thanks to automatic control 5620 dependencies.* Only use `tf.control_dependencies` when working with v1 5621 `tf.Graph` code. 
5622 5623 When eager execution is enabled, any callable object in the `control_inputs` 5624 list will be called. 5625 5626 Args: 5627 control_inputs: A list of `Operation` or `Tensor` objects which must be 5628 executed or computed before running the operations defined in the context. 5629 Can also be `None` to clear the control dependencies. If eager execution 5630 is enabled, any callable object in the `control_inputs` list will be 5631 called. 5632 5633 Returns: 5634 A context manager that specifies control dependencies for all 5635 operations constructed within the context. 5636 """ 5637 if context.executing_eagerly(): 5638 if control_inputs: 5639 # Execute any pending callables. 5640 for control in control_inputs: 5641 if callable(control): 5642 control() 5643 return NullContextmanager() 5644 else: 5645 return get_default_graph().control_dependencies(control_inputs) 5646 5647 5648class _DefaultStack(threading.local): 5649 """A thread-local stack of objects for providing implicit defaults.""" 5650 5651 def __init__(self): 5652 super(_DefaultStack, self).__init__() 5653 self._enforce_nesting = True 5654 self.stack = [] 5655 5656 def get_default(self): 5657 return self.stack[-1] if self.stack else None 5658 5659 def reset(self): 5660 self.stack = [] 5661 5662 def is_cleared(self): 5663 return not self.stack 5664 5665 @property 5666 def enforce_nesting(self): 5667 return self._enforce_nesting 5668 5669 @enforce_nesting.setter 5670 def enforce_nesting(self, value): 5671 self._enforce_nesting = value 5672 5673 @tf_contextlib.contextmanager 5674 def get_controller(self, default): 5675 """A context manager for manipulating a default stack.""" 5676 self.stack.append(default) 5677 try: 5678 yield default 5679 finally: 5680 # stack may be empty if reset() was called 5681 if self.stack: 5682 if self._enforce_nesting: 5683 if self.stack[-1] is not default: 5684 raise AssertionError( 5685 "Nesting violated for default stack of %s objects" % 5686 type(default)) 5687 
self.stack.pop() 5688 else: 5689 self.stack.remove(default) 5690 5691 5692_default_session_stack = _DefaultStack() # pylint: disable=protected-access 5693 5694 5695def default_session(session): 5696 """Python "with" handler for defining a default session. 5697 5698 This function provides a means of registering a session for handling 5699 Tensor.eval() and Operation.run() calls. It is primarily intended for use 5700 by session.Session, but can be used with any object that implements 5701 the Session.run() interface. 5702 5703 Use with the "with" keyword to specify that Tensor.eval() and Operation.run() 5704 invocations within the scope of a block should be executed by a particular 5705 session. 5706 5707 The default session applies to the current thread only, so it is always 5708 possible to inspect the call stack and determine the scope of a default 5709 session. If you create a new thread, and wish to use the default session 5710 in that thread, you must explicitly add a "with ops.default_session(sess):" 5711 block in that thread's function. 5712 5713 Example: 5714 The following code examples are equivalent: 5715 5716 # 1. Using the Session object directly: 5717 sess = ... 5718 c = tf.constant(5.0) 5719 sess.run(c) 5720 5721 # 2. Using default_session(): 5722 sess = ... 5723 with ops.default_session(sess): 5724 c = tf.constant(5.0) 5725 result = c.eval() 5726 5727 # 3. Overriding default_session(): 5728 sess = ... 5729 with ops.default_session(sess): 5730 c = tf.constant(5.0) 5731 with ops.default_session(...): 5732 c.eval(session=sess) 5733 5734 Args: 5735 session: The session to be installed as the default session. 5736 5737 Returns: 5738 A context manager for the default session. 5739 """ 5740 return _default_session_stack.get_controller(session) 5741 5742 5743@tf_export(v1=["get_default_session"]) 5744def get_default_session(): 5745 """Returns the default session for the current thread. 
5746 5747 The returned `Session` will be the innermost session on which a 5748 `Session` or `Session.as_default()` context has been entered. 5749 5750 NOTE: The default session is a property of the current thread. If you 5751 create a new thread, and wish to use the default session in that 5752 thread, you must explicitly add a `with sess.as_default():` in that 5753 thread's function. 5754 5755 Returns: 5756 The default `Session` being used in the current thread. 5757 """ 5758 return _default_session_stack.get_default() 5759 5760 5761def _eval_using_default_session(tensors, feed_dict, graph, session=None): 5762 """Uses the default session to evaluate one or more tensors. 5763 5764 Args: 5765 tensors: A single Tensor, or a list of Tensor objects. 5766 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 5767 numpy ndarrays, TensorProtos, or strings. 5768 graph: The graph in which the tensors are defined. 5769 session: (Optional) A different session to use to evaluate "tensors". 5770 5771 Returns: 5772 Either a single numpy ndarray if "tensors" is a single tensor; or a list 5773 of numpy ndarrays that each correspond to the respective element in 5774 "tensors". 5775 5776 Raises: 5777 ValueError: If no default session is available; the default session 5778 does not have "graph" as its graph; or if "session" is specified, 5779 and it does not have "graph" as its graph. 5780 """ 5781 if session is None: 5782 session = get_default_session() 5783 if session is None: 5784 raise ValueError("Cannot evaluate tensor using `eval()`: No default " 5785 "session is registered. Use `with " 5786 "sess.as_default()` or pass an explicit session to " 5787 "`eval(session=sess)`") 5788 if session.graph is not graph: 5789 raise ValueError("Cannot use the default session to evaluate tensor: " 5790 "the tensor's graph is different from the session's " 5791 "graph. 
Pass an explicit session to " 5792 "`eval(session=sess)`.") 5793 else: 5794 if session.graph is not graph: 5795 raise ValueError("Cannot use the given session to evaluate tensor: " 5796 "the tensor's graph is different from the session's " 5797 "graph.") 5798 return session.run(tensors, feed_dict) 5799 5800 5801def _run_using_default_session(operation, feed_dict, graph, session=None): 5802 """Uses the default session to run "operation". 5803 5804 Args: 5805 operation: The Operation to be run. 5806 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 5807 numpy ndarrays, TensorProtos, or strings. 5808 graph: The graph in which "operation" is defined. 5809 session: (Optional) A different session to use to run "operation". 5810 5811 Raises: 5812 ValueError: If no default session is available; the default session 5813 does not have "graph" as its graph; or if "session" is specified, 5814 and it does not have "graph" as its graph. 5815 """ 5816 if session is None: 5817 session = get_default_session() 5818 if session is None: 5819 raise ValueError("Cannot execute operation using `run()`: No default " 5820 "session is registered. Use `with " 5821 "sess.as_default():` or pass an explicit session to " 5822 "`run(session=sess)`") 5823 if session.graph is not graph: 5824 raise ValueError("Cannot use the default session to execute operation: " 5825 "the operation's graph is different from the " 5826 "session's graph. 
Pass an explicit session to " 5827 "run(session=sess).") 5828 else: 5829 if session.graph is not graph: 5830 raise ValueError("Cannot use the given session to execute operation: " 5831 "the operation's graph is different from the session's " 5832 "graph.") 5833 session.run(operation, feed_dict) 5834 5835 5836class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access 5837 """A thread-local stack of objects for providing an implicit default graph.""" 5838 5839 def __init__(self): 5840 super(_DefaultGraphStack, self).__init__() 5841 self._global_default_graph = None 5842 5843 def get_default(self): 5844 """Override that returns a global default if the stack is empty.""" 5845 if self.stack: 5846 return self.stack[-1] 5847 elif self._global_default_graph: 5848 return self._global_default_graph 5849 else: 5850 self._global_default_graph = Graph() 5851 return self._global_default_graph 5852 5853 def _GetGlobalDefaultGraph(self): 5854 if self._global_default_graph is None: 5855 # TODO(mrry): Perhaps log that the default graph is being used, or set 5856 # provide some other feedback to prevent confusion when a mixture of 5857 # the global default graph and an explicit graph are combined in the 5858 # same process. 5859 self._global_default_graph = Graph() 5860 return self._global_default_graph 5861 5862 def reset(self): 5863 super(_DefaultGraphStack, self).reset() 5864 self._global_default_graph = None 5865 5866 @tf_contextlib.contextmanager 5867 def get_controller(self, default): 5868 context.context().context_switches.push(default.building_function, 5869 default.as_default, 5870 default._device_function_stack) 5871 try: 5872 with super(_DefaultGraphStack, 5873 self).get_controller(default) as g, context.graph_mode(): 5874 yield g 5875 finally: 5876 # If an exception is raised here it may be hiding a related exception in 5877 # the try-block (just above). 
5878 context.context().context_switches.pop() 5879 5880 5881_default_graph_stack = _DefaultGraphStack() 5882 5883 5884# Shared helper used in init_scope and executing_eagerly_outside_functions 5885# to obtain the outermost context that is not building a function, and the 5886# innermost non empty device stack. 5887def _get_outer_context_and_inner_device_stack(): 5888 """Get the outermost context not building a function.""" 5889 default_graph = get_default_graph() 5890 outer_context = None 5891 innermost_nonempty_device_stack = default_graph._device_function_stack # pylint: disable=protected-access 5892 5893 if not _default_graph_stack.stack: 5894 # If the default graph stack is empty, then we cannot be building a 5895 # function. Install the global graph (which, in this case, is also the 5896 # default graph) as the outer context. 5897 if default_graph.building_function: 5898 raise RuntimeError("The global graph is building a function.") 5899 outer_context = default_graph.as_default 5900 else: 5901 # Find a context that is not building a function. 5902 for stack_entry in reversed(context.context().context_switches.stack): 5903 if not innermost_nonempty_device_stack: 5904 innermost_nonempty_device_stack = stack_entry.device_stack 5905 if not stack_entry.is_building_function: 5906 outer_context = stack_entry.enter_context_fn 5907 break 5908 5909 if outer_context is None: 5910 # As a last resort, obtain the global default graph; this graph doesn't 5911 # necessarily live on the graph stack (and hence it doesn't necessarily 5912 # live on the context stack), but it is stored in the graph stack's 5913 # encapsulating object. 5914 outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default # pylint: disable=protected-access 5915 5916 if outer_context is None: 5917 # Sanity check; this shouldn't be triggered. 
5918 raise RuntimeError("All graphs are building functions, and no " 5919 "eager context was previously active.") 5920 5921 return outer_context, innermost_nonempty_device_stack 5922 5923 5924# pylint: disable=g-doc-return-or-yield,line-too-long 5925@tf_export("init_scope") 5926@tf_contextlib.contextmanager 5927def init_scope(): 5928 """A context manager that lifts ops out of control-flow scopes and function-building graphs. 5929 5930 There is often a need to lift variable initialization ops out of control-flow 5931 scopes, function-building graphs, and gradient tapes. Entering an 5932 `init_scope` is a mechanism for satisfying these desiderata. In particular, 5933 entering an `init_scope` has three effects: 5934 5935 (1) All control dependencies are cleared the moment the scope is entered; 5936 this is equivalent to entering the context manager returned from 5937 `control_dependencies(None)`, which has the side-effect of exiting 5938 control-flow scopes like `tf.cond` and `tf.while_loop`. 5939 5940 (2) All operations that are created while the scope is active are lifted 5941 into the lowest context on the `context_stack` that is not building a 5942 graph function. Here, a context is defined as either a graph or an eager 5943 context. Every context switch, i.e., every installation of a graph as 5944 the default graph and every switch into eager mode, is logged in a 5945 thread-local stack called `context_switches`; the log entry for a 5946 context switch is popped from the stack when the context is exited. 5947 Entering an `init_scope` is equivalent to crawling up 5948 `context_switches`, finding the first context that is not building a 5949 graph function, and entering it. A caveat is that if graph mode is 5950 enabled but the default graph stack is empty, then entering an 5951 `init_scope` will simply install a fresh graph as the default one. 5952 5953 (3) The gradient tape is paused while the scope is active. 
5954 5955 When eager execution is enabled, code inside an init_scope block runs with 5956 eager execution enabled even when tracing a `tf.function`. For example: 5957 5958 ```python 5959 tf.compat.v1.enable_eager_execution() 5960 5961 @tf.function 5962 def func(): 5963 # A function constructs TensorFlow graphs, 5964 # it does not execute eagerly. 5965 assert not tf.executing_eagerly() 5966 with tf.init_scope(): 5967 # Initialization runs with eager execution enabled 5968 assert tf.executing_eagerly() 5969 ``` 5970 5971 Raises: 5972 RuntimeError: if graph state is incompatible with this initialization. 5973 """ 5974 # pylint: enable=g-doc-return-or-yield,line-too-long 5975 5976 if context.executing_eagerly(): 5977 # Fastpath. 5978 with tape.stop_recording(): 5979 yield 5980 else: 5981 # Retrieve the active name scope: entering an `init_scope` preserves 5982 # the name scope of the current context. 5983 scope = get_default_graph().get_name_scope() 5984 if scope and scope[-1] != "/": 5985 # Names that end with trailing slashes are treated by `name_scope` as 5986 # absolute. 5987 scope = scope + "/" 5988 5989 outer_context, innermost_nonempty_device_stack = ( 5990 _get_outer_context_and_inner_device_stack()) 5991 5992 outer_graph = None 5993 outer_device_stack = None 5994 try: 5995 with outer_context(), name_scope( 5996 scope, skip_on_eager=False), control_dependencies( 5997 None), tape.stop_recording(): 5998 context_manager = NullContextmanager 5999 context_manager_input = None 6000 if not context.executing_eagerly(): 6001 # The device stack is preserved when lifting into a graph. Eager 6002 # execution doesn't implement device stacks and in particular it 6003 # doesn't support device functions, so in general it's not possible 6004 # to do the same when lifting into the eager context. 
6005 outer_graph = get_default_graph() 6006 outer_device_stack = outer_graph._device_function_stack # pylint: disable=protected-access 6007 outer_graph._device_function_stack = innermost_nonempty_device_stack # pylint: disable=protected-access 6008 elif innermost_nonempty_device_stack is not None: 6009 for device_spec in innermost_nonempty_device_stack.peek_objs(): 6010 if device_spec.function is None: 6011 break 6012 if device_spec.raw_string: 6013 context_manager = context.device 6014 context_manager_input = device_spec.raw_string 6015 break 6016 # It is currently not possible to have a device function in V2, 6017 # but in V1 we are unable to apply device functions in eager mode. 6018 # This means that we will silently skip some of the entries on the 6019 # device stack in V1 + eager mode. 6020 6021 with context_manager(context_manager_input): 6022 yield 6023 finally: 6024 # If an exception is raised here it may be hiding a related exception in 6025 # try-block (just above). 6026 if outer_graph is not None: 6027 outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access 6028 6029 6030@tf_export(v1=["executing_eagerly_outside_functions"]) 6031def executing_eagerly_outside_functions(): 6032 """Returns True if executing eagerly, even if inside a graph function. 6033 6034 This function will check the outermost context for the program and see if 6035 it is in eager mode. It is useful comparing to `tf.executing_eagerly()`, 6036 which checks the current context and will return `False` within a 6037 `tf.function` body. It can be used to build library that behave differently 6038 in eager runtime and v1 session runtime (deprecated). 6039 6040 Example: 6041 6042 >>> tf.compat.v1.enable_eager_execution() 6043 >>> @tf.function 6044 ... def func(): 6045 ... # A function constructs TensorFlow graphs, it does not execute eagerly, 6046 ... # but the outer most context is still eager. 6047 ... assert not tf.executing_eagerly() 6048 ... 
return tf.compat.v1.executing_eagerly_outside_functions() 6049 >>> func() 6050 <tf.Tensor: shape=(), dtype=bool, numpy=True> 6051 6052 Returns: 6053 boolean, whether the outermost context is in eager mode. 6054 """ 6055 if context.executing_eagerly(): 6056 return True 6057 else: 6058 outer_context, _ = _get_outer_context_and_inner_device_stack() 6059 with outer_context(): 6060 return context.executing_eagerly() 6061 6062 6063@tf_export("inside_function", v1=[]) 6064def inside_function(): 6065 """Indicates whether the caller code is executing inside a `tf.function`. 6066 6067 Returns: 6068 Boolean, True if the caller code is executing inside a `tf.function` 6069 rather than eagerly. 6070 6071 Example: 6072 6073 >>> tf.inside_function() 6074 False 6075 >>> @tf.function 6076 ... def f(): 6077 ... print(tf.inside_function()) 6078 >>> f() 6079 True 6080 """ 6081 return get_default_graph().building_function 6082 6083 6084@tf_export(v1=["enable_eager_execution"]) 6085def enable_eager_execution(config=None, device_policy=None, 6086 execution_mode=None): 6087 """Enables eager execution for the lifetime of this program. 6088 6089 Eager execution provides an imperative interface to TensorFlow. With eager 6090 execution enabled, TensorFlow functions execute operations immediately (as 6091 opposed to adding to a graph to be executed later in a `tf.compat.v1.Session`) 6092 and 6093 return concrete values (as opposed to symbolic references to a node in a 6094 computational graph). 6095 6096 For example: 6097 6098 ```python 6099 tf.compat.v1.enable_eager_execution() 6100 6101 # After eager execution is enabled, operations are executed as they are 6102 # defined and Tensor objects hold concrete values, which can be accessed as 6103 # numpy.ndarray`s through the numpy() method. 6104 assert tf.multiply(6, 7).numpy() == 42 6105 ``` 6106 6107 Eager execution cannot be enabled after TensorFlow APIs have been used to 6108 create or execute graphs. 
It is typically recommended to invoke this function 6109 at program startup and not in a library (as most libraries should be usable 6110 both with and without eager execution). 6111 6112 @compatibility(TF2) 6113 This function is not necessary if you are using TF2. Eager execution is 6114 enabled by default. 6115 @end_compatibility 6116 6117 Args: 6118 config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the 6119 environment in which operations are executed. Note that 6120 `tf.compat.v1.ConfigProto` is also used to configure graph execution (via 6121 `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto` 6122 are not implemented (or are irrelevant) when eager execution is enabled. 6123 device_policy: (Optional.) Policy controlling how operations requiring 6124 inputs on a specific device (e.g., a GPU 0) handle inputs on a different 6125 device (e.g. GPU 1 or CPU). When set to None, an appropriate value will 6126 be picked automatically. The value picked may change between TensorFlow 6127 releases. 6128 Valid values: 6129 - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the 6130 placement is not correct. 6131 - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not 6132 on the right device but logs a warning. 6133 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors. 6134 Note that this may hide performance problems as there is no notification 6135 provided when operations are blocked on the tensor being copied between 6136 devices. 6137 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies 6138 int32 tensors, raising errors on the other ones. 6139 execution_mode: (Optional.) Policy controlling how operations dispatched are 6140 actually executed. When set to None, an appropriate value will be picked 6141 automatically. The value picked may change between TensorFlow releases. 6142 Valid values: 6143 - tf.contrib.eager.SYNC: executes each operation synchronously. 
6144 - tf.contrib.eager.ASYNC: executes each operation asynchronously. These 6145 operations may return "non-ready" handles. 6146 6147 Raises: 6148 ValueError: If eager execution is enabled after creating/executing a 6149 TensorFlow graph, or if options provided conflict with a previous call 6150 to this function. 6151 """ 6152 _api_usage_gauge.get_cell().set(True) 6153 logging.vlog(1, "Enabling eager execution") 6154 if context.default_execution_mode != context.EAGER_MODE: 6155 return enable_eager_execution_internal( 6156 config=config, 6157 device_policy=device_policy, 6158 execution_mode=execution_mode, 6159 server_def=None) 6160 6161 6162@tf_export(v1=["disable_eager_execution"]) 6163def disable_eager_execution(): 6164 """Disables eager execution. 6165 6166 This function can only be called before any Graphs, Ops, or Tensors have been 6167 created. 6168 6169 @compatibility(TF2) 6170 This function is not necessary if you are using TF2. Eager execution is 6171 enabled by default. If you want to use Graph mode please consider 6172 [tf.function](https://www.tensorflow.org/api_docs/python/tf/function). 6173 @end_compatibility 6174 """ 6175 _api_usage_gauge.get_cell().set(False) 6176 logging.vlog(1, "Disabling eager execution") 6177 context.default_execution_mode = context.GRAPH_MODE 6178 c = context.context_safe() 6179 if c is not None: 6180 c._thread_local_data.is_eager = False # pylint: disable=protected-access 6181 6182 6183def enable_eager_execution_internal(config=None, 6184 device_policy=None, 6185 execution_mode=None, 6186 server_def=None): 6187 """Enables eager execution for the lifetime of this program. 6188 6189 Most of the doc string for enable_eager_execution is relevant here as well. 6190 6191 Args: 6192 config: See enable_eager_execution doc string 6193 device_policy: See enable_eager_execution doc string 6194 execution_mode: See enable_eager_execution doc string 6195 server_def: (Optional.) A tensorflow::ServerDef proto. 
Enables execution on 6196 remote devices. GrpcServers need to be started by creating an identical 6197 server_def to this, and setting the appropriate task_indexes, so that the 6198 servers can communicate. It will then be possible to execute operations on 6199 remote devices. 6200 6201 Raises: 6202 ValueError 6203 6204 """ 6205 if config is not None and not isinstance(config, config_pb2.ConfigProto): 6206 raise TypeError("config must be a tf.ConfigProto, but got %s" % 6207 type(config)) 6208 if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT, 6209 context.DEVICE_PLACEMENT_WARN, 6210 context.DEVICE_PLACEMENT_SILENT, 6211 context.DEVICE_PLACEMENT_SILENT_FOR_INT32): 6212 raise ValueError( 6213 "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*" 6214 ) 6215 if execution_mode not in (None, context.SYNC, context.ASYNC): 6216 raise ValueError( 6217 "execution_mode must be one of None, tf.contrib.eager.SYNC, " 6218 "tf.contrib.eager.ASYNC") 6219 if context.default_execution_mode == context.GRAPH_MODE: 6220 graph_mode_has_been_used = ( 6221 _default_graph_stack._global_default_graph is not None) # pylint: disable=protected-access 6222 if graph_mode_has_been_used: 6223 raise ValueError( 6224 "tf.enable_eager_execution must be called at program startup.") 6225 context.default_execution_mode = context.EAGER_MODE 6226 # pylint: disable=protected-access 6227 with context._context_lock: 6228 if context._context is None: 6229 context._set_context_locked(context.Context( 6230 config=config, 6231 device_policy=device_policy, 6232 execution_mode=execution_mode, 6233 server_def=server_def)) 6234 elif ((config is not None and config is not context._context._config) or 6235 (device_policy is not None and 6236 device_policy is not context._context._device_policy) or 6237 (execution_mode is not None and 6238 execution_mode is not context._context._execution_mode)): 6239 raise ValueError( 6240 "Trying to change the options of an active eager" 6241 " 
execution. Context config: %s, specified config:" 6242 " %s. Context device policy: %s, specified device" 6243 " policy: %s. Context execution mode: %s, " 6244 " specified execution mode %s." % 6245 (context._context._config, config, context._context._device_policy, 6246 device_policy, context._context._execution_mode, execution_mode)) 6247 else: 6248 # We already created everything, so update the thread local data. 6249 context._context._thread_local_data.is_eager = True 6250 6251 # Monkey patch to get rid of an unnecessary conditional since the context is 6252 # now initialized. 6253 context.context = context.context_safe 6254 6255 6256def eager_run(main=None, argv=None): 6257 """Runs the program with an optional main function and argv list. 6258 6259 The program will run with eager execution enabled. 6260 6261 Example: 6262 ```python 6263 import tensorflow as tf 6264 # Import subject to future changes: 6265 from tensorflow.contrib.eager.python import tfe 6266 6267 def main(_): 6268 u = tf.constant(6.0) 6269 v = tf.constant(7.0) 6270 print(u * v) 6271 6272 if __name__ == "__main__": 6273 tfe.run() 6274 ``` 6275 6276 Args: 6277 main: the main function to run. 6278 argv: the arguments to pass to it. 6279 """ 6280 enable_eager_execution() 6281 app.run(main, argv) 6282 6283 6284@tf_export(v1=["reset_default_graph"]) 6285def reset_default_graph(): 6286 """Clears the default graph stack and resets the global default graph. 6287 6288 NOTE: The default graph is a property of the current thread. This 6289 function applies only to the current thread. Calling this function while 6290 a `tf.compat.v1.Session` or `tf.compat.v1.InteractiveSession` is active will 6291 result in undefined 6292 behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects 6293 after calling this function will result in undefined behavior. 
6294 6295 @compatibility(TF2) 6296 `reset_default_graph` does not work with either eager execution or 6297 `tf.function`, and you should not invoke it directly. To migrate code that 6298 uses Graph-related functions to TF2, rewrite the code without them. See the 6299 [migration guide](https://www.tensorflow.org/guide/migrate) for more 6300 description about the behavior and semantic changes between Tensorflow 1 and 6301 Tensorflow 2. 6302 @end_compatibility 6303 6304 Raises: 6305 AssertionError: If this function is called within a nested graph. 6306 """ 6307 if not _default_graph_stack.is_cleared(): 6308 raise AssertionError("Do not use tf.reset_default_graph() to clear " 6309 "nested graphs. If you need a cleared graph, " 6310 "exit the nesting and create a new graph.") 6311 _default_graph_stack.reset() 6312 6313 6314@tf_export(v1=["get_default_graph"]) 6315def get_default_graph(): 6316 """Returns the default graph for the current thread. 6317 6318 The returned graph will be the innermost graph on which a 6319 `Graph.as_default()` context has been entered, or a global default 6320 graph if none has been explicitly created. 6321 6322 NOTE: The default graph is a property of the current thread. If you 6323 create a new thread, and wish to use the default graph in that 6324 thread, you must explicitly add a `with g.as_default():` in that 6325 thread's function. 6326 6327 @compatibility(TF2) 6328 `get_default_graph` does not work with either eager execution or 6329 `tf.function`, and you should not invoke it directly. To migrate code that 6330 uses Graph-related functions to TF2, rewrite the code without them. See the 6331 [migration guide](https://www.tensorflow.org/guide/migrate) for more 6332 description about the behavior and semantic changes between Tensorflow 1 and 6333 Tensorflow 2. 6334 @end_compatibility 6335 6336 Returns: 6337 The default `Graph` being used in the current thread. 
6338 """ 6339 return _default_graph_stack.get_default() 6340 6341 6342def has_default_graph(): 6343 """Returns True if there is a default graph.""" 6344 return len(_default_graph_stack.stack) >= 1 6345 6346 6347# Exported due to b/171079555 6348@tf_export("__internal__.get_name_scope", v1=[]) 6349def get_name_scope(): 6350 """Returns the current name scope in the default_graph. 6351 6352 For example: 6353 6354 ```python 6355 with tf.name_scope('scope1'): 6356 with tf.name_scope('scope2'): 6357 print(tf.get_name_scope()) 6358 ``` 6359 would print the string `scope1/scope2`. 6360 6361 Returns: 6362 A string representing the current name scope. 6363 """ 6364 if context.executing_eagerly(): 6365 return context.context().scope_name.rstrip("/") 6366 return get_default_graph().get_name_scope() 6367 6368 6369def _assert_same_graph(original_item, item): 6370 """Fail if the 2 items are from different graphs. 6371 6372 Args: 6373 original_item: Original item to check against. 6374 item: Item to check. 6375 6376 Raises: 6377 ValueError: if graphs do not match. 6378 """ 6379 original_graph = getattr(original_item, "graph", None) 6380 graph = getattr(item, "graph", None) 6381 if original_graph and graph and original_graph is not graph: 6382 raise ValueError( 6383 "%s must be from the same graph as %s (graphs are %s and %s)." % 6384 (item, original_item, graph, original_graph)) 6385 6386 6387def _get_graph_from_inputs(op_input_list, graph=None): 6388 """Returns the appropriate graph to use for the given inputs. 6389 6390 This library method provides a consistent algorithm for choosing the graph 6391 in which an Operation should be constructed: 6392 6393 1. If the default graph is being used to construct a function, we 6394 use the default graph. 6395 2. If the "graph" is specified explicitly, we validate that all of the inputs 6396 in "op_input_list" are compatible with that graph. 6397 3. 
Otherwise, we attempt to select a graph from the first Operation- 6398 or Tensor-valued input in "op_input_list", and validate that all other 6399 such inputs are in the same graph. 6400 4. If the graph was not specified and it could not be inferred from 6401 "op_input_list", we attempt to use the default graph. 6402 6403 Args: 6404 op_input_list: A list of inputs to an operation, which may include `Tensor`, 6405 `Operation`, and other objects that may be converted to a graph element. 6406 graph: (Optional) The explicit graph to use. 6407 6408 Raises: 6409 TypeError: If op_input_list is not a list or tuple, or if graph is not a 6410 Graph. 6411 ValueError: If a graph is explicitly passed and not all inputs are from it, 6412 or if the inputs are from multiple graphs, or we could not find a graph 6413 and there was no default graph. 6414 6415 Returns: 6416 The appropriate graph to use for the given inputs. 6417 6418 """ 6419 current_default_graph = get_default_graph() 6420 if current_default_graph.building_function: 6421 return current_default_graph 6422 6423 op_input_list = tuple(op_input_list) # Handle generators correctly 6424 if graph and not isinstance(graph, Graph): 6425 raise TypeError("Input graph needs to be a Graph: %s" % (graph,)) 6426 6427 # 1. We validate that all of the inputs are from the same graph. This is 6428 # either the supplied graph parameter, or the first one selected from one 6429 # the graph-element-valued inputs. In the latter case, we hold onto 6430 # that input in original_graph_element so we can provide a more 6431 # informative error if a mismatch is found. 6432 original_graph_element = None 6433 for op_input in op_input_list: 6434 # Determine if this is a valid graph_element. 6435 # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this 6436 # up. 
6437 graph_element = None 6438 if (isinstance(op_input, (Operation, internal.NativeObject)) and 6439 ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck 6440 graph_element = op_input 6441 else: 6442 graph_element = _as_graph_element(op_input) 6443 6444 if graph_element is not None: 6445 if not graph: 6446 original_graph_element = graph_element 6447 graph = getattr(graph_element, "graph", None) 6448 elif original_graph_element is not None: 6449 _assert_same_graph(original_graph_element, graph_element) 6450 elif graph_element.graph is not graph: 6451 raise ValueError("%s is not from the passed-in graph." % graph_element) 6452 6453 # 2. If all else fails, we use the default graph, which is always there. 6454 return graph or current_default_graph 6455 6456 6457@tf_export(v1=["GraphKeys"]) 6458class GraphKeys(object): 6459 """Standard names to use for graph collections. 6460 6461 The standard library uses various well-known names to collect and 6462 retrieve values associated with a graph. For example, the 6463 `tf.Optimizer` subclasses default to optimizing the variables 6464 collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is 6465 specified, but it is also possible to pass an explicit list of 6466 variables. 6467 6468 The following standard keys are defined: 6469 6470 * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared 6471 across distributed environment (model variables are subset of these). See 6472 `tf.compat.v1.global_variables` 6473 for more details. 6474 Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`, 6475 and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`. 6476 * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each 6477 machine. Usually used for temporarily variables, like counters. 6478 Note: use `tf.contrib.framework.local_variable` to add to this collection. 
6479 * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the 6480 model for inference (feed forward). Note: use 6481 `tf.contrib.framework.model_variable` to add to this collection. 6482 * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will 6483 be trained by an optimizer. See 6484 `tf.compat.v1.trainable_variables` 6485 for more details. 6486 * `SUMMARIES`: the summary `Tensor` objects that have been created in the 6487 graph. See 6488 `tf.compat.v1.summary.merge_all` 6489 for more details. 6490 * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to 6491 produce input for a computation. See 6492 `tf.compat.v1.train.start_queue_runners` 6493 for more details. 6494 * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also 6495 keep moving averages. See 6496 `tf.compat.v1.moving_average_variables` 6497 for more details. 6498 * `REGULARIZATION_LOSSES`: regularization losses collected during graph 6499 construction. 6500 6501 The following standard keys are _defined_, but their collections are **not** 6502 automatically populated as many of the others are: 6503 6504 * `WEIGHTS` 6505 * `BIASES` 6506 * `ACTIVATIONS` 6507 """ 6508 6509 # Key to collect Variable objects that are global (shared across machines). 6510 # Default collection for all variables, except local ones. 6511 GLOBAL_VARIABLES = "variables" 6512 # Key to collect local variables that are local to the machine and are not 6513 # saved/restored. 6514 LOCAL_VARIABLES = "local_variables" 6515 # Key to collect local variables which are used to accumulate interal state 6516 # to be used in tf.metrics.*. 6517 METRIC_VARIABLES = "metric_variables" 6518 # Key to collect model variables defined by layers. 6519 MODEL_VARIABLES = "model_variables" 6520 # Key to collect Variable objects that will be trained by the 6521 # optimizers. 6522 TRAINABLE_VARIABLES = "trainable_variables" 6523 # Key to collect summaries. 
6524 SUMMARIES = "summaries" 6525 # Key to collect QueueRunners. 6526 QUEUE_RUNNERS = "queue_runners" 6527 # Key to collect table initializers. 6528 TABLE_INITIALIZERS = "table_initializer" 6529 # Key to collect asset filepaths. An asset represents an external resource 6530 # like a vocabulary file. 6531 ASSET_FILEPATHS = "asset_filepaths" 6532 # Key to collect Variable objects that keep moving averages. 6533 MOVING_AVERAGE_VARIABLES = "moving_average_variables" 6534 # Key to collect regularization losses at graph construction. 6535 REGULARIZATION_LOSSES = "regularization_losses" 6536 # Key to collect concatenated sharded variables. 6537 CONCATENATED_VARIABLES = "concatenated_variables" 6538 # Key to collect savers. 6539 SAVERS = "savers" 6540 # Key to collect weights 6541 WEIGHTS = "weights" 6542 # Key to collect biases 6543 BIASES = "biases" 6544 # Key to collect activations 6545 ACTIVATIONS = "activations" 6546 # Key to collect update_ops 6547 UPDATE_OPS = "update_ops" 6548 # Key to collect losses 6549 LOSSES = "losses" 6550 # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. 6551 SAVEABLE_OBJECTS = "saveable_objects" 6552 # Key to collect all shared resources used by the graph which need to be 6553 # initialized once per cluster. 6554 RESOURCES = "resources" 6555 # Key to collect all shared resources used in this graph which need to be 6556 # initialized once per session. 6557 LOCAL_RESOURCES = "local_resources" 6558 # Trainable resource-style variables. 6559 TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables" 6560 6561 # Key to indicate various ops. 6562 INIT_OP = "init_op" 6563 LOCAL_INIT_OP = "local_init_op" 6564 READY_OP = "ready_op" 6565 READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op" 6566 SUMMARY_OP = "summary_op" 6567 GLOBAL_STEP = "global_step" 6568 6569 # Used to count the number of evaluations performed during a single evaluation 6570 # run. 
6571 EVAL_STEP = "eval_step" 6572 TRAIN_OP = "train_op" 6573 6574 # Key for control flow context. 6575 COND_CONTEXT = "cond_context" 6576 WHILE_CONTEXT = "while_context" 6577 6578 # Used to store v2 summary names. 6579 _SUMMARY_COLLECTION = "_SUMMARY_V2" 6580 6581 # List of all collections that keep track of variables. 6582 _VARIABLE_COLLECTIONS = [ 6583 GLOBAL_VARIABLES, 6584 LOCAL_VARIABLES, 6585 METRIC_VARIABLES, 6586 MODEL_VARIABLES, 6587 TRAINABLE_VARIABLES, 6588 MOVING_AVERAGE_VARIABLES, 6589 CONCATENATED_VARIABLES, 6590 TRAINABLE_RESOURCE_VARIABLES, 6591 ] 6592 6593 # Key for streaming model ports. 6594 # NOTE(yuanbyu): internal and experimental. 6595 _STREAMING_MODEL_PORTS = "streaming_model_ports" 6596 6597 @decorator_utils.classproperty 6598 @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.") 6599 def VARIABLES(cls): # pylint: disable=no-self-argument 6600 return cls.GLOBAL_VARIABLES 6601 6602 6603def dismantle_graph(graph): 6604 """Cleans up reference cycles from a `Graph`. 6605 6606 Helpful for making sure the garbage collector doesn't need to run after a 6607 temporary `Graph` is no longer needed. 6608 6609 Args: 6610 graph: A `Graph` object to destroy. Neither it nor any of its ops are usable 6611 after this function runs. 6612 """ 6613 memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access 6614 6615 # Now clean up Operation<->Graph reference cycles by clearing all of the 6616 # attributes for the Graph and its ops. 6617 graph_operations = graph.get_operations() 6618 for op in graph_operations: 6619 op.__dict__ = {} 6620 graph.__dict__ = {} 6621 6622 6623@tf_export(v1=["add_to_collection"]) 6624def add_to_collection(name, value): 6625 """Wrapper for `Graph.add_to_collection()` using the default graph. 6626 6627 See `tf.Graph.add_to_collection` 6628 for more details. 6629 6630 Args: 6631 name: The key for the collection. 
For example, the `GraphKeys` class 6632 contains many standard names for collections. 6633 value: The value to add to the collection. 6634 6635 @compatibility(eager) 6636 Collections are only supported in eager when variables are created inside 6637 an EagerVariableStore (e.g. as part of a layer or template). 6638 @end_compatibility 6639 """ 6640 get_default_graph().add_to_collection(name, value) 6641 6642 6643@tf_export(v1=["add_to_collections"]) 6644def add_to_collections(names, value): 6645 """Wrapper for `Graph.add_to_collections()` using the default graph. 6646 6647 See `tf.Graph.add_to_collections` 6648 for more details. 6649 6650 Args: 6651 names: The key for the collections. The `GraphKeys` class contains many 6652 standard names for collections. 6653 value: The value to add to the collections. 6654 6655 @compatibility(eager) 6656 Collections are only supported in eager when variables are created inside 6657 an EagerVariableStore (e.g. as part of a layer or template). 6658 @end_compatibility 6659 """ 6660 get_default_graph().add_to_collections(names, value) 6661 6662 6663@tf_export(v1=["get_collection_ref"]) 6664def get_collection_ref(key): 6665 """Wrapper for `Graph.get_collection_ref()` using the default graph. 6666 6667 See `tf.Graph.get_collection_ref` 6668 for more details. 6669 6670 Args: 6671 key: The key for the collection. For example, the `GraphKeys` class contains 6672 many standard names for collections. 6673 6674 Returns: 6675 The list of values in the collection with the given `name`, or an empty 6676 list if no value has been added to that collection. Note that this returns 6677 the collection list itself, which can be modified in place to change the 6678 collection. 6679 6680 @compatibility(eager) 6681 Collections are not supported when eager execution is enabled. 
6682 @end_compatibility 6683 """ 6684 return get_default_graph().get_collection_ref(key) 6685 6686 6687@tf_export(v1=["get_collection"]) 6688def get_collection(key, scope=None): 6689 """Wrapper for `Graph.get_collection()` using the default graph. 6690 6691 See `tf.Graph.get_collection` 6692 for more details. 6693 6694 Args: 6695 key: The key for the collection. For example, the `GraphKeys` class contains 6696 many standard names for collections. 6697 scope: (Optional.) If supplied, the resulting list is filtered to include 6698 only items whose `name` attribute matches using `re.match`. Items without 6699 a `name` attribute are never returned if a scope is supplied and the 6700 choice or `re.match` means that a `scope` without special tokens filters 6701 by prefix. 6702 6703 Returns: 6704 The list of values in the collection with the given `name`, or 6705 an empty list if no value has been added to that collection. The 6706 list contains the values in the order under which they were 6707 collected. 6708 6709 @compatibility(eager) 6710 Collections are not supported when eager execution is enabled. 6711 @end_compatibility 6712 """ 6713 return get_default_graph().get_collection(key, scope) 6714 6715 6716def get_all_collection_keys(): 6717 """Returns a list of collections used in the default graph.""" 6718 return get_default_graph().get_all_collection_keys() 6719 6720 6721def name_scope(name, default_name=None, values=None, skip_on_eager=True): 6722 """Internal-only entry point for `name_scope*`. 6723 6724 Internal ops do not use the public API and instead rely on 6725 `ops.name_scope` regardless of the execution mode. This function 6726 dispatches to the correct `name_scope*` implementation based on 6727 the arguments provided and the current mode. Specifically, 6728 6729 * if `values` contains a graph tensor `Graph.name_scope` is used; 6730 * `name_scope_v1` is used in graph mode; 6731 * `name_scope_v2` -- in eager mode. 
6732 6733 Args: 6734 name: The name argument that is passed to the op function. 6735 default_name: The default name to use if the `name` argument is `None`. 6736 values: The list of `Tensor` arguments that are passed to the op function. 6737 skip_on_eager: Indicates to return NullContextmanager if executing eagerly. 6738 By default this is True since naming tensors and operations in eager mode 6739 have little use and cause unnecessary performance overhead. However, it is 6740 important to preserve variable names since they are often useful for 6741 debugging and saved models. 6742 6743 Returns: 6744 `name_scope*` context manager. 6745 """ 6746 if not context.executing_eagerly(): 6747 return internal_name_scope_v1(name, default_name, values) 6748 6749 if skip_on_eager: 6750 return NullContextmanager() 6751 6752 name = default_name if name is None else name 6753 if values: 6754 # The presence of a graph tensor in `values` overrides the context. 6755 # TODO(slebedev): this is Keras-specific and should be removed. 6756 # pylint: disable=unidiomatic-typecheck 6757 graph_value = next((value for value in values if type(value) == Tensor), 6758 None) 6759 # pylint: enable=unidiomatic-typecheck 6760 if graph_value is not None: 6761 return graph_value.graph.name_scope(name) 6762 6763 return name_scope_v2(name or "") 6764 6765 6766class internal_name_scope_v1(object): # pylint: disable=invalid-name 6767 """Graph-only version of `name_scope_v1`.""" 6768 6769 @property 6770 def name(self): 6771 return self._name 6772 6773 def __init__(self, name, default_name=None, values=None): 6774 """Initialize the context manager. 6775 6776 Args: 6777 name: The name argument that is passed to the op function. 6778 default_name: The default name to use if the `name` argument is `None`. 6779 values: The list of `Tensor` arguments that are passed to the op function. 6780 6781 Raises: 6782 TypeError: if `default_name` is passed in but not a string. 
6783 """ 6784 if not (default_name is None or isinstance(default_name, str)): 6785 raise TypeError( 6786 "`default_name` type (%s) is not a string type. You likely meant to " 6787 "pass this into the `values` kwarg." % type(default_name)) 6788 self._name = default_name if name is None else name 6789 self._default_name = default_name 6790 self._values = values 6791 6792 def __enter__(self): 6793 """Start the scope block. 6794 6795 Returns: 6796 The scope name. 6797 6798 Raises: 6799 ValueError: if neither `name` nor `default_name` is provided 6800 but `values` are. 6801 """ 6802 if self._name is None and self._values is not None: 6803 # We only raise an error if values is not None (provided) because 6804 # currently tf.name_scope(None) (values=None then) is sometimes used as 6805 # an idiom to reset to top scope. 6806 raise ValueError( 6807 "At least one of name (%s) and default_name (%s) must be provided." 6808 % (self._name, self._default_name)) 6809 6810 g = get_default_graph() 6811 if self._values and not g.building_function: 6812 # Specialize based on the knowledge that `_get_graph_from_inputs()` 6813 # ignores `inputs` when building a function. 6814 g_from_inputs = _get_graph_from_inputs(self._values) 6815 if g_from_inputs is not g: 6816 g = g_from_inputs 6817 self._g_manager = g.as_default() 6818 self._g_manager.__enter__() 6819 else: 6820 self._g_manager = None 6821 else: 6822 self._g_manager = None 6823 6824 try: 6825 self._name_scope = g.name_scope(self._name) 6826 return self._name_scope.__enter__() 6827 except: 6828 if self._g_manager is not None: 6829 self._g_manager.__exit__(*sys.exc_info()) 6830 raise 6831 6832 def __exit__(self, *exc_info): 6833 self._name_scope.__exit__(*exc_info) 6834 if self._g_manager is not None: 6835 self._g_manager.__exit__(*exc_info) 6836 6837 6838# Named like a function for backwards compatibility with the 6839# @tf_contextlib.contextmanager version, which was switched to a class to avoid 6840# some object creation overhead. 
@tf_export(v1=["name_scope"])
class name_scope_v1(object):  # pylint: disable=invalid-name
  """A context manager for use when defining a Python op.

  This context manager validates that the given `values` are from the
  same graph, makes that graph the default graph, and pushes a
  name scope in that graph (see
  `tf.Graph.name_scope`
  for more details on that).

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```
  """

  __slots__ = ["_name", "_name_scope"]

  @property
  def name(self):
    return self._name

  def __init__(self, name, default_name=None, values=None):
    """Initialize the context manager.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.

    Raises:
      TypeError: if `default_name` is passed in but not a string.
    """
    # Delegate to the internal dispatcher; unlike internal ops, the public v1
    # API always creates a real scope, even when executing eagerly.
    self._name_scope = name_scope(
        name, default_name, values, skip_on_eager=False)
    self._name = default_name if name is None else name

  def __enter__(self):
    return self._name_scope.__enter__()

  def __exit__(self, *exc_info):
    return self._name_scope.__exit__(*exc_info)


@tf_export("get_current_name_scope", v1=[])
def get_current_name_scope():
  """Returns current full name scope specified by `tf.name_scope(...)`s.

  For example,
  ```python
  with tf.name_scope("outer"):
    tf.get_current_name_scope()  # "outer"

    with tf.name_scope("inner"):
      tf.get_current_name_scope()  # "outer/inner"
  ```

  In other words, `tf.get_current_name_scope()` returns the op name prefix that
  will be prepended to, if an op is created at that place.

  Note that `@tf.function` resets the name scope stack as shown below.

  ```
  with tf.name_scope("outer"):

    @tf.function
    def foo(x):
      with tf.name_scope("inner"):
        return tf.add(x, x)  # Op name is "inner/Add", not "outer/inner/Add"
  ```
  """

  ctx = context.context()
  if ctx.executing_eagerly():
    return ctx.scope_name.rstrip("/")
  else:
    return get_default_graph().get_name_scope()


@tf_export("name_scope", v1=[])
class name_scope_v2(object):
  """A context manager for use when defining a Python op.

  This context manager pushes a name scope, which will make the name of all
  operations added within it have a prefix.

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope("MyOp") as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```

  When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
  and `MyOp/c`.

  Inside a `tf.function`, if the scope name already exists, the name will be
  made unique by appending `_n`. For example, calling `my_op` the second time
  will generate `MyOp_1/a`, etc.
  """

  __slots__ = ["_name", "_exit_fns"]

  def __init__(self, name):
    """Initialize the context manager.

    Args:
      name: The prefix to use on all names created within the name scope.

    Raises:
      ValueError: If name is not a string.
    """
    if not isinstance(name, str):
      raise ValueError("name for name_scope must be a string.")
    self._name = name
    # Stack of callables that undo each `__enter__`, so one instance can be
    # re-entered and unwound in LIFO order.
    self._exit_fns = []

  @property
  def name(self):
    return self._name

  def __enter__(self):
    """Start the scope block.

    Returns:
      The scope name.
    """
    ctx = context.context()
    if ctx.executing_eagerly():
      # Names are not auto-incremented in eager mode.
      # A trailing slash breaks out of nested name scopes, indicating a
      # fully specified scope name, for compatibility with Graph.name_scope.
      # This also prevents auto-incrementing.
      old_name = ctx.scope_name
      name = self._name
      if not name:
        scope_name = ""
      elif name[-1] == "/":
        scope_name = name
      elif old_name:
        scope_name = old_name + name + "/"
      else:
        scope_name = name + "/"
      ctx.scope_name = scope_name

      def _restore_name_scope(*_):
        ctx.scope_name = old_name

      self._exit_fns.append(_restore_name_scope)
    else:
      scope = get_default_graph().name_scope(self._name)
      scope_name = scope.__enter__()
      self._exit_fns.append(scope.__exit__)
    return scope_name

  def __exit__(self, type_arg, value_arg, traceback_arg):
    self._exit_fns.pop()(type_arg, value_arg, traceback_arg)
    return False  # False values do not suppress exceptions

  def __getstate__(self):
    return self._name, self._exit_fns

  def __setstate__(self, state):
    self._name = state[0]
    self._exit_fns = state[1]


def strip_name_scope(name, export_scope):
  """Removes name scope from a name.

  Only the first occurrence is stripped, and only when it appears at the start
  of `name` or right after a leading "^" or "loc:@" prefix. `export_scope` is
  matched literally (regex metacharacters such as "." are escaped).

  Args:
    name: A `string` name.
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    Name with name scope removed, or the original name if export_scope
    is None (or empty), or if `name`/`export_scope` cannot be processed as
    strings (a warning is logged in that case).
  """
  if export_scope:
    if export_scope[-1] == "/":
      export_scope = export_scope[:-1]

    try:
      # Strips export_scope/, export_scope///,
      # ^export_scope/, loc:@export_scope/.
      # Escape the scope so characters legal in TF names (e.g. ".") are not
      # interpreted as regex syntax.
      str_to_replace = (r"([\^]|loc:@|^)" + re.escape(export_scope) +
                        r"[\/]+(.*)")
      return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
    except TypeError as e:
      # If the name is not of a type we can process, simply return it.
      logging.warning(e)
      return name
  else:
    return name


def prepend_name_scope(name, import_scope):
  """Prepends name scope to a name.

  A leading "^" (control input) or "loc:@" (colocation) prefix is preserved in
  front of the added scope.

  Args:
    name: A `string` name.
    import_scope: Optional `string`. Name scope to add.

  Returns:
    Name with name scope added, or the original name if import_scope
    is None (or empty), or if `name`/`import_scope` cannot be processed as
    strings (a warning is logged in that case).
  """
  if import_scope:
    if import_scope[-1] == "/":
      import_scope = import_scope[:-1]

    try:
      str_to_replace = r"([\^]|loc:@|^)(.*)"
      # Use a callable replacement rather than a template string: a template
      # like r"\1" + import_scope would misparse scopes that start with a digit
      # (e.g. "2scope" -> group reference \12) and raise re.error.
      return re.sub(
          str_to_replace,
          lambda m: m.group(1) + import_scope + "/" + m.group(2),
          compat.as_str(name))
    except TypeError as e:
      # If the name is not of a type we can process, simply return it.
      logging.warning(e)
      return name
  else:
    return name


# pylint: disable=g-doc-return-or-yield
# pylint: disable=not-context-manager
@tf_export(v1=["op_scope"])
@tf_contextlib.contextmanager
def op_scope(values, name, default_name=None):
  """DEPRECATED. Same as name_scope above, just different argument order."""
  logging.warning("tf.op_scope(values, name, default_name) is deprecated,"
                  " use tf.name_scope(name, default_name, values)")
  with name_scope(name, default_name=default_name, values=values) as scope:
    yield scope


_proto_function_registry = registry.Registry("proto functions")


def register_proto_function(collection_name,
                            proto_type=None,
                            to_proto=None,
                            from_proto=None):
  """Registers `to_proto` and `from_proto` functions for collection_name.

  `to_proto` function converts a Python object to the corresponding protocol
  buffer, and returns the protocol buffer.

  `from_proto` function converts protocol buffer into a Python object, and
  returns the object.

  Args:
    collection_name: Name of the collection.
    proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
      `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`.
    to_proto: Function that implements Python object to protobuf conversion.
    from_proto: Function that implements protobuf to Python object conversion.

  Raises:
    TypeError: If `to_proto` or `from_proto` is given but is not callable.
  """
  if to_proto and not callable(to_proto):
    raise TypeError("to_proto must be callable.")
  if from_proto and not callable(from_proto):
    raise TypeError("from_proto must be callable.")

  # Stored as a (proto_type, to_proto, from_proto) triple; the getters below
  # index into it.
  _proto_function_registry.register((proto_type, to_proto, from_proto),
                                    collection_name)


def get_collection_proto_type(collection_name):
  """Returns the proto_type for collection_name, or None if unregistered."""
  try:
    return _proto_function_registry.lookup(collection_name)[0]
  except LookupError:
    return None


def get_to_proto_function(collection_name):
  """Returns the to_proto function for collection_name, or None."""
  try:
    return _proto_function_registry.lookup(collection_name)[1]
  except LookupError:
    return None


def get_from_proto_function(collection_name):
  """Returns the from_proto function for collection_name, or None."""
  try:
    return _proto_function_registry.lookup(collection_name)[2]
  except LookupError:
    return None


def _op_to_colocate_with(v, graph):
  """Operation object corresponding to v to use for colocation constraints."""
  if v is None:
    return None, None
  if isinstance(v, Operation):
    return v, None

  # We always want to colocate with the reference op.
  # When 'v' is a ResourceVariable, the reference op is the handle creating op.
  #
  # What this should be is:
  # if isinstance(v, ResourceVariable):
  #   return v.handle.op, v
  # However, that would require a circular import dependency.
  # As of October 2018, there were attempts underway to remove
  # colocation constraints altogether. Assuming that will
  # happen soon, perhaps this hack to work around the circular
  # import dependency is acceptable.
7160 if hasattr(v, "handle") and isinstance(v.handle, Tensor): 7161 device_only_candidate = lambda: None 7162 device_only_candidate.device = v.device 7163 device_only_candidate.name = v.name 7164 if graph.building_function: 7165 return graph.capture(v.handle).op, device_only_candidate 7166 else: 7167 return v.handle.op, device_only_candidate 7168 return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op, None 7169 7170 7171def _is_keras_symbolic_tensor(x): 7172 return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph" 7173 7174 7175# These symbols were originally defined in this module; import them for 7176# backwards compatibility until all references have been updated to access 7177# them from the indexed_slices.py module. 7178IndexedSlices = indexed_slices.IndexedSlices 7179IndexedSlicesValue = indexed_slices.IndexedSlicesValue 7180convert_to_tensor_or_indexed_slices = \ 7181 indexed_slices.convert_to_tensor_or_indexed_slices 7182convert_n_to_tensor_or_indexed_slices = \ 7183 indexed_slices.convert_n_to_tensor_or_indexed_slices 7184internal_convert_to_tensor_or_indexed_slices = \ 7185 indexed_slices.internal_convert_to_tensor_or_indexed_slices 7186internal_convert_n_to_tensor_or_indexed_slices = \ 7187 indexed_slices.internal_convert_n_to_tensor_or_indexed_slices 7188register_tensor_conversion_function = \ 7189 tensor_conversion_registry.register_tensor_conversion_function 7190 7191 7192# Helper functions for op wrapper modules generated by `python_op_gen`. 7193 7194 7195def to_raw_op(f): 7196 """Make a given op wrapper function `f` raw. 7197 7198 Raw op wrappers can only be called with keyword arguments. 7199 7200 Args: 7201 f: An op wrapper function to make raw. 7202 7203 Returns: 7204 Raw `f`. 7205 """ 7206 # Copy `f` to get a new `__dict__`, otherwise `tf_export` will fail 7207 # due to double-registration. 
7208 f = types.FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, 7209 f.__closure__) 7210 return kwarg_only(f) 7211 7212 7213def raise_from_not_ok_status(e, name): 7214 e.message += (" name: " + name if name is not None else "") 7215 raise core._status_to_exception(e) from None # pylint: disable=protected-access 7216 7217 7218def add_exit_callback_to_default_func_graph(fn): 7219 """Add a callback to run when the default function graph goes out of scope. 7220 7221 Usage: 7222 7223 ```python 7224 @tf.function 7225 def fn(x, v): 7226 expensive = expensive_object(v) 7227 add_exit_callback_to_default_func_graph(lambda: expensive.release()) 7228 return g(x, expensive) 7229 7230 fn(x=tf.constant(...), v=...) 7231 # `expensive` has been released. 7232 ``` 7233 7234 Args: 7235 fn: A callable that takes no arguments and whose output is ignored. 7236 To be executed when exiting func graph scope. 7237 7238 Raises: 7239 RuntimeError: If executed when the current default graph is not a FuncGraph, 7240 or not currently executing in function creation mode (e.g., if inside 7241 an init_scope). 7242 """ 7243 default_graph = get_default_graph() 7244 if not default_graph._building_function: # pylint: disable=protected-access 7245 raise RuntimeError( 7246 "Cannot add scope exit callbacks when not building a function. " 7247 "Default graph: {}".format(default_graph)) 7248 default_graph._add_scope_exit_callback(fn) # pylint: disable=protected-access 7249 7250 7251def _reconstruct_sequence_inputs(op_def, inputs, attrs): 7252 """Regroups a flat list of input tensors into scalar and sequence inputs. 7253 7254 Args: 7255 op_def: The `op_def_pb2.OpDef` (for knowing the input types) 7256 inputs: a list of input `Tensor`s to the op. 
7257 attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define 7258 how long each sequence is) 7259 7260 Returns: 7261 A list of `Tensor`s (corresponding to scalar inputs) and lists of 7262 `Tensor`s (corresponding to sequence inputs). 7263 """ 7264 grouped_inputs = [] 7265 i = 0 7266 for input_arg in op_def.input_arg: 7267 if input_arg.number_attr: 7268 input_len = attrs[input_arg.number_attr].i 7269 is_sequence = True 7270 elif input_arg.type_list_attr: 7271 input_len = len(attrs[input_arg.type_list_attr].list.type) 7272 is_sequence = True 7273 else: 7274 input_len = 1 7275 is_sequence = False 7276 7277 if is_sequence: 7278 grouped_inputs.append(inputs[i:i + input_len]) 7279 else: 7280 grouped_inputs.append(inputs[i]) 7281 i += input_len 7282 7283 assert i == len(inputs) 7284 return grouped_inputs 7285 7286 7287_numpy_style_type_promotion = False 7288 7289 7290def enable_numpy_style_type_promotion(): 7291 """If called, follows NumPy's rules for type promotion. 7292 7293 Used for enabling NumPy behavior on methods for TF NumPy. 7294 """ 7295 global _numpy_style_type_promotion 7296 _numpy_style_type_promotion = True 7297 7298 7299_numpy_style_slicing = False 7300 7301 7302def enable_numpy_style_slicing(): 7303 """If called, follows NumPy's rules for slicing Tensors. 7304 7305 Used for enabling NumPy behavior on slicing for TF NumPy. 7306 """ 7307 global _numpy_style_slicing 7308 _numpy_style_slicing = True 7309 7310 7311class _TensorIterator(object): 7312 """Iterates over the leading dim of a Tensor. 
Performs no error checks.""" 7313 7314 __slots__ = ["_tensor", "_index", "_limit"] 7315 7316 def __init__(self, tensor, dim0): 7317 self._tensor = tensor 7318 self._index = 0 7319 self._limit = dim0 7320 7321 def __iter__(self): 7322 return self 7323 7324 def __next__(self): 7325 if self._index == self._limit: 7326 raise StopIteration 7327 result = self._tensor[self._index] 7328 self._index += 1 7329 return result 7330 7331 next = __next__ # python2.x compatibility. 7332 7333 7334def set_int_list_attr(op, attr_name, ints): 7335 """TF internal method used to set a list(int) attribute in the node_def.""" 7336 ints_list = attr_value_pb2.AttrValue.ListValue(i=ints) 7337 op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list)) # pylint:disable=protected-access 7338 7339 7340def _get_enclosing_context(graph): 7341 # pylint: disable=protected-access 7342 if graph is None: 7343 return None 7344 7345 if graph._control_flow_context is not None: 7346 return graph._control_flow_context 7347 7348 if graph.building_function and hasattr(graph, "outer_graph"): 7349 return _get_enclosing_context(graph.outer_graph) 7350 7351 7352def get_resource_handle_data(graph_op): 7353 assert type(graph_op) == Tensor # pylint: disable=unidiomatic-typecheck 7354 7355 with graph_op.graph._c_graph.get() as c_graph: # pylint: disable=protected-access 7356 handle_data = pywrap_tf_session.GetHandleShapeAndType( 7357 c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access 7358 7359 return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString( 7360 compat.as_bytes(handle_data)) 7361 7362 7363def _copy_handle_data_to_arg_def(tensor, arg_def): 7364 handle_data = get_resource_handle_data(tensor) 7365 if handle_data.shape_and_type: 7366 shape_and_type = handle_data.shape_and_type[0] 7367 proto = arg_def.handle_data.add() 7368 proto.dtype = shape_and_type.dtype 7369 proto.shape.CopyFrom(handle_data.shape_and_type[0].shape) 7370