1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes and functions used to construct graphs.""" 16# pylint: disable=g-bad-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23import re 24import sys 25import threading 26import types 27from absl import app 28 29import numpy as np 30import six 31from six.moves import map # pylint: disable=redefined-builtin 32from six.moves import xrange # pylint: disable=redefined-builtin 33 34from tensorflow.core.framework import attr_value_pb2 35from tensorflow.core.framework import function_pb2 36from tensorflow.core.framework import graph_pb2 37from tensorflow.core.framework import node_def_pb2 38from tensorflow.core.framework import op_def_pb2 39from tensorflow.core.framework import versions_pb2 40from tensorflow.core.protobuf import config_pb2 41# pywrap_tensorflow must be imported first to avoid protobuf issues. 
42# (b/143110113) 43# pylint: disable=invalid-import-order,g-bad-import-order,unused-import 44from tensorflow.python import pywrap_tensorflow 45from tensorflow.python import pywrap_tfe 46# pylint: enable=invalid-import-order,g-bad-import-order,unused-import 47from tensorflow.python import tf2 48from tensorflow.python.client import pywrap_tf_session 49from tensorflow.python.eager import context 50from tensorflow.python.eager import core 51from tensorflow.python.eager import monitoring 52from tensorflow.python.eager import tape 53from tensorflow.python.framework import c_api_util 54from tensorflow.python.framework import composite_tensor 55from tensorflow.python.framework import cpp_shape_inference_pb2 56from tensorflow.python.framework import device as pydev 57from tensorflow.python.framework import dtypes 58from tensorflow.python.framework import errors 59from tensorflow.python.framework import indexed_slices 60from tensorflow.python.framework import registry 61from tensorflow.python.framework import tensor_conversion_registry 62from tensorflow.python.framework import tensor_shape 63from tensorflow.python.framework import traceable_stack 64from tensorflow.python.framework import versions 65from tensorflow.python.ops import control_flow_util 66from tensorflow.python.platform import tf_logging as logging 67from tensorflow.python.profiler import trace 68from tensorflow.python.types import core as core_tf_types 69from tensorflow.python.types import internal 70from tensorflow.python.util import compat 71from tensorflow.python.util import decorator_utils 72from tensorflow.python.util import deprecation 73from tensorflow.python.util import dispatch 74from tensorflow.python.util import function_utils 75from tensorflow.python.util import lock_util 76from tensorflow.python.util import memory 77from tensorflow.python.util import object_identity 78from tensorflow.python.util import tf_contextlib 79from tensorflow.python.util import tf_stack 80from tensorflow.python.util import 
class _UserDeviceSpec(object):
  """One user-supplied device specification plus how to merge it into a node.

  The constructor classifies the argument into one of four cases (an existing
  `pydev.MergeDevice`, some other callable, `None`, or anything else, which is
  handed to `pydev.merge_device`) and records:

  * `function`: the object later called by `string_merge`.
  * `raw_string`: the original argument when it was not callable/None,
    otherwise `None`.
  * `display_name`: `str()` of the argument, or `name<file, line>` for a
    plain callable.
  * `is_null_merge`: copied from the merge device when there is one,
    otherwise `False`.
  * `fast_string_merge`: whether `function` is a `pydev.MergeDevice`.
  """

  def __init__(self, device_name_or_function):
    spec = device_name_or_function
    self._device_name_or_function = spec
    self.display_name = str(spec)
    self.function = spec
    self.raw_string = None

    if isinstance(spec, pydev.MergeDevice):
      self.is_null_merge = spec.is_null_merge
    elif callable(spec):
      self.is_null_merge = False
      self.display_name = self._callable_display_name(spec)
    elif spec is None:
      # NOTE(taylorrobie): This MUST be False. None signals a break in the
      #   device stack, so `is_null_merge` must be False for such a case to
      #   allow callers to safely skip over null merges without missing a None.
      self.is_null_merge = False
    else:
      # Anything else is kept verbatim and converted into a merge device.
      self.raw_string = spec
      merger = pydev.merge_device(spec)
      self.function = merger
      self.is_null_merge = merger.is_null_merge

    # Computed once here because the isinstance check has a non-trivial cost
    # and `string_merge` is typically called many times.
    self.fast_string_merge = isinstance(self.function, pydev.MergeDevice)

  @staticmethod
  def _callable_display_name(dev_func):
    """Formats a device function as `name<filename, first_line>`."""
    func_name = function_utils.get_func_name(dev_func)
    func_code = function_utils.get_func_code(dev_func)
    if func_code:
      location = (func_code.co_filename, func_code.co_firstlineno)
    else:
      # No code object available: fall back to placeholder values.
      location = ("unknown", -1)
    return "%s<%s, %d>" % ((func_name,) + location)

  def string_merge(self, node_def):
    """Applies the stored device function to `node_def`, returning a string.

    Merge devices take the `shortcut_string_merge` path; any other callable is
    invoked directly and its result is passed through `_device_string` and
    `compat.as_str`.
    """
    if not self.fast_string_merge:
      return compat.as_str(_device_string(self.function(node_def)))
    return self.function.shortcut_string_merge(node_def)
def numpy_text(tensor, is_repr=False):
  """Renders the numpy value of `tensor` as text for `str()`/`repr()` output.

  Args:
    tensor: Object exposing `dtype.is_numpy_compatible` and a `_numpy()`
      method.
    is_repr: When true the value is rendered with `repr`, otherwise with `str`.

  Returns:
    The rendered text, prefixed with a newline when it spans several lines, or
    the literal "<unprintable>" when the dtype is not numpy compatible.
  """
  if not tensor.dtype.is_numpy_compatible:
    return "<unprintable>"
  render = repr if is_repr else str
  # pylint: disable=protected-access
  rendered = render(tensor._numpy())
  # pylint: enable=protected-access
  # Multi-line values start on their own line so they stay aligned.
  return "\n" + rendered if "\n" in rendered else rendered
@tf_export(v1=["disable_tensor_equality"])
def disable_tensor_equality():
  """Compare Tensors by their id and be hashable.

  This is a legacy behaviour of TensorFlow and is highly discouraged.

  Besides clearing the `Tensor._USE_EQUALITY` class flag, the call is logged
  at verbosity level 1 and recorded as `False` in the tensor-equality usage
  gauge.
  """
  logging.vlog(1, "Disabling tensor equality")
  usage_cell = _tensor_equality_api_usage_gauge.get_cell()
  usage_cell.set(False)
  Tensor._USE_EQUALITY = False  # pylint: disable=protected-access
  # Names of the Python special methods that may be replaced on `Tensor`.
  # `_override_helper` checks membership in this set and raises `ValueError`
  # for any operator name that is not listed here.
  OVERLOADABLE_OPERATORS = {
      # Binary.
      "__add__",
      "__radd__",
      "__sub__",
      "__rsub__",
      "__mul__",
      "__rmul__",
      "__div__",
      "__rdiv__",
      "__truediv__",
      "__rtruediv__",
      "__floordiv__",
      "__rfloordiv__",
      "__mod__",
      "__rmod__",
      "__lt__",
      "__le__",
      "__gt__",
      "__ge__",
      "__ne__",
      "__eq__",
      "__and__",
      "__rand__",
      "__or__",
      "__ror__",
      "__xor__",
      "__rxor__",
      "__getitem__",
      "__pow__",
      "__rpow__",
      # Unary.
      "__invert__",
      "__neg__",
      "__abs__",
      # Binary again: matrix multiplication and its reflected form are listed
      # after the unary operators.
      "__matmul__",
      "__rmatmul__"
  }
Type of elements stored in this tensor. 375 376 Raises: 377 TypeError: If the op is not an `Operation`. 378 """ 379 if not isinstance(op, Operation): 380 raise TypeError("op needs to be an Operation: %s" % (op,)) 381 self._op = op 382 self._value_index = value_index 383 self._dtype = dtypes.as_dtype(dtype) 384 # This will be set by self._as_tf_output(). 385 self._tf_output = None 386 # This will be set by self.shape(). 387 self._shape_val = None 388 # List of operations that use this Tensor as input. We maintain this list 389 # to easily navigate a computation graph. 390 self._consumers = [] 391 self._id = uid() 392 self._name = None 393 394 def __getattr__(self, name): 395 if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size", 396 "tolist", "data"}: 397 # TODO(wangpeng): Export the enable_numpy_behavior knob 398 raise AttributeError(""" 399 '{}' object has no attribute '{}'. 400 If you are looking for numpy-related methods, please run the following: 401 from tensorflow.python.ops.numpy_ops import np_config 402 np_config.enable_numpy_behavior()""".format(type(self).__name__, name)) 403 self.__getattribute__(name) 404 405 @staticmethod 406 def _create_with_tf_output(op, value_index, dtype, tf_output): 407 ret = Tensor(op, value_index, dtype) 408 ret._tf_output = tf_output 409 return ret 410 411 @property 412 def op(self): 413 """The `Operation` that produces this tensor as an output.""" 414 return self._op 415 416 @property 417 def dtype(self): 418 """The `DType` of elements in this tensor.""" 419 return self._dtype 420 421 @property 422 def graph(self): 423 """The `Graph` that contains this tensor.""" 424 return self._op.graph 425 426 @property 427 def name(self): 428 """The string name of this tensor.""" 429 if self._name is None: 430 if not self._op.name: 431 raise ValueError("Operation was not named: %s" % self._op) 432 self._name = "%s:%d" % (self._op.name, self._value_index) 433 return self._name 434 435 @property 436 def device(self): 437 
"""The name of the device on which this tensor will be produced, or None.""" 438 return self._op.device 439 440 @property 441 def shape(self): 442 """Returns a `tf.TensorShape` that represents the shape of this tensor. 443 444 >>> t = tf.constant([1,2,3,4,5]) 445 >>> t.shape 446 TensorShape([5]) 447 448 `tf.Tensor.shape` is equivalent to `tf.Tensor.get_shape()`. 449 450 In a `tf.function` or when building a model using 451 `tf.keras.Input`, they return the build-time shape of the 452 tensor, which may be partially unknown. 453 454 A `tf.TensorShape` is not a tensor. Use `tf.shape(t)` to get a tensor 455 containing the shape, calculated at runtime. 456 457 See `tf.Tensor.get_shape()`, and `tf.TensorShape` for details and examples. 458 """ 459 if self._shape_val is None: 460 self._shape_val = self._c_api_shape() 461 return self._shape_val 462 463 def _c_api_shape(self): 464 """Returns the TensorShape of this tensor according to the C API.""" 465 c_graph = self._op._graph._c_graph # pylint: disable=protected-access 466 shape_vec, unknown_shape = pywrap_tf_session.TF_GraphGetTensorShapeHelper( 467 c_graph, self._as_tf_output()) 468 if unknown_shape: 469 return tensor_shape.unknown_shape() 470 else: 471 shape_vec = [None if d == -1 else d for d in shape_vec] 472 return tensor_shape.TensorShape(shape_vec) 473 474 @property 475 def _shape(self): 476 logging.warning("Tensor._shape is private, use Tensor.shape " 477 "instead. Tensor._shape will eventually be removed.") 478 return self.shape 479 480 @_shape.setter 481 def _shape(self, value): 482 raise ValueError( 483 "Tensor._shape cannot be assigned, use Tensor.set_shape instead.") 484 485 def _disallow_when_autograph_disabled(self, task): 486 raise errors.OperatorNotAllowedInGraphError( 487 "{} is not allowed: AutoGraph is disabled in this function." 
488 " Try decorating it directly with @tf.function.".format(task)) 489 490 def _disallow_when_autograph_enabled(self, task): 491 raise errors.OperatorNotAllowedInGraphError( 492 "{} is not allowed: AutoGraph did convert this function. This might" 493 " indicate you are trying to use an unsupported feature.".format(task)) 494 495 def _disallow_in_graph_mode(self, task): 496 raise errors.OperatorNotAllowedInGraphError( 497 "{} is not allowed in Graph execution. Use Eager execution or decorate" 498 " this function with @tf.function.".format(task)) 499 500 def _disallow_bool_casting(self): 501 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: 502 self._disallow_when_autograph_disabled( 503 "using a `tf.Tensor` as a Python `bool`") 504 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED: 505 self._disallow_when_autograph_enabled( 506 "using a `tf.Tensor` as a Python `bool`") 507 else: 508 # Default: V1-style Graph execution. 509 self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`") 510 511 def _disallow_iteration(self): 512 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: 513 self._disallow_when_autograph_disabled("iterating over `tf.Tensor`") 514 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED: 515 self._disallow_when_autograph_enabled("iterating over `tf.Tensor`") 516 else: 517 # Default: V1-style Graph execution. 
518 self._disallow_in_graph_mode("iterating over `tf.Tensor`") 519 520 def __iter__(self): 521 if not context.executing_eagerly(): 522 self._disallow_iteration() 523 524 shape = self._shape_tuple() 525 if shape is None: 526 raise TypeError("Cannot iterate over a tensor with unknown shape.") 527 if not shape: 528 raise TypeError("Cannot iterate over a scalar tensor.") 529 if shape[0] is None: 530 raise TypeError( 531 "Cannot iterate over a tensor with unknown first dimension.") 532 return _TensorIterator(self, shape[0]) 533 534 def _shape_as_list(self): 535 if self.shape.ndims is not None: 536 return [dim.value for dim in self.shape.dims] 537 else: 538 return None 539 540 def _shape_tuple(self): 541 shape = self._shape_as_list() 542 if shape is None: 543 return None 544 return tuple(shape) 545 546 def _rank(self): 547 """Integer rank of this Tensor, if known, else None. 548 549 Returns: 550 Integer rank or None 551 """ 552 return self.shape.ndims 553 554 def get_shape(self): 555 """Returns a `tf.TensorShape` that represents the shape of this tensor. 556 557 In eager execution the shape is always fully-known. 558 559 >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 560 >>> print(a.shape) 561 (2, 3) 562 563 `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`. 564 565 566 When executing in a `tf.function` or building a model using 567 `tf.keras.Input`, `Tensor.shape` may return a partial shape (including 568 `None` for unknown dimensions). See `tf.TensorShape` for more details. 569 570 >>> inputs = tf.keras.Input(shape = [10]) 571 >>> # Unknown batch size 572 >>> print(inputs.shape) 573 (None, 10) 574 575 The shape is computed using shape inference functions that are 576 registered for each `tf.Operation`. 577 578 The returned `tf.TensorShape` is determined at *build* time, without 579 executing the underlying kernel. It is not a `tf.Tensor`. 
If you need a 580 shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or 581 use the `tf.shape(tensor)` function, which returns the tensor's shape at 582 *execution* time. 583 584 This is useful for debugging and providing early errors. For 585 example, when tracing a `tf.function`, no ops are being executed, shapes 586 may be unknown (See the [Concrete Functions 587 Guide](https://www.tensorflow.org/guide/concrete_function) for details). 588 589 >>> @tf.function 590 ... def my_matmul(a, b): 591 ... result = a@b 592 ... # the `print` executes during tracing. 593 ... print("Result shape: ", result.shape) 594 ... return result 595 596 The shape inference functions propagate shapes to the extent possible: 597 598 >>> f = my_matmul.get_concrete_function( 599 ... tf.TensorSpec([None,3]), 600 ... tf.TensorSpec([3,5])) 601 Result shape: (None, 5) 602 603 Tracing may fail if a shape missmatch can be detected: 604 605 >>> cf = my_matmul.get_concrete_function( 606 ... tf.TensorSpec([None,3]), 607 ... tf.TensorSpec([4,5])) 608 Traceback (most recent call last): 609 ... 610 ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op: 611 'MatMul') with input shapes: [?,3], [4,5]. 612 613 In some cases, the inferred shape may have unknown dimensions. If 614 the caller has additional information about the values of these 615 dimensions, `tf.ensure_shape` or `Tensor.set_shape()` can be used to augment 616 the inferred shape. 617 618 >>> @tf.function 619 ... def my_fun(a): 620 ... a = tf.ensure_shape(a, [5, 5]) 621 ... # the `print` executes during tracing. 622 ... print("Result shape: ", a.shape) 623 ... return a 624 625 >>> cf = my_fun.get_concrete_function( 626 ... tf.TensorSpec([None, None])) 627 Result shape: (5, 5) 628 629 Returns: 630 A `tf.TensorShape` representing the shape of this tensor. 631 632 """ 633 return self.shape 634 635 def set_shape(self, shape): 636 """Updates the shape of this tensor. 
637 638 Note: It is recommended to use `tf.ensure_shape` instead of 639 `Tensor.set_shape`, because `tf.ensure_shape` provides better checking for 640 programming errors and can create guarantees for compiler 641 optimization. 642 643 With eager execution this operates as a shape assertion. 644 Here the shapes match: 645 646 >>> t = tf.constant([[1,2,3]]) 647 >>> t.set_shape([1, 3]) 648 649 Passing a `None` in the new shape allows any value for that axis: 650 651 >>> t.set_shape([1,None]) 652 653 An error is raised if an incompatible shape is passed. 654 655 >>> t.set_shape([1,5]) 656 Traceback (most recent call last): 657 ... 658 ValueError: Tensor's shape (1, 3) is not compatible with supplied 659 shape [1, 5] 660 661 When executing in a `tf.function`, or building a model using 662 `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with 663 the current shape of this tensor, and set the tensor's shape to the 664 merged value (see `tf.TensorShape.merge_with` for details): 665 666 >>> t = tf.keras.Input(shape=[None, None, 3]) 667 >>> print(t.shape) 668 (None, None, None, 3) 669 670 Dimensions set to `None` are not updated: 671 672 >>> t.set_shape([None, 224, 224, None]) 673 >>> print(t.shape) 674 (None, 224, 224, 3) 675 676 The main use case for this is to provide additional shape information 677 that cannot be inferred from the graph alone. 678 679 For example if you know all the images in a dataset have shape [28,28,3] you 680 can set it with `tf.set_shape`: 681 682 >>> @tf.function 683 ... def load_image(filename): 684 ... raw = tf.io.read_file(filename) 685 ... image = tf.image.decode_png(raw, channels=3) 686 ... # the `print` executes during tracing. 687 ... print("Initial shape: ", image.shape) 688 ... image.set_shape([28, 28, 3]) 689 ... print("Final shape: ", image.shape) 690 ... return image 691 692 Trace the function, see the [Concrete Functions 693 Guide](https://www.tensorflow.org/guide/concrete_function) for details. 
694 695 >>> cf = load_image.get_concrete_function( 696 ... tf.TensorSpec([], dtype=tf.string)) 697 Initial shape: (None, None, 3) 698 Final shape: (28, 28, 3) 699 700 Similarly the `tf.io.parse_tensor` function could return a tensor with 701 any shape, even the `tf.rank` is unknown. If you know that all your 702 serialized tensors will be 2d, set it with `set_shape`: 703 704 >>> @tf.function 705 ... def my_parse(string_tensor): 706 ... result = tf.io.parse_tensor(string_tensor, out_type=tf.float32) 707 ... # the `print` executes during tracing. 708 ... print("Initial shape: ", result.shape) 709 ... result.set_shape([None, None]) 710 ... print("Final shape: ", result.shape) 711 ... return result 712 713 Trace the function 714 715 >>> concrete_parse = my_parse.get_concrete_function( 716 ... tf.TensorSpec([], dtype=tf.string)) 717 Initial shape: <unknown> 718 Final shape: (None, None) 719 720 Make sure it works: 721 722 >>> t = tf.ones([5,3], dtype=tf.float32) 723 >>> serialized = tf.io.serialize_tensor(t) 724 >>> print(serialized.dtype) 725 <dtype: 'string'> 726 >>> print(serialized.shape) 727 () 728 >>> t2 = concrete_parse(serialized) 729 >>> print(t2.shape) 730 (5, 3) 731 732 Caution: `set_shape` ensures that the applied shape is compatible with 733 the existing shape, but it does not check at runtime. Setting 734 incorrect shapes can result in inconsistencies between the 735 statically-known graph and the runtime value of tensors. For runtime 736 validation of the shape, use `tf.ensure_shape` instead. It also modifies 737 the `shape` of the tensor. 738 739 >>> # Serialize a rank-3 tensor 740 >>> t = tf.ones([5,5,5], dtype=tf.float32) 741 >>> serialized = tf.io.serialize_tensor(t) 742 >>> # The function still runs, even though it `set_shape([None,None])` 743 >>> t2 = concrete_parse(serialized) 744 >>> print(t2.shape) 745 (5, 5, 5) 746 747 Args: 748 shape: A `TensorShape` representing the shape of this tensor, a 749 `TensorShapeProto`, a list, a tuple, or None. 
750 751 Raises: 752 ValueError: If `shape` is not compatible with the current shape of 753 this tensor. 754 """ 755 # Reset cached shape. 756 self._shape_val = None 757 758 # We want set_shape to be reflected in the C API graph for when we run it. 759 if not isinstance(shape, tensor_shape.TensorShape): 760 shape = tensor_shape.TensorShape(shape) 761 dim_list = [] 762 if shape.dims is None: 763 unknown_shape = True 764 else: 765 unknown_shape = False 766 for dim in shape.dims: 767 if dim.value is None: 768 dim_list.append(-1) 769 else: 770 dim_list.append(dim.value) 771 try: 772 pywrap_tf_session.TF_GraphSetTensorShape_wrapper( 773 self._op._graph._c_graph, # pylint: disable=protected-access 774 self._as_tf_output(), 775 dim_list, 776 unknown_shape) 777 except errors.InvalidArgumentError as e: 778 # Convert to ValueError for backwards compatibility. 779 raise ValueError(e.message) 780 781 @property 782 def value_index(self): 783 """The index of this tensor in the outputs of its `Operation`.""" 784 return self._value_index 785 786 def consumers(self): 787 """Returns a list of `Operation`s that consume this tensor. 788 789 Returns: 790 A list of `Operation`s. 791 """ 792 consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper( 793 self._as_tf_output()) 794 # pylint: disable=protected-access 795 return [ 796 self.graph._get_operation_by_name_unsafe(name) 797 for name in consumer_names 798 ] 799 # pylint: enable=protected-access 800 801 def _as_node_def_input(self): 802 """Return a value to use for the NodeDef "input" attribute. 803 804 The returned string can be used in a NodeDef "input" attribute 805 to indicate that the NodeDef uses this Tensor as input. 806 807 Raises: 808 ValueError: if this Tensor's Operation does not have a name. 809 810 Returns: 811 a string. 
812 """ 813 if not self._op.name: 814 raise ValueError("Operation was not named: %s" % self._op) 815 if self._value_index == 0: 816 return self._op.name 817 else: 818 return "%s:%d" % (self._op.name, self._value_index) 819 820 def _as_tf_output(self): 821 # pylint: disable=protected-access 822 # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object 823 # also guarantees that a dictionary of tf_output objects will retain a 824 # deterministic (yet unsorted) order which prevents memory blowup in the 825 # cache of executor(s) stored for every session. 826 if self._tf_output is None: 827 self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index) 828 return self._tf_output 829 # pylint: enable=protected-access 830 831 def __str__(self): 832 return "Tensor(\"%s\"%s%s%s)" % ( 833 self.name, 834 (", shape=%s" % 835 self.get_shape()) if self.get_shape().ndims is not None else "", 836 (", dtype=%s" % self._dtype.name) if self._dtype else "", 837 (", device=%s" % self.device) if self.device else "") 838 839 def __repr__(self): 840 return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(), 841 self._dtype.name) 842 843 def __hash__(self): 844 g = getattr(self, "graph", None) 845 if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and 846 (g is None or g.building_function)): 847 raise TypeError("Tensor is unhashable. " 848 "Instead, use tensor.ref() as the key.") 849 else: 850 return id(self) 851 852 def __copy__(self): 853 # TODO(b/77597810): get rid of Tensor copies. 854 cls = self.__class__ 855 result = cls.__new__(cls) 856 result.__dict__.update(self.__dict__) 857 return result 858 859 # NOTE(mrry): This enables the Tensor's overloaded "right" binary 860 # operators to run when the left operand is an ndarray, because it 861 # accords the Tensor class higher priority than an ndarray, or a 862 # numpy matrix. 
863 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 864 # mechanism, which allows more control over how Tensors interact 865 # with ndarrays. 866 __array_priority__ = 100 867 868 def __array__(self): 869 raise NotImplementedError( 870 "Cannot convert a symbolic Tensor ({}) to a numpy array." 871 " This error may indicate that you're trying to pass a Tensor to" 872 " a NumPy call, which is not supported".format(self.name)) 873 874 def __len__(self): 875 raise TypeError("len is not well defined for symbolic Tensors. ({}) " 876 "Please call `x.shape` rather than `len(x)` for " 877 "shape information.".format(self.name)) 878 879 # TODO(mdan): This convoluted machinery is hard to maintain. Clean up. 880 @staticmethod 881 def _override_operator(operator, func): 882 _override_helper(Tensor, operator, func) 883 884 def __bool__(self): 885 """Dummy method to prevent a tensor from being used as a Python `bool`. 886 887 This overload raises a `TypeError` when the user inadvertently 888 treats a `Tensor` as a boolean (most commonly in an `if` or `while` 889 statement), in code that was not converted by AutoGraph. For example: 890 891 ```python 892 if tf.constant(True): # Will raise. 893 # ... 894 895 if tf.constant(5) < tf.constant(7): # Will raise. 896 # ... 897 ``` 898 899 Raises: 900 `TypeError`. 901 """ 902 self._disallow_bool_casting() 903 904 def __nonzero__(self): 905 """Dummy method to prevent a tensor from being used as a Python `bool`. 906 907 This is the Python 2.x counterpart to `__bool__()` above. 908 909 Raises: 910 `TypeError`. 911 """ 912 self._disallow_bool_casting() 913 914 def eval(self, feed_dict=None, session=None): 915 """Evaluates this tensor in a `Session`. 916 917 Note: If you are not using `compat.v1` libraries, you should not need this, 918 (or `feed_dict` or `Session`). In eager execution (or within `tf.function`) 919 you do not need to call `eval`. 
  @deprecation.deprecated(None, "Use ref() instead.")
  def experimental_ref(self):
    """Deprecated alias of `ref()`; returns exactly what `self.ref()` returns."""
    return self.ref()
# TODO(agarwal): consider getting rid of this.
# TODO(mdan): This object should not subclass ops.Tensor.
class _EagerTensorBase(Tensor):
  """Base class for EagerTensor.

  Holds the Python-level behaviour of eager tensors: scalar conversions,
  NumPy interop, printing, pickling/copying, device copies and shape access.

  Several low-level accessors (`_numpy_internal`, `_datatype_enum`,
  `_shape_tuple`, `_rank`, `_num_elements`, `_copy_to_device`,
  `backing_device`) only raise `NotImplementedError` here; presumably the
  concrete `EagerTensor` class supplies them.

  Graph-only members (`op`, `graph`, `name`, `value_index`, `consumers`,
  `eval`, ...) are overridden to raise, since they have no meaning for an
  eager tensor.
  """

  # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and
  # only work for scalars; values are cast as per numpy.
  def __complex__(self):
    return complex(self._numpy())

  def __int__(self):
    return int(self._numpy())

  def __long__(self):
    # Legacy Python 2 conversion hook. The `long` builtin does not exist on
    # Python 3 (which this file already requires, see `raise ... from None`
    # below), so calling it would raise NameError; `int` is
    # arbitrary-precision and is the equivalent conversion.
    return int(self._numpy())

  def __float__(self):
    return float(self._numpy())

  def __index__(self):
    return self._numpy().__index__()

  def __bool__(self):
    return bool(self._numpy())

  __nonzero__ = __bool__

  def __format__(self, format_spec):
    """Formats the tensor value (not the `tf.Tensor(...)` wrapper text)."""
    if self._prefer_custom_summarizer():
      return self._summarize_value().__format__(format_spec)
    elif self.dtype.is_numpy_compatible:
      # Not numpy_text here, otherwise the __format__ behaves differently.
      return self._numpy().__format__(format_spec)
    else:
      return "<unprintable>".__format__(format_spec)

  def __reduce__(self):
    # Pickle support: serialize the NumPy value and rebuild the tensor with
    # `convert_to_tensor` on load.
    return convert_to_tensor, (self._numpy(),)

  def __copy__(self):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    return self

  def __deepcopy__(self, memo):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    del memo
    return self

  def __str__(self):
    if self._prefer_custom_summarizer():
      value_text = self._summarize_value()
    else:
      value_text = numpy_text(self)
    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (value_text, self.shape,
                                                  self.dtype.name)

  def __repr__(self):
    if self._prefer_custom_summarizer():
      value_text = "value=" + self._summarize_value()
    else:
      value_text = "numpy=" + numpy_text(self, is_repr=True)
    return "<tf.Tensor: shape=%s, dtype=%s, %s>" % (self.shape, self.dtype.name,
                                                    value_text)

  def __len__(self):
    """Returns the length of the first dimension in the Tensor.

    Raises:
      TypeError: if `self.shape.ndims` is falsy (e.g. a scalar tensor).
    """
    if not self.shape.ndims:
      raise TypeError("Scalar tensor has no `len()`")
    # pylint: disable=protected-access
    try:
      return self._shape_tuple()[0]
    except core._NotOkStatusException as e:
      raise core._status_to_exception(e.code, e.message) from None

  def __array__(self, dtype=None):
    """Implements the NumPy array protocol.

    Args:
      dtype: Optional NumPy dtype requested by the caller (NumPy passes it for
        calls such as `np.asarray(tensor, dtype=...)`). When `None`, the value
        from `_numpy()` is returned unchanged.

    Returns:
      The result of `self._numpy()`, converted to `dtype` if one was given.
    """
    value = self._numpy()
    # Compare against None explicitly: `np.dtype` instances can be falsy.
    if dtype is None:
      return value
    return np.asarray(value, dtype=dtype)

  def _numpy_internal(self):
    raise NotImplementedError()

  def _numpy(self):
    """Calls `_numpy_internal()`, translating status errors to exceptions."""
    try:
      return self._numpy_internal()
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e.code, e.message) from None  # pylint: disable=protected-access

  @property
  def dtype(self):
    # Note: using the intern table directly here as this is
    # performance-sensitive in some models.
    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access

  def numpy(self):
    """Copy of the contents of this Tensor into a NumPy array or scalar.

    Unlike NumPy arrays, Tensors are immutable, so this method has to copy
    the contents to ensure safety. Use `memoryview` to get a readonly
    view of the contents without doing a copy:

    >>> t = tf.constant([42])
    >>> np.array(memoryview(t))
    array([42], dtype=int32)

    Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor
    is on GPU, it will have to be transferred to CPU first in order for
    `memoryview` to work.

    Returns:
      A NumPy array of the same shape and dtype or a NumPy scalar, if this
      Tensor has rank 0.

    Raises:
      ValueError: If the dtype of this Tensor does not have a compatible
        NumPy dtype.
    """
    # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
    maybe_arr = self._numpy()  # pylint: disable=protected-access
    # Only ndarrays are copied; any other returned value is passed through.
    return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr

  @property
  def backing_device(self):
    """Returns the name of the device holding this tensor's memory.

    `.backing_device` is usually the same as `.device`, which returns
    the device on which the kernel of the operation that produced this tensor
    ran. However, some operations can produce tensors on a different device
    (e.g., an operation that executes on the GPU but produces output tensors
    in host memory).
    """
    raise NotImplementedError()

  def _datatype_enum(self):
    raise NotImplementedError()

  def _shape_tuple(self):
    """The shape of this Tensor, as a tuple.

    This is more performant than tuple(shape().as_list()) as it avoids
    two list and one object creation. Marked private for now as from an API
    perspective, it would be better to have a single performant way of
    getting a shape rather than exposing shape() and shape_tuple()
    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
    but ideally one would work things out and remove the need for this method.

    Returns:
      tuple with the shape.
    """
    raise NotImplementedError()

  def _rank(self):
    """Integer rank of this Tensor.

    Unlike regular Tensors, the rank is always known for EagerTensors.

    This is more performant than len(self._shape_tuple())

    Returns:
      Integer rank
    """
    raise NotImplementedError()

  def _num_elements(self):
    """Number of elements of this Tensor.

    Unlike regular Tensors, the number of elements is always known for
    EagerTensors.

    This is more performant than tensor.shape.num_elements

    Returns:
      Long - num elements in the tensor
    """
    raise NotImplementedError()

  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
    raise NotImplementedError()

  @staticmethod
  def _override_operator(name, func):
    # Installs `func` as attribute `name` on this base class.
    setattr(_EagerTensorBase, name, func)

  def _copy_nograd(self, ctx=None, device_name=None):
    """Copies tensor to dest device, but doesn't record the operation.

    Args:
      ctx: Context to use; defaults to `context.context()`.
      device_name: Destination device; defaults to `ctx.device_name`.

    Returns:
      The tensor returned by `_copy_to_device(device_name)`.
    """
    # Creates a new tensor on the dest device.
    if ctx is None:
      ctx = context.context()
    if device_name is None:
      device_name = ctx.device_name
    # pylint: disable=protected-access
    try:
      ctx.ensure_initialized()
      new_tensor = self._copy_to_device(device_name)
    except core._NotOkStatusException as e:
      raise core._status_to_exception(e.code, e.message) from None
    return new_tensor

  def _copy(self, ctx=None, device_name=None):
    """Copies tensor to dest device."""
    new_tensor = self._copy_nograd(ctx, device_name)
    # Record the copy on tape and define backprop copy as well.
    if context.executing_eagerly():
      self_device = self.device

      def grad_fun(dresult):
        # The gradient of a copy is a copy back to the source device; values
        # without a `_copy` method are passed through unchanged.
        return [
            dresult._copy(device_name=self_device)
            if hasattr(dresult, "_copy") else dresult
        ]

      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
    return new_tensor
    # pylint: enable=protected-access

  @property
  def shape(self):
    """The `TensorShape` of this tensor, built lazily and then cached."""
    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
      # pylint: disable=protected-access
      try:
        # `_tensor_shape` is declared and defined in the definition of
        # `EagerTensor`, in C.
        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
      except core._NotOkStatusException as e:
        raise core._status_to_exception(e.code, e.message) from None

    return self._tensor_shape

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def _shape_as_list(self):
    """The shape of the tensor as a list."""
    return list(self._shape_tuple())

  @property
  def ndim(self):
    """Returns the number of Tensor dimensions."""
    return self.shape.ndims

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def cpu(self):
    """A copy of this Tensor with contents backed by host memory."""
    return self._copy(context.context(), "CPU:0")

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def gpu(self, gpu_index=0):
    """A copy of this Tensor with contents backed by memory on the GPU.

    Args:
      gpu_index: Identifies which GPU to place the contents on the returned
        Tensor in.

    Returns:
      A GPU-memory backed Tensor object initialized with the same contents
      as this Tensor.
    """
    return self._copy(context.context(), "GPU:" + str(gpu_index))

  def set_shape(self, shape):
    """Checks `shape` for compatibility; never modifies this tensor's shape.

    Raises:
      ValueError: if `shape` is not compatible with `self.shape`.
    """
    if not self.shape.is_compatible_with(shape):
      raise ValueError(
          "Tensor's shape %s is not compatible with supplied shape %s" %
          (self.shape, shape))

  # Methods not supported / implemented for Eager Tensors.
  @property
  def op(self):
    raise AttributeError(
        "Tensor.op is meaningless when eager execution is enabled.")

  @property
  def graph(self):
    raise AttributeError(
        "Tensor.graph is meaningless when eager execution is enabled.")

  @property
  def name(self):
    raise AttributeError(
        "Tensor.name is meaningless when eager execution is enabled.")

  @property
  def value_index(self):
    raise AttributeError(
        "Tensor.value_index is meaningless when eager execution is enabled.")

  def consumers(self):
    raise NotImplementedError(
        "Tensor.consumers is meaningless when eager execution is enabled.")

  def _add_consumer(self, consumer):
    raise NotImplementedError(
        "_add_consumer not supported when eager execution is enabled.")

  def _as_node_def_input(self):
    raise NotImplementedError(
        "_as_node_def_input not supported when eager execution is enabled.")

  def _as_tf_output(self):
    raise NotImplementedError(
        "_as_tf_output not supported when eager execution is enabled.")

  def eval(self, feed_dict=None, session=None):
    raise NotImplementedError(
        "eval is not supported when eager execution is enabled, "
        "is .numpy() what you're looking for?")
1297EagerTensor = tf_export("__internal__.EagerTensor", v1=[])( 1298 pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)) 1299 1300 1301@tf_export(v1=["convert_to_tensor"]) 1302@dispatch.add_dispatch_support 1303def convert_to_tensor_v1_with_dispatch( 1304 value, 1305 dtype=None, 1306 name=None, 1307 preferred_dtype=None, 1308 dtype_hint=None): 1309 """Converts the given `value` to a `Tensor`. 1310 1311 This function converts Python objects of various types to `Tensor` 1312 objects. It accepts `Tensor` objects, numpy arrays, Python lists, 1313 and Python scalars. For example: 1314 1315 ```python 1316 import numpy as np 1317 1318 def my_func(arg): 1319 arg = tf.convert_to_tensor(arg, dtype=tf.float32) 1320 return tf.matmul(arg, arg) + arg 1321 1322 # The following calls are equivalent. 1323 value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]])) 1324 value_2 = my_func([[1.0, 2.0], [3.0, 4.0]]) 1325 value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) 1326 ``` 1327 1328 This function can be useful when composing a new operation in Python 1329 (such as `my_func` in the example above). All standard Python op 1330 constructors apply this function to each of their Tensor-valued 1331 inputs, which allows those ops to accept numpy arrays, Python lists, 1332 and scalars in addition to `Tensor` objects. 1333 1334 Note: This function diverges from default Numpy behavior for `float` and 1335 `string` types when `None` is present in a Python list or scalar. Rather 1336 than silently converting `None` values, an error will be thrown. 1337 1338 Args: 1339 value: An object whose type has a registered `Tensor` conversion function. 1340 dtype: Optional element type for the returned tensor. If missing, the type 1341 is inferred from the type of `value`. 1342 name: Optional name to use if a new `Tensor` is created. 1343 preferred_dtype: Optional element type for the returned tensor, used when 1344 dtype is None. 
def convert_to_tensor_v1(value,
                         dtype=None,
                         name=None,
                         preferred_dtype=None,
                         dtype_hint=None):
  """TF1-flavoured `convert_to_tensor`: resolves the hint, then defers to v2.

  `dtype_hint` and `preferred_dtype` are two names for the same soft dtype
  preference; `deprecation.deprecated_argument_lookup` picks the one to use
  before the call is forwarded to `convert_to_tensor_v2`.

  Args:
    value: Object to convert.
    dtype: Optional required element type.
    name: Optional name for a newly created `Tensor`.
    preferred_dtype: Deprecated spelling of `dtype_hint`.
    dtype_hint: Optional soft dtype preference.

  Returns:
    Whatever `convert_to_tensor_v2` returns for the resolved arguments.
  """
  resolved_hint = deprecation.deprecated_argument_lookup(
      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
  return convert_to_tensor_v2(
      value, dtype=dtype, dtype_hint=resolved_hint, name=name)
4.]], shape=(2, 2), dtype=float32) 1398 >>> value_2 = my_func([[1.0, 2.0], [3.0, 4.0]]) 1399 >>> print(value_2) 1400 tf.Tensor( 1401 [[1. 2.] 1402 [3. 4.]], shape=(2, 2), dtype=float32) 1403 >>> value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) 1404 >>> print(value_3) 1405 tf.Tensor( 1406 [[1. 2.] 1407 [3. 4.]], shape=(2, 2), dtype=float32) 1408 1409 This function can be useful when composing a new operation in Python 1410 (such as `my_func` in the example above). All standard Python op 1411 constructors apply this function to each of their Tensor-valued 1412 inputs, which allows those ops to accept numpy arrays, Python lists, 1413 and scalars in addition to `Tensor` objects. 1414 1415 Note: This function diverges from default Numpy behavior for `float` and 1416 `string` types when `None` is present in a Python list or scalar. Rather 1417 than silently converting `None` values, an error will be thrown. 1418 1419 Args: 1420 value: An object whose type has a registered `Tensor` conversion function. 1421 dtype: Optional element type for the returned tensor. If missing, the type 1422 is inferred from the type of `value`. 1423 dtype_hint: Optional element type for the returned tensor, used when dtype 1424 is None. In some cases, a caller may not have a dtype in mind when 1425 converting to a tensor, so dtype_hint can be used as a soft preference. 1426 If the conversion to `dtype_hint` is not possible, this argument has no 1427 effect. 1428 name: Optional name to use if a new `Tensor` is created. 1429 1430 Returns: 1431 A `Tensor` based on `value`. 1432 1433 Raises: 1434 TypeError: If no conversion function is registered for `value` to `dtype`. 1435 RuntimeError: If a registered conversion function returns an invalid value. 1436 ValueError: If the `value` is a tensor not of given `dtype` in graph mode. 
def pack_eager_tensors(tensors, ctx=None):
  """Pack multiple `EagerTensor`s of the same dtype and shape.

  Args:
    tensors: a non-empty list of EagerTensors to pack. Every element must have
      the same dtype and shape as the first one; when the dtype is
      `dtypes.resource` the handle data must match as well.
    ctx: context.context(). Defaults to the current context when `None`.

  Returns:
    A packed EagerTensor.

  Raises:
    TypeError: if `tensors` is not a list, or an element is not an
      `EagerTensor`.
    ValueError: if `tensors` is empty, or its elements disagree on dtype,
      shape or (for resource tensors) handle data.
  """
  if not isinstance(tensors, list):
    # Wrap the argument in a 1-tuple: `"%s" % some_tuple` would unpack the
    # tuple and fail with an unrelated formatting error.
    raise TypeError("tensors must be a list: %s" % (tensors,))

  if not tensors:
    raise ValueError("Empty tensors is unexpected for packing.")

  # Validate the first element before reading the reference attributes from
  # it, so that a bad first element yields the documented TypeError.
  if not isinstance(tensors[0], EagerTensor):
    raise TypeError(
        "tensors must be a list of EagerTensors: %s" % (tensors[0],))

  dtype = tensors[0].dtype
  shape = tensors[0].shape
  handle_data = tensors[0]._handle_data  # pylint: disable=protected-access
  is_resource = dtype == dtypes.resource
  for i, t in enumerate(tensors):
    if not isinstance(t, EagerTensor):
      raise TypeError("tensors must be a list of EagerTensors: %s" % (t,))

    if t.dtype != dtype:
      raise ValueError(
          "All tensors being packed should have the same dtype %s, "
          "but the %d-th tensor is of dtype %s" % (dtype, i, t.dtype))
    if t.shape != shape:
      raise ValueError(
          "All tensors being packed should have the same shape %s, "
          "but the %d-th tensor is of shape %s" % (shape, i, t.shape))
    # pylint: disable=protected-access
    if is_resource and t._handle_data != handle_data:
      raise ValueError(
          "All tensors being packed should have the same handle data %s, "
          "but the %d-th tensor is of handle data %s" %
          (handle_data, i, t._handle_data))
    # pylint: enable=protected-access

  if ctx is None:
    ctx = context.context()

  packed_tensor = ctx.pack_eager_tensors(tensors)
  # Propagate handle data for resource variables.
  if handle_data is not None:
    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access

  def grad_fun(_):
    raise ValueError(
        "Gradients through pack_eager_tensors are not supported yet.")

  # Record the op so that differentiating through it fails loudly via
  # `grad_fun` above.
  tape.record_operation("pack_eager_tensors", [packed_tensor], tensors,
                        grad_fun)

  return packed_tensor
1550 # https://docs.python.org/3.8/reference/datamodel.html#special-lookup 1551 overload = getattr(type(value), "__tf_tensor__", None) 1552 if overload is not None: 1553 return overload(value, dtype, name) # pylint: disable=not-callable 1554 1555 for base_type, conversion_func in tensor_conversion_registry.get(type(value)): 1556 # If dtype is None but preferred_dtype is not None, we try to 1557 # cast to preferred_dtype first. 1558 ret = None 1559 if dtype is None and preferred_dtype is not None: 1560 try: 1561 ret = conversion_func( 1562 value, dtype=preferred_dtype, name=name, as_ref=as_ref) 1563 except (TypeError, ValueError): 1564 # Could not coerce the conversion to use the preferred dtype. 1565 pass 1566 else: 1567 if (ret is not NotImplemented and 1568 ret.dtype.base_dtype != preferred_dtype.base_dtype): 1569 raise TypeError("convert_to_tensor did not convert to " 1570 "the preferred dtype: %s vs %s " % 1571 (ret.dtype.base_dtype, preferred_dtype.base_dtype)) 1572 1573 if ret is None: 1574 ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) 1575 1576 if ret is NotImplemented: 1577 continue 1578 1579 if not isinstance(ret, accepted_result_types): 1580 raise RuntimeError( 1581 "%sConversion function %r for type %s returned non-Tensor: %r" % 1582 (_error_prefix(name), conversion_func, base_type, ret)) 1583 if dtype and not dtype.is_compatible_with(ret.dtype): 1584 raise RuntimeError( 1585 "%sConversion function %r for type %s returned incompatible " 1586 "dtype: requested = %s, actual = %s" % 1587 (_error_prefix(name), conversion_func, base_type, dtype.name, 1588 ret.dtype.name)) 1589 return ret 1590 raise TypeError("%sCannot convert %r with type %s to Tensor: " 1591 "no conversion function registered." 
def internal_convert_n_to_tensor(values,
                                 dtype=None,
                                 name=None,
                                 as_ref=False,
                                 preferred_dtype=None,
                                 ctx=None):
  """Converts each element of `values` with `convert_to_tensor`.

  Args:
    values: A sequence of objects that `convert_to_tensor()` accepts.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) Name prefix for newly created tensors; element `i` is
      named `name + '_' + i`. With `None`, elements get no explicit name.
    as_ref: True if the caller wants the results as ref tensors.
    preferred_dtype: Optional soft dtype preference, forwarded unchanged to
      every per-element conversion.
    ctx: The value of context.context(); looked up once here when `None` and
      shared by all per-element conversions.

  Returns:
    A list with one converted object per element of `values`, in order.

  Raises:
    TypeError: If `values` is not a sequence, or no conversion function is
      registered for one of its elements.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  if ctx is None:
    ctx = context.context()

  def _element_name(index):
    return None if name is None else "%s_%d" % (name, index)

  return [
      convert_to_tensor(
          element,
          dtype=dtype,
          name=_element_name(index),
          as_ref=as_ref,
          preferred_dtype=preferred_dtype,
          ctx=ctx) for index, element in enumerate(values)
  ]
1648 1649 Args: 1650 values: A list of objects that can be consumed by `tf.convert_to_tensor()`. 1651 dtype: (Optional.) The required `DType` of the returned `Tensor` objects. 1652 name: (Optional.) A name prefix to used when a new `Tensor` is created, in 1653 which case element `i` will be given the name `name + '_' + i`. 1654 preferred_dtype: Optional element type for the returned tensors, used when 1655 dtype is None. In some cases, a caller may not have a dtype in mind when 1656 converting to a tensor, so preferred_dtype can be used as a soft 1657 preference. If the conversion to `preferred_dtype` is not possible, this 1658 argument has no effect. 1659 1660 Returns: 1661 A list of `Tensor` and/or `IndexedSlices` objects. 1662 1663 Raises: 1664 TypeError: If no conversion function is registered for an element in 1665 `values`. 1666 RuntimeError: If a registered conversion function returns an invalid 1667 value. 1668 """ 1669 return internal_convert_n_to_tensor( 1670 values=values, 1671 dtype=dtype, 1672 name=name, 1673 preferred_dtype=preferred_dtype, 1674 as_ref=False) 1675 1676 1677def convert_to_tensor_or_composite(value, dtype=None, name=None): 1678 """Converts the given object to a `Tensor` or `CompositeTensor`. 1679 1680 If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it 1681 is converted to a `Tensor` using `convert_to_tensor()`. 1682 1683 Args: 1684 value: A `CompositeTensor` or an object that can be consumed by 1685 `convert_to_tensor()`. 1686 dtype: (Optional.) The required `DType` of the returned `Tensor` or 1687 `CompositeTensor`. 1688 name: (Optional.) A name to use if a new `Tensor` is created. 1689 1690 Returns: 1691 A `Tensor` or `CompositeTensor`, based on `value`. 1692 1693 Raises: 1694 ValueError: If `dtype` does not match the element type of `value`. 
def internal_convert_to_tensor_or_composite(value,
                                            dtype=None,
                                            name=None,
                                            as_ref=False):
  """Returns `value` as a `Tensor`, or untouched if it is a `CompositeTensor`.

  Anything that is not a `CompositeTensor` goes through `convert_to_tensor()`
  (which is allowed to return either a `Tensor` or a `CompositeTensor`). A
  `CompositeTensor` is passed through as-is after a dtype compatibility check.

  Args:
    value: A `CompositeTensor`, or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    return convert_to_tensor(
        value,
        dtype=dtype,
        name=name,
        as_ref=as_ref,
        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))

  # Composite path: no conversion happens, only validation of the request.
  composite_dtype = getattr(value, "dtype", None)
  if dtype and not dtypes.as_dtype(dtype).is_compatible_with(composite_dtype):
    raise ValueError(
        "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
        (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
  return value
The required `DType` of the returned `Tensor`s or 1751 `CompositeTensor`s. 1752 name: (Optional.) A name prefix to used when a new `Tensor` is created, in 1753 which case element `i` will be given the name `name + '_' + i`. 1754 as_ref: True if the caller wants the results as ref tensors. 1755 1756 Returns: 1757 A list of `Tensor`, `CompositeTensor`, and/or `None` objects. 1758 1759 Raises: 1760 TypeError: If no conversion function is registered for an element in 1761 `values`. 1762 RuntimeError: If a registered conversion function returns an invalid 1763 value. 1764 """ 1765 if not isinstance(values, collections_abc.Sequence): 1766 raise TypeError("values must be a sequence.") 1767 ret = [] 1768 for i, value in enumerate(values): 1769 if value is None: 1770 ret.append(value) 1771 else: 1772 n = None if name is None else "%s_%d" % (name, i) 1773 ret.append( 1774 internal_convert_to_tensor_or_composite( 1775 value, dtype=dtype, name=n, as_ref=as_ref)) 1776 return ret 1777 1778 1779def convert_n_to_tensor_or_composite(values, dtype=None, name=None): 1780 """Converts `values` to a list of `Output` or `CompositeTensor` objects. 1781 1782 Any `CompositeTensor` objects in `values` are returned unmodified. 1783 1784 Args: 1785 values: A list of `None`, `CompositeTensor``, or objects that can be 1786 consumed by `convert_to_tensor()`. 1787 dtype: (Optional.) The required `DType` of the returned `Tensor`s or 1788 `CompositeTensor`s. 1789 name: (Optional.) A name prefix to used when a new `Tensor` is created, in 1790 which case element `i` will be given the name `name + '_' + i`. 1791 1792 Returns: 1793 A list of `Tensor` and/or `CompositeTensor` objects. 1794 1795 Raises: 1796 TypeError: If no conversion function is registered for an element in 1797 `values`. 1798 RuntimeError: If a registered conversion function returns an invalid 1799 value. 
@tf_export("__internal__.create_c_op", v1=[])
@traceback_utils.filter_traceback
def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
  """Creates a TF_Operation.

  Builds an operation description through the `pywrap_tf_session` C-API
  wrappers (device, inputs, control inputs, attrs, in that order) and then
  finishes it.

  Args:
    graph: a `Graph`.
    node_def: `node_def_pb2.NodeDef` for the operation to create.
    inputs: A flattened list of `Tensor`s. This function handles grouping
      tensors into lists as per attributes in the `node_def`.
    control_inputs: A list of `Operation`s to set as control dependencies.
    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
      specified, is looked up from the `graph` using `node_def.op`.

  Returns:
    A wrapped TF_Operation*.

  Raises:
    ValueError: if finishing the operation raises
      `errors.InvalidArgumentError`.
  """
  if op_def is None:
    op_def = graph._get_op_def(node_def.op)  # pylint: disable=protected-access
  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
  # Refactor so we don't have to do this here.
  inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
  # pylint: disable=protected-access
  op_desc = pywrap_tf_session.TF_NewOperation(graph._c_graph,
                                              compat.as_str(node_def.op),
                                              compat.as_str(node_def.name))
  # The device is only set when the NodeDef carries a non-empty one.
  if node_def.device:
    pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
  # Add inputs
  for op_input in inputs:
    # After the regrouping above, a list/tuple entry is added as one input
    # list; anything else is added as a single input.
    if isinstance(op_input, (list, tuple)):
      pywrap_tf_session.TF_AddInputList(op_desc,
                                        [t._as_tf_output() for t in op_input])
    else:
      pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())

  # Add control inputs
  for control_input in control_inputs:
    pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
  # pylint: enable=protected-access

  # Add attrs
  for name, attr_value in node_def.attr.items():
    # Each attr is handed to the C API as a serialized AttrValue proto.
    serialized = attr_value.SerializeToString()
    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use the same status.
    pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
                                           serialized)

  try:
    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(e.message)

  return c_op
1903 Objects of type `Operation` are created by calling a Python op constructor 1904 (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default` 1905 context manager. 1906 1907 For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an 1908 `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and 1909 produces `c` as output. 1910 1911 If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be 1912 executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for 1913 calling `tf.compat.v1.get_default_session().run(op)`. 1914 """ 1915 1916 def __init__(self, 1917 node_def, 1918 g, 1919 inputs=None, 1920 output_types=None, 1921 control_inputs=None, 1922 input_types=None, 1923 original_op=None, 1924 op_def=None): 1925 r"""Creates an `Operation`. 1926 1927 NOTE: This constructor validates the name of the `Operation` (passed 1928 as `node_def.name`). Valid `Operation` names match the following 1929 regular expression: 1930 1931 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* 1932 1933 Args: 1934 node_def: `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. Used for 1935 attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and 1936 `device`. The `input` attribute is irrelevant here as it will be 1937 computed when generating the model. 1938 g: `Graph`. The parent graph. 1939 inputs: list of `Tensor` objects. The inputs to this `Operation`. 1940 output_types: list of `DType` objects. List of the types of the `Tensors` 1941 computed by this operation. The length of this list indicates the 1942 number of output endpoints of the `Operation`. 1943 control_inputs: list of operations or tensors from which to have a control 1944 dependency. 1945 input_types: List of `DType` objects representing the types of the tensors 1946 accepted by the `Operation`. By default uses `[x.dtype.base_dtype for x 1947 in inputs]`. Operations that expect reference-typed inputs must specify 1948 these explicitly. 
1949 original_op: Optional. Used to associate the new `Operation` with an 1950 existing `Operation` (for example, a replica with the op that was 1951 replicated). 1952 op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type 1953 that this `Operation` represents. 1954 1955 Raises: 1956 TypeError: if control inputs are not Operations or Tensors, 1957 or if `node_def` is not a `NodeDef`, 1958 or if `g` is not a `Graph`, 1959 or if `inputs` are not tensors, 1960 or if `inputs` and `input_types` are incompatible. 1961 ValueError: if the `node_def` name is not valid. 1962 """ 1963 # For internal use only: `node_def` can be set to a TF_Operation to create 1964 # an Operation for that op. This is useful for creating Operations for ops 1965 # indirectly created by C API methods, e.g. the ops created by 1966 # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields 1967 # should be None. 1968 1969 if isinstance(node_def, node_def_pb2.NodeDef): 1970 if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0: 1971 raise ValueError( 1972 "Cannot create a tensor proto whose content is larger than 2GB.") 1973 if not _VALID_OP_NAME_REGEX.match(node_def.name): 1974 raise ValueError("'%s' is not a valid node name" % node_def.name) 1975 c_op = None 1976 elif type(node_def).__name__ == "TF_Operation": 1977 assert inputs is None 1978 assert output_types is None 1979 assert control_inputs is None 1980 assert input_types is None 1981 assert original_op is None 1982 assert op_def is None 1983 c_op = node_def 1984 else: 1985 raise TypeError("node_def needs to be a NodeDef: %s" % (node_def,)) 1986 1987 if not isinstance(g, Graph): 1988 raise TypeError("g needs to be a Graph: %s" % (g,)) 1989 self._graph = g 1990 1991 if inputs is None: 1992 inputs = [] 1993 elif not isinstance(inputs, list): 1994 raise TypeError("inputs needs to be a list of Tensors: %s" % (inputs,)) 1995 for a in inputs: 1996 if not isinstance(a, Tensor): 1997 raise TypeError("input 
needs to be a Tensor: %s" % (a,)) 1998 if input_types is None: 1999 input_types = [i.dtype.base_dtype for i in inputs] 2000 else: 2001 if not all( 2002 x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)): 2003 raise TypeError("In op '%s', input types (%s) are not compatible " 2004 "with expected types (%s)" % 2005 (node_def.name, [i.dtype for i in inputs], input_types)) 2006 2007 # Build the list of control inputs. 2008 control_input_ops = [] 2009 if control_inputs: 2010 for c in control_inputs: 2011 control_op = None 2012 if isinstance(c, Operation): 2013 control_op = c 2014 elif isinstance(c, (Tensor, IndexedSlices)): 2015 control_op = c.op 2016 else: 2017 raise TypeError("Control input must be an Operation, " 2018 "a Tensor, or IndexedSlices: %s" % c) 2019 control_input_ops.append(control_op) 2020 2021 # This will be set by self.inputs. 2022 self._inputs_val = None 2023 2024 # pylint: disable=protected-access 2025 self._original_op = original_op 2026 2027 # List of _UserDevSpecs holding code location of device context manager 2028 # invocations and the users original argument to them. 2029 self._device_code_locations = None 2030 # Dict mapping op name to file and line information for op colocation 2031 # context managers. 2032 self._colocation_code_locations = None 2033 self._control_flow_context = self.graph._get_control_flow_context() 2034 2035 # Gradient function for this op. There are three ways to specify gradient 2036 # function, and first available gradient gets used, in the following order. 2037 # 1. self._gradient_function 2038 # 2. Gradient name registered by "_gradient_op_type" attribute. 2039 # 3. Gradient name registered by op.type. 2040 self._gradient_function = None 2041 2042 # Initialize self._c_op. 
2043 if c_op: 2044 self._c_op = c_op 2045 op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op)) 2046 name = self.name 2047 else: 2048 if op_def is None: 2049 op_def = self._graph._get_op_def(node_def.op) 2050 self._c_op = _create_c_op(self._graph, node_def, inputs, 2051 control_input_ops, op_def) 2052 name = compat.as_str(node_def.name) 2053 2054 self._traceback = tf_stack.extract_stack_for_node(self._c_op) 2055 2056 # pylint: enable=protected-access 2057 2058 self._is_stateful = op_def.is_stateful 2059 2060 # Initialize self._outputs. 2061 num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op) 2062 self._outputs = [] 2063 for i in range(num_outputs): 2064 tf_output = c_api_util.tf_output(self._c_op, i) 2065 output_type = pywrap_tf_session.TF_OperationOutputType(tf_output) 2066 tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output) # pylint: disable=protected-access 2067 self._outputs.append(tensor) 2068 2069 self._id_value = self._graph._add_op(self, name) # pylint: disable=protected-access 2070 2071 if not c_op: 2072 self._control_flow_post_processing(input_tensors=inputs) 2073 2074 def _control_flow_post_processing(self, input_tensors=None): 2075 """Add this op to its control flow context. 2076 2077 This may add new ops and change this op's inputs. self.inputs must be 2078 available before calling this method. 2079 2080 Args: 2081 input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs 2082 of this op, which should be equivalent to `self.inputs`. Pass this 2083 argument to avoid evaluating `self.inputs` unnecessarily. 
2084 """ 2085 if input_tensors is None: 2086 input_tensors = self.inputs 2087 for input_tensor in input_tensors: 2088 control_flow_util.CheckInputFromValidContext(self, input_tensor.op) 2089 if self._control_flow_context is not None: 2090 self._control_flow_context.AddOp(self) 2091 2092 def colocation_groups(self): 2093 """Returns the list of colocation groups of the op.""" 2094 default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)] 2095 try: 2096 class_attr = self.get_attr("_class") 2097 except ValueError: 2098 # This op has no explicit colocation group, so it is itself its 2099 # own root of a colocation group. 2100 return default_colocation_group 2101 2102 attr_groups = [ 2103 class_name for class_name in class_attr 2104 if class_name.startswith(b"loc:@") 2105 ] 2106 2107 # If there are no colocation groups in the explicit _class field, 2108 # return the default colocation group. 2109 return attr_groups if attr_groups else default_colocation_group 2110 2111 def values(self): 2112 """DEPRECATED: Use outputs.""" 2113 return tuple(self.outputs) 2114 2115 def _get_control_flow_context(self): 2116 """Returns the control flow context of this op. 2117 2118 Returns: 2119 A context object. 2120 """ 2121 return self._control_flow_context 2122 2123 def _set_control_flow_context(self, ctx): 2124 """Sets the current control flow context of this op. 2125 2126 Args: 2127 ctx: a context object. 2128 """ 2129 self._control_flow_context = ctx 2130 2131 @property 2132 def name(self): 2133 """The full name of this operation.""" 2134 return pywrap_tf_session.TF_OperationName(self._c_op) 2135 2136 @property 2137 def _id(self): 2138 """The unique integer id of this operation.""" 2139 return self._id_value 2140 2141 @property 2142 def device(self): 2143 """The name of the device to which this op has been assigned, if any. 
2144 2145 Returns: 2146 The string name of the device to which this op has been 2147 assigned, or an empty string if it has not been assigned to a 2148 device. 2149 """ 2150 return pywrap_tf_session.TF_OperationDevice(self._c_op) 2151 2152 @property 2153 def _device_assignments(self): 2154 """Code locations for device context managers active at op creation. 2155 2156 This property will return a list of traceable_stack.TraceableObject 2157 instances where .obj is a string representing the assigned device 2158 (or information about the function that would be applied to this op 2159 to compute the desired device) and the filename and lineno members 2160 record the location of the relevant device context manager. 2161 2162 For example, suppose file_a contained these lines: 2163 2164 file_a.py: 2165 15: with tf.device('/gpu:0'): 2166 16: node_b = tf.constant(4, name='NODE_B') 2167 2168 Then a TraceableObject t_obj representing the device context manager 2169 would have these member values: 2170 2171 t_obj.obj -> '/gpu:0' 2172 t_obj.filename = 'file_a.py' 2173 t_obj.lineno = 15 2174 2175 and node_b.op._device_assignments would return the list [t_obj]. 2176 2177 Returns: 2178 [str: traceable_stack.TraceableObject, ...] as per this method's 2179 description, above. 2180 """ 2181 return self._device_code_locations or [] 2182 2183 @property 2184 def _colocation_dict(self): 2185 """Code locations for colocation context managers active at op creation. 2186 2187 This property will return a dictionary for which the keys are nodes with 2188 which this Operation is colocated, and for which the values are 2189 traceable_stack.TraceableObject instances. The TraceableObject instances 2190 record the location of the relevant colocation context manager but have the 2191 "obj" field set to None to prevent leaking private data. 
2192 2193 For example, suppose file_a contained these lines: 2194 2195 file_a.py: 2196 14: node_a = tf.constant(3, name='NODE_A') 2197 15: with tf.compat.v1.colocate_with(node_a): 2198 16: node_b = tf.constant(4, name='NODE_B') 2199 2200 Then a TraceableObject t_obj representing the colocation context manager 2201 would have these member values: 2202 2203 t_obj.obj -> None 2204 t_obj.filename = 'file_a.py' 2205 t_obj.lineno = 15 2206 2207 and node_b.op._colocation_dict would return the dictionary 2208 2209 { 'NODE_A': t_obj } 2210 2211 Returns: 2212 {str: traceable_stack.TraceableObject} as per this method's description, 2213 above. 2214 """ 2215 locations_dict = self._colocation_code_locations or {} 2216 return locations_dict.copy() 2217 2218 @property 2219 def _output_types(self): 2220 """List this operation's output types. 2221 2222 Returns: 2223 List of the types of the Tensors computed by this operation. 2224 Each element in the list is an integer whose value is one of 2225 the TF_DataType enums defined in pywrap_tf_session.h 2226 The length of this list indicates the number of output endpoints 2227 of the operation. 
2228 """ 2229 num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op) 2230 output_types = [ 2231 int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i))) 2232 for i in xrange(num_outputs) 2233 ] 2234 2235 return output_types 2236 2237 def _tf_output(self, output_idx): 2238 """Create and return a new TF_Output for output_idx'th output of this op.""" 2239 tf_output = pywrap_tf_session.TF_Output() 2240 tf_output.oper = self._c_op 2241 tf_output.index = output_idx 2242 return tf_output 2243 2244 def _tf_input(self, input_idx): 2245 """Create and return a new TF_Input for input_idx'th input of this op.""" 2246 tf_input = pywrap_tf_session.TF_Input() 2247 tf_input.oper = self._c_op 2248 tf_input.index = input_idx 2249 return tf_input 2250 2251 def _set_device(self, device): # pylint: disable=redefined-outer-name 2252 """Set the device of this operation. 2253 2254 Args: 2255 device: string or device.. The device to set. 2256 """ 2257 self._set_device_from_string(compat.as_str(_device_string(device))) 2258 2259 def _set_device_from_string(self, device_str): 2260 """Fast path to set device if the type is known to be a string. 2261 2262 This function is called frequently enough during graph construction that 2263 there are non-trivial performance gains if the caller can guarantee that 2264 the specified device is already a string. 2265 2266 Args: 2267 device_str: A string specifying where to place this op. 2268 """ 2269 pywrap_tf_session.SetRequestedDevice( 2270 self._graph._c_graph, # pylint: disable=protected-access 2271 self._c_op, # pylint: disable=protected-access 2272 device_str) 2273 2274 def _update_input(self, index, tensor): 2275 """Update the input to this operation at the given index. 2276 2277 NOTE: This is for TF internal use only. Please don't use it. 2278 2279 Args: 2280 index: the index of the input to update. 2281 tensor: the Tensor to be used as the input at the given index. 
2282 2283 Raises: 2284 TypeError: if tensor is not a Tensor, 2285 or if input tensor type is not convertible to dtype. 2286 ValueError: if the Tensor is from a different graph. 2287 """ 2288 if not isinstance(tensor, Tensor): 2289 raise TypeError("tensor must be a Tensor: %s" % tensor) 2290 _assert_same_graph(self, tensor) 2291 2292 # Reset cached inputs. 2293 self._inputs_val = None 2294 pywrap_tf_session.UpdateEdge( 2295 self._graph._c_graph, # pylint: disable=protected-access 2296 tensor._as_tf_output(), # pylint: disable=protected-access 2297 self._tf_input(index)) 2298 2299 def _add_while_inputs(self, tensors): 2300 """See AddWhileInputHack in python_api.h. 2301 2302 NOTE: This is for TF internal use only. Please don't use it. 2303 2304 Args: 2305 tensors: list of Tensors 2306 2307 Raises: 2308 TypeError: if tensor is not a Tensor, 2309 or if input tensor type is not convertible to dtype. 2310 ValueError: if the Tensor is from a different graph. 2311 """ 2312 for tensor in tensors: 2313 if not isinstance(tensor, Tensor): 2314 raise TypeError("tensor must be a Tensor: %s" % tensor) 2315 _assert_same_graph(self, tensor) 2316 2317 # Reset cached inputs. 2318 self._inputs_val = None 2319 pywrap_tf_session.AddWhileInputHack( 2320 self._graph._c_graph, # pylint: disable=protected-access 2321 tensor._as_tf_output(), # pylint: disable=protected-access 2322 self._c_op) 2323 2324 def _add_control_inputs(self, ops): 2325 """Add a list of new control inputs to this operation. 2326 2327 Args: 2328 ops: the list of Operations to add as control input. 2329 2330 Raises: 2331 TypeError: if ops is not a list of Operations. 2332 ValueError: if any op in ops is from a different graph. 
2333 """ 2334 for op in ops: 2335 if not isinstance(op, Operation): 2336 raise TypeError("op must be an Operation: %s" % op) 2337 pywrap_tf_session.AddControlInput( 2338 self._graph._c_graph, # pylint: disable=protected-access 2339 self._c_op, # pylint: disable=protected-access 2340 op._c_op) # pylint: disable=protected-access 2341 2342 def _add_control_input(self, op): 2343 """Add a new control input to this operation. 2344 2345 Args: 2346 op: the Operation to add as control input. 2347 2348 Raises: 2349 TypeError: if op is not an Operation. 2350 ValueError: if op is from a different graph. 2351 """ 2352 if not isinstance(op, Operation): 2353 raise TypeError("op must be an Operation: %s" % op) 2354 pywrap_tf_session.AddControlInput( 2355 self._graph._c_graph, # pylint: disable=protected-access 2356 self._c_op, # pylint: disable=protected-access 2357 op._c_op) # pylint: disable=protected-access 2358 2359 def _remove_all_control_inputs(self): 2360 """Removes any control inputs to this operation.""" 2361 pywrap_tf_session.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access 2362 2363 def _add_outputs(self, types, shapes): 2364 """Adds new Tensors to self.outputs. 2365 2366 Note: this is generally unsafe to use. This is used in certain situations in 2367 conjunction with _set_type_list_attr. 
2368 2369 Args: 2370 types: list of DTypes 2371 shapes: list of TensorShapes 2372 """ 2373 assert len(types) == len(shapes) 2374 orig_num_outputs = len(self.outputs) 2375 for i in range(len(types)): 2376 t = Tensor(self, orig_num_outputs + i, types[i]) 2377 self._outputs.append(t) 2378 t.set_shape(shapes[i]) 2379 2380 def __str__(self): 2381 return str(self.node_def) 2382 2383 def __repr__(self): 2384 return "<tf.Operation '%s' type=%s>" % (self.name, self.type) 2385 2386 def __tf_tensor__(self, dtype=None, name=None): 2387 """Raises a helpful error.""" 2388 raise TypeError("can't convert Operation '{}' to Tensor".format(self.name)) 2389 2390 @property 2391 def outputs(self): 2392 """The list of `Tensor` objects representing the outputs of this op.""" 2393 return self._outputs 2394 2395 @property 2396 def inputs(self): 2397 """The sequence of `Tensor` objects representing the data inputs of this op.""" 2398 if self._inputs_val is None: 2399 # pylint: disable=protected-access 2400 self._inputs_val = tuple( 2401 map(self.graph._get_tensor_by_tf_output, 2402 pywrap_tf_session.GetOperationInputs(self._c_op))) 2403 # pylint: enable=protected-access 2404 return self._inputs_val 2405 2406 @property 2407 def _input_types(self): 2408 num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op) 2409 input_types = [ 2410 dtypes.as_dtype( 2411 pywrap_tf_session.TF_OperationInputType(self._tf_input(i))) 2412 for i in xrange(num_inputs) 2413 ] 2414 return input_types 2415 2416 @property 2417 def control_inputs(self): 2418 """The `Operation` objects on which this op has a control dependency. 2419 2420 Before this op is executed, TensorFlow will ensure that the 2421 operations in `self.control_inputs` have finished executing. This 2422 mechanism can be used to run ops sequentially for performance 2423 reasons, or to ensure that the side effects of an op are observed 2424 in the correct order. 2425 2426 Returns: 2427 A list of `Operation` objects. 
2428 2429 """ 2430 control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper( 2431 self._c_op) 2432 # pylint: disable=protected-access 2433 return [ 2434 self.graph._get_operation_by_name_unsafe( 2435 pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops 2436 ] 2437 # pylint: enable=protected-access 2438 2439 @property 2440 def _control_outputs(self): 2441 """The `Operation` objects which have a control dependency on this op. 2442 2443 Before any of the ops in self._control_outputs can execute tensorflow will 2444 ensure self has finished executing. 2445 2446 Returns: 2447 A list of `Operation` objects. 2448 2449 """ 2450 control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper( 2451 self._c_op) 2452 # pylint: disable=protected-access 2453 return [ 2454 self.graph._get_operation_by_name_unsafe( 2455 pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops 2456 ] 2457 # pylint: enable=protected-access 2458 2459 @property 2460 def type(self): 2461 """The type of the op (e.g. `"MatMul"`).""" 2462 return pywrap_tf_session.TF_OperationOpType(self._c_op) 2463 2464 @property 2465 def graph(self): 2466 """The `Graph` that contains this operation.""" 2467 return self._graph 2468 2469 @property 2470 def node_def(self): 2471 # pylint: disable=line-too-long 2472 """Returns the `NodeDef` representation of this operation. 2473 2474 Returns: 2475 A 2476 [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto) 2477 protocol buffer. 2478 """ 2479 # pylint: enable=line-too-long 2480 with c_api_util.tf_buffer() as buf: 2481 pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf) 2482 data = pywrap_tf_session.TF_GetBuffer(buf) 2483 node_def = node_def_pb2.NodeDef() 2484 node_def.ParseFromString(compat.as_bytes(data)) 2485 return node_def 2486 2487 @property 2488 def op_def(self): 2489 # pylint: disable=line-too-long 2490 """Returns the `OpDef` proto that represents the type of this op. 
2491 2492 Returns: 2493 An 2494 [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto) 2495 protocol buffer. 2496 """ 2497 # pylint: enable=line-too-long 2498 return self._graph._get_op_def(self.type) 2499 2500 @property 2501 def traceback(self): 2502 """Returns the call stack from when this operation was constructed.""" 2503 return self._traceback 2504 2505 def _set_attr(self, attr_name, attr_value): 2506 """Private method used to set an attribute in the node_def.""" 2507 buf = pywrap_tf_session.TF_NewBufferFromString( 2508 compat.as_bytes(attr_value.SerializeToString())) 2509 try: 2510 self._set_attr_with_buf(attr_name, buf) 2511 finally: 2512 pywrap_tf_session.TF_DeleteBuffer(buf) 2513 2514 def _set_attr_with_buf(self, attr_name, attr_buf): 2515 """Set an attr in the node_def with a pre-allocated buffer.""" 2516 # pylint: disable=protected-access 2517 pywrap_tf_session.SetAttr(self._graph._c_graph, self._c_op, attr_name, 2518 attr_buf) 2519 # pylint: enable=protected-access 2520 2521 def _set_func_attr(self, attr_name, func_name): 2522 """Private method used to set a function attribute in the node_def.""" 2523 func = attr_value_pb2.NameAttrList(name=func_name) 2524 self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func)) 2525 2526 def _set_func_list_attr(self, attr_name, func_names): 2527 """Private method used to set a list(function) attribute in the node_def.""" 2528 funcs = [attr_value_pb2.NameAttrList(name=func_name) 2529 for func_name in func_names] 2530 funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs) 2531 self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list)) 2532 2533 def _set_type_list_attr(self, attr_name, types): 2534 """Private method used to set a list(type) attribute in the node_def.""" 2535 if not types: 2536 return 2537 if isinstance(types[0], dtypes.DType): 2538 types = [dt.as_datatype_enum for dt in types] 2539 types_list = attr_value_pb2.AttrValue.ListValue(type=types) 2540 
self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list)) 2541 2542 def _set_shape_list_attr(self, attr_name, shapes): 2543 """Private method used to set a list(shape) attribute in the node_def.""" 2544 shapes = [s.as_proto() for s in shapes] 2545 shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes) 2546 self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list)) 2547 2548 def _clear_attr(self, attr_name): 2549 """Private method used to clear an attribute in the node_def.""" 2550 # pylint: disable=protected-access 2551 pywrap_tf_session.ClearAttr(self._graph._c_graph, self._c_op, attr_name) 2552 # pylint: enable=protected-access 2553 2554 def get_attr(self, name): 2555 """Returns the value of the attr of this op with the given `name`. 2556 2557 Args: 2558 name: The name of the attr to fetch. 2559 2560 Returns: 2561 The value of the attr, as a Python object. 2562 2563 Raises: 2564 ValueError: If this op does not have an attr with the given `name`. 2565 """ 2566 fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func") 2567 try: 2568 with c_api_util.tf_buffer() as buf: 2569 pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf) 2570 data = pywrap_tf_session.TF_GetBuffer(buf) 2571 except errors.InvalidArgumentError as e: 2572 # Convert to ValueError for backwards compatibility. 
2573 raise ValueError(e.message) 2574 x = attr_value_pb2.AttrValue() 2575 x.ParseFromString(data) 2576 2577 oneof_value = x.WhichOneof("value") 2578 if oneof_value is None: 2579 return [] 2580 if oneof_value == "list": 2581 for f in fields: 2582 if getattr(x.list, f): 2583 if f == "type": 2584 return [dtypes.as_dtype(t) for t in x.list.type] 2585 else: 2586 return list(getattr(x.list, f)) 2587 return [] 2588 if oneof_value == "type": 2589 return dtypes.as_dtype(x.type) 2590 assert oneof_value in fields, "Unsupported field type in " + str(x) 2591 return getattr(x, oneof_value) 2592 2593 def _get_attr_type(self, name): 2594 """Returns the `DType` value of the attr of this op with the given `name`.""" 2595 try: 2596 dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name) 2597 return _DTYPES_INTERN_TABLE[dtype_enum] 2598 except errors.InvalidArgumentError as e: 2599 # Convert to ValueError for backwards compatibility. 2600 raise ValueError(e.message) 2601 2602 def _get_attr_bool(self, name): 2603 """Returns the `bool` value of the attr of this op with the given `name`.""" 2604 try: 2605 return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name) 2606 except errors.InvalidArgumentError as e: 2607 # Convert to ValueError for backwards compatibility. 2608 raise ValueError(e.message) 2609 2610 def _get_attr_int(self, name): 2611 """Returns the `int` value of the attr of this op with the given `name`.""" 2612 try: 2613 return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name) 2614 except errors.InvalidArgumentError as e: 2615 # Convert to ValueError for backwards compatibility. 2616 raise ValueError(e.message) 2617 2618 def run(self, feed_dict=None, session=None): 2619 """Runs this operation in a `Session`. 2620 2621 Calling this method will execute all preceding operations that 2622 produce the inputs needed for this operation. 
2623 2624 *N.B.* Before invoking `Operation.run()`, its graph must have been 2625 launched in a session, and either a default session must be 2626 available, or `session` must be specified explicitly. 2627 2628 Args: 2629 feed_dict: A dictionary that maps `Tensor` objects to feed values. See 2630 `tf.Session.run` for a description of the valid feed values. 2631 session: (Optional.) The `Session` to be used to run to this operation. If 2632 none, the default session will be used. 2633 """ 2634 _run_using_default_session(self, feed_dict, self.graph, session) 2635 2636# TODO(b/185395742): Clean up usages of _gradient_registry 2637gradient_registry = _gradient_registry = registry.Registry("gradient") 2638 2639 2640@tf_export("RegisterGradient") 2641class RegisterGradient(object): 2642 """A decorator for registering the gradient function for an op type. 2643 2644 This decorator is only used when defining a new op type. For an op 2645 with `m` inputs and `n` outputs, the gradient function is a function 2646 that takes the original `Operation` and `n` `Tensor` objects 2647 (representing the gradients with respect to each output of the op), 2648 and returns `m` `Tensor` objects (representing the partial gradients 2649 with respect to each input of the op). 2650 2651 For example, assuming that operations of type `"Sub"` take two 2652 inputs `x` and `y`, and return a single output `x - y`, the 2653 following gradient function would be registered: 2654 2655 ```python 2656 @tf.RegisterGradient("Sub") 2657 def _sub_grad(unused_op, grad): 2658 return grad, tf.negative(grad) 2659 ``` 2660 2661 The decorator argument `op_type` is the string type of an 2662 operation. This corresponds to the `OpDef.name` field for the proto 2663 that defines the operation. 2664 """ 2665 2666 __slots__ = ["_op_type"] 2667 2668 def __init__(self, op_type): 2669 """Creates a new decorator with `op_type` as the Operation type. 2670 2671 Args: 2672 op_type: The string type of an operation. 
This corresponds to the 2673 `OpDef.name` field for the proto that defines the operation. 2674 2675 Raises: 2676 TypeError: If `op_type` is not string. 2677 """ 2678 if not isinstance(op_type, six.string_types): 2679 raise TypeError("op_type must be a string") 2680 self._op_type = op_type 2681 2682 def __call__(self, f): 2683 """Registers the function `f` as gradient function for `op_type`.""" 2684 gradient_registry.register(f, self._op_type) 2685 return f 2686 2687 2688@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient") 2689@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"]) 2690def no_gradient(op_type): 2691 """Specifies that ops of type `op_type` is not differentiable. 2692 2693 This function should *not* be used for operations that have a 2694 well-defined gradient that is not yet implemented. 2695 2696 This function is only used when defining a new op type. It may be 2697 used for ops such as `tf.size()` that are not differentiable. For 2698 example: 2699 2700 ```python 2701 tf.no_gradient("Size") 2702 ``` 2703 2704 The gradient computed for 'op_type' will then propagate zeros. 2705 2706 For ops that have a well-defined gradient but are not yet implemented, 2707 no declaration should be made, and an error *must* be thrown if 2708 an attempt to request its gradient is made. 2709 2710 Args: 2711 op_type: The string type of an operation. This corresponds to the 2712 `OpDef.name` field for the proto that defines the operation. 2713 2714 Raises: 2715 TypeError: If `op_type` is not a string. 2716 2717 """ 2718 if not isinstance(op_type, six.string_types): 2719 raise TypeError("op_type must be a string") 2720 gradient_registry.register(None, op_type) 2721 2722 2723# Aliases for the old names, will be eventually removed. 
2724NoGradient = no_gradient 2725NotDifferentiable = no_gradient 2726 2727 2728def get_gradient_function(op): 2729 """Returns the function that computes gradients for "op".""" 2730 if not op.inputs: 2731 return None 2732 2733 gradient_function = op._gradient_function # pylint: disable=protected-access 2734 if gradient_function: 2735 return gradient_function 2736 2737 try: 2738 op_type = op.get_attr("_gradient_op_type") 2739 except ValueError: 2740 op_type = op.type 2741 return gradient_registry.lookup(op_type) 2742 2743 2744def set_shape_and_handle_data_for_outputs(_): 2745 """No op. TODO(b/74620627): Remove this.""" 2746 pass 2747 2748 2749class OpStats(object): 2750 """A holder for statistics about an operator. 2751 2752 This class holds information about the resource requirements for an op, 2753 including the size of its weight parameters on-disk and how many FLOPS it 2754 requires to execute forward inference. 2755 2756 If you define a new operation, you can create a function that will return a 2757 set of information about its usage of the CPU and disk space when serialized. 2758 The function itself takes a Graph object that's been set up so you can call 2759 methods like get_tensor_by_name to help calculate the results, and a NodeDef 2760 argument. 
2761 2762 """ 2763 2764 __slots__ = ["_statistic_type", "_value"] 2765 2766 def __init__(self, statistic_type, value=None): 2767 """Sets up the initial placeholders for the statistics.""" 2768 self.statistic_type = statistic_type 2769 self.value = value 2770 2771 @property 2772 def statistic_type(self): 2773 return self._statistic_type 2774 2775 @statistic_type.setter 2776 def statistic_type(self, statistic_type): 2777 self._statistic_type = statistic_type 2778 2779 @property 2780 def value(self): 2781 return self._value 2782 2783 @value.setter 2784 def value(self, value): 2785 self._value = value 2786 2787 def __iadd__(self, other): 2788 if other.statistic_type != self.statistic_type: 2789 raise ValueError("Can't add an OpStat of type %s to one of %s." % 2790 (self.statistic_type, other.statistic_type)) 2791 if self.value is None: 2792 self.value = other.value 2793 elif other.value is not None: 2794 self._value += other.value 2795 return self 2796 2797 2798_stats_registry = registry.Registry("statistical functions") 2799 2800 2801class RegisterStatistics(object): 2802 """A decorator for registering the statistics function for an op type. 2803 2804 This decorator can be defined for an op type so that it gives a 2805 report on the resources used by an instance of an operator, in the 2806 form of an OpStats object. 2807 2808 Well-known types of statistics include these so far: 2809 2810 - flops: When running a graph, the bulk of the computation happens doing 2811 numerical calculations like matrix multiplications. This type allows a node 2812 to return how many floating-point operations it takes to complete. The 2813 total number of FLOPs for a graph is a good guide to its expected latency. 2814 2815 You can add your own statistics just by picking a new type string, registering 2816 functions for the ops you care about, and then calling get_stats_for_node_def. 2817 2818 If a statistic for an op is registered multiple times, a KeyError will be 2819 raised. 
2820 2821 Since the statistics is counted on a per-op basis. It is not suitable for 2822 model parameters (capacity), which is expected to be counted only once, even 2823 if it is shared by multiple ops. (e.g. RNN) 2824 2825 For example, you can define a new metric called doohickey for a Foo operation 2826 by placing this in your code: 2827 2828 ```python 2829 @ops.RegisterStatistics("Foo", "doohickey") 2830 def _calc_foo_bojangles(unused_graph, unused_node_def): 2831 return ops.OpStats("doohickey", 20) 2832 ``` 2833 2834 Then in client code you can retrieve the value by making this call: 2835 2836 ```python 2837 doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey") 2838 ``` 2839 2840 If the NodeDef is for an op with a registered doohickey function, you'll get 2841 back the calculated amount in doohickey.value, or None if it's not defined. 2842 2843 """ 2844 2845 __slots__ = ["_op_type", "_statistic_type"] 2846 2847 def __init__(self, op_type, statistic_type): 2848 """Saves the `op_type` as the `Operation` type.""" 2849 if not isinstance(op_type, six.string_types): 2850 raise TypeError("op_type must be a string.") 2851 if "," in op_type: 2852 raise TypeError("op_type must not contain a comma.") 2853 self._op_type = op_type 2854 if not isinstance(statistic_type, six.string_types): 2855 raise TypeError("statistic_type must be a string.") 2856 if "," in statistic_type: 2857 raise TypeError("statistic_type must not contain a comma.") 2858 self._statistic_type = statistic_type 2859 2860 def __call__(self, f): 2861 """Registers "f" as the statistics function for "op_type".""" 2862 _stats_registry.register(f, self._op_type + "," + self._statistic_type) 2863 return f 2864 2865 2866def get_stats_for_node_def(graph, node, statistic_type): 2867 """Looks up the node's statistics function in the registry and calls it. 
def name_from_scope_name(name):
  """Strips the trailing "/" that marks a scope name, yielding the op name.

  Args:
    name: the name of the scope.

  Returns:
    `name` without its last character when it ends with "/"; otherwise `name`
    unchanged (this includes empty and `None` values).
  """
  # Guard first: falsy names and names without a trailing slash pass through.
  if not name or name[-1] != "/":
    return name
  return name[:-1]
  def __init__(self):
    """Creates a new, empty Graph.

    Initializes the Python-side bookkeeping (locks, op tables, scope stacks,
    caches), creates the underlying C graph, and, when `tf2.enabled()` is
    true, calls `switch_to_thread_local()`.
    """
    # Protects core state that can be returned via public accessors.
    # Thread-safety is provided on a best-effort basis to support buggy
    # programs, and is not guaranteed by the public `tf.Graph` API.
    #
    # NOTE(mrry): This does not protect the various stacks. A warning will
    # be reported if these are used from multiple threads.
    self._lock = threading.RLock()
    # The group lock synchronizes Session.run calls with methods that create
    # and mutate ops (e.g. Graph.create_op()). This synchronization is
    # necessary because it's illegal to modify an operation after it's been run.
    # The group lock allows any number of threads to mutate ops at the same time
    # but if any modification is going on, all Session.run calls have to wait.
    # Similarly, if one or more Session.run calls are going on, all mutate ops
    # have to wait until all Session.run calls have finished.
    # NOTE(review): the two groups presumably correspond to the module-level
    # _MUTATION_LOCK_GROUP and _SESSION_RUN_LOCK_GROUP constants -- confirm.
    self._group_lock = lock_util.GroupLock(num_groups=2)
    self._nodes_by_id = {}  # GUARDED_BY(self._lock)
    self._next_id_counter = 0  # GUARDED_BY(self._lock)
    self._nodes_by_name = {}  # GUARDED_BY(self._lock)
    self._version = 0  # GUARDED_BY(self._lock)
    # Maps a name used in the graph to the next id to use for that name.
    self._names_in_use = {}
    # NOTE(review): presumably flipped by switch_to_thread_local() -- confirm.
    self._stack_state_is_thread_local = False
    # Per-thread storage. The variable/resource creator stacks are kept here
    # (see the `_variable_creator_stack` and `_resource_creator_stack`
    # accessors).
    self._thread_local = threading.local()
    # Functions that will be applied to choose a device if none is specified.
    # In TF2.x or after switch_to_thread_local(),
    # self._thread_local._device_function_stack is used instead.
    self._graph_device_function_stack = traceable_stack.TraceableStack()
    # Default original_op applied to new ops.
    self._default_original_op = None
    # Current control flow context. It could be either CondContext or
    # WhileContext defined in ops/control_flow_ops.py
    self._control_flow_context = None
    # A new node will depend on the union of all of the nodes in the stack.
    # In TF2.x or after switch_to_thread_local(),
    # self._thread_local._control_dependencies_stack is used instead.
    self._graph_control_dependencies_stack = []
    # Arbitrary collections of objects.
    self._collections = {}
    # The graph-level random seed
    self._seed = None
    # A dictionary of attributes that should be applied to all ops.
    self._attr_scope_map = {}
    # A map from op type to the kernel label that should be used.
    self._op_to_kernel_label_map = {}
    # A map from op type to an alternative op type that should be used when
    # computing gradients.
    self._gradient_override_map = {}
    # A map from op type to a gradient function that should be used instead.
    self._gradient_function_map = {}
    # True if the graph is considered "finalized".  In that case no
    # new operations can be added.
    self._finalized = False
    # Functions defined in the graph
    self._functions = collections.OrderedDict()
    # Default GraphDef versions
    self._graph_def_versions = versions_pb2.VersionDef(
        producer=versions.GRAPH_DEF_VERSION,
        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
    # Backing field of the `building_function` property.
    self._building_function = False
    # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
    # self._thread_local._colocation_stack is used instead.
    self._graph_colocation_stack = traceable_stack.TraceableStack()
    # Set of tensors that are dangerous to feed!
    self._unfeedable_tensors = object_identity.ObjectIdentitySet()
    # Set of operations that are dangerous to fetch!
    self._unfetchable_ops = set()
    # A map of tensor handle placeholder to tensor dtype.
    self._handle_feeders = {}
    # A map from tensor handle to its read op.
    self._handle_readers = {}
    # A map from tensor handle to its move op.
    self._handle_movers = {}
    # A map from tensor handle to its delete op.
    self._handle_deleters = {}
    # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
    # will be shared when defining function graphs, for example, so optimizers
    # being called inside function definitions behave as if they were seeing the
    # actual outside graph).
    self._graph_key = "grap-key-%d/" % (uid(),)
    # A string with the last reduction method passed to
    # losses.compute_weighted_loss(), or None. This is required only for
    # backward compatibility with Estimator and optimizer V1 use cases.
    self._last_loss_reduction = None
    # Flag that is used to indicate whether loss has been scaled by optimizer.
    # If this flag has been set, then estimator uses it to scale the loss back
    # before reporting. This is required only for backward compatibility with
    # Estimator and optimizer V1 use cases.
    self._is_loss_scaled_by_optimizer = False
    self._container = ""
    # Set to True if this graph is being built in an
    # AutomaticControlDependencies context.
    self._add_control_dependencies = False
    # Cache for OpDef protobufs retrieved via the C API.
    self._op_def_cache = {}
    # Cache for constant results of `broadcast_gradient_args()`. The keys are
    # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the
    # values are tuples of reduction indices: (rx, ry).
    self._bcast_grad_args_cache = {}
    # Cache for constant results of `reduced_shape()`. The keys are pairs of
    # tuples: (input_shape_tuple, reduction_indices_tuple), and the values
    # are pairs of tuples: (output_shape_kept_dims, tile_scaling).
    self._reduced_shape_cache = {}

    # TODO(skyewm): fold as much of the above as possible into the C
    # implementation
    # Must be created before `self._c_graph` is read below: that property
    # returns `self._scoped_c_graph.graph`.
    self._scoped_c_graph = c_api_util.ScopedTFGraph()
    # The C API requires all ops to have shape functions. Disable this
    # requirement (many custom ops do not have shape functions, and we don't
    # want to break these existing cases).
    pywrap_tf_session.SetRequireShapeInferenceFns(self._c_graph, False)
    if tf2.enabled():
      self.switch_to_thread_local()
3111 # pylint: disable=protected-access 3112 @tf_contextlib.contextmanager 3113 def _resource_creator_scope(self, resource_type, creator): 3114 """Scope which defines a resource creation function used by some resource. 3115 3116 The resource should be a subclass of CachableResource with a class method 3117 `cls._resource_type`, the output of which is what the `resource_type` 3118 argument should be. By default, `cls._resource_type` returns the class name, 3119 `cls.__name__`. Given a scope, creators being added with the same 3120 `resource_type` argument will be composed together to apply to all classes 3121 with this `_resource_type`. 3122 3123 3124 `creator` is expected to be a function with the following signature: 3125 3126 ``` 3127 def resource_creator(next_creator, *a, **kwargs) 3128 ``` 3129 3130 The creator is supposed to eventually call the next_creator to create an 3131 instance if it does want to create an instance and not call 3132 the class initialization method directly. This helps make creators 3133 composable. A creator may choose to create multiple instances, return 3134 already existing instances, or simply register that an instance was created 3135 and defer to the next creator in line. Creators can also modify keyword 3136 arguments seen by the next creators. 3137 3138 Valid keyword arguments in `kwargs` depends on the specific resource 3139 class. For StaticHashTable, this may be: 3140 * initializer: The table initializer to use. 3141 * default_value: The value to use if a key is missing in the table. 3142 * name: Optional name for the table, default to None. 3143 3144 3145 Args: 3146 resource_type: the output of the resource class's `_resource_type` method. 3147 creator: the passed creator for the resource. 3148 3149 Yields: 3150 A scope in which the creator is active 3151 3152 Raises: 3153 RuntimeError: If resource_creator_scope is existed without proper nesting. 
3154 """ 3155 # This step keeps a reference to the existing stack, and it also initializes 3156 # self._thread_local._variable_creator_stack if it doesn't exist yet. 3157 old = self._resource_creator_stack 3158 new = copy.deepcopy(old) 3159 if isinstance(resource_type, (list, tuple)): 3160 for r in resource_type: 3161 new[r].append(creator) 3162 else: 3163 new[resource_type].append(creator) 3164 self._thread_local._resource_creator_stack = new 3165 try: 3166 yield 3167 finally: 3168 if self._thread_local._resource_creator_stack is not new: 3169 raise RuntimeError( 3170 "Exiting resource_creator_scope without proper nesting.") 3171 self._thread_local._resource_creator_stack = old 3172 3173 @property 3174 def _resource_creator_stack(self): 3175 if not hasattr(self._thread_local, "_resource_creator_stack"): 3176 self._thread_local._resource_creator_stack = collections.defaultdict(list) 3177 return self._thread_local._resource_creator_stack 3178 3179 @_resource_creator_stack.setter 3180 def _resource_creator_stack(self, resource_creator_stack): 3181 self._thread_local._resource_creator_stack = resource_creator_stack 3182 # pylint: enable=protected-access 3183 3184 # Note: this method is private because the API of tf.Graph() is public and 3185 # frozen, and this functionality is still not ready for public visibility. 3186 @property 3187 def _variable_creator_stack(self): 3188 if not hasattr(self._thread_local, "_variable_creator_stack"): 3189 self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access 3190 3191 # This previously returned a copy of the stack instead of the stack itself, 3192 # to guard against accidental mutation. Consider, however, code that wants 3193 # to save and restore the variable creator stack: 3194 # def f(): 3195 # original_stack = graph._variable_creator_stack 3196 # graph._variable_creator_stack = new_stack 3197 # ... 
# Some code 3198 # graph._variable_creator_stack = original_stack 3199 # 3200 # And lets say you have some code that calls this function with some 3201 # variable_creator: 3202 # def g(): 3203 # with variable_scope.variable_creator_scope(creator): 3204 # f() 3205 # When exiting the variable creator scope, it would see a different stack 3206 # object than it expected leading to a "Exiting variable_creator_scope 3207 # without proper nesting" error. 3208 return self._thread_local._variable_creator_stack # pylint: disable=protected-access 3209 3210 @_variable_creator_stack.setter 3211 def _variable_creator_stack(self, variable_creator_stack): 3212 self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access 3213 3214 def _check_not_finalized(self): 3215 """Check if the graph is finalized. 3216 3217 Raises: 3218 RuntimeError: If the graph finalized. 3219 """ 3220 if self._finalized: 3221 raise RuntimeError("Graph is finalized and cannot be modified.") 3222 3223 def _add_op(self, op, op_name): 3224 """Adds 'op' to the graph and returns the unique ID for the added Operation. 3225 3226 Args: 3227 op: the Operation to add. 3228 op_name: the name of the Operation. 3229 3230 Returns: 3231 An integer that is a unique ID for the added Operation. 3232 """ 3233 self._check_not_finalized() 3234 with self._lock: 3235 self._next_id_counter += 1 3236 op_id = self._next_id_counter 3237 self._nodes_by_id[op_id] = op 3238 self._nodes_by_name[op_name] = op 3239 self._version = max(self._version, op_id) 3240 return op_id 3241 3242 @property 3243 def _c_graph(self): 3244 return self._scoped_c_graph.graph 3245 3246 @property 3247 def version(self): 3248 """Returns a version number that increases as ops are added to the graph. 3249 3250 Note that this is unrelated to the 3251 `tf.Graph.graph_def_versions`. 3252 3253 Returns: 3254 An integer version that increases as ops are added to the graph. 
3255 """ 3256 if self._finalized: 3257 return self._version 3258 3259 with self._lock: 3260 return self._version 3261 3262 @property 3263 def graph_def_versions(self): 3264 # pylint: disable=line-too-long 3265 """The GraphDef version information of this graph. 3266 3267 For details on the meaning of each version, see 3268 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto). 3269 3270 Returns: 3271 A `VersionDef`. 3272 """ 3273 # pylint: enable=line-too-long 3274 with c_api_util.tf_buffer() as buf: 3275 pywrap_tf_session.TF_GraphVersions(self._c_graph, buf) 3276 data = pywrap_tf_session.TF_GetBuffer(buf) 3277 version_def = versions_pb2.VersionDef() 3278 version_def.ParseFromString(compat.as_bytes(data)) 3279 return version_def 3280 3281 @property 3282 def seed(self): 3283 """The graph-level random seed of this graph.""" 3284 return self._seed 3285 3286 @seed.setter 3287 def seed(self, seed): 3288 self._seed = seed 3289 3290 @property 3291 def finalized(self): 3292 """True if this graph has been finalized.""" 3293 return self._finalized 3294 3295 def finalize(self): 3296 """Finalizes this graph, making it read-only. 3297 3298 After calling `g.finalize()`, no new operations can be added to 3299 `g`. This method is used to ensure that no operations are added 3300 to a graph when it is shared between multiple threads, for example 3301 when using a `tf.compat.v1.train.QueueRunner`. 3302 """ 3303 self._finalized = True 3304 3305 def _unsafe_unfinalize(self): 3306 """Opposite of `finalize`. 3307 3308 Internal interface. 3309 3310 NOTE: Unfinalizing a graph could have negative impact on performance, 3311 especially in a multi-threaded environment. Unfinalizing a graph 3312 when it is in use by a Session may lead to undefined behavior. Ensure 3313 that all sessions using a graph are closed before calling this method. 
3314 """ 3315 self._finalized = False 3316 3317 def _get_control_flow_context(self): 3318 """Returns the current control flow context. 3319 3320 Returns: 3321 A context object. 3322 """ 3323 return self._control_flow_context 3324 3325 def _set_control_flow_context(self, ctx): 3326 """Sets the current control flow context. 3327 3328 Args: 3329 ctx: a context object. 3330 """ 3331 self._control_flow_context = ctx 3332 3333 def _copy_functions_to_graph_def(self, graph_def, starting_bytesize): 3334 """If this graph contains functions, copy them to `graph_def`.""" 3335 bytesize = starting_bytesize 3336 for f in self._functions.values(): 3337 bytesize += f.definition.ByteSize() 3338 if bytesize >= (1 << 31) or bytesize < 0: 3339 raise ValueError("GraphDef cannot be larger than 2GB.") 3340 graph_def.library.function.extend([f.definition]) 3341 if f.grad_func_name: 3342 grad_def = function_pb2.GradientDef() 3343 grad_def.function_name = f.name 3344 grad_def.gradient_func = f.grad_func_name 3345 graph_def.library.gradient.extend([grad_def]) 3346 3347 def _as_graph_def(self, from_version=None, add_shapes=False): 3348 # pylint: disable=line-too-long 3349 """Returns a serialized `GraphDef` representation of this graph. 3350 3351 The serialized `GraphDef` can be imported into another `Graph` 3352 (using `tf.import_graph_def`) or used with the 3353 [C++ Session API](../../../../api_docs/cc/index.md). 3354 3355 This method is thread-safe. 3356 3357 Args: 3358 from_version: Optional. If this is set, returns a `GraphDef` containing 3359 only the nodes that were added to this graph since its `version` 3360 property had the given value. 3361 add_shapes: If true, adds an "_output_shapes" list attr to each node with 3362 the inferred shapes of each of its outputs. 
3363 3364 Returns: 3365 A tuple containing a 3366 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 3367 protocol buffer, and the version of the graph to which that 3368 `GraphDef` corresponds. 3369 3370 Raises: 3371 ValueError: If the `graph_def` would be too large. 3372 3373 """ 3374 # pylint: enable=line-too-long 3375 with self._lock: 3376 with c_api_util.tf_buffer() as buf: 3377 pywrap_tf_session.TF_GraphToGraphDef(self._c_graph, buf) 3378 data = pywrap_tf_session.TF_GetBuffer(buf) 3379 graph = graph_pb2.GraphDef() 3380 graph.ParseFromString(compat.as_bytes(data)) 3381 # Strip the experimental library field iff it's empty. 3382 if not graph.library.function: 3383 graph.ClearField("library") 3384 3385 if add_shapes: 3386 for node in graph.node: 3387 op = self._nodes_by_name[node.name] 3388 if op.outputs: 3389 node.attr["_output_shapes"].list.shape.extend( 3390 [output.get_shape().as_proto() for output in op.outputs]) 3391 for function_def in graph.library.function: 3392 defined_function = self._functions[function_def.signature.name] 3393 try: 3394 func_graph = defined_function.graph 3395 except AttributeError: 3396 # _DefinedFunction doesn't have a graph, _EagerDefinedFunction 3397 # does. Both rely on ops.py, so we can't really isinstance check 3398 # them. 3399 continue 3400 input_shapes = function_def.attr["_input_shapes"] 3401 try: 3402 func_graph_inputs = func_graph.inputs 3403 except AttributeError: 3404 continue 3405 # TODO(b/141471245): Fix the inconsistency when inputs of func graph 3406 # are appended during gradient computation of while/cond. 3407 assert len(input_shapes.list.shape) in [0, len(func_graph_inputs)] 3408 # If the function_def has inputs already filled out, skip this step. 
3409 if not input_shapes.list.shape: 3410 for input_tensor, arg_def in zip(func_graph_inputs, 3411 function_def.signature.input_arg): 3412 input_shapes.list.shape.add().CopyFrom( 3413 input_tensor.get_shape().as_proto()) 3414 if input_tensor.dtype == dtypes.resource: 3415 _copy_handle_data_to_arg_def(input_tensor, arg_def) 3416 3417 for output_tensor, arg_def in zip(func_graph.outputs, 3418 function_def.signature.output_arg): 3419 if output_tensor.dtype == dtypes.resource: 3420 _copy_handle_data_to_arg_def(output_tensor, arg_def) 3421 3422 for node in function_def.node_def: 3423 try: 3424 op = func_graph.get_operation_by_name(node.name) 3425 except KeyError: 3426 continue 3427 outputs = op.outputs 3428 3429 if op.type == "StatefulPartitionedCall": 3430 # Filter out any extra outputs (possibly added by function 3431 # backpropagation rewriting). 3432 num_outputs = len(node.attr["Tout"].list.type) 3433 outputs = outputs[:num_outputs] 3434 3435 node.attr["_output_shapes"].list.shape.extend( 3436 [output.get_shape().as_proto() for output in outputs]) 3437 3438 return graph, self._version 3439 3440 def as_graph_def(self, from_version=None, add_shapes=False): 3441 # pylint: disable=line-too-long 3442 """Returns a serialized `GraphDef` representation of this graph. 3443 3444 The serialized `GraphDef` can be imported into another `Graph` 3445 (using `tf.import_graph_def`) or used with the 3446 [C++ Session API](../../api_docs/cc/index.md). 3447 3448 This method is thread-safe. 3449 3450 Args: 3451 from_version: Optional. If this is set, returns a `GraphDef` containing 3452 only the nodes that were added to this graph since its `version` 3453 property had the given value. 3454 add_shapes: If true, adds an "_output_shapes" list attr to each node with 3455 the inferred shapes of each of its outputs. 3456 3457 Returns: 3458 A 3459 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 3460 protocol buffer. 
3461 3462 Raises: 3463 ValueError: If the `graph_def` would be too large. 3464 """ 3465 # pylint: enable=line-too-long 3466 result, _ = self._as_graph_def(from_version, add_shapes) 3467 return result 3468 3469 def _is_function(self, name): 3470 """Tests whether 'name' is registered in this graph's function library. 3471 3472 Args: 3473 name: string op name. 3474 3475 Returns: 3476 bool indicating whether or not 'name' is registered in function library. 3477 """ 3478 return compat.as_str(name) in self._functions 3479 3480 def _get_function(self, name): 3481 """Returns the function definition for 'name'. 3482 3483 Args: 3484 name: string function name. 3485 3486 Returns: 3487 The function def proto. 3488 """ 3489 return self._functions.get(compat.as_str(name), None) 3490 3491 def _add_function(self, function): 3492 """Adds a function to the graph. 3493 3494 After the function has been added, you can call to the function by 3495 passing the function name in place of an op name to 3496 `Graph.create_op()`. 3497 3498 Args: 3499 function: A `_DefinedFunction` object. 3500 3501 Raises: 3502 ValueError: if another function is defined with the same name. 3503 """ 3504 self._check_not_finalized() 3505 3506 name = function.name 3507 # Sanity checks on gradient definition. 3508 if (function.grad_func_name is not None) and (function.python_grad_func is 3509 not None): 3510 raise ValueError("Gradient defined twice for function %s" % name) 3511 3512 # Add function to graph 3513 # pylint: disable=protected-access 3514 gradient = ( 3515 function._grad_func._c_func.func if function._grad_func else None) 3516 pywrap_tf_session.TF_GraphCopyFunction(self._c_graph, function._c_func.func, 3517 gradient) 3518 # pylint: enable=protected-access 3519 3520 self._functions[compat.as_str(name)] = function 3521 3522 # Need a new-enough consumer to support the functions we add to the graph. 
3523 if self._graph_def_versions.min_consumer < 12: 3524 self._graph_def_versions.min_consumer = 12 3525 3526 @property 3527 def building_function(self): 3528 """Returns True iff this graph represents a function.""" 3529 return self._building_function 3530 3531 # Helper functions to create operations. 3532 @deprecated_args(None, 3533 "Shapes are always computed; don't use the compute_shapes " 3534 "as it has no effect.", "compute_shapes") 3535 @traceback_utils.filter_traceback 3536 def create_op( 3537 self, 3538 op_type, 3539 inputs, 3540 dtypes=None, # pylint: disable=redefined-outer-name 3541 input_types=None, 3542 name=None, 3543 attrs=None, 3544 op_def=None, 3545 compute_shapes=True, 3546 compute_device=True): 3547 """Creates an `Operation` in this graph. 3548 3549 This is a low-level interface for creating an `Operation`. Most 3550 programs will not call this method directly, and instead use the 3551 Python op constructors, such as `tf.constant()`, which add ops to 3552 the default graph. 3553 3554 Args: 3555 op_type: The `Operation` type to create. This corresponds to the 3556 `OpDef.name` field for the proto that defines the operation. 3557 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 3558 dtypes: (Optional) A list of `DType` objects that will be the types of the 3559 tensors that the operation produces. 3560 input_types: (Optional.) A list of `DType`s that will be the types of the 3561 tensors that the operation consumes. By default, uses the base `DType` 3562 of each input in `inputs`. Operations that expect reference-typed inputs 3563 must specify `input_types` explicitly. 3564 name: (Optional.) A string name for the operation. If not specified, a 3565 name is generated based on `op_type`. 3566 attrs: (Optional.) A dictionary where the key is the attribute name (a 3567 string) and the value is the respective `attr` attribute of the 3568 `NodeDef` proto that will represent the operation (an `AttrValue` 3569 proto). 
def _create_op_internal(
    self,
    op_type,
    inputs,
    dtypes=None,  # pylint: disable=redefined-outer-name
    input_types=None,
    name=None,
    attrs=None,
    op_def=None,
    compute_device=True):
  """Builds a new `Operation` and registers it with this graph.

  This is the implementation behind `Graph.create_op()`; calling it directly
  avoids the cost of the deprecation wrapper on the public method.

  Args:
    op_type: The type of `Operation` to create, i.e. the `OpDef.name` of the
      proto that defines the operation.
    inputs: A list of `Tensor` objects consumed by the new `Operation`.
    dtypes: (Optional) A list of `DType` objects giving the types of the
      tensors the operation produces.
    input_types: (Optional.) A list of `DType`s giving the types of the
      tensors the operation consumes. Defaults to the base `DType` of each
      entry in `inputs`; operations expecting reference-typed inputs must
      pass `input_types` explicitly.
    name: (Optional.) A string name for the operation. When omitted, a name
      derived from `op_type` is used.
    attrs: (Optional.) A dictionary mapping attribute names (strings) to the
      matching `attr` value of the `NodeDef` proto for the operation (an
      `AttrValue` proto).
    op_def: (Optional.) The `OpDef` proto describing `op_type`.
    compute_device: (Optional.) If True, device functions are run to compute
      the device property of the Operation.

  Raises:
    ValueError: if colocation conflicts with existing device assignment.

  Returns:
    An `Operation` object.
  """
  self._check_not_finalized()

  requested_name = op_type if name is None else name
  # A trailing '/' marks the name as a "name scope": it is used verbatim once
  # the slash is stripped. Any other name is made unique within the graph.
  if requested_name and requested_name[-1] == "/":
    op_name = name_from_scope_name(requested_name)
  else:
    op_name = self.unique_name(requested_name)

  node_def = _NodeDef(op_type, op_name, attrs)

  producer_ops = {tensor.op for tensor in inputs}
  control_inputs = self._control_dependencies_for_inputs(producer_ops)

  # `_create_op_helper` mutates the freshly built Operation, so the mutation
  # lock is held across both steps to keep a Session.run call from slipping
  # in between construction and mutation.
  with self._mutation_lock():
    new_op = Operation(
        node_def,
        self,
        inputs=inputs,
        output_types=dtypes,
        control_inputs=control_inputs,
        input_types=input_types,
        original_op=self._default_original_op,
        op_def=op_def)
    self._create_op_helper(new_op, compute_device=compute_device)
  return new_op
def _create_op_helper(self, op, compute_device=True):
  """Common logic for creating an op in this graph.

  Applies the graph's currently scoped state to the newly constructed `op`,
  in this order:

  1. Attributes from `self._attr_scope_map` that `op` does not already have.
  2. The `_kernel` attribute, if a kernel label is registered for `op.type`.
  3. `op._gradient_function`, looked up in `self._gradient_function_map`.
  4. The `_gradient_op_type` attribute, if a gradient override is registered
     for `op.type`.
  5. Registration through `_record_op_seen_by_control_dependencies`.
  6. Device functions (only when `compute_device` is true).
  7. The colocation stack: a device copied from the stack and the `_class`
     attribute.
  8. The `container` attribute for stateful ops whose container is empty.

  Args:
    op: The `Operation` to update in place.
    compute_device: (Optional.) If True, `self._apply_device_functions(op)` is
      called for `op`.

  Raises:
    TypeError: If a callable stored in the attribute scope map returns
      something other than `None` or an `AttrValue` proto.
  """
  # Apply any additional attributes requested. Do not overwrite any existing
  # attributes.
  for key, value in self._attr_scope_map.items():
    try:
      # `get_attr` raising ValueError is taken to mean the op does not have
      # this attribute yet; only then is the scoped value applied.
      op.get_attr(key)
    except ValueError:
      if callable(value):
        # Callables are invoked with the op's NodeDef to produce the value.
        value = value(op.node_def)
        if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
          raise TypeError(
              "Callable for scope map key '%s' must return either None or "
              "an AttrValue protocol buffer; but it returned: %s" %
              (key, value))
      # A falsy value (e.g. None returned by the callable) leaves the
      # attribute unset.
      if value:
        op._set_attr(key, value)  # pylint: disable=protected-access

  # Apply a kernel label if one has been specified for this op type.
  try:
    kernel_label = self._op_to_kernel_label_map[op.type]
    op._set_attr("_kernel",  # pylint: disable=protected-access
                 attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
  except KeyError:
    # No kernel label registered for this op type.
    pass

  # `.get` yields None when no gradient function is registered for the type.
  op._gradient_function = self._gradient_function_map.get(op.type)  # pylint: disable=protected-access

  # Apply the overriding op type for gradients if one has been specified for
  # this op type.
  try:
    mapped_op_type = self._gradient_override_map[op.type]
    op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
                 attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
  except KeyError:
    # No gradient override registered for this op type.
    pass

  self._record_op_seen_by_control_dependencies(op)

  if compute_device:
    self._apply_device_functions(op)

  # Snapshot the colocation stack metadata before we might generate error
  # messages using it.  Note that this snapshot depends on the actual stack
  # and is independent of the op's _class attribute.
  # pylint: disable=protected-access
  op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
  # pylint: enable=protected-access

  if self._colocation_stack:
    all_colocation_groups = []
    is_device_set = False
    for colocation_op in self._colocation_stack.peek_objs():
      try:
        all_colocation_groups.extend(colocation_op.colocation_groups())
      except AttributeError:
        # Stack entries without `colocation_groups()` contribute no groups
        # (presumably device-only candidates -- TODO confirm).
        pass
      # Only the first stack entry (in `peek_objs()` order) that carries a
      # device sets the op's device.
      if colocation_op.device and not is_device_set:
        # pylint: disable=protected-access
        op._set_device(colocation_op.device)
        # pylint: enable=protected-access
        is_device_set = True

    # De-duplicate and sort so the `_class` attribute is deterministic.
    all_colocation_groups = sorted(set(all_colocation_groups))
    # pylint: disable=protected-access
    op._set_attr(
        "_class",
        attr_value_pb2.AttrValue(
            list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
    # pylint: enable=protected-access

  # Sets "container" attribute if
  # (1) self._container is not None
  # (2) "is_stateful" is set in OpDef
  # (3) "container" attribute is in OpDef
  # (4) "container" attribute is None
  if self._container and op._is_stateful:  # pylint: disable=protected-access
    try:
      container_attr = op.get_attr("container")
    except ValueError:
      # "container" attribute is not in OpDef
      pass
    else:
      # An existing non-empty container value is never overwritten.
      if not container_attr:
        op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
            s=compat.as_bytes(self._container)))
If True, device functions will be executed to 3791 compute the device properties of each new Operation. 3792 3793 Returns: 3794 A list of the new `Operation` objects. 3795 """ 3796 self._check_not_finalized() 3797 3798 # Create all Operation objects before accessing their inputs since an op may 3799 # be created before its inputs. 3800 new_ops = [ 3801 self._create_op_from_tf_operation(c_op, compute_device=compute_devices) 3802 for c_op in c_api_util.new_tf_operations(self) 3803 ] 3804 3805 # pylint: disable=protected-access 3806 for op in new_ops: 3807 new_control_inputs = self._control_dependencies_for_inputs(op.inputs) 3808 op._add_control_inputs(new_control_inputs) 3809 op._control_flow_post_processing() 3810 # pylint: enable=protected-access 3811 3812 return new_ops 3813 3814 def as_graph_element(self, obj, allow_tensor=True, allow_operation=True): 3815 """Returns the object referred to by `obj`, as an `Operation` or `Tensor`. 3816 3817 This function validates that `obj` represents an element of this 3818 graph, and gives an informative error message if it is not. 3819 3820 This function is the canonical way to get/validate an object of 3821 one of the allowed types from an external argument reference in the 3822 Session API. 3823 3824 This method may be called concurrently from multiple threads. 3825 3826 Args: 3827 obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. Can 3828 also be any object with an `_as_graph_element()` method that returns a 3829 value of one of these types. Note: `_as_graph_element` will be called 3830 inside the graph's lock and so may not modify the graph. 3831 allow_tensor: If true, `obj` may refer to a `Tensor`. 3832 allow_operation: If true, `obj` may refer to an `Operation`. 3833 3834 Returns: 3835 The `Tensor` or `Operation` in the Graph corresponding to `obj`. 3836 3837 Raises: 3838 TypeError: If `obj` is not a type we support attempting to convert 3839 to types. 
def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
  """Resolves `obj` to a `Tensor` or `Operation` of this graph.

  See `Graph.as_graph_element()` for the public contract; that method calls
  this one, taking `self._lock` first unless the graph is finalized.

  Args:
    obj: A `Tensor`, an `Operation`, the name of a tensor or operation, or an
      object that `_as_graph_element()` can convert to one of these.
    allow_tensor: If true, `obj` may refer to a `Tensor`.
    allow_operation: If true, `obj` may refer to an `Operation`.

  Returns:
    The `Tensor` or `Operation` in the graph corresponding to `obj`.

  Raises:
    TypeError: If `obj` cannot be converted to one of the allowed types.
    ValueError: If both `allow_tensor` and `allow_operation` are false, if
      `obj` is a malformed name or a name of the wrong kind, or if `obj` is a
      `Tensor`/`Operation` that belongs to a different graph.
    KeyError: If `obj` names an operation or tensor output that is not in
      this graph.
  """
  # The vast majority of this function is figuring
  # out what an API user might be doing wrong, so
  # that we can give helpful error messages.
  #
  # Ideally, it would be nice to split it up, but we
  # need context to generate nice error messages.

  if allow_tensor and allow_operation:
    types_str = "Tensor or Operation"
  elif allow_tensor:
    types_str = "Tensor"
  elif allow_operation:
    types_str = "Operation"
  else:
    raise ValueError("allow_tensor and allow_operation can't both be False.")

  temp_obj = _as_graph_element(obj)
  if temp_obj is not None:
    obj = temp_obj

  # If obj appears to be a name...
  if isinstance(obj, compat.bytes_or_text_types):
    name = compat.as_str(obj)

    if ":" in name and allow_tensor:
      # Looks like a Tensor name and can be a Tensor.
      try:
        op_name, out_n = name.split(":")
        out_n = int(out_n)
      except ValueError:
        # Both failure modes are ValueError: the wrong number of ':' parts
        # when unpacking, or a non-integer output index. Catching only
        # ValueError (not a bare `except:`) lets KeyboardInterrupt and
        # SystemExit propagate instead of being reported as a bad name.
        raise ValueError("The name %s looks a like a Tensor name, but is "
                         "not a valid one. Tensor names must be of the "
                         "form \"<op_name>:<output_index>\"." % repr(name))
      if op_name in self._nodes_by_name:
        op = self._nodes_by_name[op_name]
      else:
        raise KeyError("The name %s refers to a Tensor which does not "
                       "exist. The operation, %s, does not exist in the "
                       "graph." % (repr(name), repr(op_name)))
      try:
        return op.outputs[out_n]
      except IndexError:
        # `out_n` is already an int, so an out-of-range index is the only
        # lookup failure to translate into the friendlier KeyError.
        raise KeyError("The name %s refers to a Tensor which does not "
                       "exist. The operation, %s, exists but only has "
                       "%s outputs." %
                       (repr(name), repr(op_name), len(op.outputs)))

    elif ":" in name and not allow_tensor:
      # Looks like a Tensor name but can't be a Tensor.
      raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
                       (repr(name), types_str))

    elif ":" not in name and allow_operation:
      # Looks like an Operation name and can be an Operation.
      if name not in self._nodes_by_name:
        raise KeyError("The name %s refers to an Operation not in the "
                       "graph." % repr(name))
      return self._nodes_by_name[name]

    elif ":" not in name and not allow_operation:
      # Looks like an Operation name but can't be an Operation.
      if name in self._nodes_by_name:
        # Yep, it's an Operation name
        err_msg = ("The name %s refers to an Operation, not a %s." %
                   (repr(name), types_str))
      else:
        err_msg = ("The name %s looks like an (invalid) Operation name, "
                   "not a %s." % (repr(name), types_str))
      err_msg += (" Tensor names must be of the form "
                  "\"<op_name>:<output_index>\".")
      raise ValueError(err_msg)

  elif isinstance(obj, Tensor) and allow_tensor:
    # Actually obj is just the object it's referring to.
    if obj.graph is not self:
      raise ValueError("Tensor %s is not an element of this graph." % obj)
    return obj
  elif isinstance(obj, Operation) and allow_operation:
    # Actually obj is just the object it's referring to.
    if obj.graph is not self:
      raise ValueError("Operation %s is not an element of this graph." % obj)
    return obj
  else:
    # We give up!
    raise TypeError("Can not convert a %s into a %s." %
                    (type(obj).__name__, types_str))
3993 """ 3994 3995 if self._finalized: 3996 return self._nodes_by_name[name] 3997 3998 with self._lock: 3999 return self._nodes_by_name[name] 4000 4001 def _get_operation_by_tf_operation(self, tf_oper): 4002 op_name = pywrap_tf_session.TF_OperationName(tf_oper) 4003 return self._get_operation_by_name_unsafe(op_name) 4004 4005 def get_tensor_by_name(self, name): 4006 """Returns the `Tensor` with the given `name`. 4007 4008 This method may be called concurrently from multiple threads. 4009 4010 Args: 4011 name: The name of the `Tensor` to return. 4012 4013 Returns: 4014 The `Tensor` with the given `name`. 4015 4016 Raises: 4017 TypeError: If `name` is not a string. 4018 KeyError: If `name` does not correspond to a tensor in this graph. 4019 """ 4020 # Names should be strings. 4021 if not isinstance(name, six.string_types): 4022 raise TypeError("Tensor names are strings (or similar), not %s." % 4023 type(name).__name__) 4024 return self.as_graph_element(name, allow_tensor=True, allow_operation=False) 4025 4026 def _get_tensor_by_tf_output(self, tf_output): 4027 """Returns the `Tensor` representing `tf_output`. 4028 4029 Note that there is only one such `Tensor`, i.e. multiple calls to this 4030 function with the same TF_Output value will always return the same `Tensor` 4031 object. 4032 4033 Args: 4034 tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`). 4035 4036 Returns: 4037 The `Tensor` that represents `tf_output`. 4038 """ 4039 op = self._get_operation_by_tf_operation(tf_output.oper) 4040 return op.outputs[tf_output.index] 4041 4042 @property 4043 def _last_id(self): 4044 return self._next_id_counter 4045 4046 def _get_op_def(self, type): # pylint: disable=redefined-builtin 4047 """Returns the `OpDef` proto for `type`. `type` is a string.""" 4048 # NOTE: No locking is required because the lookup and insertion operations 4049 # on Python dictionaries are atomic. 
4050 try: 4051 return self._op_def_cache[type] 4052 except KeyError: 4053 with c_api_util.tf_buffer() as buf: 4054 # pylint: disable=protected-access 4055 pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), 4056 buf) 4057 # pylint: enable=protected-access 4058 data = pywrap_tf_session.TF_GetBuffer(buf) 4059 op_def = op_def_pb2.OpDef() 4060 op_def.ParseFromString(compat.as_bytes(data)) 4061 self._op_def_cache[type] = op_def 4062 return op_def 4063 4064 def as_default(self): 4065 """Returns a context manager that makes this `Graph` the default graph. 4066 4067 This method should be used if you want to create multiple graphs 4068 in the same process. For convenience, a global default graph is 4069 provided, and all ops will be added to this graph if you do not 4070 create a new graph explicitly. 4071 4072 Use this method with the `with` keyword to specify that ops created within 4073 the scope of a block should be added to this graph. In this case, once 4074 the scope of the `with` is exited, the previous default graph is set again 4075 as default. There is a stack, so it's ok to have multiple nested levels 4076 of `as_default` calls. 4077 4078 The default graph is a property of the current thread. If you 4079 create a new thread, and wish to use the default graph in that 4080 thread, you must explicitly add a `with g.as_default():` in that 4081 thread's function. 4082 4083 The following code examples are equivalent: 4084 4085 ```python 4086 # 1. Using Graph.as_default(): 4087 g = tf.Graph() 4088 with g.as_default(): 4089 c = tf.constant(5.0) 4090 assert c.graph is g 4091 4092 # 2. Constructing and making default: 4093 with tf.Graph().as_default() as g: 4094 c = tf.constant(5.0) 4095 assert c.graph is g 4096 ``` 4097 4098 If eager execution is enabled ops created under this context manager will be 4099 added to the graph instead of executed eagerly. 4100 4101 Returns: 4102 A context manager for using this graph as the default graph. 
@property
def collections(self):
  """Returns the names of the collections known to this graph."""
  known_names = list(self._collections)
  return known_names

def add_to_collection(self, name, value):
  """Appends `value` to the collection stored under `name`.

  Collections are lists, not sets: adding the same value several times
  stores it several times.

  Args:
    name: The key for the collection. The `GraphKeys` class contains many
      standard names for collections.
    value: The value to add to the collection.
  """  # pylint: disable=g-doc-exception
  self._check_not_finalized()
  with self._lock:
    # Create the list on first use, then append in either case.
    self._collections.setdefault(name, []).append(value)

def add_to_collections(self, names, value):
  """Appends `value` to every collection listed in `names`.

  Collections are lists, not sets, so a value may end up in a collection
  several times. Duplicate entries in `names` are collapsed, but no check is
  made for `value` already being a member of any target collection.

  `names` may be any iterable; a plain string is treated as the name of a
  single collection rather than as an iterable of characters.

  Args:
    names: The keys for the collections to add to. The `GraphKeys` class
      contains many standard names for collections.
    value: The value to add to the collections.
  """
  if isinstance(names, six.string_types):
    # A lone string names exactly one collection.
    targets = (names,)
  else:
    # De-duplicate so each collection receives the value once.
    targets = set(names)
  for target in targets:
    self.add_to_collection(target, value)
def get_collection(self, name, scope=None):
  """Returns a fresh list holding the values of the collection `name`.

  Unlike `get_collection_ref()`, which hands back the stored list itself when
  it exists, this method builds a new list on every call.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    scope: (Optional.) A string. When given, only items whose `name`
      attribute matches `scope` under `re.match` are kept; items lacking a
      `name` attribute are dropped. Because `re.match` anchors at the start,
      a `scope` without special tokens acts as a prefix filter.

  Returns:
    The list of values in the collection with the given `name`, in the order
    in which they were collected, or an empty list if nothing has been added
    to that collection.
  """  # pylint: disable=g-doc-exception
  with self._lock:
    stored = self._collections.get(name, None)
    if stored is None:
      return []
    if scope is None:
      return list(stored)

    pattern = re.compile(scope)
    unnamed = object()  # Sentinel marking items that expose no `name`.
    selected = []
    for entry in stored:
      entry_name = getattr(entry, "name", unnamed)
      # Collection items with no name are ignored.
      if entry_name is not unnamed and pattern.match(entry_name):
        selected.append(entry)
    return selected
4260 if not hasattr(self._thread_local, "_name_stack"): 4261 self._thread_local._name_stack = "" 4262 return self._thread_local._name_stack 4263 4264 @_name_stack.setter 4265 def _name_stack(self, name_stack): 4266 self._thread_local._name_stack = name_stack 4267 4268 # pylint: disable=g-doc-return-or-yield,line-too-long 4269 @tf_contextlib.contextmanager 4270 def name_scope(self, name): 4271 """Returns a context manager that creates hierarchical names for operations. 4272 4273 A graph maintains a stack of name scopes. A `with name_scope(...):` 4274 statement pushes a new name onto the stack for the lifetime of the context. 4275 4276 The `name` argument will be interpreted as follows: 4277 4278 * A string (not ending with '/') will create a new name scope, in which 4279 `name` is appended to the prefix of all operations created in the 4280 context. If `name` has been used before, it will be made unique by 4281 calling `self.unique_name(name)`. 4282 * A scope previously captured from a `with g.name_scope(...) as 4283 scope:` statement will be treated as an "absolute" name scope, which 4284 makes it possible to re-enter existing scopes. 4285 * A value of `None` or the empty string will reset the current name scope 4286 to the top-level (empty) name scope. 4287 4288 For example: 4289 4290 ```python 4291 with tf.Graph().as_default() as g: 4292 c = tf.constant(5.0, name="c") 4293 assert c.op.name == "c" 4294 c_1 = tf.constant(6.0, name="c") 4295 assert c_1.op.name == "c_1" 4296 4297 # Creates a scope called "nested" 4298 with g.name_scope("nested") as scope: 4299 nested_c = tf.constant(10.0, name="c") 4300 assert nested_c.op.name == "nested/c" 4301 4302 # Creates a nested scope called "inner". 4303 with g.name_scope("inner"): 4304 nested_inner_c = tf.constant(20.0, name="c") 4305 assert nested_inner_c.op.name == "nested/inner/c" 4306 4307 # Create a nested scope called "inner_1". 
4308 with g.name_scope("inner"): 4309 nested_inner_1_c = tf.constant(30.0, name="c") 4310 assert nested_inner_1_c.op.name == "nested/inner_1/c" 4311 4312 # Treats `scope` as an absolute name scope, and 4313 # switches to the "nested/" scope. 4314 with g.name_scope(scope): 4315 nested_d = tf.constant(40.0, name="d") 4316 assert nested_d.op.name == "nested/d" 4317 4318 with g.name_scope(""): 4319 e = tf.constant(50.0, name="e") 4320 assert e.op.name == "e" 4321 ``` 4322 4323 The name of the scope itself can be captured by `with 4324 g.name_scope(...) as scope:`, which stores the name of the scope 4325 in the variable `scope`. This value can be used to name an 4326 operation that represents the overall result of executing the ops 4327 in a scope. For example: 4328 4329 ```python 4330 inputs = tf.constant(...) 4331 with g.name_scope('my_layer') as scope: 4332 weights = tf.Variable(..., name="weights") 4333 biases = tf.Variable(..., name="biases") 4334 affine = tf.matmul(inputs, weights) + biases 4335 output = tf.nn.relu(affine, name=scope) 4336 ``` 4337 4338 NOTE: This constructor validates the given `name`. Valid scope 4339 names match one of the following regular expressions: 4340 4341 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root) 4342 [A-Za-z0-9_.\\-/]* (for other scopes) 4343 4344 Args: 4345 name: A name for the scope. 4346 4347 Returns: 4348 A context manager that installs `name` as a new name scope. 4349 4350 Raises: 4351 ValueError: If `name` is not a valid scope name, according to the rules 4352 above. 4353 """ 4354 if name: 4355 if isinstance(name, compat.bytes_or_text_types): 4356 name = compat.as_str(name) 4357 4358 if self._name_stack: 4359 # Scopes created in a nested scope may have initial characters 4360 # that are illegal as the initial character of an op name 4361 # (viz. '-', '\', '/', and '_'). 
def unique_name(self, name, mark_as_used=True):
  """Produces an operation name, based on `name`, that is unused in the graph.

  Note: Calling `unique_name()` directly is rarely necessary; wrapping code
  in `with g.name_scope()` blocks is usually enough to get structured names.

  The structured, `"/"`-separated names produced here help identify
  operations while debugging a graph. They appear in error messages from the
  TensorFlow runtime and in visualization tools such as TensorBoard.

  With `mark_as_used` set to `True` (the default) the returned name is also
  recorded as in use. With `False` the name is only computed, which is handy
  when the caller merely wants to know what the next name would be.

  Args:
    name: The name for an operation.
    mark_as_used: Whether to mark this name as being used.

  Returns:
    A string to be passed to `create_op()` that will be used
    to name the operation being created.
  """
  scope_prefix = self._name_stack
  if scope_prefix:
    name = scope_prefix + "/" + name

  # Names are compared case-insensitively when checking for reuse
  # (e.g. foo = Foo).
  in_use = self._names_in_use
  base_key = name.lower()
  suffix = in_use.get(base_key, 0)
  # Bump the usage count recorded for the base key.
  if mark_as_used:
    in_use[base_key] = suffix + 1
  if suffix <= 0:
    # First use of this name: it is returned unchanged.
    return name

  # Probe "<base>_<n>" keys until one is free.
  candidate_key = base_key
  while candidate_key in in_use:
    candidate_key = "%s_%d" % (base_key, suffix)
    suffix += 1
  # Record the composed key too, in case someone later asks for
  # unique_name("name_1").
  if mark_as_used:
    in_use[candidate_key] = 1

  # The result keeps the capitalization of the name that was passed in.
  return "%s_%d" % (name, suffix - 1)
@tf_contextlib.contextmanager
def _colocate_with_for_gradient(self, op, gradient_uid,
                                ignore_existing=False):
  """Enters `colocate_with(op)` and notifies the enclosing context, if any.

  When `gradient_uid` is not None and `_get_enclosing_context(self)` returns
  a context, that context's `EnterGradientColocation` /
  `ExitGradientColocation` hooks are called around the body of the `with`
  block. Otherwise this behaves like a plain `colocate_with` scope.

  Args:
    op: The op to colocate with; forwarded to `colocate_with`.
    gradient_uid: Identifier passed to the gradient colocation hooks, or None
      to skip them.
    ignore_existing: Forwarded to `colocate_with`.

  Yields:
    Nothing.
  """
  with self.colocate_with(op, ignore_existing):
    enclosing_ctx = None
    if gradient_uid is not None:
      enclosing_ctx = _get_enclosing_context(self)

    if enclosing_ctx is None:
      yield
    else:
      enclosing_ctx.EnterGradientColocation(op, gradient_uid)
      try:
        yield
      finally:
        # Always pair the Enter call with an Exit, even if the body raised.
        enclosing_ctx.ExitGradientColocation(op, gradient_uid)
4508 """ 4509 if op is None and not ignore_existing: 4510 raise ValueError("Trying to reset colocation (op is None) but " 4511 "ignore_existing is not True") 4512 op, device_only_candidate = _op_to_colocate_with(op, self) 4513 4514 # By default, colocate_with resets the device function stack, 4515 # since colocate_with is typically used in specific internal 4516 # library functions where colocation is intended to be "stronger" 4517 # than device functions. 4518 # 4519 # In the future, a caller may specify that device_functions win 4520 # over colocation, in which case we can add support. 4521 device_fn_tmp = self._device_function_stack 4522 self._device_function_stack = traceable_stack.TraceableStack() 4523 4524 if ignore_existing: 4525 current_stack = self._colocation_stack 4526 self._colocation_stack = traceable_stack.TraceableStack() 4527 4528 if op is not None: 4529 # offset refers to the stack frame used for storing code location. 4530 # We use 4, the sum of 1 to use our caller's stack frame and 3 4531 # to jump over layers of context managers above us. 4532 if device_only_candidate is not None: 4533 self._colocation_stack.push_obj(device_only_candidate, offset=4) 4534 self._colocation_stack.push_obj(op, offset=4) 4535 elif not ignore_existing: 4536 raise ValueError("Trying to reset colocation (op is None) but " 4537 "ignore_existing is not True") 4538 try: 4539 yield 4540 finally: 4541 # Restore device function stack 4542 self._device_function_stack = device_fn_tmp 4543 if op is not None: 4544 self._colocation_stack.pop_obj() 4545 if device_only_candidate is not None: 4546 self._colocation_stack.pop_obj() 4547 4548 # Reset the colocation stack if requested. 
4549 if ignore_existing: 4550 self._colocation_stack = current_stack 4551 4552 def _add_device_to_stack(self, device_name_or_function, offset=0): 4553 """Add device to stack manually, separate from a context manager.""" 4554 total_offset = 1 + offset 4555 spec = _UserDeviceSpec(device_name_or_function) 4556 self._device_function_stack.push_obj(spec, offset=total_offset) 4557 return spec 4558 4559 @tf_contextlib.contextmanager 4560 def device(self, device_name_or_function): 4561 # pylint: disable=line-too-long 4562 """Returns a context manager that specifies the default device to use. 4563 4564 The `device_name_or_function` argument may either be a device name 4565 string, a device function, or None: 4566 4567 * If it is a device name string, all operations constructed in 4568 this context will be assigned to the device with that name, unless 4569 overridden by a nested `device()` context. 4570 * If it is a function, it will be treated as a function from 4571 Operation objects to device name strings, and invoked each time 4572 a new Operation is created. The Operation will be assigned to 4573 the device with the returned name. 4574 * If it is None, all `device()` invocations from the enclosing context 4575 will be ignored. 4576 4577 For information about the valid syntax of device name strings, see 4578 the documentation in 4579 [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h). 4580 4581 For example: 4582 4583 ```python 4584 with g.device('/device:GPU:0'): 4585 # All operations constructed in this context will be placed 4586 # on GPU 0. 4587 with g.device(None): 4588 # All operations constructed in this context will have no 4589 # assigned device. 4590 4591 # Defines a function from `Operation` to device string. 
def _apply_device_functions(self, op):
  """Runs the graph's device function stack against a freshly built op.

  Specs are visited in the order returned by `peek_objs()` (LIFO, so the most
  recently pushed spec gets the first chance to place the op). This happens
  here rather than earlier because the outcome can depend on the Operation's
  signature, which is only available once the Operation is constructed.

  Args:
    op: The operation to place. Its `_set_device_from_string` method is
      called for each newly merged device string, and its
      `_device_code_locations` attribute is overwritten with a snapshot of
      the device function stack metadata.
  """
  # pylint: disable=protected-access
  last_applied = None
  for spec in self._device_function_stack.peek_objs():
    if not spec.is_null_merge:
      if spec.function is None:
        # Nothing further down the stack is consulted once a spec without a
        # function is reached.
        break

      merged = spec.string_merge(op)

      # Identity (not equality) is used on purpose: None is a singleton and
      # Python interns strings, and `is` is cheaper than `==`.
      if merged is not last_applied:
        op._set_device_from_string(merged)
        last_applied = merged
  op._device_code_locations = self._snapshot_device_function_stack_metadata()
  # pylint: enable=protected-access
class _ControlDependenciesController(object):
  """Context manager backing `with control_dependencies(...)` blocks."""

  def __init__(self, graph, control_inputs):
    """Builds a controller for one `control_dependencies()` scope.

    Controllers normally nest: each one contributes `control_inputs` on top
    of whatever the enclosing scopes already require. Because the caller
    de-duplicates, the list may be empty even when ops were passed in.

    Passing `None` is special: it asks for a brand-new, empty set of control
    dependencies instead of an extension of the current one. In that case
    the graph's control flow context is cleared as well on entry, since it
    is an additional mechanism that adds control dependencies.

    Args:
      graph: The graph that this controller is managing.
      control_inputs: List of ops to use as control inputs in addition to the
        current control dependencies, or None to indicate that the
        dependencies should be cleared.
    """
    self._graph = graph
    self._new_stack = control_inputs is None
    self._control_inputs_val = [] if self._new_stack else control_inputs
    # Ops recorded via `add_op` while this scope was on the stack.
    self._seen_nodes = set()
    # Saved graph state; only populated while a "new stack" scope is active.
    self._old_stack = None
    self._old_control_flow_context = None

# pylint: disable=protected-access

  def __enter__(self):
    graph = self._graph
    if self._new_stack:
      # Stash and clear the graph's control dependencies stack ...
      self._old_stack = graph._control_dependencies_stack
      graph._control_dependencies_stack = []
      # ... and do the same for the control flow context.
      self._old_control_flow_context = graph._get_control_flow_context()
      graph._set_control_flow_context(None)
    graph._push_control_dependencies_controller(self)

  def __exit__(self, unused_type, unused_value, unused_traceback):
    graph = self._graph
    graph._pop_control_dependencies_controller(self)
    if self._new_stack:
      # Put back what `__enter__` stashed away.
      graph._control_dependencies_stack = self._old_stack
      graph._set_control_flow_context(self._old_control_flow_context)

# pylint: enable=protected-access

  @property
  def control_inputs(self):
    """The control inputs this scope contributes (empty for a reset scope)."""
    return self._control_inputs_val

  @staticmethod
  def _hashable(op):
    """Returns `op.ref()` for a `Tensor`, and `op` itself otherwise."""
    return op.ref() if isinstance(op, Tensor) else op

  def add_op(self, op):
    """Records `op` as seen within this scope."""
    self._seen_nodes.add(self._hashable(op))

  def op_in_group(self, op):
    """Returns whether `op` was previously recorded with `add_op`."""
    return self._hashable(op) in self._seen_nodes
def _control_dependencies_for_inputs(self, input_ops):
  """Computes the control inputs for a new op whose data inputs are known.

  The result should behave as if every control input on
  `self._control_dependencies_stack` were attached to the new op, but with
  some pruning: ops created inside the same `with control_dependencies(...)`
  block may already carry the dependency through their data edges, which
  makes the explicit control edge redundant.

  Args:
    input_ops: The data input ops for an op to be created.

  Returns:
    A list of control inputs for the op to be created.
  """
  pruned = []
  for scope in self._control_dependencies_stack:
    # If any input was itself created under this scope, the new op is
    # dominated by that input and needs nothing further from this scope.
    if any(scope.op_in_group(candidate) for candidate in input_ops):
      continue
    # Skip control inputs that are already direct data inputs.
    # NOTE(mrry): We do not currently track transitive data dependencies,
    # so we may add redundant control inputs.
    pruned += [dep for dep in scope.control_inputs if dep not in input_ops]
  return pruned
4833 """ 4834 for controller in self._control_dependencies_stack: 4835 controller.add_op(op) 4836 4837 def control_dependencies(self, control_inputs): 4838 """Returns a context manager that specifies control dependencies. 4839 4840 Use with the `with` keyword to specify that all operations constructed 4841 within the context should have control dependencies on 4842 `control_inputs`. For example: 4843 4844 ```python 4845 with g.control_dependencies([a, b, c]): 4846 # `d` and `e` will only run after `a`, `b`, and `c` have executed. 4847 d = ... 4848 e = ... 4849 ``` 4850 4851 Multiple calls to `control_dependencies()` can be nested, and in 4852 that case a new `Operation` will have control dependencies on the union 4853 of `control_inputs` from all active contexts. 4854 4855 ```python 4856 with g.control_dependencies([a, b]): 4857 # Ops constructed here run after `a` and `b`. 4858 with g.control_dependencies([c, d]): 4859 # Ops constructed here run after `a`, `b`, `c`, and `d`. 4860 ``` 4861 4862 You can pass None to clear the control dependencies: 4863 4864 ```python 4865 with g.control_dependencies([a, b]): 4866 # Ops constructed here run after `a` and `b`. 4867 with g.control_dependencies(None): 4868 # Ops constructed here run normally, not waiting for either `a` or `b`. 4869 with g.control_dependencies([c, d]): 4870 # Ops constructed here run after `c` and `d`, also not waiting 4871 # for either `a` or `b`. 4872 ``` 4873 4874 *N.B.* The control dependencies context applies *only* to ops that 4875 are constructed within the context. Merely using an op or tensor 4876 in the context does not add a control dependency. The following 4877 example illustrates this point: 4878 4879 ```python 4880 # WRONG 4881 def my_func(pred, tensor): 4882 t = tf.matmul(tensor, tensor) 4883 with tf.control_dependencies([pred]): 4884 # The matmul op is created outside the context, so no control 4885 # dependency will be added. 
4886 return t 4887 4888 # RIGHT 4889 def my_func(pred, tensor): 4890 with tf.control_dependencies([pred]): 4891 # The matmul op is created in the context, so a control dependency 4892 # will be added. 4893 return tf.matmul(tensor, tensor) 4894 ``` 4895 4896 Also note that though execution of ops created under this scope will trigger 4897 execution of the dependencies, the ops created under this scope might still 4898 be pruned from a normal tensorflow graph. For example, in the following 4899 snippet of code the dependencies are never executed: 4900 4901 ```python 4902 loss = model.loss() 4903 with tf.control_dependencies(dependencies): 4904 loss = loss + tf.constant(1) # note: dependencies ignored in the 4905 # backward pass 4906 return tf.gradients(loss, model.variables) 4907 ``` 4908 4909 This is because evaluating the gradient graph does not require evaluating 4910 the constant(1) op created in the forward pass. 4911 4912 Args: 4913 control_inputs: A list of `Operation` or `Tensor` objects which must be 4914 executed or computed before running the operations defined in the 4915 context. Can also be `None` to clear the control dependencies. 4916 4917 Returns: 4918 A context manager that specifies control dependencies for all 4919 operations constructed within the context. 4920 4921 Raises: 4922 TypeError: If `control_inputs` is not a list of `Operation` or 4923 `Tensor` objects. 4924 """ 4925 if control_inputs is None: 4926 return self._ControlDependenciesController(self, None) 4927 # First convert the inputs to ops, and deduplicate them. 4928 # NOTE(mrry): Other than deduplication, we do not currently track direct 4929 # or indirect dependencies between control_inputs, which may result in 4930 # redundant control inputs. 4931 control_ops = [] 4932 current = self._current_control_dependencies() 4933 for c in control_inputs: 4934 # The hasattr(handle) is designed to match ResourceVariables. 
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _attr_scope(self, attr_map):
  """EXPERIMENTAL: A context manager for setting attributes on operators.

  This context manager can be used to add additional
  attributes to operators within the scope of the context.

  For example:

      with ops.Graph().as_default() as g:
        f_1 = Foo()  # No extra attributes
        with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
          f_2 = Foo()  # Additional attribute _a=False
          with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
            f_3 = Foo()  # Additional attribute _a=True
            with g._attr_scope({"_a": None}):
              f_4 = Foo()  # No additional attributes.

  Args:
    attr_map: A dictionary mapping attr name strings to AttrValue protocol
      buffers, callables that emit AttrValue protocol buffers, or None.
      A value of None removes the attribute (if it is currently set) for the
      duration of the context.

  Returns:
    A context manager that installs the given entries in the graph's
    attribute scope map while the context is active and restores the
    previous entries on exit.

  Raises:
    TypeError: If attr_map is not a dictionary mapping strings to AttrValue
      protobufs, callables or None. The graph is left untouched in that case.
  """
  if not isinstance(attr_map, dict):
    raise TypeError("attr_map must be a dictionary mapping "
                    "strings to AttrValue protocol buffers")
  # Validate every entry before mutating anything, so that a bad entry cannot
  # leave earlier entries installed with nobody around to remove them.
  for name, attr in attr_map.items():
    if not (isinstance(name, six.string_types) and
            (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
             callable(attr))):
      raise TypeError("attr_map must be a dictionary mapping "
                      "strings to AttrValue protocol buffers or "
                      "callables that emit AttrValue protocol buffers")
  # The saved_attrs dictionary stores any currently-set attributes that
  # will be overridden by this context manager.
  saved_attrs = {}
  # Install the given attributes.
  for name, attr in attr_map.items():
    if name in self._attr_scope_map:
      saved_attrs[name] = self._attr_scope_map[name]
    if attr is None:
      # Clearing an attribute that is not set is a no-op rather than a
      # KeyError.
      self._attr_scope_map.pop(name, None)
    else:
      self._attr_scope_map[name] = attr
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the attributes set for this context, and restore any saved
    # attributes.
    for name in attr_map:
      if name in saved_attrs:
        self._attr_scope_map[name] = saved_attrs[name]
      else:
        # The name may legitimately be absent (it was cleared with None and
        # never set before), so tolerate a missing key.
        self._attr_scope_map.pop(name, None)
5041 5042 Returns: 5043 A context manager that sets the kernel label to be used for one or more 5044 ops created in that context. 5045 5046 Raises: 5047 TypeError: If op_to_kernel_label_map is not a dictionary mapping 5048 strings to strings. 5049 """ 5050 if not isinstance(op_to_kernel_label_map, dict): 5051 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 5052 "strings to strings") 5053 # The saved_labels dictionary stores any currently-set labels that 5054 # will be overridden by this context manager. 5055 saved_labels = {} 5056 # Install the given label 5057 for op_type, label in op_to_kernel_label_map.items(): 5058 if not (isinstance(op_type, six.string_types) and 5059 isinstance(label, six.string_types)): 5060 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 5061 "strings to strings") 5062 try: 5063 saved_labels[op_type] = self._op_to_kernel_label_map[op_type] 5064 except KeyError: 5065 pass 5066 self._op_to_kernel_label_map[op_type] = label 5067 try: 5068 yield # The code within the context runs here. 5069 finally: 5070 # Remove the labels set for this context, and restore any saved labels. 5071 for op_type, label in op_to_kernel_label_map.items(): 5072 try: 5073 self._op_to_kernel_label_map[op_type] = saved_labels[op_type] 5074 except KeyError: 5075 del self._op_to_kernel_label_map[op_type] 5076 5077 # pylint: enable=g-doc-return-or-yield 5078 5079 @tf_contextlib.contextmanager 5080 def _override_gradient_function(self, gradient_function_map): 5081 """Specify gradient function for the given op type.""" 5082 5083 # This is an internal API and we don't need nested context for this. 5084 # TODO(mdan): make it a proper context manager. 
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def gradient_override_map(self, op_type_map):
  """EXPERIMENTAL: A context manager for overriding gradient functions.

  This context manager can be used to override the gradient function
  that will be used for ops within the scope of the context.

  For example:

  ```python
  @tf.RegisterGradient("CustomSquare")
  def _custom_square_grad(op, grad):
    # ...

  with tf.Graph().as_default() as g:
    c = tf.constant(5.0)
    s_1 = tf.square(c)  # Uses the default gradient for tf.square.
    with g.gradient_override_map({"Square": "CustomSquare"}):
      s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
                          # gradient of s_2.
  ```

  Args:
    op_type_map: A dictionary mapping op type strings to alternative op type
      strings.

  Returns:
    A context manager that sets the alternative op type to be used for one
    or more ops created in that context.

  Raises:
    TypeError: If `op_type_map` is not a dictionary mapping strings to
      strings. The graph's override map is left untouched in that case.
  """
  if not isinstance(op_type_map, dict):
    raise TypeError("op_type_map must be a dictionary mapping "
                    "strings to strings")
  # Validate every entry before mutating anything, so that a bad entry cannot
  # leave earlier overrides installed with nobody around to remove them.
  for op_type, mapped_op_type in op_type_map.items():
    if not (isinstance(op_type, six.string_types) and
            isinstance(mapped_op_type, six.string_types)):
      raise TypeError("op_type_map must be a dictionary mapping "
                      "strings to strings")
  # The saved_mappings dictionary stores any currently-set mappings that
  # will be overridden by this context manager.
  saved_mappings = {}
  # Install the given overrides.
  for op_type, mapped_op_type in op_type_map.items():
    if op_type in self._gradient_override_map:
      saved_mappings[op_type] = self._gradient_override_map[op_type]
    self._gradient_override_map[op_type] = mapped_op_type
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the overrides set for this context, and restore any saved ones.
    for op_type in op_type_map:
      if op_type in saved_mappings:
        self._gradient_override_map[op_type] = saved_mappings[op_type]
      else:
        # Tolerate the entry having been removed inside the context; raising
        # from `finally` would mask the body's own exception.
        self._gradient_override_map.pop(op_type, None)
@property
def _device_functions_outer_to_inner(self):
  """The `function` of every device spec on the stack, in reverse peek order.

  Returns:
    A new list holding `spec.function` for each spec yielded by
    `self._device_function_stack.peek_objs()`, reversed so that the last
    peeked spec comes first.
  """
  peek_order = [spec.function for spec in self._device_function_stack.peek_objs()]
  return peek_order[::-1]
def _snapshot_colocation_stack_metadata(self):
  """Returns the colocation stack's metadata as a dictionary.

  Returns:
    A dict mapping `entry.obj.name` to `entry.copy_metadata()` for each entry
    produced by `self._colocation_stack.peek_traceable_objs()`. When two
    entries share a name, the one peeked later wins.
  """
  snapshot = {}
  for entry in self._colocation_stack.peek_traceable_objs():
    key = entry.obj.name
    snapshot[key] = entry.copy_metadata()
  return snapshot
@property
def _distribution_strategy_stack(self):
  """A stack to maintain distribution strategy context for each thread.

  The stack is stored on `self._thread_local` and is created as an empty
  list the first time it is read.
  """
  local_state = self._thread_local
  # pylint: disable=protected-access
  try:
    return local_state._distribution_strategy_stack
  except AttributeError:
    local_state._distribution_strategy_stack = []
    return local_state._distribution_strategy_stack
  # pylint: enable=protected-access

@_distribution_strategy_stack.setter
def _distribution_strategy_stack(self, _distribution_strategy_stack):
  """Replaces the stack stored on `self._thread_local`."""
  local_state = self._thread_local
  local_state._distribution_strategy_stack = _distribution_strategy_stack  # pylint: disable=protected-access
@tf_export(v1=["device"])
def device(device_name_or_function):
  """Wrapper for `Graph.device()` using the default graph.

  See `tf.Graph.device` for more details.

  Three cases are handled:

  * Executing eagerly: delegates to `context.device`; callables are rejected.
  * Building a graph while eager execution is enabled outside functions:
    enters the default graph's device scope and, for non-callable arguments,
    the eager `context.device` scope as well.
  * Otherwise: delegates to the default graph's `device()`.

  Args:
    device_name_or_function: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If eager execution is enabled and a function is passed in.
  """
  if context.executing_eagerly():
    if callable(device_name_or_function):
      raise RuntimeError(
          "tf.device does not support functions when eager execution "
          "is enabled.")
    return context.device(device_name_or_function)

  if not executing_eagerly_outside_functions():
    return get_default_graph().device(device_name_or_function)

  @tf_contextlib.contextmanager
  def graph_then_eager_scope():
    # The default graph is looked up when the scope is entered, not when
    # `device()` is called.
    with get_default_graph().device(device_name_or_function):
      if callable(device_name_or_function):
        yield
      else:
        with context.device(device_name_or_function):
          yield

  return graph_then_eager_scope()
If a specific device is not required, 5373 consider not using this function so that a device can be automatically 5374 assigned. In general the use of this function is optional. `device_name` can 5375 be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially 5376 specified, containing only a subset of the "/"-separated fields. Any fields 5377 which are specified will override device annotations from outer scopes. 5378 5379 For example: 5380 5381 ```python 5382 with tf.device('/job:foo'): 5383 # ops created here have devices with /job:foo 5384 with tf.device('/job:bar/task:0/device:gpu:2'): 5385 # ops created here have the fully specified device above 5386 with tf.device('/device:gpu:1'): 5387 # ops created here have the device '/job:foo/device:gpu:1' 5388 ``` 5389 5390 Args: 5391 device_name: The device name to use in the context. 5392 5393 Returns: 5394 A context manager that specifies the default device to use for newly 5395 created ops. 5396 5397 Raises: 5398 RuntimeError: If a function is passed in. 5399 """ 5400 if callable(device_name): 5401 raise RuntimeError("tf.device does not support functions.") 5402 return device(device_name) 5403 5404 5405@tf_export(v1=["container"]) 5406def container(container_name): 5407 """Wrapper for `Graph.container()` using the default graph. 5408 5409 Args: 5410 container_name: The container string to use in the context. 5411 5412 Returns: 5413 A context manager that specifies the default container to use for newly 5414 created stateful ops. 
5415 """ 5416 return get_default_graph().container(container_name) 5417 5418 5419def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False): 5420 if context.executing_eagerly(): 5421 if op is not None: 5422 if not hasattr(op, "device"): 5423 op = internal_convert_to_tensor_or_indexed_slices(op) 5424 return device(op.device) 5425 else: 5426 return NullContextmanager() 5427 else: 5428 default_graph = get_default_graph() 5429 if isinstance(op, EagerTensor): 5430 if default_graph.building_function: 5431 return default_graph.device(op.device) 5432 else: 5433 raise ValueError("Encountered an Eager-defined Tensor during graph " 5434 "construction, but a function was not being built.") 5435 return default_graph._colocate_with_for_gradient( 5436 op, gradient_uid=gradient_uid, ignore_existing=ignore_existing) 5437 5438 5439# Internal interface to colocate_with. colocate_with has been deprecated from 5440# public API. There are still a few internal uses of colocate_with. Add internal 5441# only API for those uses to avoid deprecation warning. 5442def colocate_with(op, ignore_existing=False): 5443 return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing) 5444 5445 5446@deprecation.deprecated( 5447 date=None, instructions="Colocations handled automatically by placer.") 5448@tf_export(v1=["colocate_with"]) 5449def _colocate_with(op, ignore_existing=False): 5450 return colocate_with(op, ignore_existing) 5451 5452 5453@tf_export("control_dependencies") 5454def control_dependencies(control_inputs): 5455 """Wrapper for `Graph.control_dependencies()` using the default graph. 5456 5457 See `tf.Graph.control_dependencies` for more details. 5458 5459 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require 5460 this method, as ops execute in the expected order thanks to automatic control 5461 dependencies.* Only use `tf.control_dependencies` when working with v1 5462 `tf.Graph` code. 
5463 5464 When eager execution is enabled, any callable object in the `control_inputs` 5465 list will be called. 5466 5467 Args: 5468 control_inputs: A list of `Operation` or `Tensor` objects which must be 5469 executed or computed before running the operations defined in the context. 5470 Can also be `None` to clear the control dependencies. If eager execution 5471 is enabled, any callable object in the `control_inputs` list will be 5472 called. 5473 5474 Returns: 5475 A context manager that specifies control dependencies for all 5476 operations constructed within the context. 5477 """ 5478 if context.executing_eagerly(): 5479 if control_inputs: 5480 # Execute any pending callables. 5481 for control in control_inputs: 5482 if callable(control): 5483 control() 5484 return NullContextmanager() 5485 else: 5486 return get_default_graph().control_dependencies(control_inputs) 5487 5488 5489class _DefaultStack(threading.local): 5490 """A thread-local stack of objects for providing implicit defaults.""" 5491 5492 def __init__(self): 5493 super(_DefaultStack, self).__init__() 5494 self._enforce_nesting = True 5495 self.stack = [] 5496 5497 def get_default(self): 5498 return self.stack[-1] if self.stack else None 5499 5500 def reset(self): 5501 self.stack = [] 5502 5503 def is_cleared(self): 5504 return not self.stack 5505 5506 @property 5507 def enforce_nesting(self): 5508 return self._enforce_nesting 5509 5510 @enforce_nesting.setter 5511 def enforce_nesting(self, value): 5512 self._enforce_nesting = value 5513 5514 @tf_contextlib.contextmanager 5515 def get_controller(self, default): 5516 """A context manager for manipulating a default stack.""" 5517 self.stack.append(default) 5518 try: 5519 yield default 5520 finally: 5521 # stack may be empty if reset() was called 5522 if self.stack: 5523 if self._enforce_nesting: 5524 if self.stack[-1] is not default: 5525 raise AssertionError( 5526 "Nesting violated for default stack of %s objects" % 5527 type(default)) 5528 
self.stack.pop() 5529 else: 5530 self.stack.remove(default) 5531 5532 5533_default_session_stack = _DefaultStack() # pylint: disable=protected-access 5534 5535 5536def default_session(session): 5537 """Python "with" handler for defining a default session. 5538 5539 This function provides a means of registering a session for handling 5540 Tensor.eval() and Operation.run() calls. It is primarily intended for use 5541 by session.Session, but can be used with any object that implements 5542 the Session.run() interface. 5543 5544 Use with the "with" keyword to specify that Tensor.eval() and Operation.run() 5545 invocations within the scope of a block should be executed by a particular 5546 session. 5547 5548 The default session applies to the current thread only, so it is always 5549 possible to inspect the call stack and determine the scope of a default 5550 session. If you create a new thread, and wish to use the default session 5551 in that thread, you must explicitly add a "with ops.default_session(sess):" 5552 block in that thread's function. 5553 5554 Example: 5555 The following code examples are equivalent: 5556 5557 # 1. Using the Session object directly: 5558 sess = ... 5559 c = tf.constant(5.0) 5560 sess.run(c) 5561 5562 # 2. Using default_session(): 5563 sess = ... 5564 with ops.default_session(sess): 5565 c = tf.constant(5.0) 5566 result = c.eval() 5567 5568 # 3. Overriding default_session(): 5569 sess = ... 5570 with ops.default_session(sess): 5571 c = tf.constant(5.0) 5572 with ops.default_session(...): 5573 c.eval(session=sess) 5574 5575 Args: 5576 session: The session to be installed as the default session. 5577 5578 Returns: 5579 A context manager for the default session. 5580 """ 5581 return _default_session_stack.get_controller(session) 5582 5583 5584@tf_export(v1=["get_default_session"]) 5585def get_default_session(): 5586 """Returns the default session for the current thread. 
5587 5588 The returned `Session` will be the innermost session on which a 5589 `Session` or `Session.as_default()` context has been entered. 5590 5591 NOTE: The default session is a property of the current thread. If you 5592 create a new thread, and wish to use the default session in that 5593 thread, you must explicitly add a `with sess.as_default():` in that 5594 thread's function. 5595 5596 Returns: 5597 The default `Session` being used in the current thread. 5598 """ 5599 return _default_session_stack.get_default() 5600 5601 5602def _eval_using_default_session(tensors, feed_dict, graph, session=None): 5603 """Uses the default session to evaluate one or more tensors. 5604 5605 Args: 5606 tensors: A single Tensor, or a list of Tensor objects. 5607 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 5608 numpy ndarrays, TensorProtos, or strings. 5609 graph: The graph in which the tensors are defined. 5610 session: (Optional) A different session to use to evaluate "tensors". 5611 5612 Returns: 5613 Either a single numpy ndarray if "tensors" is a single tensor; or a list 5614 of numpy ndarrays that each correspond to the respective element in 5615 "tensors". 5616 5617 Raises: 5618 ValueError: If no default session is available; the default session 5619 does not have "graph" as its graph; or if "session" is specified, 5620 and it does not have "graph" as its graph. 5621 """ 5622 if session is None: 5623 session = get_default_session() 5624 if session is None: 5625 raise ValueError("Cannot evaluate tensor using `eval()`: No default " 5626 "session is registered. Use `with " 5627 "sess.as_default()` or pass an explicit session to " 5628 "`eval(session=sess)`") 5629 if session.graph is not graph: 5630 raise ValueError("Cannot use the default session to evaluate tensor: " 5631 "the tensor's graph is different from the session's " 5632 "graph. 
Pass an explicit session to " 5633 "`eval(session=sess)`.") 5634 else: 5635 if session.graph is not graph: 5636 raise ValueError("Cannot use the given session to evaluate tensor: " 5637 "the tensor's graph is different from the session's " 5638 "graph.") 5639 return session.run(tensors, feed_dict) 5640 5641 5642def _run_using_default_session(operation, feed_dict, graph, session=None): 5643 """Uses the default session to run "operation". 5644 5645 Args: 5646 operation: The Operation to be run. 5647 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 5648 numpy ndarrays, TensorProtos, or strings. 5649 graph: The graph in which "operation" is defined. 5650 session: (Optional) A different session to use to run "operation". 5651 5652 Raises: 5653 ValueError: If no default session is available; the default session 5654 does not have "graph" as its graph; or if "session" is specified, 5655 and it does not have "graph" as its graph. 5656 """ 5657 if session is None: 5658 session = get_default_session() 5659 if session is None: 5660 raise ValueError("Cannot execute operation using `run()`: No default " 5661 "session is registered. Use `with " 5662 "sess.as_default():` or pass an explicit session to " 5663 "`run(session=sess)`") 5664 if session.graph is not graph: 5665 raise ValueError("Cannot use the default session to execute operation: " 5666 "the operation's graph is different from the " 5667 "session's graph. 
Pass an explicit session to " 5668 "run(session=sess).") 5669 else: 5670 if session.graph is not graph: 5671 raise ValueError("Cannot use the given session to execute operation: " 5672 "the operation's graph is different from the session's " 5673 "graph.") 5674 session.run(operation, feed_dict) 5675 5676 5677class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access 5678 """A thread-local stack of objects for providing an implicit default graph.""" 5679 5680 def __init__(self): 5681 super(_DefaultGraphStack, self).__init__() 5682 self._global_default_graph = None 5683 5684 def get_default(self): 5685 """Override that returns a global default if the stack is empty.""" 5686 if self.stack: 5687 return self.stack[-1] 5688 elif self._global_default_graph: 5689 return self._global_default_graph 5690 else: 5691 self._global_default_graph = Graph() 5692 return self._global_default_graph 5693 5694 def _GetGlobalDefaultGraph(self): 5695 if self._global_default_graph is None: 5696 # TODO(mrry): Perhaps log that the default graph is being used, or set 5697 # provide some other feedback to prevent confusion when a mixture of 5698 # the global default graph and an explicit graph are combined in the 5699 # same process. 5700 self._global_default_graph = Graph() 5701 return self._global_default_graph 5702 5703 def reset(self): 5704 super(_DefaultGraphStack, self).reset() 5705 self._global_default_graph = None 5706 5707 @tf_contextlib.contextmanager 5708 def get_controller(self, default): 5709 context.context().context_switches.push(default.building_function, 5710 default.as_default, 5711 default._device_function_stack) 5712 try: 5713 with super(_DefaultGraphStack, 5714 self).get_controller(default) as g, context.graph_mode(): 5715 yield g 5716 finally: 5717 # If an exception is raised here it may be hiding a related exception in 5718 # the try-block (just above). 
5719 context.context().context_switches.pop() 5720 5721 5722_default_graph_stack = _DefaultGraphStack() 5723 5724 5725# Shared helper used in init_scope and executing_eagerly_outside_functions 5726# to obtain the outermost context that is not building a function, and the 5727# innermost non empty device stack. 5728def _get_outer_context_and_inner_device_stack(): 5729 """Get the outermost context not building a function.""" 5730 default_graph = get_default_graph() 5731 outer_context = None 5732 innermost_nonempty_device_stack = default_graph._device_function_stack # pylint: disable=protected-access 5733 5734 if not _default_graph_stack.stack: 5735 # If the default graph stack is empty, then we cannot be building a 5736 # function. Install the global graph (which, in this case, is also the 5737 # default graph) as the outer context. 5738 if default_graph.building_function: 5739 raise RuntimeError("The global graph is building a function.") 5740 outer_context = default_graph.as_default 5741 else: 5742 # Find a context that is not building a function. 5743 for stack_entry in reversed(context.context().context_switches.stack): 5744 if not innermost_nonempty_device_stack: 5745 innermost_nonempty_device_stack = stack_entry.device_stack 5746 if not stack_entry.is_building_function: 5747 outer_context = stack_entry.enter_context_fn 5748 break 5749 5750 if outer_context is None: 5751 # As a last resort, obtain the global default graph; this graph doesn't 5752 # necessarily live on the graph stack (and hence it doesn't necessarily 5753 # live on the context stack), but it is stored in the graph stack's 5754 # encapsulating object. 5755 outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default # pylint: disable=protected-access 5756 5757 if outer_context is None: 5758 # Sanity check; this shouldn't be triggered. 
5759 raise RuntimeError("All graphs are building functions, and no " 5760 "eager context was previously active.") 5761 5762 return outer_context, innermost_nonempty_device_stack 5763 5764 5765# pylint: disable=g-doc-return-or-yield,line-too-long 5766@tf_export("init_scope") 5767@tf_contextlib.contextmanager 5768def init_scope(): 5769 """A context manager that lifts ops out of control-flow scopes and function-building graphs. 5770 5771 There is often a need to lift variable initialization ops out of control-flow 5772 scopes, function-building graphs, and gradient tapes. Entering an 5773 `init_scope` is a mechanism for satisfying these desiderata. In particular, 5774 entering an `init_scope` has three effects: 5775 5776 (1) All control dependencies are cleared the moment the scope is entered; 5777 this is equivalent to entering the context manager returned from 5778 `control_dependencies(None)`, which has the side-effect of exiting 5779 control-flow scopes like `tf.cond` and `tf.while_loop`. 5780 5781 (2) All operations that are created while the scope is active are lifted 5782 into the lowest context on the `context_stack` that is not building a 5783 graph function. Here, a context is defined as either a graph or an eager 5784 context. Every context switch, i.e., every installation of a graph as 5785 the default graph and every switch into eager mode, is logged in a 5786 thread-local stack called `context_switches`; the log entry for a 5787 context switch is popped from the stack when the context is exited. 5788 Entering an `init_scope` is equivalent to crawling up 5789 `context_switches`, finding the first context that is not building a 5790 graph function, and entering it. A caveat is that if graph mode is 5791 enabled but the default graph stack is empty, then entering an 5792 `init_scope` will simply install a fresh graph as the default one. 5793 5794 (3) The gradient tape is paused while the scope is active. 
5795 5796 When eager execution is enabled, code inside an init_scope block runs with 5797 eager execution enabled even when tracing a `tf.function`. For example: 5798 5799 ```python 5800 tf.compat.v1.enable_eager_execution() 5801 5802 @tf.function 5803 def func(): 5804 # A function constructs TensorFlow graphs, 5805 # it does not execute eagerly. 5806 assert not tf.executing_eagerly() 5807 with tf.init_scope(): 5808 # Initialization runs with eager execution enabled 5809 assert tf.executing_eagerly() 5810 ``` 5811 5812 Raises: 5813 RuntimeError: if graph state is incompatible with this initialization. 5814 """ 5815 # pylint: enable=g-doc-return-or-yield,line-too-long 5816 5817 if context.executing_eagerly(): 5818 # Fastpath. 5819 with tape.stop_recording(): 5820 yield 5821 else: 5822 # Retrieve the active name scope: entering an `init_scope` preserves 5823 # the name scope of the current context. 5824 scope = get_default_graph().get_name_scope() 5825 if scope and scope[-1] != "/": 5826 # Names that end with trailing slashes are treated by `name_scope` as 5827 # absolute. 5828 scope = scope + "/" 5829 5830 outer_context, innermost_nonempty_device_stack = ( 5831 _get_outer_context_and_inner_device_stack()) 5832 5833 outer_graph = None 5834 outer_device_stack = None 5835 try: 5836 with outer_context(), name_scope( 5837 scope, skip_on_eager=False), control_dependencies( 5838 None), tape.stop_recording(): 5839 context_manager = NullContextmanager 5840 context_manager_input = None 5841 if not context.executing_eagerly(): 5842 # The device stack is preserved when lifting into a graph. Eager 5843 # execution doesn't implement device stacks and in particular it 5844 # doesn't support device functions, so in general it's not possible 5845 # to do the same when lifting into the eager context. 
5846 outer_graph = get_default_graph() 5847 outer_device_stack = outer_graph._device_function_stack # pylint: disable=protected-access 5848 outer_graph._device_function_stack = innermost_nonempty_device_stack # pylint: disable=protected-access 5849 elif innermost_nonempty_device_stack is not None: 5850 for device_spec in innermost_nonempty_device_stack.peek_objs(): 5851 if device_spec.function is None: 5852 break 5853 if device_spec.raw_string: 5854 context_manager = context.device 5855 context_manager_input = device_spec.raw_string 5856 break 5857 # It is currently not possible to have a device function in V2, 5858 # but in V1 we are unable to apply device functions in eager mode. 5859 # This means that we will silently skip some of the entries on the 5860 # device stack in V1 + eager mode. 5861 5862 with context_manager(context_manager_input): 5863 yield 5864 finally: 5865 # If an exception is raised here it may be hiding a related exception in 5866 # try-block (just above). 5867 if outer_graph is not None: 5868 outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access 5869 5870 5871@tf_export(v1=["executing_eagerly_outside_functions"]) 5872def executing_eagerly_outside_functions(): 5873 """Returns True if executing eagerly, even if inside a graph function. 5874 5875 This function will check the outermost context for the program and see if 5876 it is in eager mode. It is useful comparing to `tf.executing_eagerly()`, 5877 which checks the current context and will return `False` within a 5878 `tf.function` body. It can be used to build library that behave differently 5879 in eager runtime and v1 session runtime (deprecated). 5880 5881 Example: 5882 5883 >>> tf.compat.v1.enable_eager_execution() 5884 >>> @tf.function 5885 ... def func(): 5886 ... # A function constructs TensorFlow graphs, it does not execute eagerly, 5887 ... # but the outer most context is still eager. 5888 ... assert not tf.executing_eagerly() 5889 ... 
return tf.compat.v1.executing_eagerly_outside_functions() 5890 >>> func() 5891 <tf.Tensor: shape=(), dtype=bool, numpy=True> 5892 5893 Returns: 5894 boolean, whether the outermost context is in eager mode. 5895 """ 5896 if context.executing_eagerly(): 5897 return True 5898 else: 5899 outer_context, _ = _get_outer_context_and_inner_device_stack() 5900 with outer_context(): 5901 return context.executing_eagerly() 5902 5903 5904@tf_export("inside_function", v1=[]) 5905def inside_function(): 5906 """Indicates whether the caller code is executing inside a `tf.function`. 5907 5908 Returns: 5909 Boolean, True if the caller code is executing inside a `tf.function` 5910 rather than eagerly. 5911 5912 Example: 5913 5914 >>> tf.inside_function() 5915 False 5916 >>> @tf.function 5917 ... def f(): 5918 ... print(tf.inside_function()) 5919 >>> f() 5920 True 5921 """ 5922 return get_default_graph().building_function 5923 5924 5925@tf_export(v1=["enable_eager_execution"]) 5926def enable_eager_execution(config=None, device_policy=None, 5927 execution_mode=None): 5928 """Enables eager execution for the lifetime of this program. 5929 5930 Eager execution provides an imperative interface to TensorFlow. With eager 5931 execution enabled, TensorFlow functions execute operations immediately (as 5932 opposed to adding to a graph to be executed later in a `tf.compat.v1.Session`) 5933 and 5934 return concrete values (as opposed to symbolic references to a node in a 5935 computational graph). 5936 5937 For example: 5938 5939 ```python 5940 tf.compat.v1.enable_eager_execution() 5941 5942 # After eager execution is enabled, operations are executed as they are 5943 # defined and Tensor objects hold concrete values, which can be accessed as 5944 # numpy.ndarray`s through the numpy() method. 5945 assert tf.multiply(6, 7).numpy() == 42 5946 ``` 5947 5948 Eager execution cannot be enabled after TensorFlow APIs have been used to 5949 create or execute graphs. 
It is typically recommended to invoke this function 5950 at program startup and not in a library (as most libraries should be usable 5951 both with and without eager execution). 5952 5953 @compatibility(TF2) 5954 This function is not necessary if you are using TF2. Eager execution is 5955 enabled by default. 5956 @end_compatibility 5957 5958 Args: 5959 config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the 5960 environment in which operations are executed. Note that 5961 `tf.compat.v1.ConfigProto` is also used to configure graph execution (via 5962 `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto` 5963 are not implemented (or are irrelevant) when eager execution is enabled. 5964 device_policy: (Optional.) Policy controlling how operations requiring 5965 inputs on a specific device (e.g., a GPU 0) handle inputs on a different 5966 device (e.g. GPU 1 or CPU). When set to None, an appropriate value will 5967 be picked automatically. The value picked may change between TensorFlow 5968 releases. 5969 Valid values: 5970 - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the 5971 placement is not correct. 5972 - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not 5973 on the right device but logs a warning. 5974 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors. 5975 Note that this may hide performance problems as there is no notification 5976 provided when operations are blocked on the tensor being copied between 5977 devices. 5978 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies 5979 int32 tensors, raising errors on the other ones. 5980 execution_mode: (Optional.) Policy controlling how operations dispatched are 5981 actually executed. When set to None, an appropriate value will be picked 5982 automatically. The value picked may change between TensorFlow releases. 5983 Valid values: 5984 - tf.contrib.eager.SYNC: executes each operation synchronously. 
5985 - tf.contrib.eager.ASYNC: executes each operation asynchronously. These 5986 operations may return "non-ready" handles. 5987 5988 Raises: 5989 ValueError: If eager execution is enabled after creating/executing a 5990 TensorFlow graph, or if options provided conflict with a previous call 5991 to this function. 5992 """ 5993 _api_usage_gauge.get_cell().set(True) 5994 logging.vlog(1, "Enabling eager execution") 5995 if context.default_execution_mode != context.EAGER_MODE: 5996 return enable_eager_execution_internal( 5997 config=config, 5998 device_policy=device_policy, 5999 execution_mode=execution_mode, 6000 server_def=None) 6001 6002 6003@tf_export(v1=["disable_eager_execution"]) 6004def disable_eager_execution(): 6005 """Disables eager execution. 6006 6007 This function can only be called before any Graphs, Ops, or Tensors have been 6008 created. 6009 6010 @compatibility(TF2) 6011 This function is not necessary if you are using TF2. Eager execution is 6012 enabled by default. If you want to use Graph mode please consider 6013 [tf.function](https://www.tensorflow.org/api_docs/python/tf/function). 6014 @end_compatibility 6015 """ 6016 _api_usage_gauge.get_cell().set(False) 6017 logging.vlog(1, "Disabling eager execution") 6018 context.default_execution_mode = context.GRAPH_MODE 6019 c = context.context_safe() 6020 if c is not None: 6021 c._thread_local_data.is_eager = False # pylint: disable=protected-access 6022 6023 6024def enable_eager_execution_internal(config=None, 6025 device_policy=None, 6026 execution_mode=None, 6027 server_def=None): 6028 """Enables eager execution for the lifetime of this program. 6029 6030 Most of the doc string for enable_eager_execution is relevant here as well. 6031 6032 Args: 6033 config: See enable_eager_execution doc string 6034 device_policy: See enable_eager_execution doc string 6035 execution_mode: See enable_eager_execution doc string 6036 server_def: (Optional.) A tensorflow::ServerDef proto. 
Enables execution on 6037 remote devices. GrpcServers need to be started by creating an identical 6038 server_def to this, and setting the appropriate task_indexes, so that the 6039 servers can communicate. It will then be possible to execute operations on 6040 remote devices. 6041 6042 Raises: 6043 ValueError 6044 6045 """ 6046 if config is not None and not isinstance(config, config_pb2.ConfigProto): 6047 raise TypeError("config must be a tf.ConfigProto, but got %s" % 6048 type(config)) 6049 if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT, 6050 context.DEVICE_PLACEMENT_WARN, 6051 context.DEVICE_PLACEMENT_SILENT, 6052 context.DEVICE_PLACEMENT_SILENT_FOR_INT32): 6053 raise ValueError( 6054 "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*" 6055 ) 6056 if execution_mode not in (None, context.SYNC, context.ASYNC): 6057 raise ValueError( 6058 "execution_mode must be one of None, tf.contrib.eager.SYNC, " 6059 "tf.contrib.eager.ASYNC") 6060 if context.default_execution_mode == context.GRAPH_MODE: 6061 graph_mode_has_been_used = ( 6062 _default_graph_stack._global_default_graph is not None) # pylint: disable=protected-access 6063 if graph_mode_has_been_used: 6064 raise ValueError( 6065 "tf.enable_eager_execution must be called at program startup.") 6066 context.default_execution_mode = context.EAGER_MODE 6067 # pylint: disable=protected-access 6068 with context._context_lock: 6069 if context._context is None: 6070 context._set_context_locked(context.Context( 6071 config=config, 6072 device_policy=device_policy, 6073 execution_mode=execution_mode, 6074 server_def=server_def)) 6075 elif ((config is not None and config is not context._context._config) or 6076 (device_policy is not None and 6077 device_policy is not context._context._device_policy) or 6078 (execution_mode is not None and 6079 execution_mode is not context._context._execution_mode)): 6080 raise ValueError( 6081 "Trying to change the options of an active eager" 6082 " 
execution. Context config: %s, specified config:" 6083 " %s. Context device policy: %s, specified device" 6084 " policy: %s. Context execution mode: %s, " 6085 " specified execution mode %s." % 6086 (context._context._config, config, context._context._device_policy, 6087 device_policy, context._context._execution_mode, execution_mode)) 6088 else: 6089 # We already created everything, so update the thread local data. 6090 context._context._thread_local_data.is_eager = True 6091 6092 # Monkey patch to get rid of an unnecessary conditional since the context is 6093 # now initialized. 6094 context.context = context.context_safe 6095 6096 6097def eager_run(main=None, argv=None): 6098 """Runs the program with an optional main function and argv list. 6099 6100 The program will run with eager execution enabled. 6101 6102 Example: 6103 ```python 6104 import tensorflow as tf 6105 # Import subject to future changes: 6106 from tensorflow.contrib.eager.python import tfe 6107 6108 def main(_): 6109 u = tf.constant(6.0) 6110 v = tf.constant(7.0) 6111 print(u * v) 6112 6113 if __name__ == "__main__": 6114 tfe.run() 6115 ``` 6116 6117 Args: 6118 main: the main function to run. 6119 argv: the arguments to pass to it. 6120 """ 6121 enable_eager_execution() 6122 app.run(main, argv) 6123 6124 6125@tf_export(v1=["reset_default_graph"]) 6126def reset_default_graph(): 6127 """Clears the default graph stack and resets the global default graph. 6128 6129 NOTE: The default graph is a property of the current thread. This 6130 function applies only to the current thread. Calling this function while 6131 a `tf.compat.v1.Session` or `tf.compat.v1.InteractiveSession` is active will 6132 result in undefined 6133 behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects 6134 after calling this function will result in undefined behavior. 
6135 6136 @compatibility(TF2) 6137 `reset_default_graph` does not work with either eager execution or 6138 `tf.function`, and you should not invoke it directly. To migrate code that 6139 uses Graph-related functions to TF2, rewrite the code without them. See the 6140 [migration guide](https://www.tensorflow.org/guide/migrate) for more 6141 description about the behavior and semantic changes between Tensorflow 1 and 6142 Tensorflow 2. 6143 @end_compatibility 6144 6145 Raises: 6146 AssertionError: If this function is called within a nested graph. 6147 """ 6148 if not _default_graph_stack.is_cleared(): 6149 raise AssertionError("Do not use tf.reset_default_graph() to clear " 6150 "nested graphs. If you need a cleared graph, " 6151 "exit the nesting and create a new graph.") 6152 _default_graph_stack.reset() 6153 6154 6155@tf_export(v1=["get_default_graph"]) 6156def get_default_graph(): 6157 """Returns the default graph for the current thread. 6158 6159 The returned graph will be the innermost graph on which a 6160 `Graph.as_default()` context has been entered, or a global default 6161 graph if none has been explicitly created. 6162 6163 NOTE: The default graph is a property of the current thread. If you 6164 create a new thread, and wish to use the default graph in that 6165 thread, you must explicitly add a `with g.as_default():` in that 6166 thread's function. 6167 6168 @compatibility(TF2) 6169 `get_default_graph` does not work with either eager execution or 6170 `tf.function`, and you should not invoke it directly. To migrate code that 6171 uses Graph-related functions to TF2, rewrite the code without them. See the 6172 [migration guide](https://www.tensorflow.org/guide/migrate) for more 6173 description about the behavior and semantic changes between Tensorflow 1 and 6174 Tensorflow 2. 6175 @end_compatibility 6176 6177 Returns: 6178 The default `Graph` being used in the current thread. 
6179 """ 6180 return _default_graph_stack.get_default() 6181 6182 6183def has_default_graph(): 6184 """Returns True if there is a default graph.""" 6185 return len(_default_graph_stack.stack) >= 1 6186 6187 6188# Exported due to b/171079555 6189@tf_export("__internal__.get_name_scope", v1=[]) 6190def get_name_scope(): 6191 """Returns the current name scope in the default_graph. 6192 6193 For example: 6194 6195 ```python 6196 with tf.name_scope('scope1'): 6197 with tf.name_scope('scope2'): 6198 print(tf.get_name_scope()) 6199 ``` 6200 would print the string `scope1/scope2`. 6201 6202 Returns: 6203 A string representing the current name scope. 6204 """ 6205 if context.executing_eagerly(): 6206 return context.context().scope_name.rstrip("/") 6207 return get_default_graph().get_name_scope() 6208 6209 6210def _assert_same_graph(original_item, item): 6211 """Fail if the 2 items are from different graphs. 6212 6213 Args: 6214 original_item: Original item to check against. 6215 item: Item to check. 6216 6217 Raises: 6218 ValueError: if graphs do not match. 6219 """ 6220 original_graph = getattr(original_item, "graph", None) 6221 graph = getattr(item, "graph", None) 6222 if original_graph and graph and original_graph is not graph: 6223 raise ValueError( 6224 "%s must be from the same graph as %s (graphs are %s and %s)." % 6225 (item, original_item, graph, original_graph)) 6226 6227 6228def _get_graph_from_inputs(op_input_list, graph=None): 6229 """Returns the appropriate graph to use for the given inputs. 6230 6231 This library method provides a consistent algorithm for choosing the graph 6232 in which an Operation should be constructed: 6233 6234 1. If the default graph is being used to construct a function, we 6235 use the default graph. 6236 2. If the "graph" is specified explicitly, we validate that all of the inputs 6237 in "op_input_list" are compatible with that graph. 6238 3. 
Otherwise, we attempt to select a graph from the first Operation- 6239 or Tensor-valued input in "op_input_list", and validate that all other 6240 such inputs are in the same graph. 6241 4. If the graph was not specified and it could not be inferred from 6242 "op_input_list", we attempt to use the default graph. 6243 6244 Args: 6245 op_input_list: A list of inputs to an operation, which may include `Tensor`, 6246 `Operation`, and other objects that may be converted to a graph element. 6247 graph: (Optional) The explicit graph to use. 6248 6249 Raises: 6250 TypeError: If op_input_list is not a list or tuple, or if graph is not a 6251 Graph. 6252 ValueError: If a graph is explicitly passed and not all inputs are from it, 6253 or if the inputs are from multiple graphs, or we could not find a graph 6254 and there was no default graph. 6255 6256 Returns: 6257 The appropriate graph to use for the given inputs. 6258 6259 """ 6260 current_default_graph = get_default_graph() 6261 if current_default_graph.building_function: 6262 return current_default_graph 6263 6264 op_input_list = tuple(op_input_list) # Handle generators correctly 6265 if graph and not isinstance(graph, Graph): 6266 raise TypeError("Input graph needs to be a Graph: %s" % (graph,)) 6267 6268 # 1. We validate that all of the inputs are from the same graph. This is 6269 # either the supplied graph parameter, or the first one selected from one 6270 # the graph-element-valued inputs. In the latter case, we hold onto 6271 # that input in original_graph_element so we can provide a more 6272 # informative error if a mismatch is found. 6273 original_graph_element = None 6274 for op_input in op_input_list: 6275 # Determine if this is a valid graph_element. 6276 # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this 6277 # up. 
@tf_export(v1=["GraphKeys"])
class GraphKeys(object):
  """Standard names to use for graph collections.

  The standard library uses various well-known names to collect and
  retrieve values associated with a graph. For example, the
  `tf.Optimizer` subclasses default to optimizing the variables
  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
  specified, but it is also possible to pass an explicit list of
  variables.

  The following standard keys are defined:

  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
    across distributed environment (model variables are a subset of these). See
    `tf.compat.v1.global_variables`
    for more details.
    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
    machine. Usually used for temporary variables, like counters.
    Note: use `tf.contrib.framework.local_variable` to add to this collection.
  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
    model for inference (feed forward). Note: use
    `tf.contrib.framework.model_variable` to add to this collection.
  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
    be trained by an optimizer. See
    `tf.compat.v1.trainable_variables`
    for more details.
  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
    graph. See
    `tf.compat.v1.summary.merge_all`
    for more details.
  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
    produce input for a computation. See
    `tf.compat.v1.train.start_queue_runners`
    for more details.
  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
    keep moving averages.  See
    `tf.compat.v1.moving_average_variables`
    for more details.
  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
    construction.

  The following standard keys are _defined_, but their collections are **not**
  automatically populated as many of the others are:

  * `WEIGHTS`
  * `BIASES`
  * `ACTIVATIONS`
  """

  # Key to collect Variable objects that are global (shared across machines).
  # Default collection for all variables, except local ones.
  GLOBAL_VARIABLES = "variables"
  # Key to collect local variables that are local to the machine and are not
  # saved/restored.
  LOCAL_VARIABLES = "local_variables"
  # Key to collect local variables which are used to accumulate internal state
  # to be used in tf.metrics.*.
  METRIC_VARIABLES = "metric_variables"
  # Key to collect model variables defined by layers.
  MODEL_VARIABLES = "model_variables"
  # Key to collect Variable objects that will be trained by the
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  # Key to collect summaries.
  SUMMARIES = "summaries"
  # Key to collect QueueRunners.
  QUEUE_RUNNERS = "queue_runners"
  # Key to collect table initializers.
  # NOTE: the string value is singular ("table_initializer") even though the
  # attribute name is plural; it is a runtime key and must not be "corrected".
  TABLE_INITIALIZERS = "table_initializer"
  # Key to collect asset filepaths. An asset represents an external resource
  # like a vocabulary file.
  ASSET_FILEPATHS = "asset_filepaths"
  # Key to collect Variable objects that keep moving averages.
  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
  # Key to collect regularization losses at graph construction.
  REGULARIZATION_LOSSES = "regularization_losses"
  # Key to collect concatenated sharded variables.
  CONCATENATED_VARIABLES = "concatenated_variables"
  # Key to collect savers.
  SAVERS = "savers"
  # Key to collect weights.
  WEIGHTS = "weights"
  # Key to collect biases.
  BIASES = "biases"
  # Key to collect activations.
  ACTIVATIONS = "activations"
  # Key to collect update_ops.
  UPDATE_OPS = "update_ops"
  # Key to collect losses.
  LOSSES = "losses"
  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
  SAVEABLE_OBJECTS = "saveable_objects"
  # Key to collect all shared resources used by the graph which need to be
  # initialized once per cluster.
  RESOURCES = "resources"
  # Key to collect all shared resources used in this graph which need to be
  # initialized once per session.
  LOCAL_RESOURCES = "local_resources"
  # Trainable resource-style variables.
  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"

  # Key to indicate various ops.
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  READY_OP = "ready_op"
  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Used to count the number of evaluations performed during a single evaluation
  # run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # Key for control flow context.
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Used to store v2 summary names.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # List of all collections that keep track of variables.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

  # Key for streaming model ports.
  # NOTE(yuanbyu): internal and experimental.
  _STREAMING_MODEL_PORTS = "streaming_model_ports"

  # Deprecated alias: reading `GraphKeys.VARIABLES` yields the value of
  # `GLOBAL_VARIABLES`, wrapped in the deprecation decorator below.
  @decorator_utils.classproperty
  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
  def VARIABLES(cls):  # pylint: disable=no-self-argument
    return cls.GLOBAL_VARIABLES
@tf_export(v1=["add_to_collections"])
def add_to_collections(names, value):
  """Wrapper for `Graph.add_to_collections()` using the default graph.

  See `tf.Graph.add_to_collections`
  for more details.

  Args:
    names: The keys for the collections. The `GraphKeys` class contains many
      standard names for collections.
    value: The value to add to the collections.

  @compatibility(eager)
  Collections are only supported in eager when variables are created inside
  an EagerVariableStore (e.g. as part of a layer or template).
  @end_compatibility
  """
  # The default graph is looked up at call time, then the call is delegated
  # unchanged; nothing is returned.
  get_default_graph().add_to_collections(names, value)
@tf_export(v1=["get_collection"])
def get_collection(key, scope=None):
  """Wrapper for `Graph.get_collection()` using the default graph.

  See `tf.Graph.get_collection`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.
    scope: (Optional.) If supplied, the resulting list is filtered to include
      only items whose `name` attribute matches using `re.match`. Items without
      a `name` attribute are never returned if a scope is supplied and the
      choice of `re.match` means that a `scope` without special tokens filters
      by prefix.

  Returns:
    The list of values in the collection with the given `key`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  # Delegates to whichever graph is the default at call time.
  return get_default_graph().get_collection(key, scope)
class internal_name_scope_v1(object):  # pylint: disable=invalid-name
  """Graph-only version of `name_scope_v1`."""

  def __init__(self, name, default_name=None, values=None):
    """Stores the scope arguments; no graph state is touched here.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.

    Raises:
      TypeError: if `default_name` is passed in but not a string.
    """
    if default_name is not None and not isinstance(default_name,
                                                   six.string_types):
      raise TypeError(
          "`default_name` type (%s) is not a string type. You likely meant to "
          "pass this into the `values` kwarg." % type(default_name))
    self._name = name if name is not None else default_name
    self._default_name = default_name
    self._values = values

  @property
  def name(self):
    return self._name

  def __enter__(self):
    """Enters the name scope, switching default graph first if needed.

    Returns:
      The scope name.

    Raises:
      ValueError: if neither `name` nor `default_name` is provided
        but `values` are.
    """
    # `tf.name_scope(None)` with no values is an accepted idiom for resetting
    # to the top scope, so only complain when values were actually supplied.
    if self._name is None and self._values is not None:
      raise ValueError(
          "At least one of name (%s) and default_name (%s) must be provided."
          % (self._name, self._default_name))

    target_graph = get_default_graph()
    graph_manager = None
    # `_get_graph_from_inputs()` ignores its inputs while a function is being
    # built, so the lookup is skipped entirely in that case.
    if self._values and not target_graph.building_function:
      inferred_graph = _get_graph_from_inputs(self._values)
      if inferred_graph is not target_graph:
        target_graph = inferred_graph
        graph_manager = target_graph.as_default()
    self._g_manager = graph_manager
    if graph_manager is not None:
      graph_manager.__enter__()

    try:
      self._name_scope = target_graph.name_scope(self._name)
      return self._name_scope.__enter__()
    except:  # pylint: disable=bare-except
      # Undo the default-graph switch before propagating any failure.
      if graph_manager is not None:
        graph_manager.__exit__(*sys.exc_info())
      raise

  def __exit__(self, *exc_info):
    """Leaves the name scope, then the default-graph override if one was set."""
    self._name_scope.__exit__(*exc_info)
    graph_manager = self._g_manager
    if graph_manager is not None:
      graph_manager.__exit__(*exc_info)
@tf_export("get_current_name_scope", v1=[])
def get_current_name_scope():
  """Returns current full name scope specified by `tf.name_scope(...)`s.

  For example,
  ```python
  with tf.name_scope("outer"):
    tf.get_current_name_scope()  # "outer"

    with tf.name_scope("inner"):
      tf.get_current_name_scope()  # "outer/inner"
  ```

  In other words, `tf.get_current_name_scope()` returns the op name prefix that
  would be prepended to the name of an op created at that point.

  Note that `@tf.function` resets the name scope stack as shown below.

  ```
  with tf.name_scope("outer"):

    @tf.function
    def foo(x):
      with tf.name_scope("inner"):
        return tf.add(x, x)  # Op name is "inner/Add", not "outer/inner/Add"
  ```
  """
  eager_context = context.context()
  if not eager_context.executing_eagerly():
    # Graph mode: the default graph tracks the scope stack.
    return get_default_graph().get_name_scope()
  # Eager mode: the context's scope name has its trailing "/" characters
  # stripped before it is returned.
  return eager_context.scope_name.rstrip("/")
6774 6775 For example, to define a new Python op called `my_op`: 6776 6777 ```python 6778 def my_op(a, b, c, name=None): 6779 with tf.name_scope("MyOp") as scope: 6780 a = tf.convert_to_tensor(a, name="a") 6781 b = tf.convert_to_tensor(b, name="b") 6782 c = tf.convert_to_tensor(c, name="c") 6783 # Define some computation that uses `a`, `b`, and `c`. 6784 return foo_op(..., name=scope) 6785 ``` 6786 6787 When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`, 6788 and `MyOp/c`. 6789 6790 Inside a `tf.function`, if the scope name already exists, the name will be 6791 made unique by appending `_n`. For example, calling `my_op` the second time 6792 will generate `MyOp_1/a`, etc. 6793 """ 6794 6795 __slots__ = ["_name", "_exit_fns"] 6796 6797 def __init__(self, name): 6798 """Initialize the context manager. 6799 6800 Args: 6801 name: The prefix to use on all names created within the name scope. 6802 6803 Raises: 6804 ValueError: If name is not a string. 6805 """ 6806 if not isinstance(name, six.string_types): 6807 raise ValueError("name for name_scope must be a string.") 6808 self._name = name 6809 self._exit_fns = [] 6810 6811 @property 6812 def name(self): 6813 return self._name 6814 6815 def __enter__(self): 6816 """Start the scope block. 6817 6818 Returns: 6819 The scope name. 6820 """ 6821 ctx = context.context() 6822 if ctx.executing_eagerly(): 6823 # Names are not auto-incremented in eager mode. 6824 # A trailing slash breaks out of nested name scopes, indicating a 6825 # fully specified scope name, for compatibility with Graph.name_scope. 6826 # This also prevents auto-incrementing. 
def strip_name_scope(name, export_scope):
  """Removes name scope from a name.

  The scope is matched literally: characters in `export_scope` that are
  special in regular expressions (for example `.`) are escaped before the
  pattern is built, so scope "a.b" strips "a.b/x" but not "axb/x".

  Args:
    name: A `string` name.
    export_scope: Optional `string`. Name scope to remove. A single trailing
      "/" is ignored.

  Returns:
    Name with name scope removed, or the original name if export_scope
    is None (or empty), or if `name` is of a type that cannot be processed.
  """
  if not export_scope:
    return name
  if export_scope[-1] == "/":
    export_scope = export_scope[:-1]

  try:
    # Strips export_scope/, export_scope///,
    # ^export_scope/, loc:@export_scope/.
    str_to_replace = (
        r"([\^]|loc:@|^)" + re.escape(export_scope) + r"[\/]+(.*)")
    return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
  except TypeError as e:
    # If the name is not of a type we can process, simply return it.
    logging.warning(e)
    return name


def prepend_name_scope(name, import_scope):
  """Prepends name scope to a name.

  The scope is inserted after any leading "^" or "loc:@" marker, e.g.
  "^foo" with scope "s" becomes "^s/foo".

  Args:
    name: A `string` name.
    import_scope: Optional `string`. Name scope to add. A single trailing
      "/" is ignored.

  Returns:
    Name with name scope added, or the original name if import_scope
    is None (or empty), or if `name` is of a type that cannot be processed.
  """
  if not import_scope:
    return name
  if import_scope[-1] == "/":
    import_scope = import_scope[:-1]

  def _insert_scope(match):
    # Build the result by concatenation rather than a replacement template:
    # in a template, a scope starting with a digit would fuse with the "\1"
    # back-reference (e.g. "\1" + "2x" -> group 12, or an octal escape), and
    # backslashes in the scope would be interpreted as escapes.
    return match.group(1) + import_scope + "/" + match.group(2)

  try:
    str_to_replace = r"([\^]|loc:@|^)(.*)"
    return re.sub(str_to_replace, _insert_scope, compat.as_str(name))
  except TypeError as e:
    # If the name is not of a type we can process, simply return it.
    logging.warning(e)
    return name
def get_collection_proto_type(collection_name):
  """Returns the proto_type for collection_name.

  Args:
    collection_name: Name of the collection to look up in the proto function
      registry.

  Returns:
    The first element of the `(proto_type, to_proto, from_proto)` tuple
    registered for `collection_name`, or `None` if the lookup raises
    `LookupError`.
  """
  try:
    return _proto_function_registry.lookup(collection_name)[0]
  except LookupError:
    return None


def get_to_proto_function(collection_name):
  """Returns the to_proto function for collection_name.

  Args:
    collection_name: Name of the collection to look up in the proto function
      registry.

  Returns:
    The second element of the `(proto_type, to_proto, from_proto)` tuple
    registered for `collection_name`, or `None` if the lookup raises
    `LookupError`.
  """
  try:
    return _proto_function_registry.lookup(collection_name)[1]
  except LookupError:
    return None


def get_from_proto_function(collection_name):
  """Returns the from_proto function for collection_name.

  Args:
    collection_name: Name of the collection to look up in the proto function
      registry.

  Returns:
    The third element of the `(proto_type, to_proto, from_proto)` tuple
    registered for `collection_name`, or `None` if the lookup raises
    `LookupError`.
  """
  try:
    return _proto_function_registry.lookup(collection_name)[2]
  except LookupError:
    return None
def _is_keras_symbolic_tensor(x):
  """Returns True if `x.graph` exists and is named "keras_graph"."""
  if not hasattr(x, "graph"):
    return False
  # A graph object without a `name` attribute is treated as not matching.
  graph_name = getattr(x.graph, "name", None)
  return graph_name == "keras_graph"
def _reconstruct_sequence_inputs(op_def, inputs, attrs):
  """Regroups a flat list of input tensors into scalar and sequence inputs.

  Walks the op's declared input args in order, consuming entries of `inputs`
  from left to right. An arg with a `number_attr` or `type_list_attr` takes a
  slice whose length is read from `attrs`; any other arg takes one entry.

  Args:
    op_def: The `op_def_pb2.OpDef` (for knowing the input types)
    inputs: a list of input `Tensor`s to the op.
    attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
      how long each sequence is)

  Returns:
    A list of `Tensor`s (corresponding to scalar inputs) and lists of
    `Tensor`s (corresponding to sequence inputs).
  """
  regrouped = []
  cursor = 0
  for arg in op_def.input_arg:
    # `seq_len` stays None for a plain (non-sequence) input.
    seq_len = None
    if arg.number_attr:
      seq_len = attrs[arg.number_attr].i
    elif arg.type_list_attr:
      seq_len = len(attrs[arg.type_list_attr].list.type)

    if seq_len is None:
      regrouped.append(inputs[cursor])
      cursor += 1
    else:
      regrouped.append(inputs[cursor:cursor + seq_len])
      cursor += seq_len

  # Every flat input must have been consumed exactly once.
  assert cursor == len(inputs)
  return regrouped
def _get_enclosing_context(graph):
  """Finds the nearest control flow context, walking outward through graphs.

  Starting at `graph`, returns the first non-None `_control_flow_context`
  found. The walk only moves to `outer_graph` when the current graph is
  building a function and has that attribute.

  Args:
    graph: The graph to start from, or None.

  Returns:
    The enclosing control flow context, or None if there is none.
  """
  # pylint: disable=protected-access
  current = graph
  while current is not None:
    enclosing = current._control_flow_context
    if enclosing is not None:
      return enclosing
    if not (current.building_function and hasattr(current, "outer_graph")):
      return None
    current = current.outer_graph
  return None