1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes and functions used to construct graphs.""" 16# pylint: disable=g-bad-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import re 23import sys 24import threading 25import types 26 27import numpy as np 28import six 29from six.moves import map # pylint: disable=redefined-builtin 30from six.moves import xrange # pylint: disable=redefined-builtin 31 32from tensorflow.core.framework import attr_value_pb2 33from tensorflow.core.framework import function_pb2 34from tensorflow.core.framework import graph_pb2 35from tensorflow.core.framework import node_def_pb2 36from tensorflow.core.framework import op_def_pb2 37from tensorflow.core.framework import versions_pb2 38from tensorflow.core.protobuf import config_pb2 39# pywrap_tensorflow must be imported first to avoid protobuf issues. 
40# (b/143110113) 41# pylint: disable=invalid-import-order,g-bad-import-order,unused-import 42from tensorflow.python import pywrap_tensorflow 43from tensorflow.python import pywrap_tfe 44# pylint: enable=invalid-import-order,g-bad-import-order,unused-import 45from tensorflow.python import tf2 46from tensorflow.python.client import pywrap_tf_session 47from tensorflow.python.eager import context 48from tensorflow.python.eager import core 49from tensorflow.python.eager import monitoring 50from tensorflow.python.eager import tape 51from tensorflow.python.framework import c_api_util 52from tensorflow.python.framework import composite_tensor 53from tensorflow.python.framework import cpp_shape_inference_pb2 54from tensorflow.python.framework import device as pydev 55from tensorflow.python.framework import dtypes 56from tensorflow.python.framework import errors 57from tensorflow.python.framework import indexed_slices 58from tensorflow.python.framework import registry 59from tensorflow.python.framework import tensor_conversion_registry 60from tensorflow.python.framework import tensor_shape 61from tensorflow.python.framework import traceable_stack 62from tensorflow.python.framework import versions 63from tensorflow.python.ops import control_flow_util 64from tensorflow.python.platform import app 65from tensorflow.python.platform import tf_logging as logging 66from tensorflow.python.profiler import trace 67from tensorflow.python.types import core as core_tf_types 68from tensorflow.python.types import internal 69from tensorflow.python.util import compat 70from tensorflow.python.util import decorator_utils 71from tensorflow.python.util import deprecation 72from tensorflow.python.util import dispatch 73from tensorflow.python.util import function_utils 74from tensorflow.python.util import lock_util 75from tensorflow.python.util import memory 76from tensorflow.python.util import object_identity 77from tensorflow.python.util import tf_contextlib 78from tensorflow.python.util import 
def tensor_id(tensor):
  """Looks up the unique identifier stored on `tensor`.

  Args:
    tensor: An object exposing the private `_id` attribute.

  Returns:
    The value of `tensor._id`.
  """
  identifier = tensor._id  # pylint: disable=protected-access
  return identifier
class NullContextmanager(object):
  """A context manager that does nothing.

  Any constructor arguments are accepted and ignored, the `with` statement
  binds `None`, and exceptions raised inside the block are never suppressed.
  """

  def __init__(self, *args, **kwargs):
    """Accepts arbitrary positional and keyword arguments and discards them."""
    del args, kwargs  # Intentionally unused.

  def __enter__(self):
    """Enters the context; there is nothing to bind, so `None` is returned."""
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    """Leaves the context without swallowing an in-flight exception."""
    del type_arg, value_arg, traceback_arg  # Intentionally unused.
    return False  # False values do not suppress exceptions
def numpy_text(tensor, is_repr=False):
  """Renders the numpy value of `tensor` as human readable text.

  Args:
    tensor: Object exposing `dtype.is_numpy_compatible` and `_numpy()`.
    is_repr: If true, format the value with `repr`; otherwise with `str`.

  Returns:
    The formatted value, or "<unprintable>" when the dtype is not numpy
    compatible. Text that contains a newline is prefixed with one extra
    newline.
  """
  if not tensor.dtype.is_numpy_compatible:
    rendered = "<unprintable>"
  else:
    value = tensor._numpy()  # pylint: disable=protected-access
    rendered = repr(value) if is_repr else str(value)
  # Multi-line output starts on its own line.
  return "\n" + rendered if "\n" in rendered else rendered
@tf_export(v1=["disable_tensor_equality"])
def disable_tensor_equality():
  """Compare Tensors by their id and be hashable.

  Records `False` in the tensor-equality API usage gauge and clears the
  class-wide `Tensor._USE_EQUALITY` flag.

  This is a legacy behaviour of TensorFlow and is highly discouraged.
  """
  usage_cell = _tensor_equality_api_usage_gauge.get_cell()
  usage_cell.set(False)
  setattr(Tensor, "_USE_EQUALITY", False)
This is an internal detail, but it does give you 293 access to a useful function, `numpy`: 294 295 >>> type(e) 296 <class '...ops.EagerTensor'> 297 >>> print(e.numpy()) 298 [[1. 3.] 299 [3. 7.]] 300 301 In TensorFlow, `tf.function`s are a common way to define graph execution. 302 303 A Tensor's shape (that is, the rank of the Tensor and the size of 304 each dimension) may not always be fully known. In `tf.function` 305 definitions, the shape may only be partially known. 306 307 Most operations produce tensors of fully-known shapes if the shapes of their 308 inputs are also fully known, but in some cases it's only possible to find the 309 shape of a tensor at execution time. 310 311 A number of specialized tensors are available: see `tf.Variable`, 312 `tf.constant`, `tf.placeholder`, `tf.sparse.SparseTensor`, and 313 `tf.RaggedTensor`. 314 315 For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor). 316 317 """ 318 319 # List of Python operators that we allow to override. 320 OVERLOADABLE_OPERATORS = { 321 # Binary. 322 "__add__", 323 "__radd__", 324 "__sub__", 325 "__rsub__", 326 "__mul__", 327 "__rmul__", 328 "__div__", 329 "__rdiv__", 330 "__truediv__", 331 "__rtruediv__", 332 "__floordiv__", 333 "__rfloordiv__", 334 "__mod__", 335 "__rmod__", 336 "__lt__", 337 "__le__", 338 "__gt__", 339 "__ge__", 340 "__ne__", 341 "__eq__", 342 "__and__", 343 "__rand__", 344 "__or__", 345 "__ror__", 346 "__xor__", 347 "__rxor__", 348 "__getitem__", 349 "__pow__", 350 "__rpow__", 351 # Unary. 352 "__invert__", 353 "__neg__", 354 "__abs__", 355 "__matmul__", 356 "__rmatmul__" 357 } 358 359 # Whether to allow hashing or numpy-style equality 360 _USE_EQUALITY = tf2.enabled() 361 362 def __init__(self, op, value_index, dtype): 363 """Creates a new `Tensor`. 364 365 Args: 366 op: An `Operation`. `Operation` that computes this tensor. 367 value_index: An `int`. Index of the operation's endpoint that produces 368 this tensor. 369 dtype: A `DType`. 
def __getattr__(self, name):
  """Fallback attribute lookup, used only when normal lookup has failed.

  For the NumPy-style names listed below a targeted hint is raised. Any other
  name is retried through `__getattribute__` and the outcome of that retry
  (the value, or its `AttributeError`) is passed on to the caller.

  Args:
    name: Name of the attribute being looked up.

  Returns:
    The attribute value, when the retried lookup succeeds.

  Raises:
    AttributeError: If `name` is one of the NumPy-style names, or if the
      retried lookup fails.
  """
  if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size",
              "tolist", "data"}:
    # TODO(wangpeng): Export the enable_numpy_behavior knob
    # The hint uses `from ... import np_config`: a plain
    # `import tensorflow.python.ops.numpy_ops.np_config` would not bind the
    # name `np_config` that the next line of the hint calls.
    raise AttributeError("""
        If you are looking for numpy-related methods, please run the following:
        from tensorflow.python.ops.numpy_ops import np_config
        np_config.enable_numpy_behavior()""")
  # Return the result; without `return`, a lookup that succeeds here would be
  # silently turned into None.
  return self.__getattribute__(name)
def _c_api_shape(self):
  """Queries the C API graph for the shape of this tensor.

  Returns:
    `tensor_shape.unknown_shape()` when the C API reports the shape as
    unknown; otherwise a `TensorShape` built from the reported dimension
    vector with every `-1` entry replaced by `None`.
  """
  graph_handle = self._op._graph._c_graph  # pylint: disable=protected-access
  reported_dims, is_unknown = (
      pywrap_tf_session.TF_GraphGetTensorShapeHelper(
          graph_handle, self._as_tf_output()))
  if is_unknown:
    return tensor_shape.unknown_shape()
  static_dims = []
  for size in reported_dims:
    # A reported -1 is passed to TensorShape as None.
    static_dims.append(None if size == -1 else size)
  return tensor_shape.TensorShape(static_dims)
def _disallow_bool_casting(self):
  """Rejects use of this tensor as a Python `bool`.

  Picks the rejection helper that matches the status reported by
  `ag_ctx.control_status_ctx()` and calls it with a description of the
  attempted operation:
    * `Status.DISABLED` -> `_disallow_when_autograph_disabled`
    * `Status.ENABLED`  -> `_disallow_when_autograph_enabled`
    * anything else     -> `_disallow_in_graph_mode`
  """
  task = "using a `tf.Tensor` as a Python `bool`"
  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    reject = self._disallow_when_autograph_disabled
  elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
    reject = self._disallow_when_autograph_enabled
  else:
    # Default: V1-style Graph execution.
    reject = self._disallow_in_graph_mode
  reject(task)
512 self._disallow_in_graph_mode("iterating over `tf.Tensor`") 513 514 def __iter__(self): 515 if not context.executing_eagerly(): 516 self._disallow_iteration() 517 518 shape = self._shape_tuple() 519 if shape is None: 520 raise TypeError("Cannot iterate over a tensor with unknown shape.") 521 if not shape: 522 raise TypeError("Cannot iterate over a scalar tensor.") 523 if shape[0] is None: 524 raise TypeError( 525 "Cannot iterate over a tensor with unknown first dimension.") 526 return _TensorIterator(self, shape[0]) 527 528 def _shape_as_list(self): 529 if self.shape.ndims is not None: 530 return [dim.value for dim in self.shape.dims] 531 else: 532 return None 533 534 def _shape_tuple(self): 535 shape = self._shape_as_list() 536 if shape is None: 537 return None 538 return tuple(shape) 539 540 def _rank(self): 541 """Integer rank of this Tensor, if known, else None. 542 543 Returns: 544 Integer rank or None 545 """ 546 return self.shape.ndims 547 548 def get_shape(self): 549 """Returns a `tf.TensorShape` that represents the shape of this tensor. 550 551 In eager execution the shape is always fully-known. 552 553 >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 554 >>> print(a.shape) 555 (2, 3) 556 557 `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`. 558 559 560 When executing in a `tf.function` or building a model using 561 `tf.keras.Input`, `Tensor.shape` may return a partial shape (including 562 `None` for unknown dimensions). See `tf.TensorShape` for more details. 563 564 >>> inputs = tf.keras.Input(shape = [10]) 565 >>> # Unknown batch size 566 >>> print(inputs.shape) 567 (None, 10) 568 569 The shape is computed using shape inference functions that are 570 registered for each `tf.Operation`. 571 572 The returned `tf.TensorShape` is determined at *build* time, without 573 executing the underlying kernel. It is not a `tf.Tensor`. 
If you need a 574 shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or 575 use the `tf.shape(tensor)` function, which returns the tensor's shape at 576 *execution* time. 577 578 This is useful for debugging and providing early errors. For 579 example, when tracing a `tf.function`, no ops are being executed, shapes 580 may be unknown (See the [Concrete Functions 581 Guide](https://www.tensorflow.org/guide/concrete_function) for details). 582 583 >>> @tf.function 584 ... def my_matmul(a, b): 585 ... result = a@b 586 ... # the `print` executes during tracing. 587 ... print("Result shape: ", result.shape) 588 ... return result 589 590 The shape inference functions propagate shapes to the extent possible: 591 592 >>> f = my_matmul.get_concrete_function( 593 ... tf.TensorSpec([None,3]), 594 ... tf.TensorSpec([3,5])) 595 Result shape: (None, 5) 596 597 Tracing may fail if a shape missmatch can be detected: 598 599 >>> cf = my_matmul.get_concrete_function( 600 ... tf.TensorSpec([None,3]), 601 ... tf.TensorSpec([4,5])) 602 Traceback (most recent call last): 603 ... 604 ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op: 605 'MatMul') with input shapes: [?,3], [4,5]. 606 607 In some cases, the inferred shape may have unknown dimensions. If 608 the caller has additional information about the values of these 609 dimensions, `tf.ensure_shape` or `Tensor.set_shape()` can be used to augment 610 the inferred shape. 611 612 >>> @tf.function 613 ... def my_fun(a): 614 ... a = tf.ensure_shape(a, [5, 5]) 615 ... # the `print` executes during tracing. 616 ... print("Result shape: ", a.shape) 617 ... return a 618 619 >>> cf = my_fun.get_concrete_function( 620 ... tf.TensorSpec([None, None])) 621 Result shape: (5, 5) 622 623 Returns: 624 A `tf.TensorShape` representing the shape of this tensor. 625 626 """ 627 return self.shape 628 629 def set_shape(self, shape): 630 """Updates the shape of this tensor. 
631 632 Note: It is recommended to use `tf.ensure_shape` instead of 633 `Tensor.set_shape`, because `tf.ensure_shape` provides better checking for 634 programming errors and can create guarantees for compiler 635 optimization. 636 637 With eager execution this operates as a shape assertion. 638 Here the shapes match: 639 640 >>> t = tf.constant([[1,2,3]]) 641 >>> t.set_shape([1, 3]) 642 643 Passing a `None` in the new shape allows any value for that axis: 644 645 >>> t.set_shape([1,None]) 646 647 An error is raised if an incompatible shape is passed. 648 649 >>> t.set_shape([1,5]) 650 Traceback (most recent call last): 651 ... 652 ValueError: Tensor's shape (1, 3) is not compatible with supplied 653 shape [1, 5] 654 655 When executing in a `tf.function`, or building a model using 656 `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with 657 the current shape of this tensor, and set the tensor's shape to the 658 merged value (see `tf.TensorShape.merge_with` for details): 659 660 >>> t = tf.keras.Input(shape=[None, None, 3]) 661 >>> print(t.shape) 662 (None, None, None, 3) 663 664 Dimensions set to `None` are not updated: 665 666 >>> t.set_shape([None, 224, 224, None]) 667 >>> print(t.shape) 668 (None, 224, 224, 3) 669 670 The main use case for this is to provide additional shape information 671 that cannot be inferred from the graph alone. 672 673 For example if you know all the images in a dataset have shape [28,28,3] you 674 can set it with `tf.set_shape`: 675 676 >>> @tf.function 677 ... def load_image(filename): 678 ... raw = tf.io.read_file(filename) 679 ... image = tf.image.decode_png(raw, channels=3) 680 ... # the `print` executes during tracing. 681 ... print("Initial shape: ", image.shape) 682 ... image.set_shape([28, 28, 3]) 683 ... print("Final shape: ", image.shape) 684 ... return image 685 686 Trace the function, see the [Concrete Functions 687 Guide](https://www.tensorflow.org/guide/concrete_function) for details. 
688 689 >>> cf = load_image.get_concrete_function( 690 ... tf.TensorSpec([], dtype=tf.string)) 691 Initial shape: (None, None, 3) 692 Final shape: (28, 28, 3) 693 694 Similarly the `tf.io.parse_tensor` function could return a tensor with 695 any shape, even the `tf.rank` is unknown. If you know that all your 696 serialized tensors will be 2d, set it with `set_shape`: 697 698 >>> @tf.function 699 ... def my_parse(string_tensor): 700 ... result = tf.io.parse_tensor(string_tensor, out_type=tf.float32) 701 ... # the `print` executes during tracing. 702 ... print("Initial shape: ", result.shape) 703 ... result.set_shape([None, None]) 704 ... print("Final shape: ", result.shape) 705 ... return result 706 707 Trace the function 708 709 >>> concrete_parse = my_parse.get_concrete_function( 710 ... tf.TensorSpec([], dtype=tf.string)) 711 Initial shape: <unknown> 712 Final shape: (None, None) 713 714 Make sure it works: 715 716 >>> t = tf.ones([5,3], dtype=tf.float32) 717 >>> serialized = tf.io.serialize_tensor(t) 718 >>> print(serialized.dtype) 719 <dtype: 'string'> 720 >>> print(serialized.shape) 721 () 722 >>> t2 = concrete_parse(serialized) 723 >>> print(t2.shape) 724 (5, 3) 725 726 Caution: `set_shape` ensures that the applied shape is compatible with 727 the existing shape, but it does not check at runtime. Setting 728 incorrect shapes can result in inconsistencies between the 729 statically-known graph and the runtime value of tensors. For runtime 730 validation of the shape, use `tf.ensure_shape` instead. It also modifies 731 the `shape` of the tensor. 732 733 >>> # Serialize a rank-3 tensor 734 >>> t = tf.ones([5,5,5], dtype=tf.float32) 735 >>> serialized = tf.io.serialize_tensor(t) 736 >>> # The function still runs, even though it `set_shape([None,None])` 737 >>> t2 = concrete_parse(serialized) 738 >>> print(t2.shape) 739 (5, 5, 5) 740 741 Args: 742 shape: A `TensorShape` representing the shape of this tensor, a 743 `TensorShapeProto`, a list, a tuple, or None. 
def consumers(self):
  """Returns a list of `Operation`s that consume this tensor.

  The consumer names are obtained from
  `pywrap_tf_session.TF_OperationOutputConsumers_wrapper` for this tensor's
  output handle, and each name is then resolved through this tensor's graph.

  Returns:
    A list of `Operation`s.
  """
  output_handle = self._as_tf_output()
  consumer_names = pywrap_tf_session.TF_OperationOutputConsumers_wrapper(
      output_handle)
  consuming_ops = []
  # pylint: disable=protected-access
  for op_name in consumer_names:
    consuming_ops.append(self.graph._get_operation_by_name_unsafe(op_name))
  # pylint: enable=protected-access
  return consuming_ops
806 """ 807 if not self._op.name: 808 raise ValueError("Operation was not named: %s" % self._op) 809 if self._value_index == 0: 810 return self._op.name 811 else: 812 return "%s:%d" % (self._op.name, self._value_index) 813 814 def _as_tf_output(self): 815 # pylint: disable=protected-access 816 # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object 817 # also guarantees that a dictionary of tf_output objects will retain a 818 # deterministic (yet unsorted) order which prevents memory blowup in the 819 # cache of executor(s) stored for every session. 820 if self._tf_output is None: 821 self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index) 822 return self._tf_output 823 # pylint: enable=protected-access 824 825 def __str__(self): 826 return "Tensor(\"%s\"%s%s%s)" % ( 827 self.name, 828 (", shape=%s" % 829 self.get_shape()) if self.get_shape().ndims is not None else "", 830 (", dtype=%s" % self._dtype.name) if self._dtype else "", 831 (", device=%s" % self.device) if self.device else "") 832 833 def __repr__(self): 834 return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(), 835 self._dtype.name) 836 837 def __hash__(self): 838 g = getattr(self, "graph", None) 839 if (Tensor._USE_EQUALITY and executing_eagerly_outside_functions() and 840 (g is None or g.building_function)): 841 raise TypeError("Tensor is unhashable. " 842 "Instead, use tensor.ref() as the key.") 843 else: 844 return id(self) 845 846 def __copy__(self): 847 # TODO(b/77597810): get rid of Tensor copies. 848 cls = self.__class__ 849 result = cls.__new__(cls) 850 result.__dict__.update(self.__dict__) 851 return result 852 853 # NOTE(mrry): This enables the Tensor's overloaded "right" binary 854 # operators to run when the left operand is an ndarray, because it 855 # accords the Tensor class higher priority than an ndarray, or a 856 # numpy matrix. 
857 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 858 # mechanism, which allows more control over how Tensors interact 859 # with ndarrays. 860 __array_priority__ = 100 861 862 def __array__(self): 863 raise NotImplementedError( 864 "Cannot convert a symbolic Tensor ({}) to a numpy array." 865 " This error may indicate that you're trying to pass a Tensor to" 866 " a NumPy call, which is not supported".format(self.name)) 867 868 def __len__(self): 869 raise TypeError("len is not well defined for symbolic Tensors. ({}) " 870 "Please call `x.shape` rather than `len(x)` for " 871 "shape information.".format(self.name)) 872 873 # TODO(mdan): This convoluted machinery is hard to maintain. Clean up. 874 @staticmethod 875 def _override_operator(operator, func): 876 _override_helper(Tensor, operator, func) 877 878 def __bool__(self): 879 """Dummy method to prevent a tensor from being used as a Python `bool`. 880 881 This overload raises a `TypeError` when the user inadvertently 882 treats a `Tensor` as a boolean (most commonly in an `if` or `while` 883 statement), in code that was not converted by AutoGraph. For example: 884 885 ```python 886 if tf.constant(True): # Will raise. 887 # ... 888 889 if tf.constant(5) < tf.constant(7): # Will raise. 890 # ... 891 ``` 892 893 Raises: 894 `TypeError`. 895 """ 896 self._disallow_bool_casting() 897 898 def __nonzero__(self): 899 """Dummy method to prevent a tensor from being used as a Python `bool`. 900 901 This is the Python 2.x counterpart to `__bool__()` above. 902 903 Raises: 904 `TypeError`. 905 """ 906 self._disallow_bool_casting() 907 908 def eval(self, feed_dict=None, session=None): 909 """Evaluates this tensor in a `Session`. 910 911 Note: If you are not using `compat.v1` libraries, you should not need this, 912 (or `feed_dict` or `Session`). In eager execution (or within `tf.function`) 913 you do not need to call `eval`. 
914 915 Calling this method will execute all preceding operations that 916 produce the inputs needed for the operation that produces this 917 tensor. 918 919 *N.B.* Before invoking `Tensor.eval()`, its graph must have been 920 launched in a session, and either a default session must be 921 available, or `session` must be specified explicitly. 922 923 Args: 924 feed_dict: A dictionary that maps `Tensor` objects to feed values. See 925 `tf.Session.run` for a description of the valid feed values. 926 session: (Optional.) The `Session` to be used to evaluate this tensor. If 927 none, the default session will be used. 928 929 Returns: 930 A numpy array corresponding to the value of this tensor. 931 """ 932 return _eval_using_default_session(self, feed_dict, self.graph, session) 933 934 @deprecation.deprecated(None, "Use ref() instead.") 935 def experimental_ref(self): 936 return self.ref() 937 938 def ref(self): 939 # tf.Variable also has the same ref() API. If you update the 940 # documentation here, please update tf.Variable.ref() as well. 941 """Returns a hashable reference object to this Tensor. 942 943 The primary use case for this API is to put tensors in a set/dictionary. 944 We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer 945 available starting Tensorflow 2.0. 946 947 The following will raise an exception starting 2.0 948 949 >>> x = tf.constant(5) 950 >>> y = tf.constant(10) 951 >>> z = tf.constant(10) 952 >>> tensor_set = {x, y, z} 953 Traceback (most recent call last): 954 ... 955 TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key. 956 >>> tensor_dict = {x: 'five', y: 'ten'} 957 Traceback (most recent call last): 958 ... 959 TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key. 960 961 Instead, we can use `tensor.ref()`. 
# TODO(agarwal): consider getting rid of this.
# TODO(mdan): This object should not subclass ops.Tensor.
class _EagerTensorBase(Tensor):
  """Base class for EagerTensor."""

  # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and
  # only work for scalars; values are cast as per numpy.
  def __complex__(self):
    return complex(self._numpy())

  def __int__(self):
    return int(self._numpy())

  def __long__(self):
    # Python 2 protocol hook; the `long` builtin does not exist on Python 3.
    return long(self._numpy())

  def __float__(self):
    return float(self._numpy())

  def __index__(self):
    return self._numpy().__index__()

  def __bool__(self):
    return bool(self._numpy())

  __nonzero__ = __bool__

  def __format__(self, format_spec):
    return self._numpy().__format__(format_spec)

  def __reduce__(self):
    # Pickle as "call convert_to_tensor on the numpy value".
    return convert_to_tensor, (self._numpy(),)

  def __copy__(self):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    return self

  def __deepcopy__(self, memo):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    del memo
    return self

  def __str__(self):
    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self), self.shape,
                                                  self.dtype.name)

  def __repr__(self):
    return "<tf.Tensor: shape=%s, dtype=%s, numpy=%s>" % (
        self.shape, self.dtype.name, numpy_text(self, is_repr=True))

  def __len__(self):
    """Returns the length of the first dimension in the Tensor."""
    if not self.shape.ndims:
      raise TypeError("Scalar tensor has no `len()`")
    # pylint: disable=protected-access
    try:
      return self._shape_tuple()[0]
    except core._NotOkStatusException as e:
      six.raise_from(core._status_to_exception(e.code, e.message), None)

  def __array__(self):
    return self._numpy()

  def _numpy_internal(self):
    raise NotImplementedError()

  def _numpy(self):
    """Returns `_numpy_internal()`, translating status errors.

    A `core._NotOkStatusException` is re-raised as the exception produced by
    `core._status_to_exception`, with the original context suppressed.
    """
    try:
      return self._numpy_internal()
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      six.raise_from(core._status_to_exception(e.code, e.message), None)  # pylint: disable=protected-access

  @property
  def dtype(self):
    # Note: using the intern table directly here as this is
    # performance-sensitive in some models.
    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access

  def numpy(self):
    """Copy of the contents of this Tensor into a NumPy array or scalar.

    Unlike NumPy arrays, Tensors are immutable, so this method has to copy
    the contents to ensure safety. Use `memoryview` to get a readonly
    view of the contents without doing a copy:

    >>> t = tf.constant([42])
    >>> np.array(memoryview(t))
    array([42], dtype=int32)

    Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor
    is on GPU, it will have to be transferred to CPU first in order for
    `memoryview` to work.

    Returns:
      A NumPy array of the same shape and dtype or a NumPy scalar, if this
      Tensor has rank 0.

    Raises:
      ValueError: If the dtype of this Tensor does not have a compatible
        NumPy dtype.
    """
    # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
    maybe_arr = self._numpy()  # pylint: disable=protected-access
    return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr

  @property
  def backing_device(self):
    """Returns the name of the device holding this tensor's memory.

    `.backing_device` is usually the same as `.device`, which returns
    the device on which the kernel of the operation that produced this tensor
    ran. However, some operations can produce tensors on a different device
    (e.g., an operation that executes on the GPU but produces output tensors
    in host memory).
    """
    raise NotImplementedError()

  def _datatype_enum(self):
    raise NotImplementedError()

  def _shape_tuple(self):
    """The shape of this Tensor, as a tuple.

    This is more performant than tuple(shape().as_list()) as it avoids
    two list and one object creation. Marked private for now as from an API
    perspective, it would be better to have a single performant way of
    getting a shape rather than exposing shape() and shape_tuple()
    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
    but ideally one would work things out and remove the need for this method.

    Returns:
      tuple with the shape.
    """
    raise NotImplementedError()

  def _rank(self):
    """Integer rank of this Tensor.

    Unlike regular Tensors, the rank is always known for EagerTensors.

    This is more performant than len(self._shape_tuple())

    Returns:
      Integer rank
    """
    raise NotImplementedError()

  def _num_elements(self):
    """Number of elements of this Tensor.

    Unlike regular Tensors, the number of elements is always known for
    EagerTensors.

    This is more performant than tensor.shape.num_elements

    Returns:
      Long - num elements in the tensor
    """
    raise NotImplementedError()

  def _copy_to_device(self, device_name):  # pylint: disable=redefined-outer-name
    raise NotImplementedError()

  @staticmethod
  def _override_operator(name, func):
    setattr(_EagerTensorBase, name, func)

  def _copy_nograd(self, ctx=None, device_name=None):
    """Copies tensor to dest device, but doesn't record the operation."""
    # Creates a new tensor on the dest device.
    if ctx is None:
      ctx = context.context()
    if device_name is None:
      device_name = ctx.device_name
    # pylint: disable=protected-access
    try:
      ctx.ensure_initialized()
      new_tensor = self._copy_to_device(device_name)
    except core._NotOkStatusException as e:
      six.raise_from(core._status_to_exception(e.code, e.message), None)
    return new_tensor

  def _copy(self, ctx=None, device_name=None):
    """Copies tensor to dest device."""
    new_tensor = self._copy_nograd(ctx, device_name)
    # Record the copy on tape and define backprop copy as well.
    if context.executing_eagerly():
      self_device = self.device

      def grad_fun(dresult):
        return [
            dresult._copy(device_name=self_device)
            if hasattr(dresult, "_copy") else dresult
        ]

      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
    return new_tensor
    # pylint: enable=protected-access

  @property
  def shape(self):
    """The shape as a `TensorShape`, built once and cached on the instance."""
    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
      # pylint: disable=protected-access
      try:
        # `_tensor_shape` is declared and defined in the definition of
        # `EagerTensor`, in C.
        self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
      except core._NotOkStatusException as e:
        six.raise_from(core._status_to_exception(e.code, e.message), None)

    return self._tensor_shape

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def _shape_as_list(self):
    """The shape of the tensor as a list."""
    return list(self._shape_tuple())

  @property
  def ndim(self):
    """Returns the number of Tensor dimensions."""
    return self.shape.ndims

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def cpu(self):
    """A copy of this Tensor with contents backed by host memory."""
    return self._copy(context.context(), "CPU:0")

  @deprecation.deprecated(None, "Use tf.identity instead.")
  def gpu(self, gpu_index=0):
    """A copy of this Tensor with contents backed by memory on the GPU.

    Args:
      gpu_index: Identifies which GPU to place the contents on the returned
        Tensor in.

    Returns:
      A GPU-memory backed Tensor object initialized with the same contents
      as this Tensor.
    """
    return self._copy(context.context(), "GPU:" + str(gpu_index))

  def set_shape(self, shape):
    """Checks `shape` against this tensor's shape; stores nothing.

    Raises:
      ValueError: If `self.shape` is not compatible with `shape`.
    """
    if not self.shape.is_compatible_with(shape):
      raise ValueError(
          "Tensor's shape %s is not compatible with supplied shape %s" %
          (self.shape, shape))

  # Methods not supported / implemented for Eager Tensors.
  @property
  def op(self):
    raise AttributeError(
        "Tensor.op is meaningless when eager execution is enabled.")

  @property
  def graph(self):
    raise AttributeError(
        "Tensor.graph is meaningless when eager execution is enabled.")

  @property
  def name(self):
    raise AttributeError(
        "Tensor.name is meaningless when eager execution is enabled.")

  @property
  def value_index(self):
    raise AttributeError(
        "Tensor.value_index is meaningless when eager execution is enabled.")

  def consumers(self):
    raise NotImplementedError(
        "Tensor.consumers is meaningless when eager execution is enabled.")

  def _add_consumer(self, consumer):
    raise NotImplementedError(
        "_add_consumer not supported when eager execution is enabled.")

  def _as_node_def_input(self):
    raise NotImplementedError(
        "_as_node_def_input not supported when eager execution is enabled.")

  def _as_tf_output(self):
    raise NotImplementedError(
        "_as_tf_output not supported when eager execution is enabled.")

  def eval(self, feed_dict=None, session=None):
    raise NotImplementedError(
        "eval is not supported when eager execution is enabled, "
        "is .numpy() what you're looking for?")


# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
# It is exposed as an __internal__ api for now (b/171081052), though we
# expect it to be eventually covered by tf Tensor types and typing.
EagerTensor = tf_export("__internal__.EagerTensor", v1=[])(
    pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase))


@tf_export(v1=["convert_to_tensor"])
@dispatch.add_dispatch_support
def convert_to_tensor_v1_with_dispatch(
    value,
    dtype=None,
    name=None,
    preferred_dtype=None,
    dtype_hint=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    dtype_hint: same meaning as preferred_dtype, and overrides it.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  return convert_to_tensor_v1(value, dtype=dtype, name=name,
                              preferred_dtype=preferred_dtype,
                              dtype_hint=dtype_hint)
def convert_to_tensor_v1(value,
                         dtype=None,
                         name=None,
                         preferred_dtype=None,
                         dtype_hint=None):
  """Converts the given `value` to a `Tensor` (with the TF1 API).

  `dtype_hint` and `preferred_dtype` are reconciled into one value by
  `deprecation.deprecated_argument_lookup`, and the call is then forwarded to
  `convert_to_tensor_v2`.
  """
  preferred_dtype = deprecation.deprecated_argument_lookup(
      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
  return convert_to_tensor_v2(value, dtype, preferred_dtype, name)


@tf_export("convert_to_tensor", v1=[])
@dispatch.add_dispatch_support
def convert_to_tensor_v2_with_dispatch(
    value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars.

  For example:

  >>> import numpy as np
  >>> def my_func(arg):
  ...   arg = tf.convert_to_tensor(arg, dtype=tf.float32)
  ...   return arg

  >>> # The following calls are equivalent.
  ...
  >>> value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  >>> print(value_1)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  >>> print(value_2)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)
  >>> value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  >>> print(value_3)
  tf.Tensor(
    [[1. 2.]
     [3. 4.]], shape=(2, 2), dtype=float32)

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor, used when dtype
      is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so dtype_hint can be used as a soft preference.
      If the conversion to `dtype_hint` is not possible, this argument has no
      effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  return convert_to_tensor_v2(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)
def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`.

  Forwards to `convert_to_tensor`, passing `dtype_hint` as `preferred_dtype`
  and `as_ref=False`.
  """
  return convert_to_tensor(
      value=value,
      dtype=dtype,
      name=name,
      preferred_dtype=dtype_hint,
      as_ref=False)


def _error_prefix(name):
  """Returns `"<name>: "` for error messages, or `""` when `name` is None."""
  return "" if name is None else "%s: " % name


def pack_eager_tensors(tensors, ctx=None):
  """Pack multiple `EagerTensor`s of the same dtype and shape.

  Args:
    tensors: a list of EagerTensors to pack.
    ctx: context.context().

  Returns:
    A packed EagerTensor.

  Raises:
    TypeError: If `tensors` is not a list, or one of its elements is not an
      `EagerTensor`.
    ValueError: If `tensors` is empty, or if the elements disagree on dtype,
      shape or (for resource tensors) handle data.
  """
  if not isinstance(tensors, list):
    # Format through an explicit tuple: `"%s" % tensors` would itself fail
    # when `tensors` is a tuple, hiding the real problem from the caller.
    raise TypeError("tensors must be a list, but got %s: %r" %
                    (type(tensors).__name__, tensors))

  if not tensors:
    raise ValueError("Empty tensors is unexpected for packing.")

  first = tensors[0]
  # Check the first element before reading its attributes so that a
  # non-EagerTensor is reported as a TypeError, not an AttributeError.
  if not isinstance(first, EagerTensor):
    raise TypeError("tensors must be a list of EagerTensors: %s" % (first,))

  dtype = first.dtype
  shape = first.shape
  handle_data = first._handle_data  # pylint: disable=protected-access
  is_resource = dtype == dtypes.resource
  for i, t in enumerate(tensors):
    if not isinstance(t, EagerTensor):
      raise TypeError("tensors must be a list of EagerTensors: %s" % (t,))

    if t.dtype != dtype:
      raise ValueError(
          "All tensors being packed should have the same dtype %s, "
          "but the %d-th tensor is of dtype %s" % (dtype, i, t.dtype))
    if t.shape != shape:
      raise ValueError(
          "All tensors being packed should have the same shape %s, "
          "but the %d-th tensor is of shape %s" % (shape, i, t.shape))
    # pylint: disable=protected-access
    if is_resource and t._handle_data != handle_data:
      raise ValueError(
          "All tensors being packed should have the same handle data %s, "
          "but the %d-th tensor is of handle data %s" %
          (handle_data, i, t._handle_data))
    # pylint: enable=protected-access

  if ctx is None:
    ctx = context.context()

  packed_tensor = ctx.pack_eager_tensors(tensors)
  # Propagate handle data for resource variables
  if handle_data is not None:
    packed_tensor._handle_data = handle_data  # pylint: disable=protected-access

  def grad_fun(_):
    raise ValueError(
        "Gradients through pack_eager_tensors are not supported yet.")

  tape.record_operation("pack_eager_tensors", [packed_tensor], tensors,
                        grad_fun)

  return packed_tensor
@trace.trace_wrapper("convert_to_tensor")
def convert_to_tensor(value,
                      dtype=None,
                      name=None,
                      as_ref=False,
                      preferred_dtype=None,
                      dtype_hint=None,
                      ctx=None,
                      accepted_result_types=(Tensor,)):
  """Implementation of the public convert_to_tensor.

  Args:
    value: The object to convert.
    dtype: Optional required dtype; normalized with `dtypes.as_dtype`.
    name: Optional name, passed to the conversion function (or to
      `graph.capture` when an `EagerTensor` is captured).
    as_ref: Passed through to the registered conversion function.
    preferred_dtype: Optional soft dtype preference, tried first when `dtype`
      is None; a `TypeError`/`ValueError` from that attempt is ignored and the
      conversion is retried with `dtype`.
    dtype_hint: Used in place of `preferred_dtype` when the latter is falsy.
    ctx: Optional context; `context.context()` is used when None and `value`
      is an `EagerTensor`.
    accepted_result_types: Types a conversion function is allowed to return.

  Returns:
    `value` itself when it is already a `Tensor` (or its graph capture when an
    `EagerTensor` is seen while building a function), otherwise the result of
    `type(value).__tf_tensor__` or of the first registered conversion function
    that does not return `NotImplemented`.

  Raises:
    RuntimeError: If an `EagerTensor` is seen outside eager execution while no
      function is being built, or if a conversion function returns a value of
      an unaccepted type or of a dtype incompatible with `dtype`.
    ValueError: If `value` is a `Tensor` whose dtype is incompatible with
      `dtype`.
    TypeError: If the `preferred_dtype` attempt succeeds with a different base
      dtype, or if no conversion function handles `value`.
  """
  # TODO(b/142518781): Fix all call-sites and remove redundant arg
  preferred_dtype = preferred_dtype or dtype_hint
  if isinstance(value, EagerTensor):
    if ctx is None:
      ctx = context.context()
    if not ctx.executing_eagerly():
      graph = get_default_graph()
      if not graph.building_function:
        raise RuntimeError("Attempting to capture an EagerTensor without "
                           "building a function.")
      return graph.capture(value, name=name)

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if isinstance(value, Tensor):
    if dtype is not None and not dtype.is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (dtype.name, value.dtype.name, value))
    return value

  if preferred_dtype is not None:
    preferred_dtype = dtypes.as_dtype(preferred_dtype)

  # See below for the reason why it's `type(value)` and not just `value`.
  # https://docs.python.org/3.8/reference/datamodel.html#special-lookup
  overload = getattr(type(value), "__tf_tensor__", None)
  if overload is not None:
    return overload(value, dtype, name)

  for base_type, conversion_func in tensor_conversion_registry.get(type(value)):
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError):
        # Could not coerce the conversion to use the preferred dtype.
        pass
      else:
        if (ret is not NotImplemented and
            ret.dtype.base_dtype != preferred_dtype.base_dtype):
          raise TypeError("convert_to_tensor did not convert to "
                          "the preferred dtype: %s vs %s " %
                          (ret.dtype.base_dtype, preferred_dtype.base_dtype))

    if ret is None:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    if ret is NotImplemented:
      continue

    if not isinstance(ret, accepted_result_types):
      raise RuntimeError(
          "%sConversion function %r for type %s returned non-Tensor: %r" %
          (_error_prefix(name), conversion_func, base_type, ret))
    if dtype and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          "%sConversion function %r for type %s returned incompatible "
          "dtype: requested = %s, actual = %s" %
          (_error_prefix(name), conversion_func, base_type, dtype.name,
           ret.dtype.name))
    return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered." %
                  (_error_prefix(name), value, type(value)))
internal_convert_to_tensor = convert_to_tensor


def internal_convert_n_to_tensor(values,
                                 dtype=None,
                                 name=None,
                                 as_ref=False,
                                 preferred_dtype=None,
                                 ctx=None):
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.
    preferred_dtype: Optional element type for the returned tensors, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.
    ctx: The value of context.context().

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  # Resolve the context once and share it across all element conversions.
  if ctx is None:
    ctx = context.context()
  for i, value in enumerate(values):
    n = None if name is None else "%s_%d" % (name, i)
    ret.append(
        convert_to_tensor(
            value,
            dtype=dtype,
            name=n,
            as_ref=as_ref,
            preferred_dtype=preferred_dtype,
            ctx=ctx))
  return ret


def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.
    preferred_dtype: Optional element type for the returned tensors, used when
      dtype is None. In some cases, a caller may not have a dtype in mind when
      converting to a tensor, so preferred_dtype can be used as a soft
      preference.  If the conversion to `preferred_dtype` is not possible, this
      argument has no effect.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor(
      values=values,
      dtype=dtype,
      name=name,
      preferred_dtype=preferred_dtype,
      as_ref=False)
def convert_to_tensor_or_composite(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
  is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor` or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  return internal_convert_to_tensor_or_composite(
      value=value, dtype=dtype, name=name, as_ref=False)
def internal_convert_to_tensor_or_composite(value,
                                            dtype=None,
                                            name=None,
                                            as_ref=False):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
  is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor`, or an object that can be consumed by
      `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    # Composite tensors are passed through; only their dtype is checked.
    value_dtype = getattr(value, "dtype", None)
    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value_dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
    return value
  else:
    return convert_to_tensor(
        value,
        dtype=dtype,
        name=name,
        as_ref=as_ref,
        accepted_result_types=(Tensor, composite_tensor.CompositeTensor))


def internal_convert_n_to_tensor_or_composite(values,
                                              dtype=None,
                                              name=None,
                                              as_ref=False):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If `values` is not a sequence, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  for i, value in enumerate(values):
    if value is None:
      # `None` entries are kept in place rather than converted.
      ret.append(value)
    else:
      n = None if name is None else "%s_%d" % (name, i)
      ret.append(
          internal_convert_to_tensor_or_composite(
              value, dtype=dtype, name=n, as_ref=as_ref))
  return ret
def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that can be
      consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to be used when a new `Tensor` is created,
      in which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_composite(
      values=values, dtype=dtype, name=name, as_ref=False)
def _device_string(dev_spec):
  """Returns `dev_spec.to_string()` for a device spec, else `dev_spec` as is."""
  if pydev.is_device_spec(dev_spec):
    return dev_spec.to_string()
  else:
    return dev_spec


def _NodeDef(op_type, name, attrs=None):
  """Create a NodeDef proto.

  Args:
    op_type: Value for the "op" attribute of the NodeDef proto.
    name: Value for the "name" attribute of the NodeDef proto.
    attrs: Dictionary where the key is the attribute name (a string)
      and the value is the respective "attr" attribute of the NodeDef proto (an
      AttrValue).

  Returns:
    A node_def_pb2.NodeDef protocol buffer.
  """
  node_def = node_def_pb2.NodeDef(op=compat.as_bytes(op_type),
                                  name=compat.as_bytes(name))
  if attrs:
    for k, v in six.iteritems(attrs):
      node_def.attr[k].CopyFrom(v)
  return node_def


# Copied from core/framework/node_def_util.cc
# TODO(mrry,josh11b): Consolidate this validation in C++ code.
_VALID_OP_NAME_REGEX = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
_VALID_SCOPE_NAME_REGEX = re.compile(r"^[A-Za-z0-9_.\\/>-]*$")


@tf_export("__internal__.create_c_op", v1=[])
def _create_c_op(graph, node_def, inputs, control_inputs, op_def=None):
  """Creates a TF_Operation.

  Args:
    graph: a `Graph`.
    node_def: `node_def_pb2.NodeDef` for the operation to create.
    inputs: A flattened list of `Tensor`s. This function handles grouping
      tensors into lists as per attributes in the `node_def`.
    control_inputs: A list of `Operation`s to set as control dependencies.
    op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not
      specified, is looked up from the `graph` using `node_def.op`.

  Returns:
    A wrapped TF_Operation*.

  Raises:
    ValueError: If finishing the operation raises
      `errors.InvalidArgumentError`.
  """
  if op_def is None:
    op_def = graph._get_op_def(node_def.op)  # pylint: disable=protected-access
  # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
  # Refactor so we don't have to do this here.
  inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr)
  # pylint: disable=protected-access
  op_desc = pywrap_tf_session.TF_NewOperation(graph._c_graph,
                                              compat.as_str(node_def.op),
                                              compat.as_str(node_def.name))
  if node_def.device:
    pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device))
  # Add inputs
  for op_input in inputs:
    if isinstance(op_input, (list, tuple)):
      pywrap_tf_session.TF_AddInputList(op_desc,
                                        [t._as_tf_output() for t in op_input])
    else:
      pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output())

  # Add control inputs
  for control_input in control_inputs:
    pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op)
  # pylint: enable=protected-access

  # Add attrs
  for name, attr_value in node_def.attr.items():
    serialized = attr_value.SerializeToString()
    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use the same status.
    pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name),
                                           serialized)

  try:
    c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(str(e))

  return c_op
@tf_export("Operation")
class Operation(object):
  """Represents a graph node that performs computation on tensors.

  An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor`
  objects as input, and produces zero or more `Tensor` objects as output.
  Objects of type `Operation` are created by calling a Python op constructor
  (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default`
  context manager.

  For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an
  `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and
  produces `c` as output.

  If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be
  executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for
  calling `tf.compat.v1.get_default_session().run(op)`.
  """

  def __init__(self,
               node_def,
               g,
               inputs=None,
               output_types=None,
               control_inputs=None,
               input_types=None,
               original_op=None,
               op_def=None):
    r"""Creates an `Operation`.

    NOTE: This constructor validates the name of the `Operation` (passed
    as `node_def.name`) against `_VALID_OP_NAME_REGEX`. Valid `Operation`
    names match the following regular expression:

        [A-Za-z0-9.][A-Za-z0-9_.\\/>-]*

    Args:
      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`. Used for
        attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and
        `device`.  The `input` attribute is irrelevant here as it will be
        computed when generating the model.
      g: `Graph`. The parent graph.
      inputs: list of `Tensor` objects. The inputs to this `Operation`.
      output_types: list of `DType` objects.  List of the types of the `Tensors`
        computed by this operation.  The length of this list indicates the
        number of output endpoints of the `Operation`.
        NOTE(review): apart from asserting it is None when `node_def` is a
        TF_Operation, this constructor does not read the value; the outputs
        are built from the output count and types reported by the C op.
      control_inputs: list of operations or tensors from which to have a control
        dependency.
      input_types: List of `DType` objects representing the types of the tensors
        accepted by the `Operation`.  By default uses `[x.dtype.base_dtype for x
        in inputs]`.  Operations that expect reference-typed inputs must specify
        these explicitly.
      original_op: Optional. Used to associate the new `Operation` with an
        existing `Operation` (for example, a replica with the op that was
        replicated).
      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type
        that this `Operation` represents.

    Raises:
      TypeError: if control inputs are not Operations, Tensors or
        IndexedSlices,
        or if `node_def` is not a `NodeDef`,
        or if `g` is not a `Graph`,
        or if `inputs` are not tensors,
        or if `inputs` and `input_types` are incompatible.
      ValueError: if the `node_def` name is not valid, or if the serialized
        size of `node_def` is negative or at least 2**31 bytes.
    """
    # For internal use only: `node_def` can be set to a TF_Operation to create
    # an Operation for that op. This is useful for creating Operations for ops
    # indirectly created by C API methods, e.g. the ops created by
    # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
    # should be None.

    if isinstance(node_def, node_def_pb2.NodeDef):
      # Measure the proto once; the same value feeds both bounds checks.
      byte_size = node_def.ByteSize()
      if byte_size >= (1 << 31) or byte_size < 0:
        raise ValueError(
            "Cannot create a tensor proto whose content is larger than 2GB.")
      if not _VALID_OP_NAME_REGEX.match(node_def.name):
        raise ValueError("'%s' is not a valid node name" % node_def.name)
      c_op = None
    elif type(node_def).__name__ == "TF_Operation":
      # The wrapped C type is recognized by its type name only.
      assert inputs is None
      assert output_types is None
      assert control_inputs is None
      assert input_types is None
      assert original_op is None
      assert op_def is None
      c_op = node_def
    else:
      raise TypeError("node_def needs to be a NodeDef: %s" % (node_def,))

    if not isinstance(g, Graph):
      raise TypeError("g needs to be a Graph: %s" % (g,))
    self._graph = g

    if inputs is None:
      inputs = []
    elif not isinstance(inputs, list):
      raise TypeError("inputs needs to be a list of Tensors: %s" % (inputs,))
    for a in inputs:
      if not isinstance(a, Tensor):
        raise TypeError("input needs to be a Tensor: %s" % (a,))
    if input_types is None:
      input_types = [i.dtype.base_dtype for i in inputs]
    else:
      if not all(
          x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)):
        raise TypeError("In op '%s', input types (%s) are not compatible "
                        "with expected types (%s)" %
                        (node_def.name, [i.dtype for i in inputs], input_types))

    # Build the list of control inputs.
    control_input_ops = []
    if control_inputs:
      for c in control_inputs:
        control_op = None
        if isinstance(c, Operation):
          control_op = c
        elif isinstance(c, (Tensor, IndexedSlices)):
          control_op = c.op
        else:
          # Wrap in a 1-tuple so that a tuple-valued `c` is formatted as a
          # whole instead of being unpacked by the % operator.
          raise TypeError("Control input must be an Operation, "
                          "a Tensor, or IndexedSlices: %s" % (c,))
        control_input_ops.append(control_op)

    # This will be set by self.inputs.
    self._inputs_val = None

    # pylint: disable=protected-access
    self._original_op = original_op

    # List of _UserDevSpecs holding code location of device context manager
    # invocations and the users original argument to them.
    self._device_code_locations = None
    # Dict mapping op name to file and line information for op colocation
    # context managers.
    self._colocation_code_locations = None
    self._control_flow_context = self.graph._get_control_flow_context()

    # Gradient function for this op. There are three ways to specify gradient
    # function, and first available gradient gets used, in the following order.
    # 1. self._gradient_function
    # 2. Gradient name registered by "_gradient_op_type" attribute.
    # 3. Gradient name registered by op.type.
    self._gradient_function = None

    # Initialize self._c_op.
    if c_op:
      self._c_op = c_op
      op_def = g._get_op_def(pywrap_tf_session.TF_OperationOpType(c_op))
      name = self.name
    else:
      if op_def is None:
        op_def = self._graph._get_op_def(node_def.op)
      self._c_op = _create_c_op(self._graph, node_def, inputs,
                                control_input_ops, op_def)
      name = compat.as_str(node_def.name)

    self._traceback = tf_stack.extract_stack_for_node(self._c_op)

    # pylint: enable=protected-access

    self._is_stateful = op_def.is_stateful

    # Initialize self._outputs.
    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
    self._outputs = []
    for i in range(num_outputs):
      tf_output = c_api_util.tf_output(self._c_op, i)
      output_type = pywrap_tf_session.TF_OperationOutputType(tf_output)
      tensor = Tensor._create_with_tf_output(self, i, output_type, tf_output)  # pylint: disable=protected-access
      self._outputs.append(tensor)

    self._id_value = self._graph._add_op(self, name)  # pylint: disable=protected-access

    # Ops wrapped from an existing TF_Operation skip control flow
    # post-processing here.
    if not c_op:
      self._control_flow_post_processing(input_tensors=inputs)

  def _control_flow_post_processing(self, input_tensors=None):
    """Add this op to its control flow context.

    This may add new ops and change this op's inputs. self.inputs must be
    available before calling this method.

    Args:
      input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs
        of this op, which should be equivalent to `self.inputs`. Pass this
        argument to avoid evaluating `self.inputs` unnecessarily.
    """
    if input_tensors is None:
      input_tensors = self.inputs
    for input_tensor in input_tensors:
      control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
    if self._control_flow_context is not None:
      self._control_flow_context.AddOp(self)

  def colocation_groups(self):
    """Returns the list of colocation groups of the op."""
    default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)]
    try:
      class_attr = self.get_attr("_class")
    except ValueError:
      # This op has no explicit colocation group, so it is itself its
      # own root of a colocation group.
      return default_colocation_group

    attr_groups = [
        class_name for class_name in class_attr
        if class_name.startswith(b"loc:@")
    ]

    # If there are no colocation groups in the explicit _class field,
    # return the default colocation group.
    return attr_groups if attr_groups else default_colocation_group

  def values(self):
    """DEPRECATED: Use outputs."""
    return tuple(self.outputs)

  def _get_control_flow_context(self):
    """Returns the control flow context of this op.

    Returns:
      A context object.
    """
    return self._control_flow_context

  def _set_control_flow_context(self, ctx):
    """Sets the current control flow context of this op.

    Args:
      ctx: a context object.
    """
    self._control_flow_context = ctx

  @property
  def name(self):
    """The full name of this operation."""
    return pywrap_tf_session.TF_OperationName(self._c_op)

  @property
  def _id(self):
    """The unique integer id of this operation."""
    return self._id_value

  @property
  def device(self):
    """The name of the device to which this op has been assigned, if any.

    Returns:
      The string name of the device to which this op has been
      assigned, or an empty string if it has not been assigned to a
      device.
    """
    return pywrap_tf_session.TF_OperationDevice(self._c_op)

  @property
  def _device_assignments(self):
    """Code locations for device context managers active at op creation.

    This property will return a list of traceable_stack.TraceableObject
    instances where .obj is a string representing the assigned device
    (or information about the function that would be applied to this op
    to compute the desired device) and the filename and lineno members
    record the location of the relevant device context manager.

    For example, suppose file_a contained these lines:

      file_a.py:
        15: with tf.device('/gpu:0'):
        16:   node_b = tf.constant(4, name='NODE_B')

    Then a TraceableObject t_obj representing the device context manager
    would have these member values:

      t_obj.obj -> '/gpu:0'
      t_obj.filename = 'file_a.py'
      t_obj.lineno = 15

    and node_b.op._device_assignments would return the list [t_obj].

    Returns:
      [str: traceable_stack.TraceableObject, ...] as per this method's
      description, above.
    """
    return self._device_code_locations or []

  @property
  def _colocation_dict(self):
    """Code locations for colocation context managers active at op creation.

    This property will return a dictionary for which the keys are nodes with
    which this Operation is colocated, and for which the values are
    traceable_stack.TraceableObject instances.  The TraceableObject instances
    record the location of the relevant colocation context manager but have the
    "obj" field set to None to prevent leaking private data.

    For example, suppose file_a contained these lines:

      file_a.py:
        14: node_a = tf.constant(3, name='NODE_A')
        15: with tf.compat.v1.colocate_with(node_a):
        16:   node_b = tf.constant(4, name='NODE_B')

    Then a TraceableObject t_obj representing the colocation context manager
    would have these member values:

      t_obj.obj -> None
      t_obj.filename = 'file_a.py'
      t_obj.lineno = 15

    and node_b.op._colocation_dict would return the dictionary

      { 'NODE_A': t_obj }

    Returns:
      {str: traceable_stack.TraceableObject} as per this method's description,
      above.
    """
    locations_dict = self._colocation_code_locations or {}
    # Hand out a copy so callers cannot mutate the stored mapping.
    return locations_dict.copy()

  @property
  def _output_types(self):
    """List this operation's output types.

    Returns:
      List of the types of the Tensors computed by this operation.
      Each element in the list is an integer whose value is one of
      the TF_DataType enums defined in pywrap_tf_session.h
      The length of this list indicates the number of output endpoints
      of the operation.
    """
    num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op)
    output_types = [
        int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i)))
        for i in xrange(num_outputs)
    ]

    return output_types

  def _tf_output(self, output_idx):
    """Create and return a new TF_Output for output_idx'th output of this op."""
    tf_output = pywrap_tf_session.TF_Output()
    tf_output.oper = self._c_op
    tf_output.index = output_idx
    return tf_output

  def _tf_input(self, input_idx):
    """Create and return a new TF_Input for input_idx'th input of this op."""
    tf_input = pywrap_tf_session.TF_Input()
    tf_input.oper = self._c_op
    tf_input.index = input_idx
    return tf_input

  def _set_device(self, device):  # pylint: disable=redefined-outer-name
    """Set the device of this operation.

    Args:
      device: string or device spec. The device to set. It is converted with
        `_device_string` and `compat.as_str` before being applied.
    """
    self._set_device_from_string(compat.as_str(_device_string(device)))

  def _set_device_from_string(self, device_str):
    """Fast path to set device if the type is known to be a string.

    This function is called frequently enough during graph construction that
    there are non-trivial performance gains if the caller can guarantee that
    the specified device is already a string.

    Args:
      device_str: A string specifying where to place this op.
    """
    pywrap_tf_session.SetRequestedDevice(
        self._graph._c_graph,  # pylint: disable=protected-access
        self._c_op,  # pylint: disable=protected-access
        device_str)

  def _update_input(self, index, tensor):
    """Update the input to this operation at the given index.

    NOTE: This is for TF internal use only. Please don't use it.

    Args:
      index: the index of the input to update.
      tensor: the Tensor to be used as the input at the given index.

    Raises:
      TypeError: if tensor is not a Tensor,
        or if input tensor type is not convertible to dtype.
      ValueError: if the Tensor is from a different graph.
    """
    if not isinstance(tensor, Tensor):
      raise TypeError("tensor must be a Tensor: %s" % (tensor,))
    _assert_same_graph(self, tensor)

    # Reset cached inputs.
    self._inputs_val = None
    pywrap_tf_session.UpdateEdge(
        self._graph._c_graph,  # pylint: disable=protected-access
        tensor._as_tf_output(),  # pylint: disable=protected-access
        self._tf_input(index))

  def _add_while_inputs(self, tensors):
    """See AddWhileInputHack in python_api.h.

    NOTE: This is for TF internal use only. Please don't use it.

    Args:
      tensors: list of Tensors

    Raises:
      TypeError: if tensor is not a Tensor,
        or if input tensor type is not convertible to dtype.
      ValueError: if the Tensor is from a different graph.
    """
    for tensor in tensors:
      if not isinstance(tensor, Tensor):
        raise TypeError("tensor must be a Tensor: %s" % (tensor,))
      _assert_same_graph(self, tensor)

      # Reset cached inputs.
      self._inputs_val = None
      pywrap_tf_session.AddWhileInputHack(
          self._graph._c_graph,  # pylint: disable=protected-access
          tensor._as_tf_output(),  # pylint: disable=protected-access
          self._c_op)

  def _add_control_inputs(self, ops):
    """Add a list of new control inputs to this operation.

    Each op is validated and added in turn by `_add_control_input`, so ops
    that precede an invalid entry have already been added when the error is
    raised.

    Args:
      ops: the list of Operations to add as control input.

    Raises:
      TypeError: if ops is not a list of Operations.
    """
    for op in ops:
      self._add_control_input(op)

  def _add_control_input(self, op):
    """Add a new control input to this operation.

    NOTE(review): this method only checks the type of `op`; it performs no
    graph-membership check itself. Whether the C layer rejects an op from a
    different graph is not visible here -- TODO confirm.

    Args:
      op: the Operation to add as control input.

    Raises:
      TypeError: if op is not an Operation.
    """
    if not isinstance(op, Operation):
      raise TypeError("op must be an Operation: %s" % (op,))
    pywrap_tf_session.AddControlInput(
        self._graph._c_graph,  # pylint: disable=protected-access
        self._c_op,  # pylint: disable=protected-access
        op._c_op)  # pylint: disable=protected-access

  def _remove_all_control_inputs(self):
    """Removes any control inputs to this operation."""
    pywrap_tf_session.RemoveAllControlInputs(self._graph._c_graph, self._c_op)  # pylint: disable=protected-access

  def _add_outputs(self, types, shapes):
    """Adds new Tensors to self.outputs.

    Note: this is generally unsafe to use. This is used in certain situations in
    conjunction with _set_type_list_attr.

    Args:
      types: list of DTypes
      shapes: list of TensorShapes
    """
    assert len(types) == len(shapes)
    orig_num_outputs = len(self.outputs)
    # New tensors are numbered after the outputs that already exist.
    for offset, (dtype, shape) in enumerate(zip(types, shapes)):
      t = Tensor(self, orig_num_outputs + offset, dtype)
      self._outputs.append(t)
      t.set_shape(shape)

  def __str__(self):
    """Returns the string form of this op's `NodeDef`."""
    return str(self.node_def)

  def __repr__(self):
    """Returns a short description containing the op's name and type."""
    return "<tf.Operation '%s' type=%s>" % (self.name, self.type)

  def __tf_tensor__(self, dtype=None, name=None):
    """Raises a helpful error."""
    raise TypeError("can't convert Operation '{}' to Tensor".format(self.name))

  @property
  def outputs(self):
    """The list of `Tensor` objects representing the outputs of this op."""
    return self._outputs

  @property
  def inputs(self):
    """The sequence of `Tensor` objects representing the data inputs of this op."""
    if self._inputs_val is None:
      # Computed on first access and cached; methods that change the inputs
      # reset the cache to None.
      # pylint: disable=protected-access
      self._inputs_val = tuple(
          map(self.graph._get_tensor_by_tf_output,
              pywrap_tf_session.GetOperationInputs(self._c_op)))
      # pylint: enable=protected-access
    return self._inputs_val

  @property
  def _input_types(self):
    """List of `DType`s of this op's inputs, as reported by the C op."""
    num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op)
    input_types = [
        dtypes.as_dtype(
            pywrap_tf_session.TF_OperationInputType(self._tf_input(i)))
        for i in xrange(num_inputs)
    ]
    return input_types

  @property
  def control_inputs(self):
    """The `Operation` objects on which this op has a control dependency.

    Before this op is executed, TensorFlow will ensure that the
    operations in `self.control_inputs` have finished executing. This
    mechanism can be used to run ops sequentially for performance
    reasons, or to ensure that the side effects of an op are observed
    in the correct order.

    Returns:
      A list of `Operation` objects.

    """
    control_c_ops = pywrap_tf_session.TF_OperationGetControlInputs_wrapper(
        self._c_op)
    # pylint: disable=protected-access
    return [
        self.graph._get_operation_by_name_unsafe(
            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
    ]
    # pylint: enable=protected-access

  @property
  def _control_outputs(self):
    """The `Operation` objects which have a control dependency on this op.

    Before any of the ops in self._control_outputs can execute tensorflow will
    ensure self has finished executing.

    Returns:
      A list of `Operation` objects.

    """
    control_c_ops = pywrap_tf_session.TF_OperationGetControlOutputs_wrapper(
        self._c_op)
    # pylint: disable=protected-access
    return [
        self.graph._get_operation_by_name_unsafe(
            pywrap_tf_session.TF_OperationName(c_op)) for c_op in control_c_ops
    ]
    # pylint: enable=protected-access

  @property
  def type(self):
    """The type of the op (e.g. `"MatMul"`)."""
    return pywrap_tf_session.TF_OperationOpType(self._c_op)

  @property
  def graph(self):
    """The `Graph` that contains this operation."""
    return self._graph

  @property
  def node_def(self):
    # pylint: disable=line-too-long
    """Returns the `NodeDef` representation of this operation.

    Returns:
      A
      [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
      protocol buffer.
    """
    # pylint: enable=line-too-long
    with c_api_util.tf_buffer() as buf:
      pywrap_tf_session.TF_OperationToNodeDef(self._c_op, buf)
      data = pywrap_tf_session.TF_GetBuffer(buf)
    node_def = node_def_pb2.NodeDef()
    node_def.ParseFromString(compat.as_bytes(data))
    return node_def

  @property
  def op_def(self):
    # pylint: disable=line-too-long
    """Returns the `OpDef` proto that represents the type of this op.

    Returns:
      An
      [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
      protocol buffer.
    """
    # pylint: enable=line-too-long
    return self._graph._get_op_def(self.type)

  @property
  def traceback(self):
    """Returns the call stack from when this operation was constructed."""
    return self._traceback

  def _set_attr(self, attr_name, attr_value):
    """Private method used to set an attribute in the node_def."""
    buf = pywrap_tf_session.TF_NewBufferFromString(
        compat.as_bytes(attr_value.SerializeToString()))
    try:
      self._set_attr_with_buf(attr_name, buf)
    finally:
      # Delete the buffer even if setting the attr raised.
      pywrap_tf_session.TF_DeleteBuffer(buf)

  def _set_attr_with_buf(self, attr_name, attr_buf):
    """Set an attr in the node_def with a pre-allocated buffer."""
    # pylint: disable=protected-access
    pywrap_tf_session.SetAttr(self._graph._c_graph, self._c_op, attr_name,
                              attr_buf)
    # pylint: enable=protected-access

  def _set_func_attr(self, attr_name, func_name):
    """Private method used to set a function attribute in the node_def."""
    func = attr_value_pb2.NameAttrList(name=func_name)
    self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))

  def _set_func_list_attr(self, attr_name, func_names):
    """Private method used to set a list(function) attribute in the node_def."""
    funcs = [attr_value_pb2.NameAttrList(name=func_name)
             for func_name in func_names]
    funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))

  def _set_type_list_attr(self, attr_name, types):
    """Private method used to set a list(type) attribute in the node_def."""
    # An empty `types` leaves the attr untouched.
    if not types:
      return
    # Only the first element is inspected to decide whether the list holds
    # DType objects that need converting to enum values.
    if isinstance(types[0], dtypes.DType):
      types = [dt.as_datatype_enum for dt in types]
    types_list = attr_value_pb2.AttrValue.ListValue(type=types)
    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))

  def _set_shape_list_attr(self, attr_name, shapes):
    """Private method used to set a list(shape) attribute in the node_def."""
    shapes = [s.as_proto() for s in shapes]
    shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))

  def _clear_attr(self, attr_name):
    """Private method used to clear an attribute in the node_def."""
    # pylint: disable=protected-access
    pywrap_tf_session.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
    # pylint: enable=protected-access

  def get_attr(self, name):
    """Returns the value of the attr of this op with the given `name`.

    Args:
      name: The name of the attr to fetch.

    Returns:
      The value of the attr, as a Python object. An attr with no value set,
      or a list attr whose fields are all empty, yields `[]`.

    Raises:
      ValueError: If this op does not have an attr with the given `name`.
    """
    fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
    try:
      with c_api_util.tf_buffer() as buf:
        pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
        data = pywrap_tf_session.TF_GetBuffer(buf)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(str(e))
    x = attr_value_pb2.AttrValue()
    x.ParseFromString(data)

    oneof_value = x.WhichOneof("value")
    if oneof_value is None:
      return []
    if oneof_value == "list":
      # Return the first non-empty list field, in the order of `fields`.
      for f in fields:
        if getattr(x.list, f):
          if f == "type":
            return [dtypes.as_dtype(t) for t in x.list.type]
          else:
            return list(getattr(x.list, f))
      return []
    if oneof_value == "type":
      return dtypes.as_dtype(x.type)
    assert oneof_value in fields, "Unsupported field type in " + str(x)
    return getattr(x, oneof_value)

  def _get_attr_type(self, name):
    """Returns the `DType` value of the attr of this op with the given `name`."""
    try:
      dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name)
      return _DTYPES_INTERN_TABLE[dtype_enum]
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(str(e))

  def _get_attr_bool(self, name):
    """Returns the `bool` value of the attr of this op with the given `name`."""
    try:
      return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(str(e))

  def _get_attr_int(self, name):
    """Returns the `int` value of the attr of this op with the given `name`."""
    try:
      return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(str(e))

  def run(self, feed_dict=None, session=None):
    """Runs this operation in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for this operation.

    *N.B.* Before invoking `Operation.run()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to run to this operation. If
        none, the default session will be used.
    """
    _run_using_default_session(self, feed_dict, self.graph, session)
_gradient_registry = registry.Registry("gradient")


@tf_export("RegisterGradient")
class RegisterGradient(object):
  """Decorator that registers a gradient function for an op type.

  Use this decorator only when defining a new op type. For an op with `m`
  inputs and `n` outputs, the gradient function receives the original
  `Operation` and `n` `Tensor` objects (the gradients with respect to each
  output of the op), and returns `m` `Tensor` objects (the partial gradients
  with respect to each input of the op).

  For example, assuming that operations of type `"Sub"` take two
  inputs `x` and `y`, and return a single output `x - y`, the
  following gradient function would be registered:

  ```python
  @tf.RegisterGradient("Sub")
  def _sub_grad(unused_op, grad):
    return grad, tf.negative(grad)
  ```

  The decorator argument `op_type` is the string type of an
  operation. This corresponds to the `OpDef.name` field for the proto
  that defines the operation.
  """

  __slots__ = ["_op_type"]

  def __init__(self, op_type):
    """Stores the op type the decorated function will be registered under.

    Args:
      op_type: The string type of an operation. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.

    Raises:
      TypeError: If `op_type` is not string.
    """
    if isinstance(op_type, six.string_types):
      self._op_type = op_type
    else:
      raise TypeError("op_type must be a string")

  def __call__(self, f):
    """Registers `f` in the gradient registry under `op_type`; returns `f`."""
    _gradient_registry.register(f, self._op_type)
    return f
@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient")
@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"])
def no_gradient(op_type):
  """Specifies that ops of type `op_type` are not differentiable.

  This function should *not* be used for operations that have a
  well-defined gradient that is not yet implemented.

  This function is only used when defining a new op type. It may be
  used for ops such as `tf.size()` that are not differentiable.  For
  example:

  ```python
  tf.no_gradient("Size")
  ```

  The gradient computed for 'op_type' will then propagate zeros.

  For ops that have a well-defined gradient but are not yet implemented,
  no declaration should be made, and an error *must* be thrown if
  an attempt to request its gradient is made.

  Args:
    op_type: The string type of an operation. This corresponds to the
      `OpDef.name` field for the proto that defines the operation.

  Raises:
    TypeError: If `op_type` is not a string.

  """
  if isinstance(op_type, six.string_types):
    # A `None` entry in the gradient registry marks the op type as having no
    # gradient.
    _gradient_registry.register(None, op_type)
  else:
    raise TypeError("op_type must be a string")


# Aliases for the old names, will be eventually removed.
NoGradient = no_gradient
NotDifferentiable = no_gradient
def get_gradient_function(op):
  """Returns the function that computes gradients for "op".

  Args:
    op: The op to look up. Its `inputs`, `_gradient_function`, `type` and
      `get_attr` members are read.

  Returns:
    `None` if `op` has no inputs. Otherwise `op._gradient_function` when it
    is set; failing that, the result of looking up the gradient registry
    under the op's "_gradient_op_type" attr, or under `op.type` when that
    attr is absent.
  """
  if not op.inputs:
    return None

  gradient_function = op._gradient_function  # pylint: disable=protected-access
  if gradient_function:
    return gradient_function

  try:
    op_type = op.get_attr("_gradient_op_type")
  except ValueError:
    # No "_gradient_op_type" attr: fall back to the op's own type.
    op_type = op.type
  return _gradient_registry.lookup(op_type)


def set_shape_and_handle_data_for_outputs(_):
  """No op. TODO(b/74620627): Remove this."""


class OpStats(object):
  """A holder for statistics about an operator.

  This class holds information about the resource requirements for an op,
  including the size of its weight parameters on-disk and how many FLOPS it
  requires to execute forward inference.

  If you define a new operation, you can create a function that will return a
  set of information about its usage of the CPU and disk space when serialized.
  The function itself takes a Graph object that's been set up so you can call
  methods like get_tensor_by_name to help calculate the results, and a NodeDef
  argument.

  """

  __slots__ = ["_statistic_type", "_value"]

  def __init__(self, statistic_type, value=None):
    """Sets up the initial placeholders for the statistics.

    Args:
      statistic_type: Identifier of the statistic this object holds.
      value: Optional initial value; `None` means no value recorded yet.
    """
    self.statistic_type = statistic_type
    self.value = value

  @property
  def statistic_type(self):
    """The identifier of the statistic held by this object."""
    return self._statistic_type

  @statistic_type.setter
  def statistic_type(self, statistic_type):
    self._statistic_type = statistic_type

  @property
  def value(self):
    """The accumulated value, or `None` if nothing has been recorded."""
    return self._value

  @value.setter
  def value(self, value):
    self._value = value

  def __iadd__(self, other):
    """Accumulates `other` into this object in place.

    Args:
      other: Another `OpStats` with the same `statistic_type`.

    Returns:
      `self`. If this object's value is `None` it takes `other.value`; if
      `other.value` is `None` this object's value is left unchanged.

    Raises:
      ValueError: If the two statistic types differ.
    """
    if other.statistic_type != self.statistic_type:
      # `other` is the operand being added to `self`, so its type is named
      # first in the message.
      raise ValueError("Can't add an OpStat of type %s to one of %s." %
                       (other.statistic_type, self.statistic_type))
    if self.value is None:
      self.value = other.value
    elif other.value is not None:
      self._value += other.value
    return self
_stats_registry = registry.Registry("statistical functions")


class RegisterStatistics(object):
  """Decorator that registers a statistics function for an op type.

  A registered function reports on the resources used by an instance of an
  operator, in the form of an OpStats object.

  Well-known types of statistics include these so far:

  - flops: When running a graph, the bulk of the computation happens doing
    numerical calculations like matrix multiplications. This type allows a node
    to return how many floating-point operations it takes to complete. The
    total number of FLOPs for a graph is a good guide to its expected latency.

  You can add your own statistics just by picking a new type string, registering
  functions for the ops you care about, and then calling get_stats_for_node_def.

  If a statistic for an op is registered multiple times, a KeyError will be
  raised.

  Because statistics are counted on a per-op basis, they are not suitable for
  model parameters (capacity), which are expected to be counted only once even
  if shared by multiple ops (e.g. an RNN).

  For example, you can define a new metric called doohickey for a Foo operation
  by placing this in your code:

  ```python
  @ops.RegisterStatistics("Foo", "doohickey")
  def _calc_foo_doohickey(unused_graph, unused_node_def):
    return ops.OpStats("doohickey", 20)
  ```

  Then in client code you can retrieve the value by making this call:

  ```python
  doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
  ```

  If the NodeDef is for an op with a registered doohickey function, you'll get
  back the calculated amount in doohickey.value, or None if it's not defined.

  """

  __slots__ = ["_op_type", "_statistic_type"]

  def __init__(self, op_type, statistic_type):
    """Validates and stores the op type and statistic type.

    Args:
      op_type: String type of the operation; must not contain a comma.
      statistic_type: String naming the statistic; must not contain a comma.

    Raises:
      TypeError: If either argument is not a string or contains a comma.
    """
    # Commas are rejected because the registry key built in `__call__` joins
    # the two strings with a comma.
    if not isinstance(op_type, six.string_types):
      raise TypeError("op_type must be a string.")
    elif "," in op_type:
      raise TypeError("op_type must not contain a comma.")
    self._op_type = op_type

    if not isinstance(statistic_type, six.string_types):
      raise TypeError("statistic_type must be a string.")
    elif "," in statistic_type:
      raise TypeError("statistic_type must not contain a comma.")
    self._statistic_type = statistic_type

  def __call__(self, f):
    """Registers "f" as the statistics function for "op_type"; returns "f"."""
    registry_key = ",".join((self._op_type, self._statistic_type))
    _stats_registry.register(f, registry_key)
    return f
2798 2799 Since the statistics is counted on a per-op basis. It is not suitable for 2800 model parameters (capacity), which is expected to be counted only once, even 2801 if it is shared by multiple ops. (e.g. RNN) 2802 2803 For example, you can define a new metric called doohickey for a Foo operation 2804 by placing this in your code: 2805 2806 ```python 2807 @ops.RegisterStatistics("Foo", "doohickey") 2808 def _calc_foo_bojangles(unused_graph, unused_node_def): 2809 return ops.OpStats("doohickey", 20) 2810 ``` 2811 2812 Then in client code you can retrieve the value by making this call: 2813 2814 ```python 2815 doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey") 2816 ``` 2817 2818 If the NodeDef is for an op with a registered doohickey function, you'll get 2819 back the calculated amount in doohickey.value, or None if it's not defined. 2820 2821 """ 2822 2823 __slots__ = ["_op_type", "_statistic_type"] 2824 2825 def __init__(self, op_type, statistic_type): 2826 """Saves the `op_type` as the `Operation` type.""" 2827 if not isinstance(op_type, six.string_types): 2828 raise TypeError("op_type must be a string.") 2829 if "," in op_type: 2830 raise TypeError("op_type must not contain a comma.") 2831 self._op_type = op_type 2832 if not isinstance(statistic_type, six.string_types): 2833 raise TypeError("statistic_type must be a string.") 2834 if "," in statistic_type: 2835 raise TypeError("statistic_type must not contain a comma.") 2836 self._statistic_type = statistic_type 2837 2838 def __call__(self, f): 2839 """Registers "f" as the statistics function for "op_type".""" 2840 _stats_registry.register(f, self._op_type + "," + self._statistic_type) 2841 return f 2842 2843 2844def get_stats_for_node_def(graph, node, statistic_type): 2845 """Looks up the node's statistics function in the registry and calls it. 
2846 2847 This function takes a Graph object and a NodeDef from a GraphDef, and if 2848 there's an associated statistics method, calls it and returns a result. If no 2849 function has been registered for the particular node type, it returns an empty 2850 statistics object. 2851 2852 Args: 2853 graph: A Graph object that's been set up with the node's graph. 2854 node: A NodeDef describing the operator. 2855 statistic_type: A string identifying the statistic we're interested in. 2856 2857 Returns: 2858 An OpStats object containing information about resource usage. 2859 """ 2860 2861 try: 2862 stats_func = _stats_registry.lookup(node.op + "," + statistic_type) 2863 result = stats_func(graph, node) 2864 except LookupError: 2865 result = OpStats(statistic_type) 2866 return result 2867 2868 2869def name_from_scope_name(name): 2870 """Returns the name of an op given the name of its scope. 2871 2872 Args: 2873 name: the name of the scope. 2874 2875 Returns: 2876 the name of the op (equal to scope name minus any trailing slash). 2877 """ 2878 return name[:-1] if (name and name[-1] == "/") else name 2879 2880 2881_MUTATION_LOCK_GROUP = 0 2882_SESSION_RUN_LOCK_GROUP = 1 2883 2884 2885@tf_export("Graph") 2886class Graph(object): 2887 """A TensorFlow computation, represented as a dataflow graph. 2888 2889 Graphs are used by `tf.function`s to represent the function's computations. 2890 Each graph contains a set of `tf.Operation` objects, which represent units of 2891 computation; and `tf.Tensor` objects, which represent the units of data that 2892 flow between operations. 2893 2894 ### Using graphs directly (deprecated) 2895 2896 A `tf.Graph` can be constructed and used directly without a `tf.function`, as 2897 was required in TensorFlow 1, but this is deprecated and it is recommended to 2898 use a `tf.function` instead. If a graph is directly used, other deprecated 2899 TensorFlow 1 classes are also required to execute the graph, such as a 2900 `tf.compat.v1.Session`. 
2901 2902 A default graph can be registered with the `tf.Graph.as_default` context 2903 manager. Then, operations will be added to the graph instead of being executed 2904 eagerly. For example: 2905 2906 ```python 2907 g = tf.Graph() 2908 with g.as_default(): 2909 # Define operations and tensors in `g`. 2910 c = tf.constant(30.0) 2911 assert c.graph is g 2912 ``` 2913 2914 `tf.compat.v1.get_default_graph()` can be used to obtain the default graph. 2915 2916 Important note: This class *is not* thread-safe for graph construction. All 2917 operations should be created from a single thread, or external 2918 synchronization must be provided. Unless otherwise specified, all methods 2919 are not thread-safe. 2920 2921 A `Graph` instance supports an arbitrary number of "collections" 2922 that are identified by name. For convenience when building a large 2923 graph, collections can store groups of related objects: for 2924 example, the `tf.Variable` uses a collection (named 2925 `tf.GraphKeys.GLOBAL_VARIABLES`) for 2926 all variables that are created during the construction of a graph. The caller 2927 may define additional collections by specifying a new name. 2928 """ 2929 2930 def __init__(self): 2931 """Creates a new, empty Graph.""" 2932 # Protects core state that can be returned via public accessors. 2933 # Thread-safety is provided on a best-effort basis to support buggy 2934 # programs, and is not guaranteed by the public `tf.Graph` API. 2935 # 2936 # NOTE(mrry): This does not protect the various stacks. A warning will 2937 # be reported if these are used from multiple threads 2938 self._lock = threading.RLock() 2939 # The group lock synchronizes Session.run calls with methods that create 2940 # and mutate ops (e.g. Graph.create_op()). This synchronization is 2941 # necessary because it's illegal to modify an operation after it's been run. 
2942 # The group lock allows any number of threads to mutate ops at the same time 2943 # but if any modification is going on, all Session.run calls have to wait. 2944 # Similarly, if one or more Session.run calls are going on, all mutate ops 2945 # have to wait until all Session.run calls have finished. 2946 self._group_lock = lock_util.GroupLock(num_groups=2) 2947 self._nodes_by_id = {} # GUARDED_BY(self._lock) 2948 self._next_id_counter = 0 # GUARDED_BY(self._lock) 2949 self._nodes_by_name = {} # GUARDED_BY(self._lock) 2950 self._version = 0 # GUARDED_BY(self._lock) 2951 # Maps a name used in the graph to the next id to use for that name. 2952 self._names_in_use = {} 2953 self._stack_state_is_thread_local = False 2954 self._thread_local = threading.local() 2955 # Functions that will be applied to choose a device if none is specified. 2956 # In TF2.x or after switch_to_thread_local(), 2957 # self._thread_local._device_function_stack is used instead. 2958 self._graph_device_function_stack = traceable_stack.TraceableStack() 2959 # Default original_op applied to new ops. 2960 self._default_original_op = None 2961 # Current control flow context. It could be either CondContext or 2962 # WhileContext defined in ops/control_flow_ops.py 2963 self._control_flow_context = None 2964 # A new node will depend of the union of all of the nodes in the stack. 2965 # In TF2.x or after switch_to_thread_local(), 2966 # self._thread_local._control_dependencies_stack is used instead. 2967 self._graph_control_dependencies_stack = [] 2968 # Arbitrary collections of objects. 2969 self._collections = {} 2970 # The graph-level random seed 2971 self._seed = None 2972 # A dictionary of attributes that should be applied to all ops. 2973 self._attr_scope_map = {} 2974 # A map from op type to the kernel label that should be used. 2975 self._op_to_kernel_label_map = {} 2976 # A map from op type to an alternative op type that should be used when 2977 # computing gradients. 
2978 self._gradient_override_map = {} 2979 # A map from op type to a gradient function that should be used instead. 2980 self._gradient_function_map = {} 2981 # True if the graph is considered "finalized". In that case no 2982 # new operations can be added. 2983 self._finalized = False 2984 # Functions defined in the graph 2985 self._functions = collections.OrderedDict() 2986 # Default GraphDef versions 2987 self._graph_def_versions = versions_pb2.VersionDef( 2988 producer=versions.GRAPH_DEF_VERSION, 2989 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER) 2990 self._building_function = False 2991 # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(), 2992 # self._thread_local._colocation_stack is used instead. 2993 self._graph_colocation_stack = traceable_stack.TraceableStack() 2994 # Set of tensors that are dangerous to feed! 2995 self._unfeedable_tensors = object_identity.ObjectIdentitySet() 2996 # Set of operations that are dangerous to fetch! 2997 self._unfetchable_ops = set() 2998 # A map of tensor handle placeholder to tensor dtype. 2999 self._handle_feeders = {} 3000 # A map from tensor handle to its read op. 3001 self._handle_readers = {} 3002 # A map from tensor handle to its move op. 3003 self._handle_movers = {} 3004 # A map from tensor handle to its delete op. 3005 self._handle_deleters = {} 3006 # Allow optimizers and other objects to pseudo-uniquely key graphs (this key 3007 # will be shared when defining function graphs, for example, so optimizers 3008 # being called inside function definitions behave as if they were seeing the 3009 # actual outside graph). 3010 self._graph_key = "grap-key-%d/" % (uid(),) 3011 # A string with the last reduction method passed to 3012 # losses.compute_weighted_loss(), or None. This is required only for 3013 # backward compatibility with Estimator and optimizer V1 use cases. 3014 self._last_loss_reduction = None 3015 # Flag that is used to indicate whether loss has been scaled by optimizer. 
3016 # If this flag has been set, then estimator uses it to scale losss back 3017 # before reporting. This is required only for backward compatibility with 3018 # Estimator and optimizer V1 use cases. 3019 self._is_loss_scaled_by_optimizer = False 3020 self._container = "" 3021 # Set to True if this graph is being built in an 3022 # AutomaticControlDependencies context. 3023 self._add_control_dependencies = False 3024 # Cache for OpDef protobufs retrieved via the C API. 3025 self._op_def_cache = {} 3026 # Cache for constant results of `broadcast_gradient_args()`. The keys are 3027 # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the 3028 # values are tuples of reduction indices: (rx, ry). 3029 self._bcast_grad_args_cache = {} 3030 # Cache for constant results of `reduced_shape()`. The keys are pairs of 3031 # tuples: (input_shape_tuple, reduction_indices_tuple), and the values 3032 # are pairs of tuples: (output_shape_kept_dims, tile_scaling). 3033 self._reduced_shape_cache = {} 3034 3035 # TODO(skyewm): fold as much of the above as possible into the C 3036 # implementation 3037 self._scoped_c_graph = c_api_util.ScopedTFGraph() 3038 # The C API requires all ops to have shape functions. Disable this 3039 # requirement (many custom ops do not have shape functions, and we don't 3040 # want to break these existing cases). 3041 pywrap_tf_session.SetRequireShapeInferenceFns(self._c_graph, False) 3042 if tf2.enabled(): 3043 self.switch_to_thread_local() 3044 3045 # Note: this method is private because the API of tf.Graph() is public and 3046 # frozen, and this functionality is still not ready for public visibility. 3047 @tf_contextlib.contextmanager 3048 def _variable_creator_scope(self, creator, priority=100): 3049 """Scope which defines a variable creation function. 3050 3051 Args: 3052 creator: A callable taking `next_creator` and `kwargs`. See the 3053 `tf.variable_creator_scope` docstring. 
3054 priority: Creators with a higher `priority` are called first. Within the 3055 same priority, creators are called inner-to-outer. 3056 3057 Yields: 3058 `_variable_creator_scope` is a context manager with a side effect, but 3059 doesn't return a value. 3060 3061 Raises: 3062 RuntimeError: If variable creator scopes are not properly nested. 3063 """ 3064 # This step keeps a reference to the existing stack, and it also initializes 3065 # self._thread_local._variable_creator_stack if it doesn't exist yet. 3066 old = self._variable_creator_stack 3067 new = list(old) 3068 new.append((priority, creator)) 3069 # Sorting is stable, so we'll put higher-priority creators later in the list 3070 # but otherwise maintain registration order. 3071 new.sort(key=lambda item: item[0]) 3072 self._thread_local._variable_creator_stack = new # pylint: disable=protected-access 3073 try: 3074 yield 3075 finally: 3076 if self._thread_local._variable_creator_stack is not new: # pylint: disable=protected-access 3077 raise RuntimeError( 3078 "Exiting variable_creator_scope without proper nesting.") 3079 self._thread_local._variable_creator_stack = old # pylint: disable=protected-access 3080 3081 # Note: this method is private because the API of tf.Graph() is public and 3082 # frozen, and this functionality is still not ready for public visibility. 3083 @property 3084 def _variable_creator_stack(self): 3085 if not hasattr(self._thread_local, "_variable_creator_stack"): 3086 self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access 3087 3088 # This previously returned a copy of the stack instead of the stack itself, 3089 # to guard against accidental mutation. Consider, however, code that wants 3090 # to save and restore the variable creator stack: 3091 # def f(): 3092 # original_stack = graph._variable_creator_stack 3093 # graph._variable_creator_stack = new_stack 3094 # ... 
# Some code 3095 # graph._variable_creator_stack = original_stack 3096 # 3097 # And lets say you have some code that calls this function with some 3098 # variable_creator: 3099 # def g(): 3100 # with variable_scope.variable_creator_scope(creator): 3101 # f() 3102 # When exiting the variable creator scope, it would see a different stack 3103 # object than it expected leading to a "Exiting variable_creator_scope 3104 # without proper nesting" error. 3105 return self._thread_local._variable_creator_stack # pylint: disable=protected-access 3106 3107 @_variable_creator_stack.setter 3108 def _variable_creator_stack(self, variable_creator_stack): 3109 self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access 3110 3111 def _check_not_finalized(self): 3112 """Check if the graph is finalized. 3113 3114 Raises: 3115 RuntimeError: If the graph finalized. 3116 """ 3117 if self._finalized: 3118 raise RuntimeError("Graph is finalized and cannot be modified.") 3119 3120 def _add_op(self, op, op_name): 3121 """Adds 'op' to the graph and returns the unique ID for the added Operation. 3122 3123 Args: 3124 op: the Operation to add. 3125 op_name: the name of the Operation. 3126 3127 Returns: 3128 An integer that is a unique ID for the added Operation. 3129 """ 3130 self._check_not_finalized() 3131 with self._lock: 3132 self._next_id_counter += 1 3133 op_id = self._next_id_counter 3134 self._nodes_by_id[op_id] = op 3135 self._nodes_by_name[op_name] = op 3136 self._version = max(self._version, op_id) 3137 return op_id 3138 3139 @property 3140 def _c_graph(self): 3141 if self._scoped_c_graph: 3142 return self._scoped_c_graph.graph 3143 return None 3144 3145 @property 3146 def version(self): 3147 """Returns a version number that increases as ops are added to the graph. 3148 3149 Note that this is unrelated to the 3150 `tf.Graph.graph_def_versions`. 3151 3152 Returns: 3153 An integer version that increases as ops are added to the graph. 
3154 """ 3155 if self._finalized: 3156 return self._version 3157 3158 with self._lock: 3159 return self._version 3160 3161 @property 3162 def graph_def_versions(self): 3163 # pylint: disable=line-too-long 3164 """The GraphDef version information of this graph. 3165 3166 For details on the meaning of each version, see 3167 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto). 3168 3169 Returns: 3170 A `VersionDef`. 3171 """ 3172 # pylint: enable=line-too-long 3173 with c_api_util.tf_buffer() as buf: 3174 pywrap_tf_session.TF_GraphVersions(self._c_graph, buf) 3175 data = pywrap_tf_session.TF_GetBuffer(buf) 3176 version_def = versions_pb2.VersionDef() 3177 version_def.ParseFromString(compat.as_bytes(data)) 3178 return version_def 3179 3180 @property 3181 def seed(self): 3182 """The graph-level random seed of this graph.""" 3183 return self._seed 3184 3185 @seed.setter 3186 def seed(self, seed): 3187 self._seed = seed 3188 3189 @property 3190 def finalized(self): 3191 """True if this graph has been finalized.""" 3192 return self._finalized 3193 3194 def finalize(self): 3195 """Finalizes this graph, making it read-only. 3196 3197 After calling `g.finalize()`, no new operations can be added to 3198 `g`. This method is used to ensure that no operations are added 3199 to a graph when it is shared between multiple threads, for example 3200 when using a `tf.compat.v1.train.QueueRunner`. 3201 """ 3202 self._finalized = True 3203 3204 def _unsafe_unfinalize(self): 3205 """Opposite of `finalize`. 3206 3207 Internal interface. 3208 3209 NOTE: Unfinalizing a graph could have negative impact on performance, 3210 especially in a multi-threaded environment. Unfinalizing a graph 3211 when it is in use by a Session may lead to undefined behavior. Ensure 3212 that all sessions using a graph are closed before calling this method. 
3213 """ 3214 self._finalized = False 3215 3216 def _get_control_flow_context(self): 3217 """Returns the current control flow context. 3218 3219 Returns: 3220 A context object. 3221 """ 3222 return self._control_flow_context 3223 3224 def _set_control_flow_context(self, ctx): 3225 """Sets the current control flow context. 3226 3227 Args: 3228 ctx: a context object. 3229 """ 3230 self._control_flow_context = ctx 3231 3232 def _copy_functions_to_graph_def(self, graph_def, starting_bytesize): 3233 """If this graph contains functions, copy them to `graph_def`.""" 3234 bytesize = starting_bytesize 3235 for f in self._functions.values(): 3236 bytesize += f.definition.ByteSize() 3237 if bytesize >= (1 << 31) or bytesize < 0: 3238 raise ValueError("GraphDef cannot be larger than 2GB.") 3239 graph_def.library.function.extend([f.definition]) 3240 if f.grad_func_name: 3241 grad_def = function_pb2.GradientDef() 3242 grad_def.function_name = f.name 3243 grad_def.gradient_func = f.grad_func_name 3244 graph_def.library.gradient.extend([grad_def]) 3245 3246 def _as_graph_def(self, from_version=None, add_shapes=False): 3247 # pylint: disable=line-too-long 3248 """Returns a serialized `GraphDef` representation of this graph. 3249 3250 The serialized `GraphDef` can be imported into another `Graph` 3251 (using `tf.import_graph_def`) or used with the 3252 [C++ Session API](../../../../api_docs/cc/index.md). 3253 3254 This method is thread-safe. 3255 3256 Args: 3257 from_version: Optional. If this is set, returns a `GraphDef` containing 3258 only the nodes that were added to this graph since its `version` 3259 property had the given value. 3260 add_shapes: If true, adds an "_output_shapes" list attr to each node with 3261 the inferred shapes of each of its outputs. 
3262 3263 Returns: 3264 A tuple containing a 3265 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 3266 protocol buffer, and the version of the graph to which that 3267 `GraphDef` corresponds. 3268 3269 Raises: 3270 ValueError: If the `graph_def` would be too large. 3271 3272 """ 3273 # pylint: enable=line-too-long 3274 with self._lock: 3275 with c_api_util.tf_buffer() as buf: 3276 pywrap_tf_session.TF_GraphToGraphDef(self._c_graph, buf) 3277 data = pywrap_tf_session.TF_GetBuffer(buf) 3278 graph = graph_pb2.GraphDef() 3279 graph.ParseFromString(compat.as_bytes(data)) 3280 # Strip the experimental library field iff it's empty. 3281 if not graph.library.function: 3282 graph.ClearField("library") 3283 3284 if add_shapes: 3285 for node in graph.node: 3286 op = self._nodes_by_name[node.name] 3287 if op.outputs: 3288 node.attr["_output_shapes"].list.shape.extend( 3289 [output.get_shape().as_proto() for output in op.outputs]) 3290 for function_def in graph.library.function: 3291 defined_function = self._functions[function_def.signature.name] 3292 try: 3293 func_graph = defined_function.graph 3294 except AttributeError: 3295 # _DefinedFunction doesn't have a graph, _EagerDefinedFunction 3296 # does. Both rely on ops.py, so we can't really isinstance check 3297 # them. 3298 continue 3299 input_shapes = function_def.attr["_input_shapes"] 3300 try: 3301 func_graph_inputs = func_graph.inputs 3302 except AttributeError: 3303 continue 3304 # TODO(b/141471245): Fix the inconsistency when inputs of func graph 3305 # are appended during gradient computation of while/cond. 
3306 for input_tensor, arg_def in zip(func_graph_inputs, 3307 function_def.signature.input_arg): 3308 input_shapes.list.shape.add().CopyFrom( 3309 input_tensor.get_shape().as_proto()) 3310 if input_tensor.dtype == dtypes.resource: 3311 _copy_handle_data_to_arg_def(input_tensor, arg_def) 3312 3313 for output_tensor, arg_def in zip(func_graph.outputs, 3314 function_def.signature.output_arg): 3315 if output_tensor.dtype == dtypes.resource: 3316 _copy_handle_data_to_arg_def(output_tensor, arg_def) 3317 3318 for node in function_def.node_def: 3319 try: 3320 op = func_graph.get_operation_by_name(node.name) 3321 except KeyError: 3322 continue 3323 outputs = op.outputs 3324 3325 if op.type == "StatefulPartitionedCall": 3326 # Filter out any extra outputs (possibly added by function 3327 # backpropagation rewriting). 3328 num_outputs = len(node.attr["Tout"].list.type) 3329 outputs = outputs[:num_outputs] 3330 3331 node.attr["_output_shapes"].list.shape.extend( 3332 [output.get_shape().as_proto() for output in outputs]) 3333 3334 return graph, self._version 3335 3336 def as_graph_def(self, from_version=None, add_shapes=False): 3337 # pylint: disable=line-too-long 3338 """Returns a serialized `GraphDef` representation of this graph. 3339 3340 The serialized `GraphDef` can be imported into another `Graph` 3341 (using `tf.import_graph_def`) or used with the 3342 [C++ Session API](../../api_docs/cc/index.md). 3343 3344 This method is thread-safe. 3345 3346 Args: 3347 from_version: Optional. If this is set, returns a `GraphDef` containing 3348 only the nodes that were added to this graph since its `version` 3349 property had the given value. 3350 add_shapes: If true, adds an "_output_shapes" list attr to each node with 3351 the inferred shapes of each of its outputs. 3352 3353 Returns: 3354 A 3355 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 3356 protocol buffer. 3357 3358 Raises: 3359 ValueError: If the `graph_def` would be too large. 
3360 """ 3361 # pylint: enable=line-too-long 3362 result, _ = self._as_graph_def(from_version, add_shapes) 3363 return result 3364 3365 def _is_function(self, name): 3366 """Tests whether 'name' is registered in this graph's function library. 3367 3368 Args: 3369 name: string op name. 3370 3371 Returns: 3372 bool indicating whether or not 'name' is registered in function library. 3373 """ 3374 return compat.as_str(name) in self._functions 3375 3376 def _get_function(self, name): 3377 """Returns the function definition for 'name'. 3378 3379 Args: 3380 name: string function name. 3381 3382 Returns: 3383 The function def proto. 3384 """ 3385 return self._functions.get(compat.as_str(name), None) 3386 3387 def _add_function(self, function): 3388 """Adds a function to the graph. 3389 3390 After the function has been added, you can call to the function by 3391 passing the function name in place of an op name to 3392 `Graph.create_op()`. 3393 3394 Args: 3395 function: A `_DefinedFunction` object. 3396 3397 Raises: 3398 ValueError: if another function is defined with the same name. 3399 """ 3400 self._check_not_finalized() 3401 3402 name = function.name 3403 # Sanity checks on gradient definition. 3404 if (function.grad_func_name is not None) and (function.python_grad_func is 3405 not None): 3406 raise ValueError("Gradient defined twice for function %s" % name) 3407 3408 # Add function to graph 3409 # pylint: disable=protected-access 3410 gradient = ( 3411 function._grad_func._c_func.func if function._grad_func else None) 3412 pywrap_tf_session.TF_GraphCopyFunction(self._c_graph, function._c_func.func, 3413 gradient) 3414 # pylint: enable=protected-access 3415 3416 self._functions[compat.as_str(name)] = function 3417 3418 # Need a new-enough consumer to support the functions we add to the graph. 
3419 if self._graph_def_versions.min_consumer < 12: 3420 self._graph_def_versions.min_consumer = 12 3421 3422 @property 3423 def building_function(self): 3424 """Returns True iff this graph represents a function.""" 3425 return self._building_function 3426 3427 # Helper functions to create operations. 3428 @deprecated_args(None, 3429 "Shapes are always computed; don't use the compute_shapes " 3430 "as it has no effect.", "compute_shapes") 3431 def create_op( 3432 self, 3433 op_type, 3434 inputs, 3435 dtypes=None, # pylint: disable=redefined-outer-name 3436 input_types=None, 3437 name=None, 3438 attrs=None, 3439 op_def=None, 3440 compute_shapes=True, 3441 compute_device=True): 3442 """Creates an `Operation` in this graph. 3443 3444 This is a low-level interface for creating an `Operation`. Most 3445 programs will not call this method directly, and instead use the 3446 Python op constructors, such as `tf.constant()`, which add ops to 3447 the default graph. 3448 3449 Args: 3450 op_type: The `Operation` type to create. This corresponds to the 3451 `OpDef.name` field for the proto that defines the operation. 3452 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 3453 dtypes: (Optional) A list of `DType` objects that will be the types of the 3454 tensors that the operation produces. 3455 input_types: (Optional.) A list of `DType`s that will be the types of the 3456 tensors that the operation consumes. By default, uses the base `DType` 3457 of each input in `inputs`. Operations that expect reference-typed inputs 3458 must specify `input_types` explicitly. 3459 name: (Optional.) A string name for the operation. If not specified, a 3460 name is generated based on `op_type`. 3461 attrs: (Optional.) A dictionary where the key is the attribute name (a 3462 string) and the value is the respective `attr` attribute of the 3463 `NodeDef` proto that will represent the operation (an `AttrValue` 3464 proto). 3465 op_def: (Optional.) 
The `OpDef` proto that describes the `op_type` that 3466 the operation will have. 3467 compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always 3468 computed). 3469 compute_device: (Optional.) If True, device functions will be executed to 3470 compute the device property of the Operation. 3471 3472 Raises: 3473 TypeError: if any of the inputs is not a `Tensor`. 3474 ValueError: if colocation conflicts with existing device assignment. 3475 3476 Returns: 3477 An `Operation` object. 3478 """ 3479 del compute_shapes 3480 for idx, a in enumerate(inputs): 3481 if not isinstance(a, Tensor): 3482 raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) 3483 return self._create_op_internal(op_type, inputs, dtypes, input_types, name, 3484 attrs, op_def, compute_device) 3485 3486 def _create_op_internal( 3487 self, 3488 op_type, 3489 inputs, 3490 dtypes=None, # pylint: disable=redefined-outer-name 3491 input_types=None, 3492 name=None, 3493 attrs=None, 3494 op_def=None, 3495 compute_device=True): 3496 """Creates an `Operation` in this graph. 3497 3498 Implements `Graph.create_op()` without the overhead of the deprecation 3499 wrapper. 3500 3501 Args: 3502 op_type: The `Operation` type to create. This corresponds to the 3503 `OpDef.name` field for the proto that defines the operation. 3504 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 3505 dtypes: (Optional) A list of `DType` objects that will be the types of the 3506 tensors that the operation produces. 3507 input_types: (Optional.) A list of `DType`s that will be the types of the 3508 tensors that the operation consumes. By default, uses the base `DType` 3509 of each input in `inputs`. Operations that expect reference-typed inputs 3510 must specify `input_types` explicitly. 3511 name: (Optional.) A string name for the operation. If not specified, a 3512 name is generated based on `op_type`. 3513 attrs: (Optional.) 
A dictionary where the key is the attribute name (a 3514 string) and the value is the respective `attr` attribute of the 3515 `NodeDef` proto that will represent the operation (an `AttrValue` 3516 proto). 3517 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 3518 the operation will have. 3519 compute_device: (Optional.) If True, device functions will be executed to 3520 compute the device property of the Operation. 3521 3522 Raises: 3523 ValueError: if colocation conflicts with existing device assignment. 3524 3525 Returns: 3526 An `Operation` object. 3527 """ 3528 self._check_not_finalized() 3529 if name is None: 3530 name = op_type 3531 # If a names ends with a '/' it is a "name scope" and we use it as-is, 3532 # after removing the trailing '/'. 3533 if name and name[-1] == "/": 3534 name = name_from_scope_name(name) 3535 else: 3536 name = self.unique_name(name) 3537 3538 node_def = _NodeDef(op_type, name, attrs) 3539 3540 input_ops = set(t.op for t in inputs) 3541 control_inputs = self._control_dependencies_for_inputs(input_ops) 3542 # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a 3543 # Session.run call cannot occur between creating and mutating the op. 3544 with self._mutation_lock(): 3545 ret = Operation( 3546 node_def, 3547 self, 3548 inputs=inputs, 3549 output_types=dtypes, 3550 control_inputs=control_inputs, 3551 input_types=input_types, 3552 original_op=self._default_original_op, 3553 op_def=op_def) 3554 self._create_op_helper(ret, compute_device=compute_device) 3555 return ret 3556 3557 def _create_op_from_tf_operation(self, c_op, compute_device=True): 3558 """Creates an `Operation` in this graph from the supplied TF_Operation. 3559 3560 This method is like create_op() except the new Operation is constructed 3561 using `c_op`. The returned Operation will have `c_op` as its _c_op 3562 field. This is used to create Operation objects around TF_Operations created 3563 indirectly by the C API (e.g. 
by TF_ImportGraphDef, TF_FinishWhile). 3564 3565 This function does not call Operation._control_flow_post_processing or 3566 Graph._control_dependencies_for_inputs (since the inputs may not be 3567 available yet). The caller is responsible for calling these methods. 3568 3569 Args: 3570 c_op: a wrapped TF_Operation 3571 compute_device: (Optional.) If True, device functions will be executed to 3572 compute the device property of the Operation. 3573 3574 Returns: 3575 An `Operation` object. 3576 """ 3577 self._check_not_finalized() 3578 ret = Operation(c_op, self) 3579 # If a name_scope was created with ret.name but no nodes were created in it, 3580 # the name will still appear in _names_in_use even though the name hasn't 3581 # been used. This is ok, just leave _names_in_use as-is in this case. 3582 # TODO(skyewm): make the C API guarantee no name conflicts. 3583 name_key = ret.name.lower() 3584 if name_key not in self._names_in_use: 3585 self._names_in_use[name_key] = 1 3586 self._create_op_helper(ret, compute_device=compute_device) 3587 return ret 3588 3589 def _create_op_helper(self, op, compute_device=True): 3590 """Common logic for creating an op in this graph.""" 3591 # Apply any additional attributes requested. Do not overwrite any existing 3592 # attributes. 3593 for key, value in self._attr_scope_map.items(): 3594 try: 3595 op.get_attr(key) 3596 except ValueError: 3597 if callable(value): 3598 value = value(op.node_def) 3599 if not isinstance(value, (type(None), attr_value_pb2.AttrValue)): 3600 raise TypeError( 3601 "Callable for scope map key '%s' must return either None or " 3602 "an AttrValue protocol buffer; but it returned: %s" % 3603 (key, value)) 3604 if value: 3605 op._set_attr(key, value) # pylint: disable=protected-access 3606 3607 # Apply a kernel label if one has been specified for this op type. 
3608 try: 3609 kernel_label = self._op_to_kernel_label_map[op.type] 3610 op._set_attr("_kernel", # pylint: disable=protected-access 3611 attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label))) 3612 except KeyError: 3613 pass 3614 3615 op._gradient_function = self._gradient_function_map.get(op.type) # pylint: disable=protected-access 3616 3617 # Apply the overriding op type for gradients if one has been specified for 3618 # this op type. 3619 try: 3620 mapped_op_type = self._gradient_override_map[op.type] 3621 op._set_attr("_gradient_op_type", # pylint: disable=protected-access 3622 attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type))) 3623 except KeyError: 3624 pass 3625 3626 self._record_op_seen_by_control_dependencies(op) 3627 3628 if compute_device: 3629 self._apply_device_functions(op) 3630 3631 # Snapshot the colocation stack metadata before we might generate error 3632 # messages using it. Note that this snapshot depends on the actual stack 3633 # and is independent of the op's _class attribute. 
3634 # pylint: disable=protected-access 3635 op._colocation_code_locations = self._snapshot_colocation_stack_metadata() 3636 # pylint: enable=protected-access 3637 3638 if self._colocation_stack: 3639 all_colocation_groups = [] 3640 is_device_set = False 3641 for colocation_op in self._colocation_stack.peek_objs(): 3642 try: 3643 all_colocation_groups.extend(colocation_op.colocation_groups()) 3644 except AttributeError: 3645 pass 3646 if colocation_op.device and not is_device_set: 3647 # pylint: disable=protected-access 3648 op._set_device(colocation_op.device) 3649 # pylint: enable=protected-access 3650 is_device_set = True 3651 3652 all_colocation_groups = sorted(set(all_colocation_groups)) 3653 # pylint: disable=protected-access 3654 op._set_attr( 3655 "_class", 3656 attr_value_pb2.AttrValue( 3657 list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) 3658 # pylint: enable=protected-access 3659 3660 # Sets "container" attribute if 3661 # (1) self._container is not None 3662 # (2) "is_stateful" is set in OpDef 3663 # (3) "container" attribute is in OpDef 3664 # (4) "container" attribute is None 3665 if self._container and op._is_stateful: # pylint: disable=protected-access 3666 try: 3667 container_attr = op.get_attr("container") 3668 except ValueError: 3669 # "container" attribute is not in OpDef 3670 pass 3671 else: 3672 if not container_attr: 3673 op._set_attr("container", attr_value_pb2.AttrValue( # pylint: disable=protected-access 3674 s=compat.as_bytes(self._container))) 3675 3676 def _add_new_tf_operations(self, compute_devices=True): 3677 """Creates `Operations` in this graph for any new TF_Operations. 3678 3679 This is useful for when TF_Operations are indirectly created by the C API 3680 outside of the Operation constructor (e.g. by TF_ImportGraphDef, 3681 TF_FinishWhile). This ensures there are corresponding Operations for all 3682 TF_Operations in the underlying TF_Graph. 3683 3684 Args: 3685 compute_devices: (Optional.) 
If True, device functions will be executed to 3686 compute the device properties of each new Operation. 3687 3688 Returns: 3689 A list of the new `Operation` objects. 3690 """ 3691 self._check_not_finalized() 3692 3693 # Create all Operation objects before accessing their inputs since an op may 3694 # be created before its inputs. 3695 new_ops = [ 3696 self._create_op_from_tf_operation(c_op, compute_device=compute_devices) 3697 for c_op in c_api_util.new_tf_operations(self) 3698 ] 3699 3700 # pylint: disable=protected-access 3701 for op in new_ops: 3702 new_control_inputs = self._control_dependencies_for_inputs(op.inputs) 3703 op._add_control_inputs(new_control_inputs) 3704 op._control_flow_post_processing() 3705 # pylint: enable=protected-access 3706 3707 return new_ops 3708 3709 def as_graph_element(self, obj, allow_tensor=True, allow_operation=True): 3710 """Returns the object referred to by `obj`, as an `Operation` or `Tensor`. 3711 3712 This function validates that `obj` represents an element of this 3713 graph, and gives an informative error message if it is not. 3714 3715 This function is the canonical way to get/validate an object of 3716 one of the allowed types from an external argument reference in the 3717 Session API. 3718 3719 This method may be called concurrently from multiple threads. 3720 3721 Args: 3722 obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. Can 3723 also be any object with an `_as_graph_element()` method that returns a 3724 value of one of these types. Note: `_as_graph_element` will be called 3725 inside the graph's lock and so may not modify the graph. 3726 allow_tensor: If true, `obj` may refer to a `Tensor`. 3727 allow_operation: If true, `obj` may refer to an `Operation`. 3728 3729 Returns: 3730 The `Tensor` or `Operation` in the Graph corresponding to `obj`. 3731 3732 Raises: 3733 TypeError: If `obj` is not a type we support attempting to convert 3734 to types. 
3735 ValueError: If `obj` is of an appropriate type but invalid. For 3736 example, an invalid string. 3737 KeyError: If `obj` is not an object in the graph. 3738 """ 3739 if self._finalized: 3740 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 3741 3742 with self._lock: 3743 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 3744 3745 def _as_graph_element_locked(self, obj, allow_tensor, allow_operation): 3746 """See `Graph.as_graph_element()` for details.""" 3747 # The vast majority of this function is figuring 3748 # out what an API user might be doing wrong, so 3749 # that we can give helpful error messages. 3750 # 3751 # Ideally, it would be nice to split it up, but we 3752 # need context to generate nice error messages. 3753 3754 if allow_tensor and allow_operation: 3755 types_str = "Tensor or Operation" 3756 elif allow_tensor: 3757 types_str = "Tensor" 3758 elif allow_operation: 3759 types_str = "Operation" 3760 else: 3761 raise ValueError("allow_tensor and allow_operation can't both be False.") 3762 3763 temp_obj = _as_graph_element(obj) 3764 if temp_obj is not None: 3765 obj = temp_obj 3766 3767 # If obj appears to be a name... 3768 if isinstance(obj, compat.bytes_or_text_types): 3769 name = compat.as_str(obj) 3770 3771 if ":" in name and allow_tensor: 3772 # Looks like a Tensor name and can be a Tensor. 3773 try: 3774 op_name, out_n = name.split(":") 3775 out_n = int(out_n) 3776 except: 3777 raise ValueError("The name %s looks a like a Tensor name, but is " 3778 "not a valid one. Tensor names must be of the " 3779 "form \"<op_name>:<output_index>\"." % repr(name)) 3780 if op_name in self._nodes_by_name: 3781 op = self._nodes_by_name[op_name] 3782 else: 3783 raise KeyError("The name %s refers to a Tensor which does not " 3784 "exist. The operation, %s, does not exist in the " 3785 "graph." 
% (repr(name), repr(op_name))) 3786 try: 3787 return op.outputs[out_n] 3788 except: 3789 raise KeyError("The name %s refers to a Tensor which does not " 3790 "exist. The operation, %s, exists but only has " 3791 "%s outputs." % 3792 (repr(name), repr(op_name), len(op.outputs))) 3793 3794 elif ":" in name and not allow_tensor: 3795 # Looks like a Tensor name but can't be a Tensor. 3796 raise ValueError("Name %s appears to refer to a Tensor, not a %s." % 3797 (repr(name), types_str)) 3798 3799 elif ":" not in name and allow_operation: 3800 # Looks like an Operation name and can be an Operation. 3801 if name not in self._nodes_by_name: 3802 raise KeyError("The name %s refers to an Operation not in the " 3803 "graph." % repr(name)) 3804 return self._nodes_by_name[name] 3805 3806 elif ":" not in name and not allow_operation: 3807 # Looks like an Operation name but can't be an Operation. 3808 if name in self._nodes_by_name: 3809 # Yep, it's an Operation name 3810 err_msg = ("The name %s refers to an Operation, not a %s." % 3811 (repr(name), types_str)) 3812 else: 3813 err_msg = ("The name %s looks like an (invalid) Operation name, " 3814 "not a %s." % (repr(name), types_str)) 3815 err_msg += (" Tensor names must be of the form " 3816 "\"<op_name>:<output_index>\".") 3817 raise ValueError(err_msg) 3818 3819 elif isinstance(obj, Tensor) and allow_tensor: 3820 # Actually obj is just the object it's referring to. 3821 if obj.graph is not self: 3822 raise ValueError("Tensor %s is not an element of this graph." % obj) 3823 return obj 3824 elif isinstance(obj, Operation) and allow_operation: 3825 # Actually obj is just the object it's referring to. 3826 if obj.graph is not self: 3827 raise ValueError("Operation %s is not an element of this graph." % obj) 3828 return obj 3829 else: 3830 # We give up! 3831 raise TypeError("Can not convert a %s into a %s." 
% 3832 (type(obj).__name__, types_str)) 3833 3834 def get_operations(self): 3835 """Return the list of operations in the graph. 3836 3837 You can modify the operations in place, but modifications 3838 to the list such as inserts/delete have no effect on the 3839 list of operations known to the graph. 3840 3841 This method may be called concurrently from multiple threads. 3842 3843 Returns: 3844 A list of Operations. 3845 """ 3846 if self._finalized: 3847 return list(self._nodes_by_id.values()) 3848 3849 with self._lock: 3850 return list(self._nodes_by_id.values()) 3851 3852 def get_operation_by_name(self, name): 3853 """Returns the `Operation` with the given `name`. 3854 3855 This method may be called concurrently from multiple threads. 3856 3857 Args: 3858 name: The name of the `Operation` to return. 3859 3860 Returns: 3861 The `Operation` with the given `name`. 3862 3863 Raises: 3864 TypeError: If `name` is not a string. 3865 KeyError: If `name` does not correspond to an operation in this graph. 3866 """ 3867 3868 if not isinstance(name, six.string_types): 3869 raise TypeError("Operation names are strings (or similar), not %s." % 3870 type(name).__name__) 3871 return self.as_graph_element(name, allow_tensor=False, allow_operation=True) 3872 3873 def _get_operation_by_name_unsafe(self, name): 3874 """Returns the `Operation` with the given `name`. 3875 3876 This is a internal unsafe version of get_operation_by_name. It skips many 3877 checks and does not have user friendly error messages but runs considerably 3878 faster. This method may be called concurrently from multiple threads. 3879 3880 Args: 3881 name: The name of the `Operation` to return. 3882 3883 Returns: 3884 The `Operation` with the given `name`. 3885 3886 Raises: 3887 KeyError: If `name` does not correspond to an operation in this graph. 
3888 """ 3889 3890 if self._finalized: 3891 return self._nodes_by_name[name] 3892 3893 with self._lock: 3894 return self._nodes_by_name[name] 3895 3896 def _get_operation_by_tf_operation(self, tf_oper): 3897 op_name = pywrap_tf_session.TF_OperationName(tf_oper) 3898 return self._get_operation_by_name_unsafe(op_name) 3899 3900 def get_tensor_by_name(self, name): 3901 """Returns the `Tensor` with the given `name`. 3902 3903 This method may be called concurrently from multiple threads. 3904 3905 Args: 3906 name: The name of the `Tensor` to return. 3907 3908 Returns: 3909 The `Tensor` with the given `name`. 3910 3911 Raises: 3912 TypeError: If `name` is not a string. 3913 KeyError: If `name` does not correspond to a tensor in this graph. 3914 """ 3915 # Names should be strings. 3916 if not isinstance(name, six.string_types): 3917 raise TypeError("Tensor names are strings (or similar), not %s." % 3918 type(name).__name__) 3919 return self.as_graph_element(name, allow_tensor=True, allow_operation=False) 3920 3921 def _get_tensor_by_tf_output(self, tf_output): 3922 """Returns the `Tensor` representing `tf_output`. 3923 3924 Note that there is only one such `Tensor`, i.e. multiple calls to this 3925 function with the same TF_Output value will always return the same `Tensor` 3926 object. 3927 3928 Args: 3929 tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`). 3930 3931 Returns: 3932 The `Tensor` that represents `tf_output`. 3933 """ 3934 op = self._get_operation_by_tf_operation(tf_output.oper) 3935 return op.outputs[tf_output.index] 3936 3937 @property 3938 def _last_id(self): 3939 return self._next_id_counter 3940 3941 def _get_op_def(self, type): # pylint: disable=redefined-builtin 3942 """Returns the `OpDef` proto for `type`. `type` is a string.""" 3943 # NOTE: No locking is required because the lookup and insertion operations 3944 # on Python dictionaries are atomic. 
3945 try: 3946 return self._op_def_cache[type] 3947 except KeyError: 3948 with c_api_util.tf_buffer() as buf: 3949 # pylint: disable=protected-access 3950 pywrap_tf_session.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), 3951 buf) 3952 # pylint: enable=protected-access 3953 data = pywrap_tf_session.TF_GetBuffer(buf) 3954 op_def = op_def_pb2.OpDef() 3955 op_def.ParseFromString(compat.as_bytes(data)) 3956 self._op_def_cache[type] = op_def 3957 return op_def 3958 3959 def as_default(self): 3960 """Returns a context manager that makes this `Graph` the default graph. 3961 3962 This method should be used if you want to create multiple graphs 3963 in the same process. For convenience, a global default graph is 3964 provided, and all ops will be added to this graph if you do not 3965 create a new graph explicitly. 3966 3967 Use this method with the `with` keyword to specify that ops created within 3968 the scope of a block should be added to this graph. In this case, once 3969 the scope of the `with` is exited, the previous default graph is set again 3970 as default. There is a stack, so it's ok to have multiple nested levels 3971 of `as_default` calls. 3972 3973 The default graph is a property of the current thread. If you 3974 create a new thread, and wish to use the default graph in that 3975 thread, you must explicitly add a `with g.as_default():` in that 3976 thread's function. 3977 3978 The following code examples are equivalent: 3979 3980 ```python 3981 # 1. Using Graph.as_default(): 3982 g = tf.Graph() 3983 with g.as_default(): 3984 c = tf.constant(5.0) 3985 assert c.graph is g 3986 3987 # 2. Constructing and making default: 3988 with tf.Graph().as_default() as g: 3989 c = tf.constant(5.0) 3990 assert c.graph is g 3991 ``` 3992 3993 If eager execution is enabled ops created under this context manager will be 3994 added to the graph instead of executed eagerly. 3995 3996 Returns: 3997 A context manager for using this graph as the default graph. 
3998 """ 3999 return _default_graph_stack.get_controller(self) 4000 4001 @property 4002 def collections(self): 4003 """Returns the names of the collections known to this graph.""" 4004 return list(self._collections) 4005 4006 def add_to_collection(self, name, value): 4007 """Stores `value` in the collection with the given `name`. 4008 4009 Note that collections are not sets, so it is possible to add a value to 4010 a collection several times. 4011 4012 Args: 4013 name: The key for the collection. The `GraphKeys` class contains many 4014 standard names for collections. 4015 value: The value to add to the collection. 4016 """ # pylint: disable=g-doc-exception 4017 self._check_not_finalized() 4018 with self._lock: 4019 if name not in self._collections: 4020 self._collections[name] = [value] 4021 else: 4022 self._collections[name].append(value) 4023 4024 def add_to_collections(self, names, value): 4025 """Stores `value` in the collections given by `names`. 4026 4027 Note that collections are not sets, so it is possible to add a value to 4028 a collection several times. This function makes sure that duplicates in 4029 `names` are ignored, but it will not check for pre-existing membership of 4030 `value` in any of the collections in `names`. 4031 4032 `names` can be any iterable, but if `names` is a string, it is treated as a 4033 single collection name. 4034 4035 Args: 4036 names: The keys for the collections to add to. The `GraphKeys` class 4037 contains many standard names for collections. 4038 value: The value to add to the collections. 4039 """ 4040 # Make sure names are unique, but treat strings as a single collection name 4041 names = (names,) if isinstance(names, six.string_types) else set(names) 4042 for name in names: 4043 self.add_to_collection(name, value) 4044 4045 def get_collection_ref(self, name): 4046 """Returns a list of values in the collection with the given `name`. 
4047 4048 If the collection exists, this returns the list itself, which can 4049 be modified in place to change the collection. If the collection does 4050 not exist, it is created as an empty list and the list is returned. 4051 4052 This is different from `get_collection()` which always returns a copy of 4053 the collection list if it exists and never creates an empty collection. 4054 4055 Args: 4056 name: The key for the collection. For example, the `GraphKeys` class 4057 contains many standard names for collections. 4058 4059 Returns: 4060 The list of values in the collection with the given `name`, or an empty 4061 list if no value has been added to that collection. 4062 """ # pylint: disable=g-doc-exception 4063 with self._lock: 4064 coll_list = self._collections.get(name, None) 4065 if coll_list is None: 4066 coll_list = [] 4067 self._collections[name] = coll_list 4068 return coll_list 4069 4070 def get_collection(self, name, scope=None): 4071 """Returns a list of values in the collection with the given `name`. 4072 4073 This is different from `get_collection_ref()` which always returns the 4074 actual collection list if it exists in that it returns a new list each time 4075 it is called. 4076 4077 Args: 4078 name: The key for the collection. For example, the `GraphKeys` class 4079 contains many standard names for collections. 4080 scope: (Optional.) A string. If supplied, the resulting list is filtered 4081 to include only items whose `name` attribute matches `scope` using 4082 `re.match`. Items without a `name` attribute are never returned if a 4083 scope is supplied. The choice of `re.match` means that a `scope` without 4084 special tokens filters by prefix. 4085 4086 Returns: 4087 The list of values in the collection with the given `name`, or 4088 an empty list if no value has been added to that collection. The 4089 list contains the values in the order under which they were 4090 collected. 
4091 """ # pylint: disable=g-doc-exception 4092 with self._lock: 4093 collection = self._collections.get(name, None) 4094 if collection is None: 4095 return [] 4096 if scope is None: 4097 return list(collection) 4098 else: 4099 c = [] 4100 regex = re.compile(scope) 4101 for item in collection: 4102 try: 4103 if regex.match(item.name): 4104 c.append(item) 4105 except AttributeError: 4106 # Collection items with no name are ignored. 4107 pass 4108 return c 4109 4110 def get_all_collection_keys(self): 4111 """Returns a list of collections used in this graph.""" 4112 with self._lock: 4113 return [x for x in self._collections if isinstance(x, six.string_types)] 4114 4115 def clear_collection(self, name): 4116 """Clears all values in a collection. 4117 4118 Args: 4119 name: The key for the collection. The `GraphKeys` class contains many 4120 standard names for collections. 4121 """ 4122 self._check_not_finalized() 4123 with self._lock: 4124 if name in self._collections: 4125 del self._collections[name] 4126 4127 @tf_contextlib.contextmanager 4128 def _original_op(self, op): 4129 """Python 'with' handler to help annotate ops with their originator. 4130 4131 An op may have an 'original_op' property that indicates the op on which 4132 it was based. For example a replica op is based on the op that was 4133 replicated and a gradient op is based on the op that was differentiated. 4134 4135 All ops created in the scope of this 'with' handler will have 4136 the given 'op' as their original op. 4137 4138 Args: 4139 op: The Operation that all ops created in this scope will have as their 4140 original op. 4141 4142 Yields: 4143 Nothing. 4144 """ 4145 old_original_op = self._default_original_op 4146 self._default_original_op = op 4147 try: 4148 yield 4149 finally: 4150 self._default_original_op = old_original_op 4151 4152 @property 4153 def _name_stack(self): 4154 # This may be called from a thread where name_stack doesn't yet exist. 
4155 if not hasattr(self._thread_local, "_name_stack"): 4156 self._thread_local._name_stack = "" 4157 return self._thread_local._name_stack 4158 4159 @_name_stack.setter 4160 def _name_stack(self, name_stack): 4161 self._thread_local._name_stack = name_stack 4162 4163 # pylint: disable=g-doc-return-or-yield,line-too-long 4164 @tf_contextlib.contextmanager 4165 def name_scope(self, name): 4166 """Returns a context manager that creates hierarchical names for operations. 4167 4168 A graph maintains a stack of name scopes. A `with name_scope(...):` 4169 statement pushes a new name onto the stack for the lifetime of the context. 4170 4171 The `name` argument will be interpreted as follows: 4172 4173 * A string (not ending with '/') will create a new name scope, in which 4174 `name` is appended to the prefix of all operations created in the 4175 context. If `name` has been used before, it will be made unique by 4176 calling `self.unique_name(name)`. 4177 * A scope previously captured from a `with g.name_scope(...) as 4178 scope:` statement will be treated as an "absolute" name scope, which 4179 makes it possible to re-enter existing scopes. 4180 * A value of `None` or the empty string will reset the current name scope 4181 to the top-level (empty) name scope. 4182 4183 For example: 4184 4185 ```python 4186 with tf.Graph().as_default() as g: 4187 c = tf.constant(5.0, name="c") 4188 assert c.op.name == "c" 4189 c_1 = tf.constant(6.0, name="c") 4190 assert c_1.op.name == "c_1" 4191 4192 # Creates a scope called "nested" 4193 with g.name_scope("nested") as scope: 4194 nested_c = tf.constant(10.0, name="c") 4195 assert nested_c.op.name == "nested/c" 4196 4197 # Creates a nested scope called "inner". 4198 with g.name_scope("inner"): 4199 nested_inner_c = tf.constant(20.0, name="c") 4200 assert nested_inner_c.op.name == "nested/inner/c" 4201 4202 # Create a nested scope called "inner_1". 
4203 with g.name_scope("inner"): 4204 nested_inner_1_c = tf.constant(30.0, name="c") 4205 assert nested_inner_1_c.op.name == "nested/inner_1/c" 4206 4207 # Treats `scope` as an absolute name scope, and 4208 # switches to the "nested/" scope. 4209 with g.name_scope(scope): 4210 nested_d = tf.constant(40.0, name="d") 4211 assert nested_d.op.name == "nested/d" 4212 4213 with g.name_scope(""): 4214 e = tf.constant(50.0, name="e") 4215 assert e.op.name == "e" 4216 ``` 4217 4218 The name of the scope itself can be captured by `with 4219 g.name_scope(...) as scope:`, which stores the name of the scope 4220 in the variable `scope`. This value can be used to name an 4221 operation that represents the overall result of executing the ops 4222 in a scope. For example: 4223 4224 ```python 4225 inputs = tf.constant(...) 4226 with g.name_scope('my_layer') as scope: 4227 weights = tf.Variable(..., name="weights") 4228 biases = tf.Variable(..., name="biases") 4229 affine = tf.matmul(inputs, weights) + biases 4230 output = tf.nn.relu(affine, name=scope) 4231 ``` 4232 4233 NOTE: This constructor validates the given `name`. Valid scope 4234 names match one of the following regular expressions: 4235 4236 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root) 4237 [A-Za-z0-9_.\\-/]* (for other scopes) 4238 4239 Args: 4240 name: A name for the scope. 4241 4242 Returns: 4243 A context manager that installs `name` as a new name scope. 4244 4245 Raises: 4246 ValueError: If `name` is not a valid scope name, according to the rules 4247 above. 4248 """ 4249 if name: 4250 if isinstance(name, compat.bytes_or_text_types): 4251 name = compat.as_str(name) 4252 4253 if self._name_stack: 4254 # Scopes created in a nested scope may have initial characters 4255 # that are illegal as the initial character of an op name 4256 # (viz. '-', '\', '/', and '_'). 
4257 if not _VALID_SCOPE_NAME_REGEX.match(name): 4258 raise ValueError("'%s' is not a valid scope name" % name) 4259 else: 4260 # Scopes created in the root must match the more restrictive 4261 # op name regex, which constrains the initial character. 4262 if not _VALID_OP_NAME_REGEX.match(name): 4263 raise ValueError("'%s' is not a valid scope name" % name) 4264 old_stack = self._name_stack 4265 if not name: # Both for name=None and name="" we re-set to empty scope. 4266 new_stack = None 4267 elif name[-1] == "/": 4268 new_stack = name_from_scope_name(name) 4269 else: 4270 new_stack = self.unique_name(name) 4271 self._name_stack = new_stack 4272 try: 4273 yield "" if new_stack is None else new_stack + "/" 4274 finally: 4275 self._name_stack = old_stack 4276 4277 # pylint: enable=g-doc-return-or-yield,line-too-long 4278 4279 def unique_name(self, name, mark_as_used=True): 4280 """Return a unique operation name for `name`. 4281 4282 Note: You rarely need to call `unique_name()` directly. Most of 4283 the time you just need to create `with g.name_scope()` blocks to 4284 generate structured names. 4285 4286 `unique_name` is used to generate structured names, separated by 4287 `"/"`, to help identify operations when debugging a graph. 4288 Operation names are displayed in error messages reported by the 4289 TensorFlow runtime, and in various visualization tools such as 4290 TensorBoard. 4291 4292 If `mark_as_used` is set to `True`, which is the default, a new 4293 unique name is created and marked as in use. If it's set to `False`, 4294 the unique name is returned without actually being marked as used. 4295 This is useful when the caller simply wants to know what the name 4296 to be created will be. 4297 4298 Args: 4299 name: The name for an operation. 4300 mark_as_used: Whether to mark this name as being used. 4301 4302 Returns: 4303 A string to be passed to `create_op()` that will be used 4304 to name the operation being created. 
4305 """ 4306 if self._name_stack: 4307 name = self._name_stack + "/" + name 4308 4309 # For the sake of checking for names in use, we treat names as case 4310 # insensitive (e.g. foo = Foo). 4311 name_key = name.lower() 4312 i = self._names_in_use.get(name_key, 0) 4313 # Increment the number for "name_key". 4314 if mark_as_used: 4315 self._names_in_use[name_key] = i + 1 4316 if i > 0: 4317 base_name_key = name_key 4318 # Make sure the composed name key is not already used. 4319 while name_key in self._names_in_use: 4320 name_key = "%s_%d" % (base_name_key, i) 4321 i += 1 4322 # Mark the composed name_key as used in case someone wants 4323 # to call unique_name("name_1"). 4324 if mark_as_used: 4325 self._names_in_use[name_key] = 1 4326 4327 # Return the new name with the original capitalization of the given name. 4328 name = "%s_%d" % (name, i - 1) 4329 return name 4330 4331 def get_name_scope(self): 4332 """Returns the current name scope. 4333 4334 For example: 4335 4336 ```python 4337 with tf.name_scope('scope1'): 4338 with tf.name_scope('scope2'): 4339 print(tf.compat.v1.get_default_graph().get_name_scope()) 4340 ``` 4341 would print the string `scope1/scope2`. 4342 4343 Returns: 4344 A string representing the current name scope. 4345 """ 4346 return self._name_stack 4347 4348 @tf_contextlib.contextmanager 4349 def _colocate_with_for_gradient(self, op, gradient_uid, 4350 ignore_existing=False): 4351 with self.colocate_with(op, ignore_existing): 4352 if gradient_uid is not None: 4353 ctx = _get_enclosing_context(self) 4354 if ctx is not None: 4355 ctx.EnterGradientColocation(op, gradient_uid) 4356 try: 4357 yield 4358 finally: 4359 ctx.ExitGradientColocation(op, gradient_uid) 4360 else: 4361 yield 4362 else: 4363 yield 4364 4365 @tf_contextlib.contextmanager 4366 def colocate_with(self, op, ignore_existing=False): 4367 """Returns a context manager that specifies an op to colocate with. 
4368 4369 Note: this function is not for public use, only for internal libraries. 4370 4371 For example: 4372 4373 ```python 4374 a = tf.Variable([1.0]) 4375 with g.colocate_with(a): 4376 b = tf.constant(1.0) 4377 c = tf.add(a, b) 4378 ``` 4379 4380 `b` and `c` will always be colocated with `a`, no matter where `a` 4381 is eventually placed. 4382 4383 **NOTE** Using a colocation scope resets any existing device constraints. 4384 4385 If `op` is `None` then `ignore_existing` must be `True` and the new 4386 scope resets all colocation and device constraints. 4387 4388 Args: 4389 op: The op to colocate all created ops with, or `None`. 4390 ignore_existing: If true, only applies colocation of this op within the 4391 context, rather than applying all colocation properties on the stack. 4392 If `op` is `None`, this value must be `True`. 4393 4394 Raises: 4395 ValueError: if op is None but ignore_existing is False. 4396 4397 Yields: 4398 A context manager that specifies the op with which to colocate 4399 newly created ops. 4400 """ 4401 if op is None and not ignore_existing: 4402 raise ValueError("Trying to reset colocation (op is None) but " 4403 "ignore_existing is not True") 4404 op, device_only_candidate = _op_to_colocate_with(op, self) 4405 4406 # By default, colocate_with resets the device function stack, 4407 # since colocate_with is typically used in specific internal 4408 # library functions where colocation is intended to be "stronger" 4409 # than device functions. 4410 # 4411 # In the future, a caller may specify that device_functions win 4412 # over colocation, in which case we can add support. 4413 device_fn_tmp = self._device_function_stack 4414 self._device_function_stack = traceable_stack.TraceableStack() 4415 4416 if ignore_existing: 4417 current_stack = self._colocation_stack 4418 self._colocation_stack = traceable_stack.TraceableStack() 4419 4420 if op is not None: 4421 # offset refers to the stack frame used for storing code location. 
4422 # We use 4, the sum of 1 to use our caller's stack frame and 3 4423 # to jump over layers of context managers above us. 4424 if device_only_candidate is not None: 4425 self._colocation_stack.push_obj(device_only_candidate, offset=4) 4426 self._colocation_stack.push_obj(op, offset=4) 4427 elif not ignore_existing: 4428 raise ValueError("Trying to reset colocation (op is None) but " 4429 "ignore_existing is not True") 4430 try: 4431 yield 4432 finally: 4433 # Restore device function stack 4434 self._device_function_stack = device_fn_tmp 4435 if op is not None: 4436 self._colocation_stack.pop_obj() 4437 if device_only_candidate is not None: 4438 self._colocation_stack.pop_obj() 4439 4440 # Reset the colocation stack if requested. 4441 if ignore_existing: 4442 self._colocation_stack = current_stack 4443 4444 def _add_device_to_stack(self, device_name_or_function, offset=0): 4445 """Add device to stack manually, separate from a context manager.""" 4446 total_offset = 1 + offset 4447 spec = _UserDeviceSpec(device_name_or_function) 4448 self._device_function_stack.push_obj(spec, offset=total_offset) 4449 return spec 4450 4451 @tf_contextlib.contextmanager 4452 def device(self, device_name_or_function): 4453 # pylint: disable=line-too-long 4454 """Returns a context manager that specifies the default device to use. 4455 4456 The `device_name_or_function` argument may either be a device name 4457 string, a device function, or None: 4458 4459 * If it is a device name string, all operations constructed in 4460 this context will be assigned to the device with that name, unless 4461 overridden by a nested `device()` context. 4462 * If it is a function, it will be treated as a function from 4463 Operation objects to device name strings, and invoked each time 4464 a new Operation is created. The Operation will be assigned to 4465 the device with the returned name. 4466 * If it is None, all `device()` invocations from the enclosing context 4467 will be ignored. 
4468 4469 For information about the valid syntax of device name strings, see 4470 the documentation in 4471 [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h). 4472 4473 For example: 4474 4475 ```python 4476 with g.device('/device:GPU:0'): 4477 # All operations constructed in this context will be placed 4478 # on GPU 0. 4479 with g.device(None): 4480 # All operations constructed in this context will have no 4481 # assigned device. 4482 4483 # Defines a function from `Operation` to device string. 4484 def matmul_on_gpu(n): 4485 if n.type == "MatMul": 4486 return "/device:GPU:0" 4487 else: 4488 return "/cpu:0" 4489 4490 with g.device(matmul_on_gpu): 4491 # All operations of type "MatMul" constructed in this context 4492 # will be placed on GPU 0; all other operations will be placed 4493 # on CPU 0. 4494 ``` 4495 4496 **N.B.** The device scope may be overridden by op wrappers or 4497 other library code. For example, a variable assignment op 4498 `v.assign()` must be colocated with the `tf.Variable` `v`, and 4499 incompatible device scopes will be ignored. 4500 4501 Args: 4502 device_name_or_function: The device name or function to use in the 4503 context. 4504 4505 Yields: 4506 A context manager that specifies the default device to use for newly 4507 created ops. 4508 4509 Raises: 4510 RuntimeError: If device scopes are not properly nested. 
4511 """ 4512 self._add_device_to_stack(device_name_or_function, offset=2) 4513 old_top_of_stack = self._device_function_stack.peek_top_obj() 4514 try: 4515 yield 4516 finally: 4517 new_top_of_stack = self._device_function_stack.peek_top_obj() 4518 if old_top_of_stack is not new_top_of_stack: 4519 raise RuntimeError("Exiting device scope without proper scope nesting.") 4520 self._device_function_stack.pop_obj() 4521 4522 def _apply_device_functions(self, op): 4523 """Applies the current device function stack to the given operation.""" 4524 # Apply any device functions in LIFO order, so that the most recently 4525 # pushed function has the first chance to apply a device to the op. 4526 # We apply here because the result can depend on the Operation's 4527 # signature, which is computed in the Operation constructor. 4528 # pylint: disable=protected-access 4529 prior_device_string = None 4530 for device_spec in self._device_function_stack.peek_objs(): 4531 if device_spec.is_null_merge: 4532 continue 4533 4534 if device_spec.function is None: 4535 break 4536 4537 device_string = device_spec.string_merge(op) 4538 4539 # Take advantage of the fact that None is a singleton and Python interns 4540 # strings, since identity checks are faster than equality checks. 4541 if device_string is not prior_device_string: 4542 op._set_device_from_string(device_string) 4543 prior_device_string = device_string 4544 op._device_code_locations = self._snapshot_device_function_stack_metadata() 4545 # pylint: enable=protected-access 4546 4547 # pylint: disable=g-doc-return-or-yield 4548 @tf_contextlib.contextmanager 4549 def container(self, container_name): 4550 """Returns a context manager that specifies the resource container to use. 4551 4552 Stateful operations, such as variables and queues, can maintain their 4553 states on devices so that they can be shared by multiple processes. 4554 A resource container is a string name under which these stateful 4555 operations are tracked. 
These resources can be released or cleared 4556 with `tf.Session.reset()`. 4557 4558 For example: 4559 4560 ```python 4561 with g.container('experiment0'): 4562 # All stateful Operations constructed in this context will be placed 4563 # in resource container "experiment0". 4564 v1 = tf.Variable([1.0]) 4565 v2 = tf.Variable([2.0]) 4566 with g.container("experiment1"): 4567 # All stateful Operations constructed in this context will be 4568 # placed in resource container "experiment1". 4569 v3 = tf.Variable([3.0]) 4570 q1 = tf.queue.FIFOQueue(10, tf.float32) 4571 # All stateful Operations constructed in this context will be 4572 # be created in the "experiment0". 4573 v4 = tf.Variable([4.0]) 4574 q1 = tf.queue.FIFOQueue(20, tf.float32) 4575 with g.container(""): 4576 # All stateful Operations constructed in this context will be 4577 # be placed in the default resource container. 4578 v5 = tf.Variable([5.0]) 4579 q3 = tf.queue.FIFOQueue(30, tf.float32) 4580 4581 # Resets container "experiment0", after which the state of v1, v2, v4, q1 4582 # will become undefined (such as uninitialized). 4583 tf.Session.reset(target, ["experiment0"]) 4584 ``` 4585 4586 Args: 4587 container_name: container name string. 4588 4589 Returns: 4590 A context manager for defining resource containers for stateful ops, 4591 yields the container name. 4592 """ 4593 original_container = self._container 4594 self._container = container_name 4595 try: 4596 yield self._container 4597 finally: 4598 self._container = original_container 4599 4600 # pylint: enable=g-doc-return-or-yield 4601 4602 class _ControlDependenciesController(object): 4603 """Context manager for `control_dependencies()`.""" 4604 4605 def __init__(self, graph, control_inputs): 4606 """Create a new `_ControlDependenciesController`. 4607 4608 A `_ControlDependenciesController` is the context manager for 4609 `with tf.control_dependencies()` blocks. 
These normally nest, 4610 as described in the documentation for `control_dependencies()`. 4611 4612 The `control_inputs` argument list control dependencies that must be 4613 added to the current set of control dependencies. Because of 4614 uniquification the set can be empty even if the caller passed a list of 4615 ops. The special value `None` indicates that we want to start a new 4616 empty set of control dependencies instead of extending the current set. 4617 4618 In that case we also clear the current control flow context, which is an 4619 additional mechanism to add control dependencies. 4620 4621 Args: 4622 graph: The graph that this controller is managing. 4623 control_inputs: List of ops to use as control inputs in addition to the 4624 current control dependencies. None to indicate that the dependencies 4625 should be cleared. 4626 """ 4627 self._graph = graph 4628 if control_inputs is None: 4629 self._control_inputs_val = [] 4630 self._new_stack = True 4631 else: 4632 self._control_inputs_val = control_inputs 4633 self._new_stack = False 4634 self._seen_nodes = set() 4635 self._old_stack = None 4636 self._old_control_flow_context = None 4637 4638# pylint: disable=protected-access 4639 4640 def __enter__(self): 4641 if self._new_stack: 4642 # Clear the control_dependencies graph. 4643 self._old_stack = self._graph._control_dependencies_stack 4644 self._graph._control_dependencies_stack = [] 4645 # Clear the control_flow_context too. 
4646 self._old_control_flow_context = self._graph._get_control_flow_context() 4647 self._graph._set_control_flow_context(None) 4648 self._graph._push_control_dependencies_controller(self) 4649 4650 def __exit__(self, unused_type, unused_value, unused_traceback): 4651 self._graph._pop_control_dependencies_controller(self) 4652 if self._new_stack: 4653 self._graph._control_dependencies_stack = self._old_stack 4654 self._graph._set_control_flow_context(self._old_control_flow_context) 4655 4656# pylint: enable=protected-access 4657 4658 @property 4659 def control_inputs(self): 4660 return self._control_inputs_val 4661 4662 def add_op(self, op): 4663 if isinstance(op, Tensor): 4664 op = op.ref() 4665 self._seen_nodes.add(op) 4666 4667 def op_in_group(self, op): 4668 if isinstance(op, Tensor): 4669 op = op.ref() 4670 return op in self._seen_nodes 4671 4672 def _push_control_dependencies_controller(self, controller): 4673 self._control_dependencies_stack.append(controller) 4674 4675 def _pop_control_dependencies_controller(self, controller): 4676 assert self._control_dependencies_stack[-1] is controller 4677 self._control_dependencies_stack.pop() 4678 4679 def _current_control_dependencies(self): 4680 ret = set() 4681 for controller in self._control_dependencies_stack: 4682 for op in controller.control_inputs: 4683 ret.add(op) 4684 return ret 4685 4686 def _control_dependencies_for_inputs(self, input_ops): 4687 """For an op that takes `input_ops` as inputs, compute control inputs. 4688 4689 The returned control dependencies should yield an execution that 4690 is equivalent to adding all control inputs in 4691 self._control_dependencies_stack to a newly created op. However, 4692 this function attempts to prune the returned control dependencies 4693 by observing that nodes created within the same `with 4694 control_dependencies(...):` block may have data dependencies that make 4695 the explicit approach redundant. 
4696 4697 Args: 4698 input_ops: The data input ops for an op to be created. 4699 4700 Returns: 4701 A list of control inputs for the op to be created. 4702 """ 4703 ret = [] 4704 for controller in self._control_dependencies_stack: 4705 # If any of the input_ops already depends on the inputs from controller, 4706 # we say that the new op is dominated (by that input), and we therefore 4707 # do not need to add control dependencies for this controller's inputs. 4708 dominated = False 4709 for op in input_ops: 4710 if controller.op_in_group(op): 4711 dominated = True 4712 break 4713 if not dominated: 4714 # Don't add a control input if we already have a data dependency on i. 4715 # NOTE(mrry): We do not currently track transitive data dependencies, 4716 # so we may add redundant control inputs. 4717 ret.extend(c for c in controller.control_inputs if c not in input_ops) 4718 return ret 4719 4720 def _record_op_seen_by_control_dependencies(self, op): 4721 """Record that the given op depends on all registered control dependencies. 4722 4723 Args: 4724 op: An Operation. 4725 """ 4726 for controller in self._control_dependencies_stack: 4727 controller.add_op(op) 4728 4729 def control_dependencies(self, control_inputs): 4730 """Returns a context manager that specifies control dependencies. 4731 4732 Use with the `with` keyword to specify that all operations constructed 4733 within the context should have control dependencies on 4734 `control_inputs`. For example: 4735 4736 ```python 4737 with g.control_dependencies([a, b, c]): 4738 # `d` and `e` will only run after `a`, `b`, and `c` have executed. 4739 d = ... 4740 e = ... 4741 ``` 4742 4743 Multiple calls to `control_dependencies()` can be nested, and in 4744 that case a new `Operation` will have control dependencies on the union 4745 of `control_inputs` from all active contexts. 4746 4747 ```python 4748 with g.control_dependencies([a, b]): 4749 # Ops constructed here run after `a` and `b`. 
4750 with g.control_dependencies([c, d]): 4751 # Ops constructed here run after `a`, `b`, `c`, and `d`. 4752 ``` 4753 4754 You can pass None to clear the control dependencies: 4755 4756 ```python 4757 with g.control_dependencies([a, b]): 4758 # Ops constructed here run after `a` and `b`. 4759 with g.control_dependencies(None): 4760 # Ops constructed here run normally, not waiting for either `a` or `b`. 4761 with g.control_dependencies([c, d]): 4762 # Ops constructed here run after `c` and `d`, also not waiting 4763 # for either `a` or `b`. 4764 ``` 4765 4766 *N.B.* The control dependencies context applies *only* to ops that 4767 are constructed within the context. Merely using an op or tensor 4768 in the context does not add a control dependency. The following 4769 example illustrates this point: 4770 4771 ```python 4772 # WRONG 4773 def my_func(pred, tensor): 4774 t = tf.matmul(tensor, tensor) 4775 with tf.control_dependencies([pred]): 4776 # The matmul op is created outside the context, so no control 4777 # dependency will be added. 4778 return t 4779 4780 # RIGHT 4781 def my_func(pred, tensor): 4782 with tf.control_dependencies([pred]): 4783 # The matmul op is created in the context, so a control dependency 4784 # will be added. 4785 return tf.matmul(tensor, tensor) 4786 ``` 4787 4788 Also note that though execution of ops created under this scope will trigger 4789 execution of the dependencies, the ops created under this scope might still 4790 be pruned from a normal tensorflow graph. For example, in the following 4791 snippet of code the dependencies are never executed: 4792 4793 ```python 4794 loss = model.loss() 4795 with tf.control_dependencies(dependencies): 4796 loss = loss + tf.constant(1) # note: dependencies ignored in the 4797 # backward pass 4798 return tf.gradients(loss, model.variables) 4799 ``` 4800 4801 This is because evaluating the gradient graph does not require evaluating 4802 the constant(1) op created in the forward pass. 
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _attr_scope(self, attr_map):
  """EXPERIMENTAL: A context manager for setting attributes on operators.

  This context manager can be used to add additional
  attributes to operators within the scope of the context.

  For example:

     with ops.Graph().as_default() as g:
       f_1 = Foo()  # No extra attributes
       with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
         f_2 = Foo()  # Additional attribute _a=False
         with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
           f_3 = Foo()  # Additional attribute _a=True
           with g._attr_scope({"_a": None}):
             f_4 = Foo()  # No additional attributes.

  Args:
    attr_map: A dictionary mapping attr name strings to AttrValue protocol
      buffers, callables that emit AttrValue protocol buffers, or None. A
      value of None removes the attribute for the duration of the scope; it
      is a no-op if the attribute is not currently set.

  Returns:
    A context manager that sets the given attributes on ops created in that
    context and restores the previous attributes on exit.

  Raises:
    TypeError: If attr_map is not a dictionary mapping
      strings to AttrValue protobufs (or callables, or None). The graph's
      attribute scope is left untouched in that case.
  """
  if not isinstance(attr_map, dict):
    raise TypeError("attr_map must be a dictionary mapping "
                    "strings to AttrValue protocol buffers")
  # Validate every entry before mutating self._attr_scope_map. Validating
  # while installing would leave earlier entries installed forever when a
  # later entry is rejected, because the cleanup below is never reached.
  for name, attr in attr_map.items():
    if not (isinstance(name, six.string_types) and
            (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
             callable(attr))):
      raise TypeError("attr_map must be a dictionary mapping "
                      "strings to AttrValue protocol buffers or "
                      "callables that emit AttrValue protocol buffers")
  # The saved_attrs dictionary stores any currently-set attributes that
  # will be overridden by this context manager.
  saved_attrs = {}
  # Install the given attributes.
  for name, attr in attr_map.items():
    if name in self._attr_scope_map:
      saved_attrs[name] = self._attr_scope_map[name]
    if attr is None:
      # Clearing an attribute that is not set must not raise.
      self._attr_scope_map.pop(name, None)
    else:
      self._attr_scope_map[name] = attr
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the attributes set for this context, and restore any saved
    # attributes. pop() with a default keeps cleanup from raising (and
    # masking an exception from the body) when the name is already absent,
    # e.g. because it was cleared with None and had no previous value.
    for name in attr_map:
      if name in saved_attrs:
        self._attr_scope_map[name] = saved_attrs[name]
      else:
        self._attr_scope_map.pop(name, None)

# pylint: enable=g-doc-return-or-yield
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def gradient_override_map(self, op_type_map):
  """EXPERIMENTAL: A context manager for overriding gradient functions.

  This context manager can be used to override the gradient function
  that will be used for ops within the scope of the context.

  For example:

  ```python
  @tf.RegisterGradient("CustomSquare")
  def _custom_square_grad(op, grad):
    # ...

  with tf.Graph().as_default() as g:
    c = tf.constant(5.0)
    s_1 = tf.square(c)  # Uses the default gradient for tf.square.
    with g.gradient_override_map({"Square": "CustomSquare"}):
      s_2 = tf.square(c)  # Uses _custom_square_grad to compute the
                          # gradient of s_2.
  ```

  Args:
    op_type_map: A dictionary mapping op type strings to alternative op type
      strings.

  Returns:
    A context manager that sets the alternative op type to be used for one
    or more ops created in that context.

  Raises:
    TypeError: If `op_type_map` is not a dictionary mapping strings to
      strings. The graph's existing overrides are left untouched in that
      case.
  """
  if not isinstance(op_type_map, dict):
    raise TypeError("op_type_map must be a dictionary mapping "
                    "strings to strings")
  # Validate every entry before mutating self._gradient_override_map.
  # Validating while installing would leave earlier mappings installed
  # forever when a later entry is rejected, because the cleanup below is
  # never reached.
  for op_type, mapped_op_type in op_type_map.items():
    if not (isinstance(op_type, six.string_types) and
            isinstance(mapped_op_type, six.string_types)):
      raise TypeError("op_type_map must be a dictionary mapping "
                      "strings to strings")
  # The saved_mappings dictionary stores any currently-set mappings that
  # will be overridden by this context manager.
  saved_mappings = {}
  # Install the given mappings.
  for op_type, mapped_op_type in op_type_map.items():
    if op_type in self._gradient_override_map:
      saved_mappings[op_type] = self._gradient_override_map[op_type]
    self._gradient_override_map[op_type] = mapped_op_type
  try:
    yield  # The code within the context runs here.
  finally:
    # Remove the mappings set for this context, and restore any saved
    # mappings. pop() with a default keeps cleanup from raising (and masking
    # an exception from the body) if the entry has already been removed.
    for op_type in op_type_map:
      if op_type in saved_mappings:
        self._gradient_override_map[op_type] = saved_mappings[op_type]
      else:
        self._gradient_override_map.pop(op_type, None)

# pylint: enable=g-doc-return-or-yield
5080 """ 5081 if not self._stack_state_is_thread_local: 5082 self._stack_state_is_thread_local = True 5083 5084 @property 5085 def _device_function_stack(self): 5086 if self._stack_state_is_thread_local: 5087 # This may be called from a thread where device_function_stack doesn't yet 5088 # exist. 5089 # pylint: disable=protected-access 5090 if not hasattr(self._thread_local, "_device_function_stack"): 5091 stack_copy_for_this_thread = self._graph_device_function_stack.copy() 5092 self._thread_local._device_function_stack = stack_copy_for_this_thread 5093 return self._thread_local._device_function_stack 5094 # pylint: enable=protected-access 5095 else: 5096 return self._graph_device_function_stack 5097 5098 @property 5099 def _device_functions_outer_to_inner(self): 5100 user_device_specs = self._device_function_stack.peek_objs() 5101 device_functions = [spec.function for spec in user_device_specs] 5102 device_functions_outer_to_inner = list(reversed(device_functions)) 5103 return device_functions_outer_to_inner 5104 5105 def _snapshot_device_function_stack_metadata(self): 5106 """Return device function stack as a list of TraceableObjects. 5107 5108 Returns: 5109 [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj 5110 member is a displayable name for the user's argument to Graph.device, and 5111 the filename and lineno members point to the code location where 5112 Graph.device was called directly or indirectly by the user. 
5113 """ 5114 snapshot = [] 5115 for obj in self._device_function_stack.peek_traceable_objs(): 5116 obj_copy = obj.copy_metadata() 5117 obj_copy.obj = obj.obj.display_name 5118 snapshot.append(obj_copy) 5119 return snapshot 5120 5121 @_device_function_stack.setter 5122 def _device_function_stack(self, device_function_stack): 5123 if self._stack_state_is_thread_local: 5124 # pylint: disable=protected-access 5125 self._thread_local._device_function_stack = device_function_stack 5126 # pylint: enable=protected-access 5127 else: 5128 self._graph_device_function_stack = device_function_stack 5129 5130 @property 5131 def _colocation_stack(self): 5132 """Return thread-local copy of colocation stack.""" 5133 if self._stack_state_is_thread_local: 5134 # This may be called from a thread where colocation_stack doesn't yet 5135 # exist. 5136 # pylint: disable=protected-access 5137 if not hasattr(self._thread_local, "_colocation_stack"): 5138 stack_copy_for_this_thread = self._graph_colocation_stack.copy() 5139 self._thread_local._colocation_stack = stack_copy_for_this_thread 5140 return self._thread_local._colocation_stack 5141 # pylint: enable=protected-access 5142 else: 5143 return self._graph_colocation_stack 5144 5145 def _snapshot_colocation_stack_metadata(self): 5146 """Return colocation stack metadata as a dictionary.""" 5147 return { 5148 traceable_obj.obj.name: traceable_obj.copy_metadata() 5149 for traceable_obj in self._colocation_stack.peek_traceable_objs() 5150 } 5151 5152 @_colocation_stack.setter 5153 def _colocation_stack(self, colocation_stack): 5154 if self._stack_state_is_thread_local: 5155 # pylint: disable=protected-access 5156 self._thread_local._colocation_stack = colocation_stack 5157 # pylint: enable=protected-access 5158 else: 5159 self._graph_colocation_stack = colocation_stack 5160 5161 @property 5162 def _control_dependencies_stack(self): 5163 if self._stack_state_is_thread_local: 5164 # This may be called from a thread where 
control_dependencies_stack 5165 # doesn't yet exist. 5166 if not hasattr(self._thread_local, "_control_dependencies_stack"): 5167 self._thread_local._control_dependencies_stack = ( 5168 self._graph_control_dependencies_stack[:]) 5169 return self._thread_local._control_dependencies_stack 5170 else: 5171 return self._graph_control_dependencies_stack 5172 5173 @_control_dependencies_stack.setter 5174 def _control_dependencies_stack(self, control_dependencies): 5175 if self._stack_state_is_thread_local: 5176 self._thread_local._control_dependencies_stack = control_dependencies 5177 else: 5178 self._graph_control_dependencies_stack = control_dependencies 5179 5180 @property 5181 def _distribution_strategy_stack(self): 5182 """A stack to maintain distribution strategy context for each thread.""" 5183 if not hasattr(self._thread_local, "_distribution_strategy_stack"): 5184 self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access 5185 return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access 5186 5187 @_distribution_strategy_stack.setter 5188 def _distribution_strategy_stack(self, _distribution_strategy_stack): 5189 self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access 5190 _distribution_strategy_stack) 5191 5192 @property 5193 def _global_distribute_strategy_scope(self): 5194 """For implementing `tf.distribute.set_strategy()`.""" 5195 if not hasattr(self._thread_local, "distribute_strategy_scope"): 5196 self._thread_local.distribute_strategy_scope = None 5197 return self._thread_local.distribute_strategy_scope 5198 5199 @_global_distribute_strategy_scope.setter 5200 def _global_distribute_strategy_scope(self, distribute_strategy_scope): 5201 self._thread_local.distribute_strategy_scope = (distribute_strategy_scope) 5202 5203 def _mutation_lock(self): 5204 """Returns a lock to guard code that creates & mutates ops. 5205 5206 See the comment for self._group_lock for more info. 
@tf_export(v1=["device"])
def device(device_name_or_function):
  """Wrapper for `Graph.device()` using the default graph.

  See `tf.Graph.device` for more details.

  Three cases are handled:

  * Executing eagerly: the eager context's device scope is returned;
    device functions are rejected.
  * Building a graph while eager execution is enabled outside functions:
    the default graph's device scope is entered and, for non-callable
    arguments, the eager context's device scope is entered inside it.
  * Otherwise: the default graph's device scope is returned.

  Args:
    device_name_or_function: The device name or function to use in the context.

  Returns:
    A context manager that specifies the default device to use for newly
    created ops.

  Raises:
    RuntimeError: If eager execution is enabled and a function is passed in.
  """
  if context.executing_eagerly():
    if callable(device_name_or_function):
      raise RuntimeError(
          "tf.device does not support functions when eager execution "
          "is enabled.")
    return context.device(device_name_or_function)

  if not executing_eagerly_outside_functions():
    return get_default_graph().device(device_name_or_function)

  @tf_contextlib.contextmanager
  def combined(device_name_or_function):
    with get_default_graph().device(device_name_or_function):
      if callable(device_name_or_function):
        # Only the graph scope is entered for device functions.
        yield
      else:
        with context.device(device_name_or_function):
          yield

  return combined(device_name_or_function)
If a specific device is not required, 5265 consider not using this function so that a device can be automatically 5266 assigned. In general the use of this function is optional. `device_name` can 5267 be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially 5268 specified, containing only a subset of the "/"-separated fields. Any fields 5269 which are specified will override device annotations from outer scopes. 5270 5271 For example: 5272 5273 ```python 5274 with tf.device('/job:foo'): 5275 # ops created here have devices with /job:foo 5276 with tf.device('/job:bar/task:0/device:gpu:2'): 5277 # ops created here have the fully specified device above 5278 with tf.device('/device:gpu:1'): 5279 # ops created here have the device '/job:foo/device:gpu:1' 5280 ``` 5281 5282 Args: 5283 device_name: The device name to use in the context. 5284 5285 Returns: 5286 A context manager that specifies the default device to use for newly 5287 created ops. 5288 5289 Raises: 5290 RuntimeError: If a function is passed in. 5291 """ 5292 if callable(device_name): 5293 raise RuntimeError("tf.device does not support functions.") 5294 return device(device_name) 5295 5296 5297@tf_export(v1=["container"]) 5298def container(container_name): 5299 """Wrapper for `Graph.container()` using the default graph. 5300 5301 Args: 5302 container_name: The container string to use in the context. 5303 5304 Returns: 5305 A context manager that specifies the default container to use for newly 5306 created stateful ops. 
def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
  """Returns a context manager that colocates new ops with `op`.

  When executing eagerly, the result is a device scope for `op.device`, or a
  `NullContextmanager` when `op` is None. An `op` without a `device`
  attribute is first passed through
  `internal_convert_to_tensor_or_indexed_slices`.

  When building a graph, an `EagerTensor` is only accepted while the default
  graph is building a function, in which case the graph's device scope for
  `op.device` is returned. Anything else is delegated to the default graph's
  `_colocate_with_for_gradient`.

  Args:
    op: The object to colocate with, or None.
    gradient_uid: Forwarded unchanged to the graph's
      `_colocate_with_for_gradient`.
    ignore_existing: Forwarded unchanged to the graph's
      `_colocate_with_for_gradient`.

  Returns:
    A context manager.

  Raises:
    ValueError: If `op` is an `EagerTensor` and the default graph is not
      building a function.
  """
  if context.executing_eagerly():
    if op is None:
      return NullContextmanager()
    if not hasattr(op, "device"):
      op = internal_convert_to_tensor_or_indexed_slices(op)
    return device(op.device)

  default_graph = get_default_graph()
  if not isinstance(op, EagerTensor):
    # pylint: disable=protected-access
    return default_graph._colocate_with_for_gradient(
        op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)
    # pylint: enable=protected-access
  if not default_graph.building_function:
    raise ValueError("Encountered an Eager-defined Tensor during graph "
                     "construction, but a function was not being built.")
  return default_graph.device(op.device)
class _DefaultStack(threading.local):
  """Thread-local stack whose top element serves as the implicit default.

  Each thread sees its own `stack`. Entries are pushed and popped through
  `get_controller`, which by default insists that scopes are exited in the
  reverse order in which they were entered.
  """

  def __init__(self):
    super(_DefaultStack, self).__init__()
    self._enforce_nesting = True
    self.stack = []

  def get_default(self):
    """Returns the innermost entry, or `None` when the stack is empty."""
    if not self.stack:
      return None
    return self.stack[-1]

  def reset(self):
    """Replaces the stack with a fresh, empty list."""
    self.stack = []

  def is_cleared(self):
    """Returns True if the stack currently holds no entries."""
    return not self.stack

  @property
  def enforce_nesting(self):
    """Whether `get_controller` checks for strictly nested exits."""
    return self._enforce_nesting

  @enforce_nesting.setter
  def enforce_nesting(self, value):
    self._enforce_nesting = value

  @tf_contextlib.contextmanager
  def get_controller(self, default):
    """Pushes `default` for the duration of a `with` block.

    Args:
      default: The object to install as the current default.

    Yields:
      `default`.

    Raises:
      AssertionError: On exit, if nesting is enforced and `default` is no
        longer the innermost entry.
    """
    self.stack.append(default)
    try:
      yield default
    finally:
      # stack may be empty if reset() was called
      if self.stack:
        if not self._enforce_nesting:
          self.stack.remove(default)
        else:
          if self.stack[-1] is not default:
            raise AssertionError(
                "Nesting violated for default stack of %s objects" %
                type(default))
          self.stack.pop()
5479 5480 The returned `Session` will be the innermost session on which a 5481 `Session` or `Session.as_default()` context has been entered. 5482 5483 NOTE: The default session is a property of the current thread. If you 5484 create a new thread, and wish to use the default session in that 5485 thread, you must explicitly add a `with sess.as_default():` in that 5486 thread's function. 5487 5488 Returns: 5489 The default `Session` being used in the current thread. 5490 """ 5491 return _default_session_stack.get_default() 5492 5493 5494def _eval_using_default_session(tensors, feed_dict, graph, session=None): 5495 """Uses the default session to evaluate one or more tensors. 5496 5497 Args: 5498 tensors: A single Tensor, or a list of Tensor objects. 5499 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 5500 numpy ndarrays, TensorProtos, or strings. 5501 graph: The graph in which the tensors are defined. 5502 session: (Optional) A different session to use to evaluate "tensors". 5503 5504 Returns: 5505 Either a single numpy ndarray if "tensors" is a single tensor; or a list 5506 of numpy ndarrays that each correspond to the respective element in 5507 "tensors". 5508 5509 Raises: 5510 ValueError: If no default session is available; the default session 5511 does not have "graph" as its graph; or if "session" is specified, 5512 and it does not have "graph" as its graph. 5513 """ 5514 if session is None: 5515 session = get_default_session() 5516 if session is None: 5517 raise ValueError("Cannot evaluate tensor using `eval()`: No default " 5518 "session is registered. Use `with " 5519 "sess.as_default()` or pass an explicit session to " 5520 "`eval(session=sess)`") 5521 if session.graph is not graph: 5522 raise ValueError("Cannot use the default session to evaluate tensor: " 5523 "the tensor's graph is different from the session's " 5524 "graph. 
Pass an explicit session to " 5525 "`eval(session=sess)`.") 5526 else: 5527 if session.graph is not graph: 5528 raise ValueError("Cannot use the given session to evaluate tensor: " 5529 "the tensor's graph is different from the session's " 5530 "graph.") 5531 return session.run(tensors, feed_dict) 5532 5533 5534def _run_using_default_session(operation, feed_dict, graph, session=None): 5535 """Uses the default session to run "operation". 5536 5537 Args: 5538 operation: The Operation to be run. 5539 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 5540 numpy ndarrays, TensorProtos, or strings. 5541 graph: The graph in which "operation" is defined. 5542 session: (Optional) A different session to use to run "operation". 5543 5544 Raises: 5545 ValueError: If no default session is available; the default session 5546 does not have "graph" as its graph; or if "session" is specified, 5547 and it does not have "graph" as its graph. 5548 """ 5549 if session is None: 5550 session = get_default_session() 5551 if session is None: 5552 raise ValueError("Cannot execute operation using `run()`: No default " 5553 "session is registered. Use `with " 5554 "sess.as_default():` or pass an explicit session to " 5555 "`run(session=sess)`") 5556 if session.graph is not graph: 5557 raise ValueError("Cannot use the default session to execute operation: " 5558 "the operation's graph is different from the " 5559 "session's graph. 
class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
  """A thread-local stack of objects for providing an implicit default graph.

  In addition to the stack inherited from `_DefaultStack`, this class keeps a
  lazily created "global" default graph that is handed out whenever the stack
  is empty.
  """

  def __init__(self):
    super(_DefaultGraphStack, self).__init__()
    # Created on first use by `_GetGlobalDefaultGraph`; cleared by `reset`.
    self._global_default_graph = None

  def get_default(self):
    """Override that returns a global default if the stack is empty."""
    if self.stack:
      return self.stack[-1]
    # Single source of truth for lazy creation of the global graph, so that
    # this method and `_GetGlobalDefaultGraph` cannot drift apart.
    return self._GetGlobalDefaultGraph()

  def _GetGlobalDefaultGraph(self):
    """Returns the global default graph, creating it on first use."""
    if self._global_default_graph is None:
      # TODO(mrry): Perhaps log that the default graph is being used, or set
      #   provide some other feedback to prevent confusion when a mixture of
      #   the global default graph and an explicit graph are combined in the
      #   same process.
      self._global_default_graph = Graph()
    return self._global_default_graph

  def reset(self):
    """Empties the stack and discards the global default graph."""
    super(_DefaultGraphStack, self).reset()
    self._global_default_graph = None

  @tf_contextlib.contextmanager
  def get_controller(self, default):
    """Installs `default` as the default graph and switches to graph mode.

    Args:
      default: The graph to install. Its `building_function`, `as_default` and
        `_device_function_stack` attributes are recorded on the context's
        `context_switches` stack for the duration of the scope.

    Yields:
      The value yielded by `_DefaultStack.get_controller(default)`.
    """
    context.context().context_switches.push(default.building_function,
                                            default.as_default,
                                            default._device_function_stack)
    try:
      with super(_DefaultGraphStack,
                 self).get_controller(default) as g, context.graph_mode():
        yield g
    finally:
      # If an exception is raised here it may be hiding a related exception in
      # the try-block (just above).
      context.context().context_switches.pop()
# pylint: disable=g-doc-return-or-yield,line-too-long
@tf_export("init_scope")
@tf_contextlib.contextmanager
def init_scope():
  """A context manager that lifts ops out of control-flow scopes and function-building graphs.

  There is often a need to lift variable initialization ops out of control-flow
  scopes, function-building graphs, and gradient tapes. Entering an
  `init_scope` is a mechanism for satisfying these desiderata. In particular,
  entering an `init_scope` has three effects:

    (1) All control dependencies are cleared the moment the scope is entered;
        this is equivalent to entering the context manager returned from
        `control_dependencies(None)`, which has the side-effect of exiting
        control-flow scopes like `tf.cond` and `tf.while_loop`.

    (2) All operations that are created while the scope is active are lifted
        into the lowest context on the `context_stack` that is not building a
        graph function. Here, a context is defined as either a graph or an eager
        context. Every context switch, i.e., every installation of a graph as
        the default graph and every switch into eager mode, is logged in a
        thread-local stack called `context_switches`; the log entry for a
        context switch is popped from the stack when the context is exited.
        Entering an `init_scope` is equivalent to crawling up
        `context_switches`, finding the first context that is not building a
        graph function, and entering it. A caveat is that if graph mode is
        enabled but the default graph stack is empty, then entering an
        `init_scope` will simply install a fresh graph as the default one.

    (3) The gradient tape is paused while the scope is active.

  When eager execution is enabled, code inside an init_scope block runs with
  eager execution enabled even when tracing a `tf.function`. For example:

  ```python
  tf.compat.v1.enable_eager_execution()

  @tf.function
  def func():
    # A function constructs TensorFlow graphs,
    # it does not execute eagerly.
    assert not tf.executing_eagerly()
    with tf.init_scope():
      # Initialization runs with eager execution enabled
      assert tf.executing_eagerly()
  ```

  Raises:
    RuntimeError: if graph state is incompatible with this initialization.
  """
  # pylint: enable=g-doc-return-or-yield,line-too-long

  if context.executing_eagerly():
    # Fastpath: already eager, so only the gradient tape needs to be paused.
    with tape.stop_recording():
      yield
  else:
    # Retrieve the active name scope: entering an `init_scope` preserves
    # the name scope of the current context.
    scope = get_default_graph().get_name_scope()
    if scope and scope[-1] != "/":
      # Names that end with trailing slashes are treated by `name_scope` as
      # absolute.
      scope = scope + "/"

    outer_context, innermost_nonempty_device_stack = (
        _get_outer_context_and_inner_device_stack())

    # These stay `None` unless the outer context turns out to be a graph, in
    # which case they record what the `finally` clause has to restore.
    outer_graph = None
    outer_device_stack = None
    try:
      with outer_context(), name_scope(
          scope, skip_on_eager=False), control_dependencies(
              None), tape.stop_recording():
        # Default to a no-op device scope; replaced below only when lifting
        # into eager mode and a usable device string is found.
        context_manager = NullContextmanager
        context_manager_input = None
        if not context.executing_eagerly():
          # The device stack is preserved when lifting into a graph. Eager
          # execution doesn't implement device stacks and in particular it
          # doesn't support device functions, so in general it's not possible
          # to do the same when lifting into the eager context.
          outer_graph = get_default_graph()
          outer_device_stack = outer_graph._device_function_stack  # pylint: disable=protected-access
          outer_graph._device_function_stack = innermost_nonempty_device_stack  # pylint: disable=protected-access
        elif innermost_nonempty_device_stack is not None:
          # Lifting into eager mode: use the first stack entry that carries a
          # raw device string, stopping early at an entry with no function.
          for device_spec in innermost_nonempty_device_stack.peek_objs():
            if device_spec.function is None:
              break
            if device_spec.raw_string:
              context_manager = context.device
              context_manager_input = device_spec.raw_string
              break
            # It is currently not possible to have a device function in V2,
            # but in V1 we are unable to apply device functions in eager mode.
            # This means that we will silently skip some of the entries on the
            # device stack in V1 + eager mode.

        with context_manager(context_manager_input):
          yield
    finally:
      # If an exception is raised here it may be hiding a related exception in
      # the try-block (just above).
      if outer_graph is not None:
        # Undo the temporary device-stack swap performed on the outer graph.
        outer_graph._device_function_stack = outer_device_stack  # pylint: disable=protected-access
@tf_export("inside_function", v1=[])
def inside_function():
  """Reports whether the default graph is currently building a function.

  Example:

  >>> tf.inside_function()
  False
  >>> @tf.function
  ... def f():
  ...   print(tf.inside_function())
  >>> f()
  True

  Returns:
    The `building_function` attribute of the current default graph: True when
    the caller is running inside a `tf.function` rather than eagerly.
  """
  current_graph = get_default_graph()
  return current_graph.building_function
It is typically recommended to invoke this function 5842 at program startup and not in a library (as most libraries should be usable 5843 both with and without eager execution). 5844 5845 Args: 5846 config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the 5847 environment in which operations are executed. Note that 5848 `tf.compat.v1.ConfigProto` is also used to configure graph execution (via 5849 `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto` 5850 are not implemented (or are irrelevant) when eager execution is enabled. 5851 device_policy: (Optional.) Policy controlling how operations requiring 5852 inputs on a specific device (e.g., a GPU 0) handle inputs on a different 5853 device (e.g. GPU 1 or CPU). When set to None, an appropriate value will 5854 be picked automatically. The value picked may change between TensorFlow 5855 releases. 5856 Valid values: 5857 - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the 5858 placement is not correct. 5859 - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not 5860 on the right device but logs a warning. 5861 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors. 5862 Note that this may hide performance problems as there is no notification 5863 provided when operations are blocked on the tensor being copied between 5864 devices. 5865 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies 5866 int32 tensors, raising errors on the other ones. 5867 execution_mode: (Optional.) Policy controlling how operations dispatched are 5868 actually executed. When set to None, an appropriate value will be picked 5869 automatically. The value picked may change between TensorFlow releases. 5870 Valid values: 5871 - tf.contrib.eager.SYNC: executes each operation synchronously. 5872 - tf.contrib.eager.ASYNC: executes each operation asynchronously. These 5873 operations may return "non-ready" handles. 
def enable_eager_execution_internal(config=None,
                                    device_policy=None,
                                    execution_mode=None,
                                    server_def=None):
  """Enables eager execution for the lifetime of this program.

  Most of the doc string for enable_eager_execution is relevant here as well.

  Args:
    config: See enable_eager_execution doc string
    device_policy: See enable_eager_execution doc string
    execution_mode: See enable_eager_execution doc string
    server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on
      remote devices. GrpcServers need to be started by creating an identical
      server_def to this, and setting the appropriate task_indexes, so that the
      servers can communicate. It will then be possible to execute operations on
      remote devices.

  Raises:
    TypeError: If `config` is given and is not a `ConfigProto`.
    ValueError: If `device_policy` or `execution_mode` is not one of the
      accepted values; if the default execution mode is graph mode and the
      global default graph has already been created; or if a context already
      exists and a non-None `config`, `device_policy` or `execution_mode` is
      not the same object as the one the context holds.
  """
  # Validate all arguments before touching any global state.
  if config is not None and not isinstance(config, config_pb2.ConfigProto):
    raise TypeError("config must be a tf.ConfigProto, but got %s" %
                    type(config))
  if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
                           context.DEVICE_PLACEMENT_WARN,
                           context.DEVICE_PLACEMENT_SILENT,
                           context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
    raise ValueError(
        "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
    )
  if execution_mode not in (None, context.SYNC, context.ASYNC):
    raise ValueError(
        "execution_mode must be one of None, tf.contrib.eager.SYNC, "
        "tf.contrib.eager.ASYNC")
  if context.default_execution_mode == context.GRAPH_MODE:
    # The existence of the global default graph is used as the signal that
    # graph mode has already been exercised.
    graph_mode_has_been_used = (
        _default_graph_stack._global_default_graph is not None)  # pylint: disable=protected-access
    if graph_mode_has_been_used:
      raise ValueError(
          "tf.enable_eager_execution must be called at program startup.")
  context.default_execution_mode = context.EAGER_MODE
  # pylint: disable=protected-access
  with context._context_lock:
    if context._context is None:
      # No context yet: create one with the requested options.
      context._set_context_locked(context.Context(
          config=config,
          device_policy=device_policy,
          execution_mode=execution_mode,
          server_def=server_def))
    elif ((config is not None and config is not context._context._config) or
          (device_policy is not None and
           device_policy is not context._context._device_policy) or
          (execution_mode is not None and
           execution_mode is not context._context._execution_mode)):
      # Options are compared by identity (`is not`), not by value.
      raise ValueError(
          "Trying to change the options of an active eager"
          " execution. Context config: %s, specified config:"
          " %s. Context device policy: %s, specified device"
          " policy: %s. Context execution mode: %s, "
          " specified execution mode %s." %
          (context._context._config, config, context._context._device_policy,
           device_policy, context._context._execution_mode, execution_mode))
    else:
      # We already created everything, so update the thread local data.
      context._context._thread_local_data.is_eager = True

  # Monkey patch to get rid of an unnecessary conditional since the context is
  # now initialized.
  context.context = context.context_safe
If you need a cleared graph, " 6021 "exit the nesting and create a new graph.") 6022 _default_graph_stack.reset() 6023 6024 6025@tf_export(v1=["get_default_graph"]) 6026def get_default_graph(): 6027 """Returns the default graph for the current thread. 6028 6029 The returned graph will be the innermost graph on which a 6030 `Graph.as_default()` context has been entered, or a global default 6031 graph if none has been explicitly created. 6032 6033 NOTE: The default graph is a property of the current thread. If you 6034 create a new thread, and wish to use the default graph in that 6035 thread, you must explicitly add a `with g.as_default():` in that 6036 thread's function. 6037 6038 Returns: 6039 The default `Graph` being used in the current thread. 6040 """ 6041 return _default_graph_stack.get_default() 6042 6043 6044def has_default_graph(): 6045 """Returns True if there is a default graph.""" 6046 return len(_default_graph_stack.stack) >= 1 6047 6048 6049# Exported due to b/171079555 6050@tf_export("__internal__.get_name_scope", v1=[]) 6051def get_name_scope(): 6052 """Returns the current name scope in the default_graph. 6053 6054 For example: 6055 6056 ```python 6057 with tf.name_scope('scope1'): 6058 with tf.name_scope('scope2'): 6059 print(tf.get_name_scope()) 6060 ``` 6061 would print the string `scope1/scope2`. 6062 6063 Returns: 6064 A string representing the current name scope. 6065 """ 6066 if context.executing_eagerly(): 6067 return context.context().scope_name.rstrip("/") 6068 return get_default_graph().get_name_scope() 6069 6070 6071def _assert_same_graph(original_item, item): 6072 """Fail if the 2 items are from different graphs. 6073 6074 Args: 6075 original_item: Original item to check against. 6076 item: Item to check. 6077 6078 Raises: 6079 ValueError: if graphs do not match. 
6080 """ 6081 original_graph = getattr(original_item, "graph", None) 6082 graph = getattr(item, "graph", None) 6083 if original_graph and graph and original_graph is not graph: 6084 raise ValueError( 6085 "%s must be from the same graph as %s (graphs are %s and %s)." % 6086 (item, original_item, graph, original_graph)) 6087 6088 6089def _get_graph_from_inputs(op_input_list, graph=None): 6090 """Returns the appropriate graph to use for the given inputs. 6091 6092 This library method provides a consistent algorithm for choosing the graph 6093 in which an Operation should be constructed: 6094 6095 1. If the default graph is being used to construct a function, we 6096 use the default graph. 6097 2. If the "graph" is specified explicitly, we validate that all of the inputs 6098 in "op_input_list" are compatible with that graph. 6099 3. Otherwise, we attempt to select a graph from the first Operation- 6100 or Tensor-valued input in "op_input_list", and validate that all other 6101 such inputs are in the same graph. 6102 4. If the graph was not specified and it could not be inferred from 6103 "op_input_list", we attempt to use the default graph. 6104 6105 Args: 6106 op_input_list: A list of inputs to an operation, which may include `Tensor`, 6107 `Operation`, and other objects that may be converted to a graph element. 6108 graph: (Optional) The explicit graph to use. 6109 6110 Raises: 6111 TypeError: If op_input_list is not a list or tuple, or if graph is not a 6112 Graph. 6113 ValueError: If a graph is explicitly passed and not all inputs are from it, 6114 or if the inputs are from multiple graphs, or we could not find a graph 6115 and there was no default graph. 6116 6117 Returns: 6118 The appropriate graph to use for the given inputs. 
@tf_export(v1=["GraphKeys"])
class GraphKeys(object):
  """Standard names to use for graph collections.

  The standard library uses various well-known names to collect and
  retrieve values associated with a graph. For example, the
  `tf.Optimizer` subclasses default to optimizing the variables
  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
  specified, but it is also possible to pass an explicit list of
  variables.

  The following standard keys are defined:

  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
    across distributed environment (model variables are a subset of these).
    See `tf.compat.v1.global_variables`
    for more details.
    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
    machine. Usually used for temporary variables, like counters.
    Note: use `tf.contrib.framework.local_variable` to add to this collection.
  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
    model for inference (feed forward). Note: use
    `tf.contrib.framework.model_variable` to add to this collection.
  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
    be trained by an optimizer. See
    `tf.compat.v1.trainable_variables`
    for more details.
  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
    graph. See
    `tf.compat.v1.summary.merge_all`
    for more details.
  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
    produce input for a computation. See
    `tf.compat.v1.train.start_queue_runners`
    for more details.
  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
    keep moving averages. See
    `tf.compat.v1.moving_average_variables`
    for more details.
  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
    construction.

  The following standard keys are _defined_, but their collections are **not**
  automatically populated as many of the others are:

  * `WEIGHTS`
  * `BIASES`
  * `ACTIVATIONS`
  """

  # Key to collect Variable objects that are global (shared across machines).
  # Default collection for all variables, except local ones.
  GLOBAL_VARIABLES = "variables"
  # Key to collect local variables that are local to the machine and are not
  # saved/restored.
  LOCAL_VARIABLES = "local_variables"
  # Key to collect local variables which are used to accumulate internal state
  # to be used in tf.metrics.*.
  METRIC_VARIABLES = "metric_variables"
  # Key to collect model variables defined by layers.
  MODEL_VARIABLES = "model_variables"
  # Key to collect Variable objects that will be trained by the
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  # Key to collect summaries.
  SUMMARIES = "summaries"
  # Key to collect QueueRunners.
  QUEUE_RUNNERS = "queue_runners"
  # Key to collect table initializers.
  # NOTE: the string value is singular ("table_initializer") although the
  # attribute name is plural; the value is the actual collection key.
  TABLE_INITIALIZERS = "table_initializer"
  # Key to collect asset filepaths. An asset represents an external resource
  # like a vocabulary file.
  ASSET_FILEPATHS = "asset_filepaths"
  # Key to collect Variable objects that keep moving averages.
  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
  # Key to collect regularization losses at graph construction.
  REGULARIZATION_LOSSES = "regularization_losses"
  # Key to collect concatenated sharded variables.
  CONCATENATED_VARIABLES = "concatenated_variables"
  # Key to collect savers.
  SAVERS = "savers"
  # Key to collect weights.
  WEIGHTS = "weights"
  # Key to collect biases.
  BIASES = "biases"
  # Key to collect activations.
  ACTIVATIONS = "activations"
  # Key to collect update_ops.
  UPDATE_OPS = "update_ops"
  # Key to collect losses.
  LOSSES = "losses"
  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
  SAVEABLE_OBJECTS = "saveable_objects"
  # Key to collect all shared resources used by the graph which need to be
  # initialized once per cluster.
  RESOURCES = "resources"
  # Key to collect all shared resources used in this graph which need to be
  # initialized once per session.
  LOCAL_RESOURCES = "local_resources"
  # Trainable resource-style variables.
  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"

  # Key to indicate various ops.
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  READY_OP = "ready_op"
  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Used to count the number of evaluations performed during a single evaluation
  # run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # Key for control flow context.
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Used to store v2 summary names.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # List of all collections that keep track of variables.
  # It is built from the key constants above, so it must stay below them.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

  # Key for streaming model ports.
  # NOTE(yuanbyu): internal and experimental.
  _STREAMING_MODEL_PORTS = "streaming_model_ports"

  # Deprecated class-level alias: reading `GraphKeys.VARIABLES` returns the
  # value of `GLOBAL_VARIABLES`.
  @decorator_utils.classproperty
  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
  def VARIABLES(cls):  # pylint: disable=no-self-argument
    return cls.GLOBAL_VARIABLES
@tf_export(v1=["get_collection_ref"])
def get_collection_ref(key):
  """Wrapper for `Graph.get_collection_ref()` using the default graph.

  Looks up the current default graph and forwards `key` to its
  `get_collection_ref` method. See `tf.Graph.get_collection_ref`
  for more details.

  Args:
    key: The key for the collection. For example, the `GraphKeys` class contains
      many standard names for collections.

  Returns:
    The list of values in the collection with the given `key`, or an empty
    list if no value has been added to that collection. Note that this returns
    the collection list itself, which can be modified in place to change the
    collection.

  @compatibility(eager)
  Collections are not supported when eager execution is enabled.
  @end_compatibility
  """
  default_graph = get_default_graph()
  return default_graph.get_collection_ref(key)
def name_scope(name, default_name=None, values=None, skip_on_eager=True):
  """Internal-only entry point for `name_scope*`.

  Internal ops do not use the public API and instead rely on
  `ops.name_scope` regardless of the execution mode. This function picks
  the `name_scope*` implementation from the arguments and the current mode:

  * in graph mode, `internal_name_scope_v1` is used;
  * in eager mode with `skip_on_eager` set, a `NullContextmanager` is returned;
  * in eager mode, if `values` contains an object whose type is exactly
    `Tensor`, that tensor's `Graph.name_scope` is used;
  * otherwise `name_scope_v2` is used.

  Args:
    name: The name argument that is passed to the op function.
    default_name: The default name to use if the `name` argument is `None`.
    values: The list of `Tensor` arguments that are passed to the op function.
    skip_on_eager: Indicates to return NullContextmanager if executing eagerly.
      By default this is True since naming tensors and operations in eager mode
      have little use and cause unnecessary performance overhead. However, it is
      important to preserve variable names since they are often useful for
      debugging and saved models.

  Returns:
    `name_scope*` context manager.
  """
  if not context.executing_eagerly():
    return internal_name_scope_v1(name, default_name, values)
  if skip_on_eager:
    return NullContextmanager()

  if name is None:
    name = default_name

  if values:
    # The presence of a graph tensor in `values` overrides the context: the
    # first exact `Tensor` found decides which graph's scope is entered.
    # TODO(slebedev): this is Keras-specific and should be removed.
    for candidate in values:
      if type(candidate) == Tensor:  # pylint: disable=unidiomatic-typecheck
        return candidate.graph.name_scope(name)

  # `name_scope_v2` insists on a string, so a missing name becomes "".
  return name_scope_v2(name or "")
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
@tf_export(v1=["name_scope"])
class name_scope_v1(object):  # pylint: disable=invalid-name
  """A context manager for use when defining a Python op.

  This context manager validates that the given `values` are from the
  same graph, makes that graph the default graph, and pushes a
  name scope in that graph (see
  `tf.Graph.name_scope`
  for more details on that).

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```
  """

  __slots__ = ["_name", "_name_scope"]

  def __init__(self, name, default_name=None, values=None):
    """Initialize the context manager.

    The actual scope handling is delegated to the context manager returned by
    `name_scope(...)`, called with `skip_on_eager=False`.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.

    Raises:
      TypeError: if `default_name` is passed in but not a string.
    """
    self._name_scope = name_scope(
        name, default_name, values, skip_on_eager=False)
    if name is None:
      self._name = default_name
    else:
      self._name = name

  @property
  def name(self):
    """The requested scope name (`default_name` when `name` was `None`)."""
    return self._name

  def __enter__(self):
    """Enters the wrapped scope and returns whatever it yields."""
    entered = self._name_scope.__enter__()
    return entered

  def __exit__(self, *exc_info):
    """Leaves the wrapped scope, forwarding its exception-suppression result."""
    outcome = self._name_scope.__exit__(*exc_info)
    return outcome
6610 return foo_op(..., name=scope) 6611 ``` 6612 6613 When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`, 6614 and `MyOp/c`. 6615 6616 Inside a `tf.function`, if the scope name already exists, the name will be 6617 made unique by appending `_n`. For example, calling `my_op` the second time 6618 will generate `MyOp_1/a`, etc. 6619 """ 6620 6621 __slots__ = ["_name", "_exit_fns"] 6622 6623 def __init__(self, name): 6624 """Initialize the context manager. 6625 6626 Args: 6627 name: The prefix to use on all names created within the name scope. 6628 6629 Raises: 6630 ValueError: If name is not a string. 6631 """ 6632 if not isinstance(name, six.string_types): 6633 raise ValueError("name for name_scope must be a string.") 6634 self._name = name 6635 self._exit_fns = [] 6636 6637 @property 6638 def name(self): 6639 return self._name 6640 6641 def __enter__(self): 6642 """Start the scope block. 6643 6644 Returns: 6645 The scope name. 6646 """ 6647 ctx = context.context() 6648 if ctx.executing_eagerly(): 6649 # Names are not auto-incremented in eager mode. 6650 # A trailing slash breaks out of nested name scopes, indicating a 6651 # fully specified scope name, for compatibility with Graph.name_scope. 6652 # This also prevents auto-incrementing. 
def prepend_name_scope(name, import_scope):
  """Prepends name scope to a name.

  The scope is inserted after a leading `^` or `loc:@` marker when the name
  has one, so `^foo` becomes `^scope/foo` and `loc:@foo` becomes
  `loc:@scope/foo`; a plain `foo` becomes `scope/foo`.

  Args:
    name: A `string` name.
    import_scope: Optional `string`. Name scope to add. A single trailing `/`
      is dropped before the scope is inserted.

  Returns:
    Name with name scope added, or the original name if import_scope
    is None or empty. If `name` is of a type that cannot be processed, a
    warning is logged and `name` is returned unchanged.
  """
  if not import_scope:
    return name

  if import_scope[-1] == "/":
    import_scope = import_scope[:-1]

  def _insert_scope(match):
    # Build the result by plain concatenation. Splicing `import_scope` into a
    # replacement *template* is unsafe: a scope starting with a digit would
    # merge with the `\1` backreference (e.g. `\1` + `2x` -> group 12), and
    # backslashes in the scope would be parsed as template escapes.
    return match.group(1) + import_scope + "/" + match.group(2)

  try:
    return re.sub(r"([\^]|loc:@|^)(.*)", _insert_scope, compat.as_str(name))
  except TypeError as e:
    # If the name is not of a type we can process, simply return it.
    logging.warning(e)
    return name
def get_collection_proto_type(collection_name):
  """Returns the proto_type for collection_name.

  Args:
    collection_name: Name of the collection to look up in the proto function
      registry.

  Returns:
    The first element of the entry registered for `collection_name`, or `None`
    when the lookup raises `LookupError`.
  """
  try:
    # Keep the indexing inside the `try`: `IndexError` is a `LookupError` too.
    registered = _proto_function_registry.lookup(collection_name)
    return registered[0]
  except LookupError:
    return None
def _is_keras_symbolic_tensor(x):
  """Returns True iff `x` has a `graph` whose `name` is `"keras_graph"`.

  Args:
    x: Any object.

  Returns:
    `False` when `x` has no `graph` attribute; otherwise the result of
    comparing the graph's `name` (or `None` if it has no `name`) with
    `"keras_graph"`.
  """
  if not hasattr(x, "graph"):
    return False
  graph_name = getattr(x.graph, "name", None)
  return graph_name == "keras_graph"
def _reconstruct_sequence_inputs(op_def, inputs, attrs):
  """Regroups a flat list of input tensors into scalar and sequence inputs.

  Walks `op_def.input_arg` in order and consumes `inputs` left to right. An
  argument with a `number_attr` or a `type_list_attr` takes a slice of
  `inputs`; any other argument takes exactly one element, stored unwrapped.

  Args:
    op_def: The `op_def_pb2.OpDef` (for knowing the input types)
    inputs: a list of input `Tensor`s to the op.
    attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
      how long each sequence is)

  Returns:
    A list of `Tensor`s (corresponding to scalar inputs) and lists of
    `Tensor`s (corresponding to sequence inputs).
  """
  regrouped = []
  cursor = 0
  for arg_def in op_def.input_arg:
    if arg_def.number_attr:
      # Length is stored directly as an integer attr.
      span = attrs[arg_def.number_attr].i
    elif arg_def.type_list_attr:
      # Length is the number of entries in the type-list attr.
      span = len(attrs[arg_def.type_list_attr].list.type)
    else:
      # Scalar input: a single element, not wrapped in a list.
      regrouped.append(inputs[cursor])
      cursor += 1
      continue

    stop = cursor + span
    regrouped.append(inputs[cursor:stop])
    cursor = stop

  # Every input must have been consumed by exactly one argument.
  assert cursor == len(inputs)
  return regrouped
def _get_enclosing_context(graph):
  """Finds the nearest control flow context of `graph` or its outer graphs.

  Args:
    graph: A graph-like object, or `None`.

  Returns:
    The first non-`None` `_control_flow_context` found on `graph` or, while the
    current graph is building a function and has an `outer_graph` attribute, on
    its chain of outer graphs. `None` if there is none.
  """
  # pylint: disable=protected-access
  current = graph
  while current is not None:
    enclosing = current._control_flow_context
    if enclosing is not None:
      return enclosing
    # Only function-building graphs that expose `outer_graph` are followed.
    if not (current.building_function and hasattr(current, "outer_graph")):
      return None
    current = current.outer_graph
  return None