1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes and functions used to construct graphs.""" 16# pylint: disable=g-bad-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23import re 24import sys 25import threading 26 27import numpy as np 28import six 29from six.moves import xrange # pylint: disable=redefined-builtin 30 31from tensorflow.core.framework import attr_value_pb2 32from tensorflow.core.framework import function_pb2 33from tensorflow.core.framework import graph_pb2 34from tensorflow.core.framework import node_def_pb2 35from tensorflow.core.framework import op_def_pb2 36from tensorflow.core.framework import versions_pb2 37from tensorflow.core.protobuf import config_pb2 38from tensorflow.python import pywrap_tensorflow as c_api 39from tensorflow.python import tf2 40from tensorflow.python.eager import context 41from tensorflow.python.eager import core 42from tensorflow.python.eager import tape 43from tensorflow.python.framework import c_api_util 44from tensorflow.python.framework import composite_tensor 45from tensorflow.python.framework import device as pydev 46from tensorflow.python.framework import dtypes 47from tensorflow.python.framework import errors 48from tensorflow.python.framework import 
op_def_registry 49from tensorflow.python.framework import registry 50from tensorflow.python.framework import tensor_shape 51from tensorflow.python.framework import traceable_stack 52from tensorflow.python.framework import versions 53from tensorflow.python.ops import control_flow_util 54from tensorflow.python.platform import app 55from tensorflow.python.platform import tf_logging as logging 56from tensorflow.python.util import compat 57from tensorflow.python.util import decorator_utils 58from tensorflow.python.util import deprecation 59from tensorflow.python.util import function_utils 60from tensorflow.python.util import lock_util 61from tensorflow.python.util import memory 62from tensorflow.python.util import tf_contextlib 63from tensorflow.python.util import tf_stack 64from tensorflow.python.util.deprecation import deprecated_args 65from tensorflow.python.util.tf_export import tf_export 66 67 68# Temporary global switches determining if we should enable the work-in-progress 69# calls to the C API. These will be removed once all functionality is supported. 
_USE_C_API = True
_USE_C_SHAPES = True


def tensor_id(tensor):
  """Returns a unique identifier for this Tensor."""
  return tensor._id  # pylint: disable=protected-access


class _UserDeviceSpec(object):
  """Store user-specified device and provide computation of merged device.

  Attributes:
    display_name: `str()` of the user value, or `"func_name<file, line>"`
      when the user value is callable.
    raw_string: The user value when it is neither `None` nor callable,
      otherwise `None`.
    function: The user value itself when it is `None` or callable; otherwise
      the result of `pydev.merge_device(user_value)`.
  """

  def __init__(self, device_name_or_function):
    self._device_name_or_function = device_name_or_function

    self.display_name = str(self._device_name_or_function)
    if callable(self._device_name_or_function):
      # For device functions, show where the function was defined instead of
      # its default repr.
      dev_func = self._device_name_or_function
      func_name = function_utils.get_func_name(dev_func)
      func_code = function_utils.get_func_code(dev_func)
      if func_code:
        fname = func_code.co_filename
        lineno = func_code.co_firstlineno
      else:
        fname = "unknown"
        lineno = -1
      self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)

    self.raw_string = None

    self.function = self._device_name_or_function
    if not (self._device_name_or_function is None or
            callable(self._device_name_or_function)):
      self.raw_string = self._device_name_or_function
      self.function = pydev.merge_device(self._device_name_or_function)


class NullContextmanager(object):
  """A no-op context manager.

  Any constructor arguments are accepted and ignored, and exceptions raised
  inside the `with` body are never suppressed.
  """

  def __init__(self, *args, **kwargs):
    pass

  def __enter__(self):
    pass

  def __exit__(self, type_arg, value_arg, traceback_arg):
    return False  # False values do not suppress exceptions


def _override_helper(clazz_object, operator, func):
  """Overrides (string) operator on Tensors to call func.

  Args:
    clazz_object: the class to override for; either Tensor or SparseTensor.
    operator: the string name of the operator to override.
    func: the function that replaces the overridden operator.

  Raises:
    ValueError: If operator has already been overwritten,
      or if operator is not allowed to be overwritten.
  """
  existing = getattr(clazz_object, operator, None)
  if existing is not None:
    # Check to see if this is a default method-wrapper or slot wrapper which
    # will be true for the comparison operators.
    if not isinstance(existing, type(object.__lt__)):
      raise ValueError("operator %s cannot be overwritten again on class %s." %
                       (operator, clazz_object))
  # The allow-list is always Tensor's, whichever class is being patched.
  if operator not in Tensor.OVERLOADABLE_OPERATORS:
    raise ValueError("Overriding %s is disallowed" % operator)
  setattr(clazz_object, operator, func)


def _as_graph_element(obj):
  """Convert `obj` to a graph element if possible, otherwise return `None`.

  Args:
    obj: Object to convert.

  Returns:
    The result of `obj._as_graph_element()` if that method is available;
        otherwise `None`.
  """
  conv_fn = getattr(obj, "_as_graph_element", None)
  if conv_fn and callable(conv_fn):
    return conv_fn()
  return None


# Types registered via `register_dense_tensor_like_type()`; kept as a tuple so
# it can be passed straight to `isinstance`.
_TENSOR_LIKE_TYPES = tuple()


def is_dense_tensor_like(t):
  """EXPERIMENTAL: Returns true if `t` implements the tensor interface.

  See `register_dense_tensor_like_type()` for the current definition of a
  "tensor-like type".

  Args:
    t: An object.

  Returns:
    True iff `t` is an instance of one of the registered "tensor-like" types.
  """
  return isinstance(t, _TENSOR_LIKE_TYPES)


def register_dense_tensor_like_type(tensor_type):
  """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.

  A "tensor-like type" can represent a single dense tensor, and implements
  the `name` and `dtype` properties.

  Args:
    tensor_type: A type implementing the tensor interface.

  Raises:
    TypeError: If `tensor_type` does not implement the tensor interface.
  """
  try:
    if not isinstance(tensor_type.name, property):
      raise TypeError("Type %s does not define a `name` property" %
                      tensor_type.__name__)
  except AttributeError:
    raise TypeError("Type %s does not define a `name` property" %
                    tensor_type.__name__)
  try:
    if not isinstance(tensor_type.dtype, property):
      raise TypeError("Type %s does not define a `dtype` property" %
                      tensor_type.__name__)
  except AttributeError:
    raise TypeError("Type %s does not define a `dtype` property" %
                    tensor_type.__name__)
  # We expect this list to be small, so choose quadratic complexity
  # for registration, so that we have a tuple that can be used for
  # more efficient `isinstance` checks later.
  global _TENSOR_LIKE_TYPES
  _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])


def uid():
  """A unique (within this program execution) integer."""
  return c_api.TFE_Py_UID()


def numpy_text(tensor, is_repr=False):
  """Human readable representation of a tensor's numpy value.

  Args:
    tensor: A tensor exposing `dtype` and `numpy()`.
    is_repr: If True, use `repr()` of the numpy value; otherwise `str()`.

  Returns:
    The text, prefixed with a newline when it spans several lines, or
    `"<unprintable>"` when the dtype is not numpy-compatible.
  """
  if tensor.dtype.is_numpy_compatible:
    text = repr(tensor.numpy()) if is_repr else str(tensor.numpy())
  else:
    text = "<unprintable>"
  if "\n" in text:
    text = "\n" + text
  return text


# NOTE(ebrevdo): Do not subclass this.  If you do, I will break you on purpose.
class _TensorLike(object):
  """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
  pass


@tf_export("Tensor")
class Tensor(_TensorLike):
  """Represents one of the outputs of an `Operation`.

  A `Tensor` is a symbolic handle to one of the outputs of an
  `Operation`. It does not hold the values of that operation's output,
  but instead provides a means of computing those values in a
  TensorFlow `tf.Session`.

  This class has two primary purposes:

  1. A `Tensor` can be passed as an input to another `Operation`.
     This builds a dataflow connection between operations, which
     enables TensorFlow to execute an entire `Graph` that represents a
     large, multi-step computation.

  2. After the graph has been launched in a session, the value of the
     `Tensor` can be computed by passing it to
     `tf.Session.run`.
     `t.eval()` is a shortcut for calling
     `tf.get_default_session().run(t)`.

  In the following example, `c`, `d`, and `e` are symbolic `Tensor`
  objects, whereas `result` is a numpy array that stores a concrete
  value:

  ```python
  # Build a dataflow graph.
  c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
  e = tf.matmul(c, d)

  # Construct a `Session` to execute the graph.
  sess = tf.Session()

  # Execute the graph and store the value that `e` represents in `result`.
  result = sess.run(e)
  ```
  """

  # List of Python operators that we allow to override.
  OVERLOADABLE_OPERATORS = {
      # Binary.
      "__add__",
      "__radd__",
      "__sub__",
      "__rsub__",
      "__mul__",
      "__rmul__",
      "__div__",
      "__rdiv__",
      "__truediv__",
      "__rtruediv__",
      "__floordiv__",
      "__rfloordiv__",
      "__mod__",
      "__rmod__",
      "__lt__",
      "__le__",
      "__gt__",
      "__ge__",
      "__and__",
      "__rand__",
      "__or__",
      "__ror__",
      "__xor__",
      "__rxor__",
      "__getitem__",
      "__pow__",
      "__rpow__",
      # Unary.
      "__invert__",
      "__neg__",
      "__abs__",
      "__matmul__",
      "__rmatmul__"
  }

  def __init__(self, op, value_index, dtype):
    """Creates a new `Tensor`.

    Args:
      op: An `Operation`. `Operation` that computes this tensor.
      value_index: An `int`. Index of the operation's endpoint that produces
        this tensor.
      dtype: A `DType`. Type of elements stored in this tensor.

    Raises:
      TypeError: If the op is not an `Operation`.
    """
    if not isinstance(op, Operation):
      raise TypeError("op needs to be an Operation: %s" % op)
    self._op = op
    self._value_index = value_index
    self._dtype = dtypes.as_dtype(dtype)
    # This will be set by self._as_tf_output().
    self._tf_output = None
    # This will be set by self.shape().
    self._shape_val = None
    # List of operations that use this Tensor as input.  We maintain this list
    # to easily navigate a computation graph.
    self._consumers = []
    self._id = uid()

  @property
  def op(self):
    """The `Operation` that produces this tensor as an output."""
    return self._op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._dtype

  @property
  def graph(self):
    """The `Graph` that contains this tensor."""
    return self._op.graph

  @property
  def name(self):
    """The string name of this tensor."""
    if not self._op.name:
      raise ValueError("Operation was not named: %s" % self._op)
    return "%s:%d" % (self._op.name, self._value_index)

  @property
  def device(self):
    """The name of the device on which this tensor will be produced, or None."""
    return self._op.device

  @property
  def shape(self):
    """Returns the `TensorShape` that represents the shape of this tensor.

    The shape is computed using shape inference functions that are
    registered in the Op for each `Operation`.  See
    `tf.TensorShape`
    for more details of what a shape represents.

    The inferred shape of a tensor is used to provide shape
    information without having to launch the graph in a session. This
    can be used for debugging, and providing early error messages. For
    example:

    ```python
    c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    print(c.shape)
    ==> TensorShape([Dimension(2), Dimension(3)])

    d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])

    print(d.shape)
    ==> TensorShape([Dimension(4), Dimension(2)])

    # Raises a ValueError, because `c` and `d` do not have compatible
    # inner dimensions.
    e = tf.matmul(c, d)

    f = tf.matmul(c, d, transpose_a=True, transpose_b=True)

    print(f.shape)
    ==> TensorShape([Dimension(3), Dimension(4)])
    ```

    In some cases, the inferred shape may have unknown dimensions. If
    the caller has additional information about the values of these
    dimensions, `Tensor.set_shape()` can be used to augment the
    inferred shape.

    Returns:
      A `TensorShape` representing the shape of this tensor.

    """
    # Cached; `set_shape()` resets the cache.
    if self._shape_val is None:
      self._shape_val = self._c_api_shape()
    return self._shape_val

  def _get_input_ops_without_shapes(self, target_op):
    """Returns ops needing shape inference to compute target_op's shape."""
    # NOTE(review): `target_op` is unused; the traversal starts from
    # `self._op` and follows only inputs whose shape is not yet cached.
    result = []
    stack = [self._op]
    visited = set()
    while stack:
      op = stack.pop()
      if op in visited: continue
      result.append(op)
      stack.extend(t.op for t in op.inputs if t._shape_val is None)
      visited.add(op)
    return result

  def _c_api_shape(self):
    """Returns the TensorShape of this tensor according to the C API."""
    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
    shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
        c_graph, self._as_tf_output())
    if unknown_shape:
      return tensor_shape.unknown_shape()
    else:
      # The C API encodes an unknown dimension as -1.
      shape_vector = [None if d == -1 else d for d in shape_vector]
      return tensor_shape.TensorShape(shape_vector)

  @property
  def _shape(self):
    logging.warning("Tensor._shape is private, use Tensor.shape "
                    "instead. Tensor._shape will eventually be removed.")
    return self.shape

  @_shape.setter
  def _shape(self, value):
    raise ValueError(
        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")

  def __iter__(self):
    """Iterates over the leading dimension of this tensor.

    Iteration requires eager execution and a known, non-scalar shape whose
    first dimension is known. These conditions are checked when the iterator
    is created (on `iter(t)`), not deferred to the first `next()`, so callers
    probing iterability with `iter()` see the failure immediately.

    Returns:
      An iterator that lazily yields `self[i]` for `i` in `range(shape[0])`.

    Raises:
      TypeError: If this tensor cannot be iterated over.
    """
    if not context.executing_eagerly():
      raise TypeError(
          "Tensor objects are only iterable when eager execution is "
          "enabled. To iterate over this tensor use tf.map_fn.")
    shape = self._shape_tuple()
    if shape is None:
      raise TypeError("Cannot iterate over a tensor with unknown shape.")
    if not shape:
      raise TypeError("Cannot iterate over a scalar tensor.")
    if shape[0] is None:
      raise TypeError(
          "Cannot iterate over a tensor with unknown first dimension.")
    # A generator expression (rather than `yield` in this method) so the
    # checks above run eagerly while elements are still produced lazily.
    return (self[i] for i in xrange(shape[0]))

  def _shape_as_list(self):
    """Returns the dimensions as a list (entries may be None), or None."""
    if self.shape.ndims is not None:
      return [dim.value for dim in self.shape.dims]
    else:
      return None

  def _shape_tuple(self):
    """Returns the dimensions as a tuple (entries may be None), or None."""
    shape = self._shape_as_list()
    if shape is None:
      return None
    return tuple(shape)

  def _rank(self):
    """Integer rank of this Tensor, if known, else None.

    Returns:
      Integer rank or None
    """
    return self.shape.ndims

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def set_shape(self, shape):
    """Updates the shape of this tensor.

    This method can be called multiple times, and will merge the given
    `shape` with the current shape of this tensor. It can be used to
    provide additional information about the shape of this tensor that
    cannot be inferred from the graph alone. For example, this can be used
    to provide additional information about the shapes of images:

    ```python
    _, image_data = tf.TFRecordReader(...).read(...)
    image = tf.image.decode_png(image_data, channels=3)

    # The height and width dimensions of `image` are data dependent, and
    # cannot be computed without executing the op.
    print(image.shape)
    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])

    # We know that each image in this dataset is 28 x 28 pixels.
    image.set_shape([28, 28, 3])
    print(image.shape)
    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
    ```

    NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
    result in inconsistencies between the statically-known graph and the runtime
    value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
    instead.

    Args:
      shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.

    Raises:
      ValueError: If `shape` is not compatible with the current shape of
        this tensor.
    """
    # Reset cached shape.
    self._shape_val = None

    # We want set_shape to be reflected in the C API graph for when we run it.
    if not isinstance(shape, tensor_shape.TensorShape):
      shape = tensor_shape.TensorShape(shape)
    dim_list = []
    if shape.dims is None:
      unknown_shape = True
    else:
      unknown_shape = False
      for dim in shape.dims:
        # The C API encodes an unknown dimension as -1.
        if dim.value is None:
          dim_list.append(-1)
        else:
          dim_list.append(dim.value)
    try:
      c_api.TF_GraphSetTensorShape_wrapper(
          self._op._graph._c_graph,  # pylint: disable=protected-access
          self._as_tf_output(),
          dim_list,
          unknown_shape)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(str(e))

  @property
  def value_index(self):
    """The index of this tensor in the outputs of its `Operation`."""
    return self._value_index

  def consumers(self):
    """Returns a list of `Operation`s that consume this tensor.

    Returns:
      A list of `Operation`s.
    """
    consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
        self._as_tf_output())
    # pylint: disable=protected-access
    return [
        self.graph._get_operation_by_name_unsafe(name)
        for name in consumer_names
    ]
    # pylint: enable=protected-access

  def _as_node_def_input(self):
    """Return a value to use for the NodeDef "input" attribute.

    The returned string can be used in a NodeDef "input" attribute
    to indicate that the NodeDef uses this Tensor as input.

    Raises:
      ValueError: if this Tensor's Operation does not have a name.

    Returns:
      a string.
    """
    if not self._op.name:
      raise ValueError("Operation was not named: %s" % self._op)
    # Output 0 is referenced by the bare op name; others as "name:index".
    if self._value_index == 0:
      return self._op.name
    else:
      return "%s:%d" % (self._op.name, self._value_index)

  def _as_tf_output(self):
    # pylint: disable=protected-access
    # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object
    # also guarantees that a dictionary of tf_output objects will retain a
    # deterministic (yet unsorted) order which prevents memory blowup in the
    # cache of executor(s) stored for every session.
    if self._tf_output is None:
      self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index)
    return self._tf_output
    # pylint: enable=protected-access

  def __str__(self):
    return "Tensor(\"%s\"%s%s%s)" % (
        self.name, (", shape=%s" % self.get_shape())
        if self.get_shape().ndims is not None else "",
        (", dtype=%s" % self._dtype.name)
        if self._dtype else "", (", device=%s" % self.device)
        if self.device else "")

  def __repr__(self):
    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
                                                   self._dtype.name)

  def __hash__(self):
    # Necessary to support Python's collection membership operators
    return id(self)

  def __eq__(self, other):
    # Necessary to support Python's collection membership operators
    return id(self) == id(other)

  def __copy__(self):
    # TODO(b/77597810): get rid of Tensor copies.
    cls = self.__class__
    result = cls.__new__(cls)
    result.__dict__.update(self.__dict__)
    return result

  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  @staticmethod
  def _override_operator(operator, func):
    _override_helper(Tensor, operator, func)

  def __bool__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This overload raises a `TypeError` when the user inadvertently
    treats a `Tensor` as a boolean (e.g. in an `if` statement). For
    example:

    ```python
    if tf.constant(True):  # Will raise.
      # ...

    if tf.constant(5) < tf.constant(7):  # Will raise.
      # ...
    ```

    This disallows ambiguities between testing the Python value vs testing the
    dynamic condition of the `Tensor`.

    Raises:
      `TypeError`.
    """
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
                    "Use `if t is not None:` instead of `if t:` to test if a "
                    "tensor is defined, and use TensorFlow ops such as "
                    "tf.cond to execute subgraphs conditioned on the value of "
                    "a tensor.")

  def __nonzero__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This is the Python 2.x counterpart to `__bool__()` above.

    Raises:
      `TypeError`.
    """
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
                    "Use `if t is not None:` instead of `if t:` to test if a "
                    "tensor is defined, and use TensorFlow ops such as "
                    "tf.cond to execute subgraphs conditioned on the value of "
                    "a tensor.")

  def eval(self, feed_dict=None, session=None):
    """Evaluates this tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values.
        See `tf.Session.run` for a
        description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this tensor. If
        none, the default session will be used.

    Returns:
      A numpy array corresponding to the value of this tensor.

    """
    return _eval_using_default_session(self, feed_dict, self.graph, session)


# TODO(agarwal): consider getting rid of this.
class _EagerTensorBase(Tensor):
  """Base class for EagerTensor."""

  @property
  def dtype(self):
    # Note: using the intern table directly here as this is
    # performance-sensitive in some models.
    return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access

  def numpy(self):
    """Returns a numpy array or a scalar with the same contents as the Tensor.

    TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
    buffer but instead always explicitly copy? Note that currently it may or may
    not copy based on whether the numpy data is properly aligned or not.

    Returns:
      A numpy array or a scalar. Numpy array may share memory with the
      Tensor object. Any changes to one may be reflected in the other. A scalar
      value is returned when self has rank 0.

    Raises:
      ValueError: if the type of this Tensor is not representable in numpy.
    """
    if self.dtype == dtypes.resource:
      raise ValueError("Resource handles are not convertible to numpy.")
    return self._cpu_nograd()._numpy()  # pylint: disable=protected-access

  # __int__, __float__ and __index__ may copy the tensor to CPU and
  # only work for scalars. __int__ and __float__ cast values as per numpy;
  # __index__ accepts only integral values.
  def __int__(self):
    """Returns `int(self.numpy())`; numpy casting rules apply."""
    return int(self.numpy())

  def __float__(self):
    """Returns `float(self.numpy())`; numpy casting rules apply."""
    return float(self.numpy())

  def __index__(self):
    """Converts a scalar integer tensor to a Python integer for indexing.

    Unlike `__int__`, this does not truncate. As the `__index__` protocol
    requires, non-integral values (e.g. float tensors) raise `TypeError`
    instead of being silently cast.

    Returns:
      A Python integer.

    Raises:
      TypeError: If the tensor's value is not an integer scalar.
      ValueError: If the tensor is a resource handle (see `numpy()`).
    """
    # Imported locally: a module-level `operator` would be shadowed by the
    # `operator` parameters of the operator-override helpers in this module.
    import operator  # pylint: disable=g-import-not-at-top
    return operator.index(self.numpy())

  def __array__(self, dtype=None):
    return np.array(self.numpy(), dtype=dtype)

  def __format__(self, format_spec):
    return self.numpy().__format__(format_spec)

  def __reduce__(self):
    # Pickles as a call to `convert_to_tensor` on the numpy value.
    return (convert_to_tensor, (self.numpy(),))

  def _numpy(self):
    raise NotImplementedError()

  @property
  def backing_device(self):
    """Returns the name of the device holding this tensor's memory.

    `.backing_device` is usually the same as `.device`, which returns
    the device on which the kernel of the operation that produced this tensor
    ran. However, some operations can produce tensors on a different device
    (e.g., an operation that executes on the GPU but produces output tensors
    in host memory).
    """
    raise NotImplementedError()

  def __copy__(self):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    return self

  def __deepcopy__(self, memo):
    # Eager Tensors are immutable so it's safe to return themselves as a copy.
    del memo
    return self

  def _datatype_enum(self):
    raise NotImplementedError()

  def _shape_tuple(self):
    """The shape of this Tensor, as a tuple.

    This is more performant than tuple(shape().as_list()) as it avoids
    two list and one object creation. Marked private for now as from an API
    perspective, it would be better to have a single performant way of
    getting a shape rather than exposing shape() and shape_tuple()
    (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
    but ideally one would work things out and remove the need for this method.

    Returns:
      tuple with the shape.
    """
    raise NotImplementedError()

  def _rank(self):
    """Integer rank of this Tensor.

    Unlike regular Tensors, the rank is always known for EagerTensors.

    This is more performant than len(self._shape_tuple())

    Returns:
      Integer rank
    """
    raise NotImplementedError()

  def _num_elements(self):
    """Number of elements of this Tensor.

    Unlike regular Tensors, the number of elements is always known for
    EagerTensors.

    This is more performant than tensor.shape.num_elements

    Returns:
      Long - num elements in the tensor
    """
    raise NotImplementedError()

  def _copy_to_device(self, context, device):  # pylint: disable=redefined-outer-name
    raise NotImplementedError()

  def __str__(self):
    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self),
                                                  self.shape,
                                                  self.dtype.name)

  def __repr__(self):
    return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
        self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True))

  @staticmethod
  def _override_operator(name, func):
    setattr(_EagerTensorBase, name, func)

  def _copy_nograd(self, ctx=None, device_name=None):
    """Copies tensor to dest device, but doesn't record the operation."""
    # pylint: disable=protected-access
    # Creates a new tensor on the dest device.
    if ctx is None:
      ctx = context.context()
    if device_name is None:
      device_name = ctx.device_name
    try:
      new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
    except core._NotOkStatusException as e:
      # Re-raise as the matching TensorFlow error type, dropping the
      # low-level status exception from the context.
      six.raise_from(core._status_to_exception(e.code, e.message), None)
    return new_tensor

  def _copy(self, ctx=None, device_name=None):
    """Copies tensor to dest device."""
    new_tensor = self._copy_nograd(ctx, device_name)
    # Record the copy on tape and define backprop copy as well.
    if context.executing_eagerly():
      self_device = self.device
      def grad_fun(dresult):
        # The gradient is copied back to the source tensor's device.
        return [dresult._copy(device_name=self_device)]
      tape.record_operation("_copy", [new_tensor], [self], grad_fun)
    return new_tensor
    # pylint: enable=protected-access

  @property
  def shape(self):
    """The `TensorShape` of this tensor, built from `_shape_tuple()` and cached."""
    if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
      # `_tensor_shape` is declared and defined in the definition of
      # `EagerTensor`, in C.
      self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
    return self._tensor_shape

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def _shape_as_list(self):
    """The shape of the tensor as a list."""
    return list(self._shape_tuple())

  @property
  def ndim(self):
    """Returns the number of Tensor dimensions."""
    return self.shape.ndims

  def __len__(self):
    """Returns the length of the first dimension in the Tensor."""
    if not self.shape.ndims:
      raise TypeError("Scalar tensor has no `len()`")
    return self._shape_tuple()[0]

  def _cpu_nograd(self):
    """A copy of this Tensor with contents backed by host memory.

    The copy cannot be differentiated through.

    Returns:
      A CPU-memory backed Tensor object with the same contents as this Tensor.
    """
    return self._copy_nograd(context.context(), "CPU:0")

  def cpu(self):
    """A copy of this Tensor with contents backed by host memory."""
    return self._copy(context.context(), "CPU:0")

  def gpu(self, gpu_index=0):
    """A copy of this Tensor with contents backed by memory on the GPU.

    Arguments:
      gpu_index: Identifies which GPU to place the contents on the returned
        Tensor in.

    Returns:
      A GPU-memory backed Tensor object initialized with the same contents
      as this Tensor.
    """
    return self._copy(context.context(), "GPU:" + str(gpu_index))

  def __bool__(self):
    """Returns `bool(self.numpy())`."""
    return bool(self.numpy())

  def __nonzero__(self):
    return self.__bool__()

  def set_shape(self, shape):
    """Checks that `shape` is compatible with this tensor's (fixed) shape.

    The shape is not modified; an incompatible `shape` raises.

    Args:
      shape: A shape-like value accepted by `TensorShape.is_compatible_with`.

    Raises:
      ValueError: If `shape` is not compatible with `self.shape`.
    """
    if not self.shape.is_compatible_with(shape):
      raise ValueError(
          "Tensor's shape %s is not compatible with supplied shape %s" %
          (self.shape, shape))

  # Methods not supported / implemented for Eager Tensors.
  @property
  def op(self):
    raise AttributeError(
        "Tensor.op is meaningless when eager execution is enabled.")

  @property
  def graph(self):
    raise AttributeError(
        "Tensor.graph is meaningless when eager execution is enabled.")

  @property
  def name(self):
    raise AttributeError(
        "Tensor.name is meaningless when eager execution is enabled.")

  @property
  def value_index(self):
    raise AttributeError(
        "Tensor.value_index is meaningless when eager execution is enabled.")

  def consumers(self):
    raise NotImplementedError(
        "Tensor.consumers is meaningless when eager execution is enabled.")

  def _add_consumer(self, consumer):
    raise NotImplementedError(
        "_add_consumer not supported when eager execution is enabled.")

  def _as_node_def_input(self):
    raise NotImplementedError(
        "_as_node_def_input not supported when eager execution is enabled.")

  def _as_tf_output(self):
    raise NotImplementedError(
        "_as_tf_output not supported when eager execution is enabled.")

  def eval(self, feed_dict=None, session=None):
    raise NotImplementedError(
        "eval is not supported when eager execution is enabled, "
        "is .numpy() what you're looking for?"
    )


# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module.
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)


def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
  """Conversion function for values that are already `Tensor`s.

  Returns `t` unchanged after checking that a requested `dtype` (if any) is
  compatible with `t.dtype`.

  Args:
    t: The `Tensor` being converted.
    dtype: Optional requested `DType`.
    name: Ignored; accepted for the conversion-function signature.
    as_ref: Ignored; accepted for the conversion-function signature.

  Returns:
    `t` itself.

  Raises:
    ValueError: If `dtype` is given and is incompatible with `t.dtype`.
  """
  _ = name, as_ref
  if dtype and not dtype.is_compatible_with(t.dtype):
    raise ValueError(
        "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
        (dtype.name, t.dtype.name, str(t)))
  return t


# Maps priority -> list of (base_type, conversion_func), in registration order.
_tensor_conversion_func_registry = {
    0: [(Tensor, _TensorTensorConversionFunction)]
}
# Maps a concrete Python type -> the ordered list of (base_type,
# conversion_func) pairs that apply to it. Rebuilt lazily by
# `internal_convert_to_tensor` and reset by
# `register_tensor_conversion_function`.
_tensor_conversion_func_cache = {}
# Held while the registry is mutated and while a cache entry is being built.
_tensor_conversion_func_lock = threading.Lock()
register_dense_tensor_like_type(Tensor)


@tf_export(v1=["convert_to_tensor"])
def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None,
                      dtype_hint=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
  `string` types when `None` is present in a Python list or scalar. Rather
  than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    preferred_dtype: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so preferred_dtype
      can be used as a soft preference. If the conversion to
      `preferred_dtype` is not possible, this argument has no effect.
      Deprecated spelling of `dtype_hint`.
    dtype_hint: Same meaning as `preferred_dtype`. The two arguments are
      reconciled by `deprecation.deprecated_argument_lookup`.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  preferred_dtype = deprecation.deprecated_argument_lookup(
      "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
  # Forward by keyword so a change in v2's parameter order cannot silently
  # swap `dtype_hint` and `name`.
  return convert_to_tensor_v2(
      value, dtype=dtype, dtype_hint=preferred_dtype, name=name)


@tf_export("convert_to_tensor", v1=[])
def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Note: This function diverges from default Numpy behavior for `float` and
  `string` types when `None` is present in a Python list or scalar. Rather
  than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so dtype_hint
      can be used as a soft preference. If the conversion to
      `dtype_hint` is not possible, this argument has no effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
  """
  return internal_convert_to_tensor(
      value=value,
      dtype=dtype,
      name=name,
      preferred_dtype=dtype_hint,
      as_ref=False)


def _error_prefix(name):
  """Returns `"<name>: "` for error messages, or `""` when `name` is None."""
  return "" if name is None else "%s: " % name


def internal_convert_to_tensor(value,
                               dtype=None,
                               name=None,
                               as_ref=False,
                               preferred_dtype=None,
                               ctx=None,
                               accept_symbolic_tensors=True):
  """Implementation of the public convert_to_tensor.

  Args:
    value: The object to convert.
    dtype: Optional required element type.
    name: Optional name for a newly created `Tensor`.
    as_ref: Forwarded to the conversion function (request a ref tensor).
    preferred_dtype: Optional soft dtype preference, tried first when `dtype`
      is None.
    ctx: Optional `context.context()`; looked up when None.
    accept_symbolic_tensors: When False and executing eagerly, a graph
      `Tensor` that `_is_keras_symbolic_tensor` recognises raises
      `core._SymbolicException`.

  Returns:
    A `Tensor`.

  Raises:
    TypeError: If no registered conversion function accepts `value`, or the
      preferred-dtype conversion produced a different base dtype.
    RuntimeError: If a conversion function returns a non-`Tensor` or a
      `Tensor` with an incompatible dtype, or if an `EagerTensor` is used
      outside a function-building graph.
  """
  if ctx is None:
    ctx = context.context()
  if isinstance(value, EagerTensor):
    if ctx.executing_eagerly():
      if dtype is not None:
        dtype = dtypes.as_dtype(dtype)
        value = _TensorTensorConversionFunction(value, dtype=dtype)
      return value
    else:
      # Graph mode: an eager value may only enter a graph that is being
      # built as a function, where it is captured.
      graph = get_default_graph()
      if not graph.building_function:
        raise RuntimeError("Attempting to capture an EagerTensor without "
                           "building a function.")
      return graph.capture(value, name=name)
  elif ((not accept_symbolic_tensors) and
        isinstance(value, Tensor) and
        ctx.executing_eagerly()):
    # Found a symbolic tensor in an eager context.
    # This happens when we use the Keras functional API (i.e. calling layers
    # on the output of `keras.Input()`, which is symbolic) while eager
    # execution is enabled.
    if _is_keras_symbolic_tensor(value):
      # If the graph of the tensor isn't the Keras graph, we should still
      # fail, for the time being. TODO(fchollet): consider allowing
      # all symbolic tensors to raise this exception in this case.
      raise core._SymbolicException(  # pylint: disable=protected-access
          "Using the symbolic output of a Keras layer during eager execution.")

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  unwrapped_type = type(value)
  conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
  if conversion_func_list is None:
    # Cache miss: collect every registered function whose base type matches,
    # ordered by ascending priority, then by registration order.
    with _tensor_conversion_func_lock:
      conversion_func_list = []
      for _, funcs_at_priority in sorted(
          _tensor_conversion_func_registry.items()):
        for base_type, conversion_func in funcs_at_priority:
          if isinstance(value, base_type):
            conversion_func_list.append((base_type, conversion_func))
      _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list

  for base_type, conversion_func in conversion_func_list:
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError, errors.UnimplementedError,
              errors.InvalidArgumentError):
        # Could not coerce the conversion to use the preferred dtype.
        ret = None

      if ret is not None and ret is not NotImplemented:
        if (ret.dtype.base_dtype !=
            dtypes.as_dtype(preferred_dtype).base_dtype):
          raise TypeError("convert_to_tensor did not convert to "
                          "the preferred dtype: %s vs %s " %
                          (ret.dtype.base_dtype,
                           dtypes.as_dtype(preferred_dtype).base_dtype))

    if ret is None:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    # NotImplemented means "try the next registered conversion function".
    if ret is NotImplemented:
      continue

    if not isinstance(ret, Tensor):
      raise RuntimeError(
          "%sConversion function %r for type %s returned non-Tensor: %r" %
          (_error_prefix(name), conversion_func, base_type, ret))
    if dtype and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          "%sConversion function %r for type %s returned incompatible "
          "dtype: requested = %s, actual = %s" %
          (_error_prefix(name), conversion_func, base_type, dtype.name,
           ret.dtype.name))
    return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered." %
                  (_error_prefix(name), value, unwrapped_type))


# Abstract base classes such as `Sequence` live in `collections.abc` on
# Python 3. The `collections.Sequence` aliases are deprecated and were removed
# in Python 3.10. Python 2 only has the `collections` spelling.
try:
  from collections import abc as _collections_abc  # pylint: disable=g-import-not-at-top
except ImportError:
  _collections_abc = collections


def internal_convert_n_to_tensor(values,
                                 dtype=None,
                                 name=None,
                                 as_ref=False,
                                 preferred_dtype=None,
                                 ctx=None):
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to use when a new `Tensor` is
      created, in which case element `i` will be given the name `name
      + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.
    preferred_dtype: Optional element type for the returned tensors,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so preferred_dtype
      can be used as a soft preference. If the conversion to
      `preferred_dtype` is not possible, this argument has no effect.
    ctx: The value of context.context().

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, _collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  if ctx is None:
    ctx = context.context()
  for i, value in enumerate(values):
    n = None if name is None else "%s_%d" % (name, i)
    ret.append(
        internal_convert_to_tensor(
            value,
            dtype=dtype,
            name=n,
            as_ref=as_ref,
            preferred_dtype=preferred_dtype,
            ctx=ctx))
  return ret


def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
  """Converts `values` to a list of `Tensor` objects.

  Args:
    values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
    name: (Optional.) A name prefix to use when a new `Tensor` is
      created, in which case element `i` will be given the name `name
      + '_' + i`.
    preferred_dtype: Optional element type for the returned tensors,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so preferred_dtype
      can be used as a soft preference. If the conversion to
      `preferred_dtype` is not possible, this argument has no effect.

  Returns:
    A list of `Tensor` and/or `IndexedSlices` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor(
      values=values,
      dtype=dtype,
      name=name,
      preferred_dtype=preferred_dtype,
      as_ref=False)


@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  return internal_convert_to_tensor_or_indexed_slices(
      value=value, dtype=dtype, name=name, as_ref=False)


def internal_convert_to_tensor_or_indexed_slices(value,
                                                 dtype=None,
                                                 name=None,
                                                 as_ref=False):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if isinstance(value, EagerTensor) and not context.executing_eagerly():
    # Route eager values through the full converter so they are captured
    # into the function graph being built.
    return internal_convert_to_tensor(
        value, dtype=dtype, name=name, as_ref=as_ref)
  elif isinstance(value, _TensorLike):
    if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
    return value
  else:
    return internal_convert_to_tensor(
        value, dtype=dtype, name=name, as_ref=as_ref)


def internal_convert_n_to_tensor_or_indexed_slices(values,
                                                   dtype=None,
                                                   name=None,
                                                   as_ref=False):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified.

  Args:
    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to use when a new `Tensor` is
      created, in which case element `i` will be given the name `name
      + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, _collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  for i, value in enumerate(values):
    if value is None:
      # `None` entries are passed through so positions stay aligned.
      ret.append(value)
    else:
      n = None if name is None else "%s_%d" % (name, i)
      ret.append(
          internal_convert_to_tensor_or_indexed_slices(
              value, dtype=dtype, name=n, as_ref=as_ref))
  return ret


def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified.

  Args:
    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to use when a new `Tensor` is
      created, in which case element `i` will be given the name `name
      + '_' + i`.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_indexed_slices(
      values=values, dtype=dtype, name=name, as_ref=False)


def convert_to_tensor_or_composite(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
  is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor` or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  return internal_convert_to_tensor_or_composite(
      value=value, dtype=dtype, name=name, as_ref=False)


def internal_convert_to_tensor_or_composite(value,
                                            dtype=None,
                                            name=None,
                                            as_ref=False):
  """Converts the given object to a `Tensor` or `CompositeTensor`.

  If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
  is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: A `CompositeTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `CompositeTensor`. Ignored for a `CompositeTensor` that has no `dtype`
      attribute.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor` or `CompositeTensor`, based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    value_dtype = getattr(value, "dtype", None)
    # A composite without a `dtype` attribute has no element type to compare
    # against, so it is returned unmodified. Without this guard, `None` would
    # reach `DType.is_compatible_with` and `value.dtype` would be dereferenced.
    if (dtype and value_dtype is not None and
        not dtypes.as_dtype(dtype).is_compatible_with(value_dtype)):
      raise ValueError(
          "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
          (dtypes.as_dtype(dtype).name, value_dtype.name, str(value)))
    return value
  else:
    return internal_convert_to_tensor(
        value, dtype=dtype, name=name, as_ref=as_ref)


def internal_convert_n_to_tensor_or_composite(values,
                                              dtype=None,
                                              name=None,
                                              as_ref=False):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to use when a new `Tensor` is
      created, in which case element `i` will be given the name `name
      + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, _collections_abc.Sequence):
    raise TypeError("values must be a sequence.")
  ret = []
  for i, value in enumerate(values):
    if value is None:
      # `None` entries are passed through so positions stay aligned.
      ret.append(value)
    else:
      n = None if name is None else "%s_%d" % (name, i)
      ret.append(
          internal_convert_to_tensor_or_composite(
              value, dtype=dtype, name=n, as_ref=as_ref))
  return ret


def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.

  Any `CompositeTensor` objects in `values` are returned unmodified.

  Args:
    values: A list of `None`, `CompositeTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor`s or
      `CompositeTensor`s.
    name: (Optional.) A name prefix to use when a new `Tensor` is
      created, in which case element `i` will be given the name `name
      + '_' + i`.

  Returns:
    A list of `Tensor`, `CompositeTensor`, and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_composite(
      values=values, dtype=dtype, name=name, as_ref=False)


# TODO(josh11b): Add ctx argument to conversion_func() signature.
@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
                                        conversion_func,
                                        priority=100):
  """Registers a function for converting objects of `base_type` to `Tensor`.

  The conversion function must have the following signature:

  ```python
      def conversion_func(value, dtype=None, name=None, as_ref=False):
        # ...
  ```

  It must return a `Tensor` with the given `dtype` if specified. If the
  conversion function creates a new `Tensor`, it should use the given
  `name` if specified. All exceptions will be propagated to the caller.

  The conversion function may return `NotImplemented` for some
  inputs. In this case, the conversion process will continue to try
  subsequent conversion functions.

  If `as_ref` is true, the function must return a `Tensor` reference,
  such as a `Variable`.

  NOTE: The conversion functions will execute in order of priority,
  followed by order of registration. To ensure that a conversion function
  `F` runs before another conversion function `G`, ensure that `F` is
  registered with a smaller priority than `G`.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values
      run earlier than conversion functions with larger priority values.
      Defaults to 100.

  Raises:
    TypeError: If the arguments do not have the appropriate type, or if
      executing eagerly and `base_type` is (or contains) a subclass of a
      Python integer type, `float`, or `np.ndarray`.

  """
  global _tensor_conversion_func_cache
  with _tensor_conversion_func_lock:
    if not (isinstance(base_type, type) or
            (isinstance(base_type, tuple) and
             all(isinstance(x, type) for x in base_type))):
      raise TypeError("base_type must be a type or a tuple of types.")
    if not callable(conversion_func):
      raise TypeError("conversion_func must be callable.")

    # context._context is checked so that we don't inadvertently create it.
    # This is because enable_eager_execution will fail when called from the main
    # function if the context._context is already created, and the
    # register_tensor_conversion_function calls happen when the module is
    # imported.
    #
    # `base_type` is a class (or a tuple of classes), not an instance, so it
    # must be tested with `issubclass`; `isinstance(base_type, int)` is always
    # False for a class and would make this guard dead code.
    base_types = base_type if isinstance(base_type, tuple) else (base_type,)
    if context._context is not None and context.executing_eagerly() and any(
        issubclass(t, six.integer_types + (float, np.ndarray))
        for t in base_types):
      # TODO(nareshmodi): consider setting a context variable which disables the
      # fastpath instead.
      raise TypeError(
          "Cannot register conversions for numpy arrays, python number types "
          "when executing eagerly.")

    try:
      funcs_at_priority = _tensor_conversion_func_registry[priority]
    except KeyError:
      funcs_at_priority = []
      _tensor_conversion_func_registry[priority] = funcs_at_priority
    funcs_at_priority.append((base_type, conversion_func))
    # Invalidate cached per-type lookups; they may now be missing this entry.
    _tensor_conversion_func_cache = {}


@tf_export("IndexedSlices")
class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a simple wrapper for a pair of `Tensor` objects:

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An `IndexedSlices` is typically used to represent a subset of a larger
  tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
  The values in `indices` are the indices in the first dimension of
  the slices that have been extracted from the larger tensor.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. `tf.gather`).

  Contrast this representation with
  `tf.SparseTensor`,
  which uses multi-dimensional indices and scalar values.
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices`."""
    # Called for its checking side effect only; the returned graph is unused.
    # NOTE(review): presumably verifies the inputs share one graph - confirm.
    _get_graph_from_inputs([values, indices, dense_shape])
    self._values = values
    self._indices = indices
    self._dense_shape = dense_shape

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
    return self._dense_shape

  @property
  def name(self):
    """The name of this `IndexedSlices`."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self.values.dtype

  @property
  def graph(self):
    """The `Graph` that contains the values, indices, and shape tensors."""
    return self._values.graph

  def __str__(self):
    return "IndexedSlices(indices=%s, values=%s%s)" % (
        self._indices, self._values, (", dense_shape=%s" % self._dense_shape)
        if self._dense_shape is not None else "")

  def __neg__(self):
    return IndexedSlices(-self.values, self.indices, self.dense_shape)

  def _to_components(self):
    """Returns `(values, indices)`, plus `dense_shape` when it is set."""
    if self._dense_shape is None:
      return (self._values, self._indices)
    else:
      return (self._values, self._indices, self._dense_shape)

  @classmethod
  def _from_components(cls, components):
    """Inverse of `_to_components`."""
    return cls(*components)

  def _shape_invariant_to_components(self, shape=None):
    """Maps a `values` shape invariant to one shape per component.

    Args:
      shape: Optional `TensorShape` for `values`; defaults to its static shape.

    Returns:
      A list of `TensorShape`s ordered like `_to_components()`.
    """
    if shape is None:
      shape = self._values.shape
    if self._dense_shape is None:
      return [shape, shape[:1]]  # values, indices
    else:
      # values, indices, dense_shape
      return [shape, shape[:1], tensor_shape.TensorShape([shape.ndims])]

  @property
  def _is_graph_tensor(self):
    """True when `values` exposes a `graph` attribute."""
    return hasattr(self._values, "graph")


# Plain (non-`Tensor`) counterpart of `IndexedSlices`, holding the same three
# fields.
IndexedSlicesValue = collections.namedtuple(
    "IndexedSlicesValue", ["values", "indices", "dense_shape"])


def _device_string(dev_spec):
  """Returns `dev_spec` as a string, serialising a `DeviceSpec` if needed."""
  if isinstance(dev_spec, pydev.DeviceSpec):
    return dev_spec.to_string()
  else:
    return dev_spec


def _NodeDef(op_type, name, device=None, attrs=None):  # pylint: disable=redefined-outer-name
  """Create a NodeDef proto.

  Args:
    op_type: Value for the "op" attribute of the NodeDef proto.
    name: Value for the "name" attribute of the NodeDef proto.
    device: string, device, or function from NodeDef to string.
      Value for the "device" attribute of the NodeDef proto.
    attrs: Optional dictionary where the key is the attribute name (a string)
      and the value is the respective "attr" attribute of the NodeDef proto (an
      AttrValue).

  Returns:
    A node_def_pb2.NodeDef protocol buffer.
  """
  node_def = node_def_pb2.NodeDef()
  node_def.op = compat.as_bytes(op_type)
  node_def.name = compat.as_bytes(name)
  if attrs is not None:
    for k, v in six.iteritems(attrs):
      node_def.attr[k].CopyFrom(v)
  if device is not None:
    if callable(device):
      # A device function receives the partially built NodeDef (op, name and
      # attrs already set).
      node_def.device = device(node_def)
    else:
      node_def.device = _device_string(device)
  return node_def


# Copied from core/framework/node_def_util.cc
# TODO(mrry,josh11b): Consolidate this validation in C++ code.
_VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$")
_VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$")


def _create_c_op(graph, node_def, inputs, control_inputs):
  """Creates a TF_Operation.

  Args:
    graph: a `Graph`.
    node_def: `node_def_pb2.NodeDef` for the operation to create.
    inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of
      `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
      "list(int64)"). The length of the list should be equal to the number of
      inputs specified by this operation's op def.
    control_inputs: A list of `Operation`s to set as control dependencies.

  Returns:
    A wrapped TF_Operation*.

  Raises:
    ValueError: If the C API rejects the operation (converted from
      `errors.InvalidArgumentError`).
  """
  # pylint: disable=protected-access
  op_desc = c_api.TF_NewOperation(graph._c_graph,
                                  compat.as_str(node_def.op),
                                  compat.as_str(node_def.name))
  if node_def.device:
    c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
  # Add inputs
  for op_input in inputs:
    if isinstance(op_input, (list, tuple)):
      c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input])
    else:
      c_api.TF_AddInput(op_desc, op_input._as_tf_output())

  # Add control inputs
  for control_input in control_inputs:
    c_api.TF_AddControlInput(op_desc, control_input._c_op)
  # pylint: enable=protected-access

  # Add attrs
  for name, attr_value in node_def.attr.items():
    serialized = attr_value.SerializeToString()
    # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use the same status.
    c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)

  try:
    c_op = c_api.TF_FinishOperation(op_desc)
  except errors.InvalidArgumentError as e:
    # Convert to ValueError for backwards compatibility.
    raise ValueError(str(e))

  return c_op


@tf_export("Operation")
class Operation(object):
  """Represents a graph node that performs computation on tensors.

  An `Operation` is a node in a TensorFlow `Graph` that takes zero or
  more `Tensor` objects as input, and produces zero or more `Tensor`
  objects as output. Objects of type `Operation` are created by
  calling a Python op constructor (such as
  `tf.matmul`)
  or `tf.Graph.create_op`.

  For example `c = tf.matmul(a, b)` creates an `Operation` of type
  "MatMul" that takes tensors `a` and `b` as input, and produces `c`
  as output.

  After the graph has been launched in a session, an `Operation` can
  be executed by passing it to
  `tf.Session.run`.
  `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
  """

  def __init__(self,
               node_def,
               g,
               inputs=None,
               output_types=None,
               control_inputs=None,
               input_types=None,
               original_op=None,
               op_def=None):
    r"""Creates an `Operation`.

    NOTE: This constructor validates the name of the `Operation` (passed
    as `node_def.name`). Valid `Operation` names match the following
    regular expression:

        [A-Za-z0-9.][A-Za-z0-9_.\\-/]*

    Args:
      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.
        Used for attributes of `node_def_pb2.NodeDef`, typically `name`,
        `op`, and `device`.  The `input` attribute is irrelevant here
        as it will be computed when generating the model.
      g: `Graph`. The parent graph.
      inputs: list of `Tensor` objects. The inputs to this `Operation`.
      output_types: list of `DType` objects.  List of the types of the
        `Tensors` computed by this operation.  The length of this list indicates
        the number of output endpoints of the `Operation`.
      control_inputs: list of operations or tensors from which to have a
        control dependency.
      input_types: List of `DType` objects representing the
        types of the tensors accepted by the `Operation`.  By default
        uses `[x.dtype.base_dtype for x in inputs]`.  Operations that expect
        reference-typed inputs must specify these explicitly.
      original_op: Optional. Used to associate the new `Operation` with an
        existing `Operation` (for example, a replica with the op that was
        replicated).
      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the
        op type that this `Operation` represents.

    Raises:
      TypeError: if control inputs are not Operations or Tensors,
        or if `node_def` is not a `NodeDef`,
        or if `g` is not a `Graph`,
        or if `inputs` are not tensors,
        or if `inputs` and `input_types` are incompatible.
      ValueError: if the `node_def` name is not valid.
    """
    # For internal use only: `node_def` can be set to a TF_Operation to create
    # an Operation for that op. This is useful for creating Operations for ops
    # indirectly created by C API methods, e.g. the ops created by
    # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
    # should be None.

    if isinstance(node_def, node_def_pb2.NodeDef):
      if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
        raise ValueError(
            "Cannot create a tensor proto whose content is larger than 2GB.")
      if not _VALID_OP_NAME_REGEX.match(node_def.name):
        raise ValueError("'%s' is not a valid node name" % node_def.name)
      c_op = None
    elif type(node_def).__name__ == "SwigPyObject":
      assert inputs is None
      assert output_types is None
      assert control_inputs is None
      assert input_types is None
      assert original_op is None
      assert op_def is None
      c_op = node_def
    else:
      raise TypeError("node_def needs to be a NodeDef: %s" % node_def)

    if not isinstance(g, Graph):
      raise TypeError("g needs to be a Graph: %s" % g)
    self._graph = g

    if inputs is None:
      inputs = []
    elif not isinstance(inputs, list):
      raise TypeError("inputs needs to be a list of Tensors: %s" % inputs)
    for a in inputs:
      if not isinstance(a, Tensor):
        raise TypeError("input needs to be a Tensor: %s" % a)
    if input_types is None:
      input_types = [i.dtype.base_dtype for i in inputs]
    else:
      if not all(
          x.is_compatible_with(i.dtype)
          for i, x in zip(inputs, input_types)):
        raise TypeError("In op '%s', input types (%s) are not compatible "
                        "with expected types (%s)" %
                        (node_def.name, [i.dtype for i in inputs],
                         input_types))

    # Build the list of control inputs.
1942 control_input_ops = [] 1943 if control_inputs: 1944 for c in control_inputs: 1945 control_op = None 1946 if isinstance(c, Operation): 1947 control_op = c 1948 elif isinstance(c, (Tensor, IndexedSlices)): 1949 control_op = c.op 1950 else: 1951 raise TypeError("Control input must be an Operation, " 1952 "a Tensor, or IndexedSlices: %s" % c) 1953 control_input_ops.append(control_op) 1954 1955 # This will be set by self.inputs. 1956 self._inputs_val = None 1957 1958 # pylint: disable=protected-access 1959 self._id_value = self._graph._next_id() 1960 self._original_op = original_op 1961 self._traceback = tf_stack.extract_stack() 1962 1963 # List of _UserDevSpecs holding code location of device context manager 1964 # invocations and the users original argument to them. 1965 self._device_code_locations = None 1966 # Dict mapping op name to file and line information for op colocation 1967 # context managers. 1968 self._colocation_code_locations = None 1969 self._control_flow_context = self.graph._get_control_flow_context() 1970 # pylint: enable=protected-access 1971 1972 # Initialize self._c_op. 1973 if c_op: 1974 self._c_op = c_op 1975 else: 1976 if op_def is None: 1977 op_def = self._graph._get_op_def(node_def.op) 1978 # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs. 1979 # Refactor so we don't have to do this here. 1980 grouped_inputs = self._reconstruct_sequence_inputs( 1981 op_def, inputs, node_def.attr) 1982 self._c_op = _create_c_op(self._graph, node_def, grouped_inputs, 1983 control_input_ops) 1984 1985 # Initialize self._outputs. 
1986 num_outputs = c_api.TF_OperationNumOutputs(self._c_op) 1987 output_types = [ 1988 c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i)) 1989 for i in range(num_outputs)] 1990 self._outputs = [ 1991 Tensor(self, i, output_type) 1992 for i, output_type in enumerate(output_types) 1993 ] 1994 1995 self._graph._add_op(self) # pylint: disable=protected-access 1996 1997 if not c_op: 1998 self._control_flow_post_processing() 1999 2000 def _control_flow_post_processing(self): 2001 """Add this op to its control flow context. 2002 2003 This may add new ops and change this op's inputs. self.inputs must be 2004 available before calling this method. 2005 """ 2006 for input_tensor in self.inputs: 2007 control_flow_util.CheckInputFromValidContext(self, input_tensor.op) 2008 if self._control_flow_context is not None: 2009 self._control_flow_context.AddOp(self) 2010 2011 def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): 2012 """Regroups a flat list of input tensors into scalar and sequence inputs. 2013 2014 Args: 2015 op_def: The `op_def_pb2.OpDef` (for knowing the input types) 2016 inputs: a list of input `Tensor`s to the op. 2017 attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define 2018 how long each sequence is) 2019 2020 Returns: 2021 A list of `Tensor`s (corresponding to scalar inputs) and lists of 2022 `Tensor`s (corresponding to sequence inputs). 
2023 """ 2024 grouped_inputs = [] 2025 i = 0 2026 for input_arg in op_def.input_arg: 2027 if input_arg.number_attr: 2028 input_len = attrs[input_arg.number_attr].i 2029 is_sequence = True 2030 elif input_arg.type_list_attr: 2031 input_len = len(attrs[input_arg.type_list_attr].list.type) 2032 is_sequence = True 2033 else: 2034 input_len = 1 2035 is_sequence = False 2036 2037 if is_sequence: 2038 grouped_inputs.append(inputs[i:i + input_len]) 2039 else: 2040 grouped_inputs.append(inputs[i]) 2041 i += input_len 2042 2043 assert i == len(inputs) 2044 return grouped_inputs 2045 2046 def colocation_groups(self): 2047 """Returns the list of colocation groups of the op.""" 2048 default_colocation_group = [ 2049 compat.as_bytes("loc:@%s" % self.name) 2050 ] 2051 try: 2052 class_attr = self.get_attr("_class") 2053 except ValueError: 2054 # This op has no explicit colocation group, so it is itself its 2055 # own root of a colocation group. 2056 return default_colocation_group 2057 2058 attr_groups = [ 2059 class_name for class_name in class_attr 2060 if class_name.startswith(b"loc:@") 2061 ] 2062 2063 # If there are no colocation groups in the explicit _class field, 2064 # return the default colocation group. 2065 return attr_groups if attr_groups else default_colocation_group 2066 2067 def values(self): 2068 """DEPRECATED: Use outputs.""" 2069 return tuple(self.outputs) 2070 2071 def _get_control_flow_context(self): 2072 """Returns the control flow context of this op. 2073 2074 Returns: 2075 A context object. 2076 """ 2077 return self._control_flow_context 2078 2079 def _set_control_flow_context(self, ctx): 2080 """Sets the current control flow context of this op. 2081 2082 Args: 2083 ctx: a context object. 
2084 """ 2085 self._control_flow_context = ctx 2086 2087 @property 2088 def name(self): 2089 """The full name of this operation.""" 2090 return c_api.TF_OperationName(self._c_op) 2091 2092 @property 2093 def _id(self): 2094 """The unique integer id of this operation.""" 2095 return self._id_value 2096 2097 @property 2098 def device(self): 2099 """The name of the device to which this op has been assigned, if any. 2100 2101 Returns: 2102 The string name of the device to which this op has been 2103 assigned, or an empty string if it has not been assigned to a 2104 device. 2105 """ 2106 return c_api.TF_OperationDevice(self._c_op) 2107 2108 @property 2109 def _device_assignments(self): 2110 """Code locations for device context managers active at op creation. 2111 2112 This property will return a list of traceable_stack.TraceableObject 2113 instances where .obj is a string representing the assigned device 2114 (or information about the function that would be applied to this op 2115 to compute the desired device) and the filename and lineno members 2116 record the location of the relevant device context manager. 2117 2118 For example, suppose file_a contained these lines: 2119 2120 file_a.py: 2121 15: with tf.device('/gpu:0'): 2122 16: node_b = tf.constant(4, name='NODE_B') 2123 2124 Then a TraceableObject t_obj representing the device context manager 2125 would have these member values: 2126 2127 t_obj.obj -> '/gpu:0' 2128 t_obj.filename = 'file_a.py' 2129 t_obj.lineno = 15 2130 2131 and node_b.op._device_assignments would return the list [t_obj]. 2132 2133 Returns: 2134 [str: traceable_stack.TraceableObject, ...] as per this method's 2135 description, above. 2136 """ 2137 return self._device_code_locations or [] 2138 2139 @property 2140 def _colocation_dict(self): 2141 """Code locations for colocation context managers active at op creation. 
2142 2143 This property will return a dictionary for which the keys are nodes with 2144 which this Operation is colocated, and for which the values are 2145 traceable_stack.TraceableObject instances. The TraceableObject instances 2146 record the location of the relevant colocation context manager but have the 2147 "obj" field set to None to prevent leaking private data. 2148 2149 For example, suppose file_a contained these lines: 2150 2151 file_a.py: 2152 14: node_a = tf.constant(3, name='NODE_A') 2153 15: with tf.colocate_with(node_a): 2154 16: node_b = tf.constant(4, name='NODE_B') 2155 2156 Then a TraceableObject t_obj representing the colocation context manager 2157 would have these member values: 2158 2159 t_obj.obj -> None 2160 t_obj.filename = 'file_a.py' 2161 t_obj.lineno = 15 2162 2163 and node_b.op._colocation_dict would return the dictionary 2164 2165 { 'NODE_A': t_obj } 2166 2167 Returns: 2168 {str: traceable_stack.TraceableObject} as per this method's description, 2169 above. 2170 """ 2171 locations_dict = self._colocation_code_locations or {} 2172 return locations_dict.copy() 2173 2174 @property 2175 def _output_types(self): 2176 """List this operation's output types. 2177 2178 Returns: 2179 List of the types of the Tensors computed by this operation. 2180 Each element in the list is an integer whose value is one of 2181 the TF_DataType enums defined in c_api.h 2182 The length of this list indicates the number of output endpoints 2183 of the operation. 2184 """ 2185 num_outputs = c_api.TF_OperationNumOutputs(self._c_op) 2186 output_types = [ 2187 c_api.TF_OperationOutputType(self._tf_output(i)) 2188 for i in xrange(num_outputs) 2189 ] 2190 # In all the tests we have output_types that are passed into 2191 # Operation.__init__ are a list of ints (which is illegal according 2192 # to the docstring), but input_types are instances of DType. 2193 # This extra assert is to catch if we ever use DType for output_types. 
2194 if output_types: 2195 assert isinstance(output_types[0], int) 2196 return output_types 2197 2198 def _tf_output(self, output_idx): 2199 """Create and return a new TF_Output for output_idx'th output of this op.""" 2200 tf_output = c_api.TF_Output() 2201 tf_output.oper = self._c_op 2202 tf_output.index = output_idx 2203 return tf_output 2204 2205 def _tf_input(self, input_idx): 2206 """Create and return a new TF_Input for input_idx'th input of this op.""" 2207 tf_input = c_api.TF_Input() 2208 tf_input.oper = self._c_op 2209 tf_input.index = input_idx 2210 return tf_input 2211 2212 def _set_device(self, device): # pylint: disable=redefined-outer-name 2213 """Set the device of this operation. 2214 2215 Args: 2216 device: string or device.. The device to set. 2217 """ 2218 c_api.SetRequestedDevice( 2219 self._graph._c_graph, # pylint: disable=protected-access 2220 self._c_op, # pylint: disable=protected-access 2221 compat.as_str(_device_string(device))) 2222 2223 def _update_input(self, index, tensor): 2224 """Update the input to this operation at the given index. 2225 2226 NOTE: This is for TF internal use only. Please don't use it. 2227 2228 Args: 2229 index: the index of the input to update. 2230 tensor: the Tensor to be used as the input at the given index. 2231 2232 Raises: 2233 TypeError: if tensor is not a Tensor, 2234 or if input tensor type is not convertible to dtype. 2235 ValueError: if the Tensor is from a different graph. 2236 """ 2237 if not isinstance(tensor, Tensor): 2238 raise TypeError("tensor must be a Tensor: %s" % tensor) 2239 _assert_same_graph(self, tensor) 2240 2241 # Reset cached inputs. 2242 self._inputs_val = None 2243 c_api.UpdateEdge( 2244 self._graph._c_graph, # pylint: disable=protected-access 2245 tensor._as_tf_output(), # pylint: disable=protected-access 2246 self._tf_input(index)) 2247 2248 def _add_while_inputs(self, tensors): 2249 """See AddWhileInputHack in python_api.h. 2250 2251 NOTE: This is for TF internal use only. 
Please don't use it. 2252 2253 Args: 2254 tensors: list of Tensors 2255 2256 Raises: 2257 TypeError: if tensor is not a Tensor, 2258 or if input tensor type is not convertible to dtype. 2259 ValueError: if the Tensor is from a different graph. 2260 """ 2261 for tensor in tensors: 2262 if not isinstance(tensor, Tensor): 2263 raise TypeError("tensor must be a Tensor: %s" % tensor) 2264 _assert_same_graph(self, tensor) 2265 2266 # Reset cached inputs. 2267 self._inputs_val = None 2268 c_api.AddWhileInputHack( 2269 self._graph._c_graph, # pylint: disable=protected-access 2270 tensor._as_tf_output(), # pylint: disable=protected-access 2271 self._c_op) 2272 2273 def _add_control_inputs(self, ops): 2274 """Add a list of new control inputs to this operation. 2275 2276 Args: 2277 ops: the list of Operations to add as control input. 2278 2279 Raises: 2280 TypeError: if ops is not a list of Operations. 2281 ValueError: if any op in ops is from a different graph. 2282 """ 2283 for op in ops: 2284 if not isinstance(op, Operation): 2285 raise TypeError("op must be an Operation: %s" % op) 2286 c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access 2287 2288 def _add_control_input(self, op): 2289 """Add a new control input to this operation. 2290 2291 Args: 2292 op: the Operation to add as control input. 2293 2294 Raises: 2295 TypeError: if op is not an Operation. 2296 ValueError: if op is from a different graph. 2297 """ 2298 if not isinstance(op, Operation): 2299 raise TypeError("op must be an Operation: %s" % op) 2300 c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access 2301 2302 def _remove_all_control_inputs(self): 2303 """Removes any control inputs to this operation.""" 2304 c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access 2305 2306 def _add_outputs(self, types, shapes): 2307 """Adds new Tensors to self.outputs. 
2308 2309 Note: this is generally unsafe to use. This is used in certain situations in 2310 conjunction with _set_type_list_attr. 2311 2312 Arguments: 2313 types: list of DTypes 2314 shapes: list of TensorShapes 2315 """ 2316 assert len(types) == len(shapes) 2317 orig_num_outputs = len(self.outputs) 2318 for i in range(len(types)): 2319 t = Tensor(self, orig_num_outputs + i, types[i]) 2320 self._outputs.append(t) 2321 t.set_shape(shapes[i]) 2322 2323 def __str__(self): 2324 return str(self.node_def) 2325 2326 def __repr__(self): 2327 return "<tf.Operation '%s' type=%s>" % (self.name, self.type) 2328 2329 @property 2330 def outputs(self): 2331 """The list of `Tensor` objects representing the outputs of this op.""" 2332 return self._outputs 2333 2334# pylint: disable=protected-access 2335 2336 class _InputList(object): 2337 """Immutable input list wrapper.""" 2338 2339 def __init__(self, inputs): 2340 self._inputs = inputs 2341 2342 def __iter__(self): 2343 return iter(self._inputs) 2344 2345 def __len__(self): 2346 return len(self._inputs) 2347 2348 def __bool__(self): 2349 return bool(self._inputs) 2350 2351 # Python 3 wants __bool__, Python 2.7 wants __nonzero__ 2352 __nonzero__ = __bool__ 2353 2354 def __getitem__(self, i): 2355 return self._inputs[i] 2356 2357# pylint: enable=protected-access 2358 2359 @property 2360 def inputs(self): 2361 """The list of `Tensor` objects representing the data inputs of this op.""" 2362 if self._inputs_val is None: 2363 tf_outputs = c_api.GetOperationInputs(self._c_op) 2364 # pylint: disable=protected-access 2365 retval = [ 2366 self.graph._get_tensor_by_tf_output(tf_output) 2367 for tf_output in tf_outputs 2368 ] 2369 # pylint: enable=protected-access 2370 self._inputs_val = Operation._InputList(retval) 2371 return self._inputs_val 2372 2373 @property 2374 def _inputs(self): 2375 logging.warning("Operation._inputs is private, use Operation.inputs " 2376 "instead. 
Operation._inputs will eventually be removed.") 2377 return self.inputs 2378 2379 @_inputs.setter 2380 def _inputs(self, value): 2381 raise ValueError("Cannot assign _inputs") 2382 2383 @property 2384 def _input_types(self): 2385 num_inputs = c_api.TF_OperationNumInputs(self._c_op) 2386 input_types = [ 2387 dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) 2388 for i in xrange(num_inputs) 2389 ] 2390 return input_types 2391 2392 @_input_types.setter 2393 def _input_types(self, value): 2394 raise ValueError("Cannot assign _input_types") 2395 2396 @property 2397 def control_inputs(self): 2398 """The `Operation` objects on which this op has a control dependency. 2399 2400 Before this op is executed, TensorFlow will ensure that the 2401 operations in `self.control_inputs` have finished executing. This 2402 mechanism can be used to run ops sequentially for performance 2403 reasons, or to ensure that the side effects of an op are observed 2404 in the correct order. 2405 2406 Returns: 2407 A list of `Operation` objects. 2408 2409 """ 2410 control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) 2411 # pylint: disable=protected-access 2412 return [ 2413 self.graph._get_operation_by_name_unsafe( 2414 c_api.TF_OperationName(c_op)) for c_op in control_c_ops 2415 ] 2416 # pylint: enable=protected-access 2417 2418 @property 2419 def _control_outputs(self): 2420 """The `Operation` objects which have a control dependency on this op. 2421 2422 Before any of the ops in self._control_outputs can execute tensorflow will 2423 ensure self has finished executing. 2424 2425 Returns: 2426 A list of `Operation` objects. 
2427 2428 """ 2429 control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) 2430 # pylint: disable=protected-access 2431 return [ 2432 self.graph._get_operation_by_name_unsafe( 2433 c_api.TF_OperationName(c_op)) for c_op in control_c_ops 2434 ] 2435 # pylint: enable=protected-access 2436 2437 @property 2438 def _control_inputs(self): 2439 logging.warning("Operation._control_inputs is private, use " 2440 "Operation.control_inputs instead. " 2441 "Operation._control_inputs will eventually be removed.") 2442 return self.control_inputs 2443 2444 @_control_inputs.setter 2445 def _control_inputs(self, value): 2446 logging.warning("Operation._control_inputs is private, use " 2447 "Operation.control_inputs instead. " 2448 "Operation._control_inputs will eventually be removed.") 2449 # Copy value because it may be self._control_inputs_val (in particular if 2450 # this is called from self._control_inputs += ...), and we don't want to 2451 # clear value below. 2452 value = copy.copy(value) 2453 self._remove_all_control_inputs() 2454 self._add_control_inputs(value) 2455 2456 @property 2457 def type(self): 2458 """The type of the op (e.g. `"MatMul"`).""" 2459 return c_api.TF_OperationOpType(self._c_op) 2460 2461 @property 2462 def graph(self): 2463 """The `Graph` that contains this operation.""" 2464 return self._graph 2465 2466 @property 2467 def node_def(self): 2468 # pylint: disable=line-too-long 2469 """Returns the `NodeDef` representation of this operation. 2470 2471 Returns: 2472 A 2473 [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto) 2474 protocol buffer. 
2475 """ 2476 # pylint: enable=line-too-long 2477 with c_api_util.tf_buffer() as buf: 2478 c_api.TF_OperationToNodeDef(self._c_op, buf) 2479 data = c_api.TF_GetBuffer(buf) 2480 node_def = node_def_pb2.NodeDef() 2481 node_def.ParseFromString(compat.as_bytes(data)) 2482 return node_def 2483 2484 @property 2485 def _node_def(self): 2486 logging.warning("Operation._node_def is private, use Operation.node_def " 2487 "instead. Operation._node_def will eventually be removed.") 2488 return self.node_def 2489 2490 @property 2491 def op_def(self): 2492 # pylint: disable=line-too-long 2493 """Returns the `OpDef` proto that represents the type of this op. 2494 2495 Returns: 2496 An 2497 [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto) 2498 protocol buffer. 2499 """ 2500 # pylint: enable=line-too-long 2501 return self._graph._get_op_def(self.type) 2502 2503 @property 2504 def _op_def(self): 2505 logging.warning("Operation._op_def is private, use Operation.op_def " 2506 "instead. Operation._op_def will eventually be removed.") 2507 return self.op_def 2508 2509 @property 2510 def traceback(self): 2511 """Returns the call stack from when this operation was constructed.""" 2512 return tf_stack.convert_stack(self._traceback) 2513 2514 @property 2515 def traceback_with_start_lines(self): 2516 """Same as traceback but includes start line of function definition. 2517 2518 Returns: 2519 A list of 5-tuples (filename, lineno, name, code, func_start_lineno). 
2520 """ 2521 return tf_stack.convert_stack(self._traceback, 2522 include_func_start_lineno=True) 2523 2524 def _set_attr(self, attr_name, attr_value): 2525 """Private method used to set an attribute in the node_def.""" 2526 buf = c_api.TF_NewBufferFromString( 2527 compat.as_bytes(attr_value.SerializeToString())) 2528 try: 2529 # pylint: disable=protected-access 2530 c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) 2531 # pylint: enable=protected-access 2532 finally: 2533 c_api.TF_DeleteBuffer(buf) 2534 2535 def _set_func_attr(self, attr_name, func_name): 2536 """Private method used to set a function attribute in the node_def.""" 2537 func = attr_value_pb2.NameAttrList(name=func_name) 2538 self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func)) 2539 2540 def _set_type_list_attr(self, attr_name, types): 2541 """Private method used to set a function attribute in the node_def.""" 2542 if not types: return 2543 if isinstance(types[0], dtypes.DType): 2544 types = [dt.as_datatype_enum for dt in types] 2545 types_list = attr_value_pb2.AttrValue.ListValue(type=types) 2546 self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list)) 2547 2548 def _set_shape_list_attr(self, attr_name, shapes): 2549 """Private method used to set a function attribute in the node_def.""" 2550 shapes = [s.as_proto() for s in shapes] 2551 shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes) 2552 self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list)) 2553 2554 def _clear_attr(self, attr_name): 2555 """Private method used to clear an attribute in the node_def.""" 2556 # pylint: disable=protected-access 2557 c_api.ClearAttr(self._graph._c_graph, self._c_op, attr_name) 2558 # pylint: enable=protected-access 2559 2560 def get_attr(self, name): 2561 """Returns the value of the attr of this op with the given `name`. 2562 2563 Args: 2564 name: The name of the attr to fetch. 2565 2566 Returns: 2567 The value of the attr, as a Python object. 
2568 2569 Raises: 2570 ValueError: If this op does not have an attr with the given `name`. 2571 """ 2572 fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func") 2573 try: 2574 with c_api_util.tf_buffer() as buf: 2575 c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) 2576 data = c_api.TF_GetBuffer(buf) 2577 except errors.InvalidArgumentError as e: 2578 # Convert to ValueError for backwards compatibility. 2579 raise ValueError(str(e)) 2580 x = attr_value_pb2.AttrValue() 2581 x.ParseFromString(data) 2582 2583 oneof_value = x.WhichOneof("value") 2584 if oneof_value is None: 2585 return [] 2586 if oneof_value == "list": 2587 for f in fields: 2588 if getattr(x.list, f): 2589 if f == "type": 2590 return [dtypes.as_dtype(t) for t in x.list.type] 2591 else: 2592 return list(getattr(x.list, f)) 2593 return [] 2594 if oneof_value == "type": 2595 return dtypes.as_dtype(x.type) 2596 assert oneof_value in fields, "Unsupported field type in " + str(x) 2597 return getattr(x, oneof_value) 2598 2599 def run(self, feed_dict=None, session=None): 2600 """Runs this operation in a `Session`. 2601 2602 Calling this method will execute all preceding operations that 2603 produce the inputs needed for this operation. 2604 2605 *N.B.* Before invoking `Operation.run()`, its graph must have been 2606 launched in a session, and either a default session must be 2607 available, or `session` must be specified explicitly. 2608 2609 Args: 2610 feed_dict: A dictionary that maps `Tensor` objects to feed values. 2611 See `tf.Session.run` 2612 for a description of the valid feed values. 2613 session: (Optional.) The `Session` to be used to run to this operation. If 2614 none, the default session will be used. 
2615 """ 2616 _run_using_default_session(self, feed_dict, self.graph, session) 2617 2618_gradient_registry = registry.Registry("gradient") 2619 2620 2621@tf_export("RegisterGradient") 2622class RegisterGradient(object): 2623 """A decorator for registering the gradient function for an op type. 2624 2625 This decorator is only used when defining a new op type. For an op 2626 with `m` inputs and `n` outputs, the gradient function is a function 2627 that takes the original `Operation` and `n` `Tensor` objects 2628 (representing the gradients with respect to each output of the op), 2629 and returns `m` `Tensor` objects (representing the partial gradients 2630 with respect to each input of the op). 2631 2632 For example, assuming that operations of type `"Sub"` take two 2633 inputs `x` and `y`, and return a single output `x - y`, the 2634 following gradient function would be registered: 2635 2636 ```python 2637 @tf.RegisterGradient("Sub") 2638 def _sub_grad(unused_op, grad): 2639 return grad, tf.negative(grad) 2640 ``` 2641 2642 The decorator argument `op_type` is the string type of an 2643 operation. This corresponds to the `OpDef.name` field for the proto 2644 that defines the operation. 2645 """ 2646 2647 def __init__(self, op_type): 2648 """Creates a new decorator with `op_type` as the Operation type. 2649 2650 Args: 2651 op_type: The string type of an operation. This corresponds to the 2652 `OpDef.name` field for the proto that defines the operation. 
2653 """ 2654 if not isinstance(op_type, six.string_types): 2655 raise TypeError("op_type must be a string") 2656 self._op_type = op_type 2657 2658 def __call__(self, f): 2659 """Registers the function `f` as gradient function for `op_type`.""" 2660 _gradient_registry.register(f, self._op_type) 2661 return f 2662 2663 2664@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient") 2665@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"]) 2666def no_gradient(op_type): 2667 """Specifies that ops of type `op_type` is not differentiable. 2668 2669 This function should *not* be used for operations that have a 2670 well-defined gradient that is not yet implemented. 2671 2672 This function is only used when defining a new op type. It may be 2673 used for ops such as `tf.size()` that are not differentiable. For 2674 example: 2675 2676 ```python 2677 tf.NotDifferentiable("Size") 2678 ``` 2679 2680 The gradient computed for 'op_type' will then propagate zeros. 2681 2682 For ops that have a well-defined gradient but are not yet implemented, 2683 no declaration should be made, and an error *must* be thrown if 2684 an attempt to request its gradient is made. 2685 2686 Args: 2687 op_type: The string type of an operation. This corresponds to the 2688 `OpDef.name` field for the proto that defines the operation. 2689 2690 Raises: 2691 TypeError: If `op_type` is not a string. 2692 2693 """ 2694 if not isinstance(op_type, six.string_types): 2695 raise TypeError("op_type must be a string") 2696 _gradient_registry.register(None, op_type) 2697 2698 2699# Aliases for the old names, will be eventually removed. 
2700NoGradient = no_gradient 2701NotDifferentiable = no_gradient 2702 2703 2704def get_gradient_function(op): 2705 """Returns the function that computes gradients for "op".""" 2706 if not op.inputs: 2707 return None 2708 try: 2709 op_type = op.get_attr("_gradient_op_type") 2710 except ValueError: 2711 op_type = op.type 2712 return _gradient_registry.lookup(op_type) 2713 2714 2715_shape_registry = registry.Registry("shape functions") 2716_default_shape_function_registry = registry.Registry("default shape functions") 2717 2718# These are set to common_shapes.call_cpp_shape_fn by op generated code 2719# (generated by python_op_gen.cc). 2720# It is set outside ops.py to avoid a circular dependency. 2721_call_cpp_shape_fn = None 2722_call_cpp_shape_fn_and_require_op = None 2723 2724 2725def _set_call_cpp_shape_fn(call_cpp_shape_fn): 2726 """Sets default shape fns from passed common_shapes.call_cpp_shape_fn.""" 2727 global _call_cpp_shape_fn, _call_cpp_shape_fn_and_require_op 2728 if _call_cpp_shape_fn: 2729 return # already registered 2730 2731 def call_without_requiring(op): 2732 return call_cpp_shape_fn(op, require_shape_fn=False) 2733 2734 _call_cpp_shape_fn = call_without_requiring 2735 2736 def call_with_requiring(op): 2737 return call_cpp_shape_fn(op, require_shape_fn=True) 2738 2739 _call_cpp_shape_fn_and_require_op = call_with_requiring 2740 2741 2742class RegisterShape(object): 2743 """No longer used. Was: A decorator for registering a shape function. 2744 2745 Shape functions must now be registered via the SetShapeFn on the 2746 original Op specification in C++. 
2747 2748 """ 2749 2750 def __init__(self, op_type): 2751 """Saves the `op_type` as the `Operation` type.""" 2752 if not isinstance(op_type, six.string_types): 2753 raise TypeError("op_type must be a string") 2754 self._op_type = op_type 2755 2756 def __call__(self, f): 2757 """Registers "f" as the shape function for "op_type".""" 2758 if f is None: 2759 assert _call_cpp_shape_fn 2760 2761 # None is a special "weak" value that provides a default shape function, 2762 # and can be overridden by a non-None registration. 2763 try: 2764 _default_shape_function_registry.register(_call_cpp_shape_fn, 2765 self._op_type) 2766 except KeyError: 2767 # Ignore duplicate registrations of the weak value. This can 2768 # occur if the op library input to wrapper generation 2769 # inadvertently links in one or more of the standard op 2770 # libraries. 2771 pass 2772 else: 2773 _shape_registry.register(f, self._op_type) 2774 return f 2775 2776 2777def set_shape_and_handle_data_for_outputs(_): 2778 """No op. TODO(b/74620627): Remove this.""" 2779 pass 2780 2781 2782class OpStats(object): 2783 """A holder for statistics about an operator. 2784 2785 This class holds information about the resource requirements for an op, 2786 including the size of its weight parameters on-disk and how many FLOPS it 2787 requires to execute forward inference. 2788 2789 If you define a new operation, you can create a function that will return a 2790 set of information about its usage of the CPU and disk space when serialized. 2791 The function itself takes a Graph object that's been set up so you can call 2792 methods like get_tensor_by_name to help calculate the results, and a NodeDef 2793 argument. 
2794 2795 """ 2796 2797 def __init__(self, statistic_type, value=None): 2798 """Sets up the initial placeholders for the statistics.""" 2799 self.statistic_type = statistic_type 2800 self.value = value 2801 2802 @property 2803 def statistic_type(self): 2804 return self._statistic_type 2805 2806 @statistic_type.setter 2807 def statistic_type(self, statistic_type): 2808 self._statistic_type = statistic_type 2809 2810 @property 2811 def value(self): 2812 return self._value 2813 2814 @value.setter 2815 def value(self, value): 2816 self._value = value 2817 2818 def __iadd__(self, other): 2819 if other.statistic_type != self.statistic_type: 2820 raise ValueError("Can't add an OpStat of type %s to one of %s." % 2821 (self.statistic_type, other.statistic_type)) 2822 if self.value is None: 2823 self.value = other.value 2824 elif other.value is not None: 2825 self._value += other.value 2826 return self 2827 2828 2829_stats_registry = registry.Registry("statistical functions") 2830 2831 2832class RegisterStatistics(object): 2833 """A decorator for registering the statistics function for an op type. 2834 2835 This decorator can be defined for an op type so that it gives a 2836 report on the resources used by an instance of an operator, in the 2837 form of an OpStats object. 2838 2839 Well-known types of statistics include these so far: 2840 2841 - flops: When running a graph, the bulk of the computation happens doing 2842 numerical calculations like matrix multiplications. This type allows a node 2843 to return how many floating-point operations it takes to complete. The 2844 total number of FLOPs for a graph is a good guide to its expected latency. 2845 2846 You can add your own statistics just by picking a new type string, registering 2847 functions for the ops you care about, and then calling get_stats_for_node_def. 2848 2849 If a statistic for an op is registered multiple times, a KeyError will be 2850 raised. 2851 2852 Since the statistics is counted on a per-op basis. 
It is not suitable for 2853 model parameters (capacity), which is expected to be counted only once, even 2854 if it is shared by multiple ops. (e.g. RNN) 2855 2856 For example, you can define a new metric called doohickey for a Foo operation 2857 by placing this in your code: 2858 2859 ```python 2860 @ops.RegisterStatistics("Foo", "doohickey") 2861 def _calc_foo_bojangles(unused_graph, unused_node_def): 2862 return ops.OpStats("doohickey", 20) 2863 ``` 2864 2865 Then in client code you can retrieve the value by making this call: 2866 2867 ```python 2868 doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey") 2869 ``` 2870 2871 If the NodeDef is for an op with a registered doohickey function, you'll get 2872 back the calculated amount in doohickey.value, or None if it's not defined. 2873 2874 """ 2875 2876 def __init__(self, op_type, statistic_type): 2877 """Saves the `op_type` as the `Operation` type.""" 2878 if not isinstance(op_type, six.string_types): 2879 raise TypeError("op_type must be a string.") 2880 if "," in op_type: 2881 raise TypeError("op_type must not contain a comma.") 2882 self._op_type = op_type 2883 if not isinstance(statistic_type, six.string_types): 2884 raise TypeError("statistic_type must be a string.") 2885 if "," in statistic_type: 2886 raise TypeError("statistic_type must not contain a comma.") 2887 self._statistic_type = statistic_type 2888 2889 def __call__(self, f): 2890 """Registers "f" as the statistics function for "op_type".""" 2891 _stats_registry.register(f, self._op_type + "," + self._statistic_type) 2892 return f 2893 2894 2895def get_stats_for_node_def(graph, node, statistic_type): 2896 """Looks up the node's statistics function in the registry and calls it. 2897 2898 This function takes a Graph object and a NodeDef from a GraphDef, and if 2899 there's an associated statistics method, calls it and returns a result. 
If no 2900 function has been registered for the particular node type, it returns an empty 2901 statistics object. 2902 2903 Args: 2904 graph: A Graph object that's been set up with the node's graph. 2905 node: A NodeDef describing the operator. 2906 statistic_type: A string identifying the statistic we're interested in. 2907 Returns: 2908 An OpStats object containing information about resource usage. 2909 """ 2910 2911 try: 2912 stats_func = _stats_registry.lookup(node.op + "," + statistic_type) 2913 result = stats_func(graph, node) 2914 except LookupError: 2915 result = OpStats(statistic_type) 2916 return result 2917 2918 2919def _name_from_scope_name(name): 2920 """Returns the name of an op given the name of its scope. 2921 2922 Args: 2923 name: the name of the scope. 2924 2925 Returns: 2926 the name of the op (equal to scope name minus any trailing slash). 2927 """ 2928 return name[:-1] if (name and name[-1] == "/") else name 2929 2930 2931_MUTATION_LOCK_GROUP = 0 2932_SESSION_RUN_LOCK_GROUP = 1 2933 2934@tf_export("Graph") 2935class Graph(object): 2936 """A TensorFlow computation, represented as a dataflow graph. 2937 2938 A `Graph` contains a set of 2939 `tf.Operation` objects, 2940 which represent units of computation; and 2941 `tf.Tensor` objects, which represent 2942 the units of data that flow between operations. 2943 2944 A default `Graph` is always registered, and accessible by calling 2945 `tf.get_default_graph`. 2946 To add an operation to the default graph, simply call one of the functions 2947 that defines a new `Operation`: 2948 2949 ```python 2950 c = tf.constant(4.0) 2951 assert c.graph is tf.get_default_graph() 2952 ``` 2953 2954 Another typical usage involves the 2955 `tf.Graph.as_default` 2956 context manager, which overrides the current default graph for the 2957 lifetime of the context: 2958 2959 ```python 2960 g = tf.Graph() 2961 with g.as_default(): 2962 # Define operations and tensors in `g`. 
2963 c = tf.constant(30.0) 2964 assert c.graph is g 2965 ``` 2966 2967 Important note: This class *is not* thread-safe for graph construction. All 2968 operations should be created from a single thread, or external 2969 synchronization must be provided. Unless otherwise specified, all methods 2970 are not thread-safe. 2971 2972 A `Graph` instance supports an arbitrary number of "collections" 2973 that are identified by name. For convenience when building a large 2974 graph, collections can store groups of related objects: for 2975 example, the `tf.Variable` uses a collection (named 2976 `tf.GraphKeys.GLOBAL_VARIABLES`) for 2977 all variables that are created during the construction of a graph. The caller 2978 may define additional collections by specifying a new name. 2979 """ 2980 2981 def __init__(self): 2982 """Creates a new, empty Graph.""" 2983 # Protects core state that can be returned via public accessors. 2984 # Thread-safety is provided on a best-effort basis to support buggy 2985 # programs, and is not guaranteed by the public `tf.Graph` API. 2986 # 2987 # NOTE(mrry): This does not protect the various stacks. A warning will 2988 # be reported if these are used from multiple threads 2989 self._lock = threading.RLock() 2990 # The group lock synchronizes Session.run calls with methods that create 2991 # and mutate ops (e.g. Graph.create_op()). This synchronization is 2992 # necessary because it's illegal to modify an operation after it's been run. 2993 # The group lock allows any number of threads to mutate ops at the same time 2994 # but if any modification is going on, all Session.run calls have to wait. 2995 # Similarly, if one or more Session.run calls are going on, all mutate ops 2996 # have to wait until all Session.run calls have finished. 
2997 self._group_lock = lock_util.GroupLock(num_groups=2) 2998 self._nodes_by_id = dict() # GUARDED_BY(self._lock) 2999 self._next_id_counter = 0 # GUARDED_BY(self._lock) 3000 self._nodes_by_name = dict() # GUARDED_BY(self._lock) 3001 self._version = 0 # GUARDED_BY(self._lock) 3002 # Maps a name used in the graph to the next id to use for that name. 3003 self._names_in_use = {} 3004 self._stack_state_is_thread_local = False 3005 self._thread_local = threading.local() 3006 # Functions that will be applied to choose a device if none is specified. 3007 # In TF2.x or after switch_to_thread_local(), 3008 # self._thread_local._device_function_stack is used instead. 3009 self._graph_device_function_stack = traceable_stack.TraceableStack() 3010 # Default original_op applied to new ops. 3011 self._default_original_op = None 3012 # Current control flow context. It could be either CondContext or 3013 # WhileContext defined in ops/control_flow_ops.py 3014 self._control_flow_context = None 3015 # A new node will depend of the union of all of the nodes in the stack. 3016 # In TF2.x or after switch_to_thread_local(), 3017 # self._thread_local._control_dependencies_stack is used instead. 3018 self._graph_control_dependencies_stack = [] 3019 # Arbitrary collections of objects. 3020 self._collections = {} 3021 # The graph-level random seed 3022 self._seed = None 3023 # A dictionary of attributes that should be applied to all ops. 3024 self._attr_scope_map = {} 3025 # A map from op type to the kernel label that should be used. 3026 self._op_to_kernel_label_map = {} 3027 # A map from op type to an alternative op type that should be used when 3028 # computing gradients. 3029 self._gradient_override_map = {} 3030 # True if the graph is considered "finalized". In that case no 3031 # new operations can be added. 
3032 self._finalized = False 3033 # Functions defined in the graph 3034 self._functions = collections.OrderedDict() 3035 # Default GraphDef versions 3036 self._graph_def_versions = versions_pb2.VersionDef( 3037 producer=versions.GRAPH_DEF_VERSION, 3038 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER) 3039 self._building_function = False 3040 # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(), 3041 # self._thread_local._colocation_stack is used instead. 3042 self._graph_colocation_stack = traceable_stack.TraceableStack() 3043 # Set of tensors that are dangerous to feed! 3044 self._unfeedable_tensors = set() 3045 # Set of operations that are dangerous to fetch! 3046 self._unfetchable_ops = set() 3047 # A map of tensor handle placeholder to tensor dtype. 3048 self._handle_feeders = {} 3049 # A map from tensor handle to its read op. 3050 self._handle_readers = {} 3051 # A map from tensor handle to its move op. 3052 self._handle_movers = {} 3053 # A map from tensor handle to its delete op. 3054 self._handle_deleters = {} 3055 # Allow optimizers and other objects to pseudo-uniquely key graphs (this key 3056 # will be shared when defining function graphs, for example, so optimizers 3057 # being called inside function definitions behave as if they were seeing the 3058 # actual outside graph). 3059 self._graph_key = "grap-key-%d/" % (uid(),) 3060 self._container = "" 3061 self._registered_ops = op_def_registry.get_registered_ops() 3062 # Set to True if this graph is being built in an 3063 # AutomaticControlDependencies context. 3064 self._add_control_dependencies = False 3065 3066 # TODO(skyewm): fold as much of the above as possible into the C 3067 # implementation 3068 self._scoped_c_graph = c_api_util.ScopedTFGraph() 3069 # The C API requires all ops to have shape functions. Disable this 3070 # requirement (many custom ops do not have shape functions, and we don't 3071 # want to break these existing cases). 
3072 c_api.SetRequireShapeInferenceFns(self._c_graph, False) 3073 if tf2.enabled(): 3074 self.switch_to_thread_local() 3075 3076 # Note: this method is private because the API of tf.Graph() is public and 3077 # frozen, and this functionality is still not ready for public visibility. 3078 @tf_contextlib.contextmanager 3079 def _variable_creator_scope(self, creator, priority=100): 3080 """Scope which defines a variable creation function. 3081 3082 Args: 3083 creator: A callable taking `next_creator` and `kwargs`. See the 3084 `tf.variable_creator_scope` docstring. 3085 priority: Creators with a higher `priority` are called first. Within the 3086 same priority, creators are called inner-to-outer. 3087 3088 Yields: 3089 `_variable_creator_scope` is a context manager with a side effect, but 3090 doesn't return a value. 3091 """ 3092 # This step makes a copy of the existing stack, and it also initializes 3093 # self._thread_local._variable_creator_stack if it doesn't exist yet. 3094 old = list(self._variable_creator_stack) 3095 stack = self._thread_local._variable_creator_stack # pylint: disable=protected-access 3096 stack.append((priority, creator)) 3097 # Sorting is stable, so we'll put higher-priority creators later in the list 3098 # but otherwise maintain registration order. 3099 stack.sort(key=lambda item: item[0]) 3100 try: 3101 yield 3102 finally: 3103 self._thread_local._variable_creator_stack = old # pylint: disable=protected-access 3104 3105 # Note: this method is private because the API of tf.Graph() is public and 3106 # frozen, and this functionality is still not ready for public visibility. 
3107 @property 3108 def _variable_creator_stack(self): 3109 if not hasattr(self._thread_local, "_variable_creator_stack"): 3110 self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access 3111 return list(self._thread_local._variable_creator_stack) # pylint: disable=protected-access 3112 3113 @_variable_creator_stack.setter 3114 def _variable_creator_stack(self, variable_creator_stack): 3115 self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access 3116 3117 def _check_not_finalized(self): 3118 """Check if the graph is finalized. 3119 3120 Raises: 3121 RuntimeError: If the graph finalized. 3122 """ 3123 if self._finalized: 3124 raise RuntimeError("Graph is finalized and cannot be modified.") 3125 3126 def _add_op(self, op): 3127 """Adds 'op' to the graph. 3128 3129 Args: 3130 op: the Operator or Tensor to add. 3131 3132 Raises: 3133 TypeError: if op is not an Operation or Tensor. 3134 ValueError: if the op.name or op._id are already used. 3135 """ 3136 self._check_not_finalized() 3137 if not isinstance(op, (Tensor, Operation)): 3138 raise TypeError("op must be a Tensor or Operation: %s" % op) 3139 with self._lock: 3140 # pylint: disable=protected-access 3141 if op._id in self._nodes_by_id: 3142 raise ValueError("cannot add an op with id %d as it already " 3143 "exists in the graph" % op._id) 3144 if op.name in self._nodes_by_name: 3145 raise ValueError("cannot add op with name %s as that name " 3146 "is already used" % op.name) 3147 self._nodes_by_id[op._id] = op 3148 self._nodes_by_name[op.name] = op 3149 self._version = max(self._version, op._id) 3150 # pylint: enable=protected-access 3151 3152 @property 3153 def _c_graph(self): 3154 if self._scoped_c_graph: 3155 return self._scoped_c_graph.graph 3156 return None 3157 3158 @property 3159 def version(self): 3160 """Returns a version number that increases as ops are added to the graph. 
3161 3162 Note that this is unrelated to the 3163 `tf.Graph.graph_def_versions`. 3164 3165 Returns: 3166 An integer version that increases as ops are added to the graph. 3167 """ 3168 if self._finalized: 3169 return self._version 3170 3171 with self._lock: 3172 return self._version 3173 3174 @property 3175 def graph_def_versions(self): 3176 # pylint: disable=line-too-long 3177 """The GraphDef version information of this graph. 3178 3179 For details on the meaning of each version, see 3180 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto). 3181 3182 Returns: 3183 A `VersionDef`. 3184 """ 3185 # pylint: enable=line-too-long 3186 with c_api_util.tf_buffer() as buf: 3187 c_api.TF_GraphVersions(self._c_graph, buf) 3188 data = c_api.TF_GetBuffer(buf) 3189 version_def = versions_pb2.VersionDef() 3190 version_def.ParseFromString(compat.as_bytes(data)) 3191 return version_def 3192 3193 @property 3194 def seed(self): 3195 """The graph-level random seed of this graph.""" 3196 return self._seed 3197 3198 @seed.setter 3199 def seed(self, seed): 3200 self._seed = seed 3201 3202 @property 3203 def finalized(self): 3204 """True if this graph has been finalized.""" 3205 return self._finalized 3206 3207 def finalize(self): 3208 """Finalizes this graph, making it read-only. 3209 3210 After calling `g.finalize()`, no new operations can be added to 3211 `g`. This method is used to ensure that no operations are added 3212 to a graph when it is shared between multiple threads, for example 3213 when using a `tf.train.QueueRunner`. 3214 """ 3215 self._finalized = True 3216 3217 def _unsafe_unfinalize(self): 3218 """Opposite of `finalize`. Internal interface. 3219 3220 NOTE: Unfinalizing a graph could have negative impact on performance, 3221 especially in a multi-threaded environment. Unfinalizing a graph 3222 when it is in use by a Session may lead to undefined behavior. Ensure 3223 that all sessions using a graph are closed before calling this method. 
3224 """ 3225 self._finalized = False 3226 3227 def _get_control_flow_context(self): 3228 """Returns the current control flow context. 3229 3230 Returns: 3231 A context object. 3232 """ 3233 return self._control_flow_context 3234 3235 def _set_control_flow_context(self, ctx): 3236 """Sets the current control flow context. 3237 3238 Args: 3239 ctx: a context object. 3240 """ 3241 self._control_flow_context = ctx 3242 3243 def _copy_functions_to_graph_def(self, graph_def, starting_bytesize): 3244 """If this graph contains functions, copy them to `graph_def`.""" 3245 bytesize = starting_bytesize 3246 for f in self._functions.values(): 3247 bytesize += f.definition.ByteSize() 3248 if bytesize >= (1 << 31) or bytesize < 0: 3249 raise ValueError("GraphDef cannot be larger than 2GB.") 3250 graph_def.library.function.extend([f.definition]) 3251 if f.grad_func_name: 3252 grad_def = function_pb2.GradientDef() 3253 grad_def.function_name = f.name 3254 grad_def.gradient_func = f.grad_func_name 3255 graph_def.library.gradient.extend([grad_def]) 3256 3257 def _as_graph_def(self, from_version=None, add_shapes=False): 3258 # pylint: disable=line-too-long 3259 """Returns a serialized `GraphDef` representation of this graph. 3260 3261 The serialized `GraphDef` can be imported into another `Graph` 3262 (using `tf.import_graph_def`) or used with the 3263 [C++ Session API](../../../../api_docs/cc/index.md). 3264 3265 This method is thread-safe. 3266 3267 Args: 3268 from_version: Optional. If this is set, returns a `GraphDef` 3269 containing only the nodes that were added to this graph since 3270 its `version` property had the given value. 3271 add_shapes: If true, adds an "_output_shapes" list attr to each 3272 node with the inferred shapes of each of its outputs. 
3273 3274 Returns: 3275 A tuple containing a 3276 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 3277 protocol buffer, and the version of the graph to which that 3278 `GraphDef` corresponds. 3279 3280 Raises: 3281 ValueError: If the `graph_def` would be too large. 3282 3283 """ 3284 # pylint: enable=line-too-long 3285 with self._lock: 3286 with c_api_util.tf_buffer() as buf: 3287 c_api.TF_GraphToGraphDef(self._c_graph, buf) 3288 data = c_api.TF_GetBuffer(buf) 3289 graph = graph_pb2.GraphDef() 3290 graph.ParseFromString(compat.as_bytes(data)) 3291 # Strip the experimental library field iff it's empty. 3292 if not graph.library.function: 3293 graph.ClearField("library") 3294 3295 if add_shapes: 3296 for node in graph.node: 3297 op = self._nodes_by_name[node.name] 3298 if op.outputs: 3299 node.attr["_output_shapes"].list.shape.extend( 3300 [output.get_shape().as_proto() for output in op.outputs]) 3301 return graph, self._version 3302 3303 def as_graph_def(self, from_version=None, add_shapes=False): 3304 # pylint: disable=line-too-long 3305 """Returns a serialized `GraphDef` representation of this graph. 3306 3307 The serialized `GraphDef` can be imported into another `Graph` 3308 (using `tf.import_graph_def`) or used with the 3309 [C++ Session API](../../api_docs/cc/index.md). 3310 3311 This method is thread-safe. 3312 3313 Args: 3314 from_version: Optional. If this is set, returns a `GraphDef` 3315 containing only the nodes that were added to this graph since 3316 its `version` property had the given value. 3317 add_shapes: If true, adds an "_output_shapes" list attr to each 3318 node with the inferred shapes of each of its outputs. 3319 3320 Returns: 3321 A 3322 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 3323 protocol buffer. 3324 3325 Raises: 3326 ValueError: If the `graph_def` would be too large. 
3327 """ 3328 # pylint: enable=line-too-long 3329 result, _ = self._as_graph_def(from_version, add_shapes) 3330 return result 3331 3332 def _is_function(self, name): 3333 """Tests whether 'name' is registered in this graph's function library. 3334 3335 Args: 3336 name: string op name. 3337 Returns: 3338 bool indicating whether or not 'name' is registered in function library. 3339 """ 3340 return compat.as_str(name) in self._functions 3341 3342 def _get_function(self, name): 3343 """Returns the function definition for 'name'. 3344 3345 Args: 3346 name: string function name. 3347 Returns: 3348 The function def proto. 3349 """ 3350 return self._functions.get(compat.as_str(name), None) 3351 3352 def _add_function(self, function): 3353 """Adds a function to the graph. 3354 3355 After the function has been added, you can call to the function by 3356 passing the function name in place of an op name to 3357 `Graph.create_op()`. 3358 3359 Args: 3360 function: A `_DefinedFunction` object. 3361 3362 3363 Raises: 3364 ValueError: if another function is defined with the same name. 3365 """ 3366 name = function.name 3367 # Sanity checks on gradient definition. 3368 if (function.grad_func_name is not None) and (function.python_grad_func is 3369 not None): 3370 raise ValueError("Gradient defined twice for function %s" % name) 3371 3372 # Add function to graph 3373 # pylint: disable=protected-access 3374 # Handle functions created without using the C API. TODO(apassos,skyewm) 3375 # remove this when all functions are generated using the C API by default 3376 # as this will be unnecessary. 
3377 if not function._c_func: 3378 serialized = function.definition.SerializeToString() 3379 c_func = c_api.TF_FunctionImportFunctionDef(serialized) 3380 function._c_func = c_api_util.ScopedTFFunction(c_func) 3381 gradient = (function._grad_func._c_func.func if function._grad_func 3382 else None) 3383 c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient) 3384 # pylint: enable=protected-access 3385 3386 self._functions[compat.as_str(name)] = function 3387 3388 # Need a new-enough consumer to support the functions we add to the graph. 3389 if self._graph_def_versions.min_consumer < 12: 3390 self._graph_def_versions.min_consumer = 12 3391 3392 @property 3393 def building_function(self): 3394 """Returns True iff this graph represents a function.""" 3395 return self._building_function 3396 3397 # Helper functions to create operations. 3398 @deprecated_args(None, 3399 "Shapes are always computed; don't use the compute_shapes " 3400 "as it has no effect.", "compute_shapes") 3401 def create_op( 3402 self, 3403 op_type, 3404 inputs, 3405 dtypes=None, # pylint: disable=redefined-outer-name 3406 input_types=None, 3407 name=None, 3408 attrs=None, 3409 op_def=None, 3410 compute_shapes=True, 3411 compute_device=True): 3412 """Creates an `Operation` in this graph. 3413 3414 This is a low-level interface for creating an `Operation`. Most 3415 programs will not call this method directly, and instead use the 3416 Python op constructors, such as `tf.constant()`, which add ops to 3417 the default graph. 3418 3419 Args: 3420 op_type: The `Operation` type to create. This corresponds to the 3421 `OpDef.name` field for the proto that defines the operation. 3422 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 3423 dtypes: (Optional) A list of `DType` objects that will be the types of the 3424 tensors that the operation produces. 3425 input_types: (Optional.) 
A list of `DType`s that will be the types of 3426 the tensors that the operation consumes. By default, uses the base 3427 `DType` of each input in `inputs`. Operations that expect 3428 reference-typed inputs must specify `input_types` explicitly. 3429 name: (Optional.) A string name for the operation. If not specified, a 3430 name is generated based on `op_type`. 3431 attrs: (Optional.) A dictionary where the key is the attribute name (a 3432 string) and the value is the respective `attr` attribute of the 3433 `NodeDef` proto that will represent the operation (an `AttrValue` 3434 proto). 3435 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 3436 the operation will have. 3437 compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always 3438 computed). 3439 compute_device: (Optional.) If True, device functions will be executed 3440 to compute the device property of the Operation. 3441 3442 Raises: 3443 TypeError: if any of the inputs is not a `Tensor`. 3444 ValueError: if colocation conflicts with existing device assignment. 3445 3446 Returns: 3447 An `Operation` object. 3448 """ 3449 del compute_shapes 3450 3451 self._check_not_finalized() 3452 for idx, a in enumerate(inputs): 3453 if not isinstance(a, Tensor): 3454 raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) 3455 if name is None: 3456 name = op_type 3457 # If a names ends with a '/' it is a "name scope" and we use it as-is, 3458 # after removing the trailing '/'. 3459 if name and name[-1] == "/": 3460 name = _name_from_scope_name(name) 3461 else: 3462 name = self.unique_name(name) 3463 3464 node_def = _NodeDef(op_type, name, device=None, attrs=attrs) 3465 3466 input_ops = set([t.op for t in inputs]) 3467 control_inputs = self._control_dependencies_for_inputs(input_ops) 3468 # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a 3469 # Session.run call cannot occur between creating and mutating the op. 
3470 with self._mutation_lock(): 3471 ret = Operation( 3472 node_def, 3473 self, 3474 inputs=inputs, 3475 output_types=dtypes, 3476 control_inputs=control_inputs, 3477 input_types=input_types, 3478 original_op=self._default_original_op, 3479 op_def=op_def) 3480 self._create_op_helper(ret, compute_device=compute_device) 3481 return ret 3482 3483 def _create_op_from_tf_operation(self, c_op, compute_device=True): 3484 """Creates an `Operation` in this graph from the supplied TF_Operation. 3485 3486 This method is like create_op() except the new Operation is constructed 3487 using `c_op`. The returned Operation will have `c_op` as its _c_op 3488 field. This is used to create Operation objects around TF_Operations created 3489 indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile). 3490 3491 This function does not call Operation._control_flow_post_processing or 3492 Graph._control_dependencies_for_inputs (since the inputs may not be 3493 available yet). The caller is responsible for calling these methods. 3494 3495 Args: 3496 c_op: a wrapped TF_Operation 3497 compute_device: (Optional.) If True, device functions will be executed 3498 to compute the device property of the Operation. 3499 3500 Returns: 3501 An `Operation` object. 3502 """ 3503 self._check_not_finalized() 3504 ret = Operation(c_op, self) 3505 # If a name_scope was created with ret.name but no nodes were created in it, 3506 # the name will still appear in _names_in_use even though the name hasn't 3507 # been used. This is ok, just leave _names_in_use as-is in this case. 3508 # TODO(skyewm): make the C API guarantee no name conflicts. 3509 name_key = ret.name.lower() 3510 if name_key not in self._names_in_use: 3511 self._names_in_use[name_key] = 1 3512 self._create_op_helper(ret, compute_device=compute_device) 3513 return ret 3514 3515 def _create_op_helper(self, op, compute_device=True): 3516 """Common logic for creating an op in this graph.""" 3517 # Apply any additional attributes requested. 
Do not overwrite any existing 3518 # attributes. 3519 for key, value in self._attr_scope_map.items(): 3520 try: 3521 op.get_attr(key) 3522 except ValueError: 3523 if callable(value): 3524 value = value(op.node_def) 3525 if not isinstance(value, (type(None), attr_value_pb2.AttrValue)): 3526 raise TypeError( 3527 "Callable for scope map key '%s' must return either None or " 3528 "an AttrValue protocol buffer; but it returned: %s" % (key, 3529 value)) 3530 if value: 3531 op._set_attr(key, value) # pylint: disable=protected-access 3532 3533 # Apply a kernel label if one has been specified for this op type. 3534 try: 3535 kernel_label = self._op_to_kernel_label_map[op.type] 3536 op._set_attr("_kernel", # pylint: disable=protected-access 3537 attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label))) 3538 except KeyError: 3539 pass 3540 3541 # Apply the overriding op type for gradients if one has been specified for 3542 # this op type. 3543 try: 3544 mapped_op_type = self._gradient_override_map[op.type] 3545 op._set_attr("_gradient_op_type", # pylint: disable=protected-access 3546 attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type))) 3547 except KeyError: 3548 pass 3549 3550 self._record_op_seen_by_control_dependencies(op) 3551 3552 if compute_device: 3553 self._apply_device_functions(op) 3554 3555 # Snapshot the colocation stack metadata before we might generate error 3556 # messages using it. Note that this snapshot depends on the actual stack 3557 # and is independent of the op's _class attribute. 
3558 # pylint: disable=protected-access 3559 op._colocation_code_locations = self._snapshot_colocation_stack_metadata() 3560 # pylint: enable=protected-access 3561 3562 if self._colocation_stack: 3563 all_colocation_groups = [] 3564 for colocation_op in self._colocation_stack.peek_objs(): 3565 all_colocation_groups.extend(colocation_op.colocation_groups()) 3566 if colocation_op.device: 3567 # pylint: disable=protected-access 3568 op._set_device(colocation_op.device) 3569 # pylint: enable=protected-access 3570 3571 all_colocation_groups = sorted(set(all_colocation_groups)) 3572 # pylint: disable=protected-access 3573 op._set_attr("_class", attr_value_pb2.AttrValue( 3574 list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) 3575 # pylint: enable=protected-access 3576 3577 # Sets "container" attribute if 3578 # (1) self._container is not None 3579 # (2) "is_stateful" is set in OpDef 3580 # (3) "container" attribute is in OpDef 3581 # (4) "container" attribute is None 3582 if self._container and op.op_def.is_stateful: 3583 try: 3584 container_attr = op.get_attr("container") 3585 except ValueError: 3586 # "container" attribute is not in OpDef 3587 pass 3588 else: 3589 if not container_attr: 3590 op._set_attr("container", attr_value_pb2.AttrValue( # pylint: disable=protected-access 3591 s=compat.as_bytes(self._container))) 3592 3593 def _add_new_tf_operations(self, compute_devices=True): 3594 """Creates `Operations` in this graph for any new TF_Operations. 3595 3596 This is useful for when TF_Operations are indirectly created by the C API 3597 outside of the Operation constructor (e.g. by TF_ImportGraphDef, 3598 TF_FinishWhile). This ensures there are corresponding Operations for all 3599 TF_Operations in the underlying TF_Graph. 3600 3601 Args: 3602 compute_devices: (Optional.) If True, device functions will be executed 3603 to compute the device properties of each new Operation. 3604 3605 Returns: 3606 A list of the new `Operation` objects. 
3607 """ 3608 # Create all Operation objects before accessing their inputs since an op may 3609 # be created before its inputs. 3610 new_ops = [ 3611 self._create_op_from_tf_operation(c_op, compute_device=compute_devices) 3612 for c_op in c_api_util.new_tf_operations(self) 3613 ] 3614 3615 # pylint: disable=protected-access 3616 for op in new_ops: 3617 new_control_inputs = self._control_dependencies_for_inputs(op.inputs) 3618 op._add_control_inputs(new_control_inputs) 3619 op._control_flow_post_processing() 3620 # pylint: enable=protected-access 3621 3622 return new_ops 3623 3624 def as_graph_element(self, obj, allow_tensor=True, allow_operation=True): 3625 """Returns the object referred to by `obj`, as an `Operation` or `Tensor`. 3626 3627 This function validates that `obj` represents an element of this 3628 graph, and gives an informative error message if it is not. 3629 3630 This function is the canonical way to get/validate an object of 3631 one of the allowed types from an external argument reference in the 3632 Session API. 3633 3634 This method may be called concurrently from multiple threads. 3635 3636 Args: 3637 obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. 3638 Can also be any object with an `_as_graph_element()` method that returns 3639 a value of one of these types. 3640 allow_tensor: If true, `obj` may refer to a `Tensor`. 3641 allow_operation: If true, `obj` may refer to an `Operation`. 3642 3643 Returns: 3644 The `Tensor` or `Operation` in the Graph corresponding to `obj`. 3645 3646 Raises: 3647 TypeError: If `obj` is not a type we support attempting to convert 3648 to types. 3649 ValueError: If `obj` is of an appropriate type but invalid. For 3650 example, an invalid string. 3651 KeyError: If `obj` is not an object in the graph. 
3652 """ 3653 if self._finalized: 3654 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 3655 3656 with self._lock: 3657 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 3658 3659 def _as_graph_element_locked(self, obj, allow_tensor, allow_operation): 3660 """See `Graph.as_graph_element()` for details.""" 3661 # The vast majority of this function is figuring 3662 # out what an API user might be doing wrong, so 3663 # that we can give helpful error messages. 3664 # 3665 # Ideally, it would be nice to split it up, but we 3666 # need context to generate nice error messages. 3667 3668 if allow_tensor and allow_operation: 3669 types_str = "Tensor or Operation" 3670 elif allow_tensor: 3671 types_str = "Tensor" 3672 elif allow_operation: 3673 types_str = "Operation" 3674 else: 3675 raise ValueError("allow_tensor and allow_operation can't both be False.") 3676 3677 temp_obj = _as_graph_element(obj) 3678 if temp_obj is not None: 3679 obj = temp_obj 3680 3681 # If obj appears to be a name... 3682 if isinstance(obj, compat.bytes_or_text_types): 3683 name = compat.as_str(obj) 3684 3685 if ":" in name and allow_tensor: 3686 # Looks like a Tensor name and can be a Tensor. 3687 try: 3688 op_name, out_n = name.split(":") 3689 out_n = int(out_n) 3690 except: 3691 raise ValueError("The name %s looks a like a Tensor name, but is " 3692 "not a valid one. Tensor names must be of the " 3693 "form \"<op_name>:<output_index>\"." % repr(name)) 3694 if op_name in self._nodes_by_name: 3695 op = self._nodes_by_name[op_name] 3696 else: 3697 raise KeyError("The name %s refers to a Tensor which does not " 3698 "exist. The operation, %s, does not exist in the " 3699 "graph." % (repr(name), repr(op_name))) 3700 try: 3701 return op.outputs[out_n] 3702 except: 3703 raise KeyError("The name %s refers to a Tensor which does not " 3704 "exist. The operation, %s, exists but only has " 3705 "%s outputs." 
% (repr(name), repr(op_name), 3706 len(op.outputs))) 3707 3708 elif ":" in name and not allow_tensor: 3709 # Looks like a Tensor name but can't be a Tensor. 3710 raise ValueError("Name %s appears to refer to a Tensor, not a %s." % 3711 (repr(name), types_str)) 3712 3713 elif ":" not in name and allow_operation: 3714 # Looks like an Operation name and can be an Operation. 3715 if name not in self._nodes_by_name: 3716 raise KeyError("The name %s refers to an Operation not in the " 3717 "graph." % repr(name)) 3718 return self._nodes_by_name[name] 3719 3720 elif ":" not in name and not allow_operation: 3721 # Looks like an Operation name but can't be an Operation. 3722 if name in self._nodes_by_name: 3723 # Yep, it's an Operation name 3724 err_msg = ("The name %s refers to an Operation, not a %s." % 3725 (repr(name), types_str)) 3726 else: 3727 err_msg = ("The name %s looks like an (invalid) Operation name, " 3728 "not a %s." % (repr(name), types_str)) 3729 err_msg += (" Tensor names must be of the form " 3730 "\"<op_name>:<output_index>\".") 3731 raise ValueError(err_msg) 3732 3733 elif isinstance(obj, Tensor) and allow_tensor: 3734 # Actually obj is just the object it's referring to. 3735 if obj.graph is not self: 3736 raise ValueError("Tensor %s is not an element of this graph." % obj) 3737 return obj 3738 elif isinstance(obj, Operation) and allow_operation: 3739 # Actually obj is just the object it's referring to. 3740 if obj.graph is not self: 3741 raise ValueError("Operation %s is not an element of this graph." % obj) 3742 return obj 3743 else: 3744 # We give up! 3745 raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__, 3746 types_str)) 3747 3748 def get_operations(self): 3749 """Return the list of operations in the graph. 3750 3751 You can modify the operations in place, but modifications 3752 to the list such as inserts/delete have no effect on the 3753 list of operations known to the graph. 
3754 3755 This method may be called concurrently from multiple threads. 3756 3757 Returns: 3758 A list of Operations. 3759 """ 3760 if self._finalized: 3761 return list(self._nodes_by_id.values()) 3762 3763 with self._lock: 3764 return list(self._nodes_by_id.values()) 3765 3766 def get_operation_by_name(self, name): 3767 """Returns the `Operation` with the given `name`. 3768 3769 This method may be called concurrently from multiple threads. 3770 3771 Args: 3772 name: The name of the `Operation` to return. 3773 3774 Returns: 3775 The `Operation` with the given `name`. 3776 3777 Raises: 3778 TypeError: If `name` is not a string. 3779 KeyError: If `name` does not correspond to an operation in this graph. 3780 """ 3781 3782 if not isinstance(name, six.string_types): 3783 raise TypeError("Operation names are strings (or similar), not %s." % 3784 type(name).__name__) 3785 return self.as_graph_element(name, allow_tensor=False, allow_operation=True) 3786 3787 def _get_operation_by_name_unsafe(self, name): 3788 """Returns the `Operation` with the given `name`. 3789 3790 This is a internal unsafe version of get_operation_by_name. It skips many 3791 checks and does not have user friedly error messages but runs considerably 3792 faster. This method may be called concurrently from multiple threads. 3793 3794 Args: 3795 name: The name of the `Operation` to return. 3796 3797 Returns: 3798 The `Operation` with the given `name`. 3799 3800 Raises: 3801 KeyError: If `name` does not correspond to an operation in this graph. 3802 """ 3803 3804 if self._finalized: 3805 return self._nodes_by_name[name] 3806 3807 with self._lock: 3808 return self._nodes_by_name[name] 3809 3810 def _get_operation_by_tf_operation(self, tf_oper): 3811 op_name = c_api.TF_OperationName(tf_oper) 3812 return self._get_operation_by_name_unsafe(op_name) 3813 3814 def get_tensor_by_name(self, name): 3815 """Returns the `Tensor` with the given `name`. 
3816 3817 This method may be called concurrently from multiple threads. 3818 3819 Args: 3820 name: The name of the `Tensor` to return. 3821 3822 Returns: 3823 The `Tensor` with the given `name`. 3824 3825 Raises: 3826 TypeError: If `name` is not a string. 3827 KeyError: If `name` does not correspond to a tensor in this graph. 3828 """ 3829 # Names should be strings. 3830 if not isinstance(name, six.string_types): 3831 raise TypeError("Tensor names are strings (or similar), not %s." % 3832 type(name).__name__) 3833 return self.as_graph_element(name, allow_tensor=True, allow_operation=False) 3834 3835 def _get_tensor_by_tf_output(self, tf_output): 3836 """Returns the `Tensor` representing `tf_output`. 3837 3838 Note that there is only one such `Tensor`, i.e. multiple calls to this 3839 function with the same TF_Output value will always return the same `Tensor` 3840 object. 3841 3842 Args: 3843 tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`). 3844 3845 Returns: 3846 The `Tensor` that represents `tf_output`. 3847 """ 3848 op = self._get_operation_by_tf_operation(tf_output.oper) 3849 return op.outputs[tf_output.index] 3850 3851 def _next_id(self): 3852 """Id for next Operation instance. Also increments the internal id.""" 3853 self._check_not_finalized() 3854 with self._lock: 3855 self._next_id_counter += 1 3856 return self._next_id_counter 3857 3858 @property 3859 def _last_id(self): 3860 return self._next_id_counter 3861 3862 def _get_op_def(self, type): # pylint: disable=redefined-builtin 3863 """Returns the `OpDef` proto for `type`. 
`type` is a string.""" 3864 with c_api_util.tf_buffer() as buf: 3865 # pylint: disable=protected-access 3866 c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) 3867 # pylint: enable=protected-access 3868 data = c_api.TF_GetBuffer(buf) 3869 op_def = op_def_pb2.OpDef() 3870 op_def.ParseFromString(compat.as_bytes(data)) 3871 return op_def 3872 3873 def as_default(self): 3874 """Returns a context manager that makes this `Graph` the default graph. 3875 3876 This method should be used if you want to create multiple graphs 3877 in the same process. For convenience, a global default graph is 3878 provided, and all ops will be added to this graph if you do not 3879 create a new graph explicitly. 3880 3881 Use this method with the `with` keyword to specify that ops created within 3882 the scope of a block should be added to this graph. In this case, once 3883 the scope of the `with` is exited, the previous default graph is set again 3884 as default. There is a stack, so it's ok to have multiple nested levels 3885 of `as_default` calls. 3886 3887 The default graph is a property of the current thread. If you 3888 create a new thread, and wish to use the default graph in that 3889 thread, you must explicitly add a `with g.as_default():` in that 3890 thread's function. 3891 3892 The following code examples are equivalent: 3893 3894 ```python 3895 # 1. Using Graph.as_default(): 3896 g = tf.Graph() 3897 with g.as_default(): 3898 c = tf.constant(5.0) 3899 assert c.graph is g 3900 3901 # 2. Constructing and making default: 3902 with tf.Graph().as_default() as g: 3903 c = tf.constant(5.0) 3904 assert c.graph is g 3905 ``` 3906 3907 If eager execution is enabled ops created under this context manager will be 3908 added to the graph instead of executed eagerly. 3909 3910 Returns: 3911 A context manager for using this graph as the default graph. 
3912 """ 3913 return _default_graph_stack.get_controller(self) 3914 3915 @property 3916 def collections(self): 3917 """Returns the names of the collections known to this graph.""" 3918 return list(self._collections) 3919 3920 def add_to_collection(self, name, value): 3921 """Stores `value` in the collection with the given `name`. 3922 3923 Note that collections are not sets, so it is possible to add a value to 3924 a collection several times. 3925 3926 Args: 3927 name: The key for the collection. The `GraphKeys` class 3928 contains many standard names for collections. 3929 value: The value to add to the collection. 3930 """ # pylint: disable=g-doc-exception 3931 self._check_not_finalized() 3932 with self._lock: 3933 if name not in self._collections: 3934 self._collections[name] = [value] 3935 else: 3936 self._collections[name].append(value) 3937 3938 def add_to_collections(self, names, value): 3939 """Stores `value` in the collections given by `names`. 3940 3941 Note that collections are not sets, so it is possible to add a value to 3942 a collection several times. This function makes sure that duplicates in 3943 `names` are ignored, but it will not check for pre-existing membership of 3944 `value` in any of the collections in `names`. 3945 3946 `names` can be any iterable, but if `names` is a string, it is treated as a 3947 single collection name. 3948 3949 Args: 3950 names: The keys for the collections to add to. The `GraphKeys` class 3951 contains many standard names for collections. 3952 value: The value to add to the collections. 3953 """ 3954 # Make sure names are unique, but treat strings as a single collection name 3955 names = (names,) if isinstance(names, six.string_types) else set(names) 3956 for name in names: 3957 self.add_to_collection(name, value) 3958 3959 def get_collection_ref(self, name): 3960 """Returns a list of values in the collection with the given `name`. 
3961 3962 If the collection exists, this returns the list itself, which can 3963 be modified in place to change the collection. If the collection does 3964 not exist, it is created as an empty list and the list is returned. 3965 3966 This is different from `get_collection()` which always returns a copy of 3967 the collection list if it exists and never creates an empty collection. 3968 3969 Args: 3970 name: The key for the collection. For example, the `GraphKeys` class 3971 contains many standard names for collections. 3972 3973 Returns: 3974 The list of values in the collection with the given `name`, or an empty 3975 list if no value has been added to that collection. 3976 """ # pylint: disable=g-doc-exception 3977 with self._lock: 3978 coll_list = self._collections.get(name, None) 3979 if coll_list is None: 3980 coll_list = [] 3981 self._collections[name] = coll_list 3982 return coll_list 3983 3984 def get_collection(self, name, scope=None): 3985 """Returns a list of values in the collection with the given `name`. 3986 3987 This is different from `get_collection_ref()` which always returns the 3988 actual collection list if it exists in that it returns a new list each time 3989 it is called. 3990 3991 Args: 3992 name: The key for the collection. For example, the `GraphKeys` class 3993 contains many standard names for collections. 3994 scope: (Optional.) A string. If supplied, the resulting list is filtered 3995 to include only items whose `name` attribute matches `scope` using 3996 `re.match`. Items without a `name` attribute are never returned if a 3997 scope is supplied. The choice of `re.match` means that a `scope` without 3998 special tokens filters by prefix. 3999 4000 Returns: 4001 The list of values in the collection with the given `name`, or 4002 an empty list if no value has been added to that collection. The 4003 list contains the values in the order under which they were 4004 collected. 
4005 """ # pylint: disable=g-doc-exception 4006 with self._lock: 4007 collection = self._collections.get(name, None) 4008 if collection is None: 4009 return [] 4010 if scope is None: 4011 return list(collection) 4012 else: 4013 c = [] 4014 regex = re.compile(scope) 4015 for item in collection: 4016 if hasattr(item, "name") and regex.match(item.name): 4017 c.append(item) 4018 return c 4019 4020 def get_all_collection_keys(self): 4021 """Returns a list of collections used in this graph.""" 4022 with self._lock: 4023 return [x for x in self._collections if isinstance(x, six.string_types)] 4024 4025 def clear_collection(self, name): 4026 """Clears all values in a collection. 4027 4028 Args: 4029 name: The key for the collection. The `GraphKeys` class contains many 4030 standard names for collections. 4031 """ 4032 self._check_not_finalized() 4033 with self._lock: 4034 if name in self._collections: 4035 del self._collections[name] 4036 4037 @tf_contextlib.contextmanager 4038 def _original_op(self, op): 4039 """Python 'with' handler to help annotate ops with their originator. 4040 4041 An op may have an 'original_op' property that indicates the op on which 4042 it was based. For example a replica op is based on the op that was 4043 replicated and a gradient op is based on the op that was differentiated. 4044 4045 All ops created in the scope of this 'with' handler will have 4046 the given 'op' as their original op. 4047 4048 Args: 4049 op: The Operation that all ops created in this scope will have as their 4050 original op. 4051 4052 Yields: 4053 Nothing. 4054 """ 4055 old_original_op = self._default_original_op 4056 self._default_original_op = op 4057 try: 4058 yield 4059 finally: 4060 self._default_original_op = old_original_op 4061 4062 @property 4063 def _name_stack(self): 4064 # This may be called from a thread where name_stack doesn't yet exist. 
4065 if not hasattr(self._thread_local, "_name_stack"): 4066 self._thread_local._name_stack = "" 4067 return self._thread_local._name_stack 4068 4069 @_name_stack.setter 4070 def _name_stack(self, name_stack): 4071 self._thread_local._name_stack = name_stack 4072 4073 # pylint: disable=g-doc-return-or-yield,line-too-long 4074 @tf_contextlib.contextmanager 4075 def name_scope(self, name): 4076 r"""Returns a context manager that creates hierarchical names for operations. 4077 4078 A graph maintains a stack of name scopes. A `with name_scope(...):` 4079 statement pushes a new name onto the stack for the lifetime of the context. 4080 4081 The `name` argument will be interpreted as follows: 4082 4083 * A string (not ending with '/') will create a new name scope, in which 4084 `name` is appended to the prefix of all operations created in the 4085 context. If `name` has been used before, it will be made unique by 4086 calling `self.unique_name(name)`. 4087 * A scope previously captured from a `with g.name_scope(...) as 4088 scope:` statement will be treated as an "absolute" name scope, which 4089 makes it possible to re-enter existing scopes. 4090 * A value of `None` or the empty string will reset the current name scope 4091 to the top-level (empty) name scope. 4092 4093 For example: 4094 4095 ```python 4096 with tf.Graph().as_default() as g: 4097 c = tf.constant(5.0, name="c") 4098 assert c.op.name == "c" 4099 c_1 = tf.constant(6.0, name="c") 4100 assert c_1.op.name == "c_1" 4101 4102 # Creates a scope called "nested" 4103 with g.name_scope("nested") as scope: 4104 nested_c = tf.constant(10.0, name="c") 4105 assert nested_c.op.name == "nested/c" 4106 4107 # Creates a nested scope called "inner". 4108 with g.name_scope("inner"): 4109 nested_inner_c = tf.constant(20.0, name="c") 4110 assert nested_inner_c.op.name == "nested/inner/c" 4111 4112 # Create a nested scope called "inner_1". 
4113 with g.name_scope("inner"): 4114 nested_inner_1_c = tf.constant(30.0, name="c") 4115 assert nested_inner_1_c.op.name == "nested/inner_1/c" 4116 4117 # Treats `scope` as an absolute name scope, and 4118 # switches to the "nested/" scope. 4119 with g.name_scope(scope): 4120 nested_d = tf.constant(40.0, name="d") 4121 assert nested_d.op.name == "nested/d" 4122 4123 with g.name_scope(""): 4124 e = tf.constant(50.0, name="e") 4125 assert e.op.name == "e" 4126 ``` 4127 4128 The name of the scope itself can be captured by `with 4129 g.name_scope(...) as scope:`, which stores the name of the scope 4130 in the variable `scope`. This value can be used to name an 4131 operation that represents the overall result of executing the ops 4132 in a scope. For example: 4133 4134 ```python 4135 inputs = tf.constant(...) 4136 with g.name_scope('my_layer') as scope: 4137 weights = tf.Variable(..., name="weights") 4138 biases = tf.Variable(..., name="biases") 4139 affine = tf.matmul(inputs, weights) + biases 4140 output = tf.nn.relu(affine, name=scope) 4141 ``` 4142 4143 NOTE: This constructor validates the given `name`. Valid scope 4144 names match one of the following regular expressions: 4145 4146 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root) 4147 [A-Za-z0-9_.\\-/]* (for other scopes) 4148 4149 Args: 4150 name: A name for the scope. 4151 4152 Returns: 4153 A context manager that installs `name` as a new name scope. 4154 4155 Raises: 4156 ValueError: If `name` is not a valid scope name, according to the rules 4157 above. 4158 """ 4159 if name: 4160 if isinstance(name, compat.bytes_or_text_types): 4161 name = compat.as_str(name) 4162 4163 if self._name_stack: 4164 # Scopes created in a nested scope may have initial characters 4165 # that are illegal as the initial character of an op name 4166 # (viz. '-', '\', '/', and '_'). 
4167 if not _VALID_SCOPE_NAME_REGEX.match(name): 4168 raise ValueError("'%s' is not a valid scope name" % name) 4169 else: 4170 # Scopes created in the root must match the more restrictive 4171 # op name regex, which constrains the initial character. 4172 if not _VALID_OP_NAME_REGEX.match(name): 4173 raise ValueError("'%s' is not a valid scope name" % name) 4174 old_stack = self._name_stack 4175 if not name: # Both for name=None and name="" we re-set to empty scope. 4176 new_stack = None 4177 elif name[-1] == "/": 4178 new_stack = _name_from_scope_name(name) 4179 else: 4180 new_stack = self.unique_name(name) 4181 self._name_stack = new_stack 4182 try: 4183 yield "" if new_stack is None else new_stack + "/" 4184 finally: 4185 self._name_stack = old_stack 4186 4187 # pylint: enable=g-doc-return-or-yield,line-too-long 4188 4189 def unique_name(self, name, mark_as_used=True): 4190 """Return a unique operation name for `name`. 4191 4192 Note: You rarely need to call `unique_name()` directly. Most of 4193 the time you just need to create `with g.name_scope()` blocks to 4194 generate structured names. 4195 4196 `unique_name` is used to generate structured names, separated by 4197 `"/"`, to help identify operations when debugging a graph. 4198 Operation names are displayed in error messages reported by the 4199 TensorFlow runtime, and in various visualization tools such as 4200 TensorBoard. 4201 4202 If `mark_as_used` is set to `True`, which is the default, a new 4203 unique name is created and marked as in use. If it's set to `False`, 4204 the unique name is returned without actually being marked as used. 4205 This is useful when the caller simply wants to know what the name 4206 to be created will be. 4207 4208 Args: 4209 name: The name for an operation. 4210 mark_as_used: Whether to mark this name as being used. 4211 4212 Returns: 4213 A string to be passed to `create_op()` that will be used 4214 to name the operation being created. 
4215 """ 4216 if self._name_stack: 4217 name = self._name_stack + "/" + name 4218 4219 # For the sake of checking for names in use, we treat names as case 4220 # insensitive (e.g. foo = Foo). 4221 name_key = name.lower() 4222 i = self._names_in_use.get(name_key, 0) 4223 # Increment the number for "name_key". 4224 if mark_as_used: 4225 self._names_in_use[name_key] = i + 1 4226 if i > 0: 4227 base_name_key = name_key 4228 # Make sure the composed name key is not already used. 4229 while name_key in self._names_in_use: 4230 name_key = "%s_%d" % (base_name_key, i) 4231 i += 1 4232 # Mark the composed name_key as used in case someone wants 4233 # to call unique_name("name_1"). 4234 if mark_as_used: 4235 self._names_in_use[name_key] = 1 4236 4237 # Return the new name with the original capitalization of the given name. 4238 name = "%s_%d" % (name, i-1) 4239 return name 4240 4241 def get_name_scope(self): 4242 """Returns the current name scope. 4243 4244 For example: 4245 4246 ```python 4247 with tf.name_scope('scope1'): 4248 with tf.name_scope('scope2'): 4249 print(tf.get_default_graph().get_name_scope()) 4250 ``` 4251 would print the string `scope1/scope2`. 4252 4253 Returns: 4254 A string representing the current name scope. 4255 """ 4256 return self._name_stack 4257 4258 @tf_contextlib.contextmanager 4259 def _colocate_with_for_gradient(self, op, gradient_uid, 4260 ignore_existing=False): 4261 with self.colocate_with(op, ignore_existing): 4262 if gradient_uid is not None and self._control_flow_context is not None: 4263 self._control_flow_context.EnterGradientColocation(op, gradient_uid) 4264 try: 4265 yield 4266 finally: 4267 self._control_flow_context.ExitGradientColocation(op, gradient_uid) 4268 else: 4269 yield 4270 4271 @tf_contextlib.contextmanager 4272 def colocate_with(self, op, ignore_existing=False): 4273 """Returns a context manager that specifies an op to colocate with. 4274 4275 Note: this function is not for public use, only for internal libraries. 
4276 4277 For example: 4278 4279 ```python 4280 a = tf.Variable([1.0]) 4281 with g.colocate_with(a): 4282 b = tf.constant(1.0) 4283 c = tf.add(a, b) 4284 ``` 4285 4286 `b` and `c` will always be colocated with `a`, no matter where `a` 4287 is eventually placed. 4288 4289 **NOTE** Using a colocation scope resets any existing device constraints. 4290 4291 If `op` is `None` then `ignore_existing` must be `True` and the new 4292 scope resets all colocation and device constraints. 4293 4294 Args: 4295 op: The op to colocate all created ops with, or `None`. 4296 ignore_existing: If true, only applies colocation of this op within 4297 the context, rather than applying all colocation properties 4298 on the stack. If `op` is `None`, this value must be `True`. 4299 4300 Raises: 4301 ValueError: if op is None but ignore_existing is False. 4302 4303 Yields: 4304 A context manager that specifies the op with which to colocate 4305 newly created ops. 4306 """ 4307 if op is None and not ignore_existing: 4308 raise ValueError("Trying to reset colocation (op is None) but " 4309 "ignore_existing is not True") 4310 op = _op_to_colocate_with(op) 4311 4312 # By default, colocate_with resets the device function stack, 4313 # since colocate_with is typically used in specific internal 4314 # library functions where colocation is intended to be "stronger" 4315 # than device functions. 4316 # 4317 # In the future, a caller may specify that device_functions win 4318 # over colocation, in which case we can add support. 4319 device_fn_tmp = self._device_function_stack 4320 self._device_function_stack = traceable_stack.TraceableStack() 4321 4322 if ignore_existing: 4323 current_stack = self._colocation_stack 4324 self._colocation_stack = traceable_stack.TraceableStack() 4325 4326 if op is not None: 4327 # offset refers to the stack frame used for storing code location. 4328 # We use 4, the sum of 1 to use our caller's stack frame and 3 4329 # to jump over layers of context managers above us. 
4330 self._colocation_stack.push_obj(op, offset=4) 4331 4332 try: 4333 yield 4334 finally: 4335 # Restore device function stack 4336 self._device_function_stack = device_fn_tmp 4337 if op is not None: 4338 self._colocation_stack.pop_obj() 4339 4340 # Reset the colocation stack if requested. 4341 if ignore_existing: 4342 self._colocation_stack = current_stack 4343 4344 def _add_device_to_stack(self, device_name_or_function, offset=0): 4345 """Add device to stack manually, separate from a context manager.""" 4346 total_offset = 1 + offset 4347 spec = _UserDeviceSpec(device_name_or_function) 4348 self._device_function_stack.push_obj(spec, offset=total_offset) 4349 return spec 4350 4351 @tf_contextlib.contextmanager 4352 def device(self, device_name_or_function): 4353 # pylint: disable=line-too-long 4354 """Returns a context manager that specifies the default device to use. 4355 4356 The `device_name_or_function` argument may either be a device name 4357 string, a device function, or None: 4358 4359 * If it is a device name string, all operations constructed in 4360 this context will be assigned to the device with that name, unless 4361 overridden by a nested `device()` context. 4362 * If it is a function, it will be treated as a function from 4363 Operation objects to device name strings, and invoked each time 4364 a new Operation is created. The Operation will be assigned to 4365 the device with the returned name. 4366 * If it is None, all `device()` invocations from the enclosing context 4367 will be ignored. 4368 4369 For information about the valid syntax of device name strings, see 4370 the documentation in 4371 [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h). 4372 4373 For example: 4374 4375 ```python 4376 with g.device('/device:GPU:0'): 4377 # All operations constructed in this context will be placed 4378 # on GPU 0. 
4379 with g.device(None): 4380 # All operations constructed in this context will have no 4381 # assigned device. 4382 4383 # Defines a function from `Operation` to device string. 4384 def matmul_on_gpu(n): 4385 if n.type == "MatMul": 4386 return "/device:GPU:0" 4387 else: 4388 return "/cpu:0" 4389 4390 with g.device(matmul_on_gpu): 4391 # All operations of type "MatMul" constructed in this context 4392 # will be placed on GPU 0; all other operations will be placed 4393 # on CPU 0. 4394 ``` 4395 4396 **N.B.** The device scope may be overridden by op wrappers or 4397 other library code. For example, a variable assignment op 4398 `v.assign()` must be colocated with the `tf.Variable` `v`, and 4399 incompatible device scopes will be ignored. 4400 4401 Args: 4402 device_name_or_function: The device name or function to use in 4403 the context. 4404 4405 Yields: 4406 A context manager that specifies the default device to use for newly 4407 created ops. 4408 """ 4409 self._add_device_to_stack(device_name_or_function, offset=2) 4410 try: 4411 yield 4412 finally: 4413 self._device_function_stack.pop_obj() 4414 4415 def _apply_device_functions(self, op): 4416 """Applies the current device function stack to the given operation.""" 4417 # Apply any device functions in LIFO order, so that the most recently 4418 # pushed function has the first chance to apply a device to the op. 4419 # We apply here because the result can depend on the Operation's 4420 # signature, which is computed in the Operation constructor. 
4421 # pylint: disable=protected-access 4422 for device_spec in self._device_function_stack.peek_objs(): 4423 if device_spec.function is None: 4424 break 4425 op._set_device(device_spec.function(op)) 4426 op._device_code_locations = self._snapshot_device_function_stack_metadata() 4427 # pylint: enable=protected-access 4428 4429 # pylint: disable=g-doc-return-or-yield 4430 @tf_contextlib.contextmanager 4431 def container(self, container_name): 4432 """Returns a context manager that specifies the resource container to use. 4433 4434 Stateful operations, such as variables and queues, can maintain their 4435 states on devices so that they can be shared by multiple processes. 4436 A resource container is a string name under which these stateful 4437 operations are tracked. These resources can be released or cleared 4438 with `tf.Session.reset()`. 4439 4440 For example: 4441 4442 ```python 4443 with g.container('experiment0'): 4444 # All stateful Operations constructed in this context will be placed 4445 # in resource container "experiment0". 4446 v1 = tf.Variable([1.0]) 4447 v2 = tf.Variable([2.0]) 4448 with g.container("experiment1"): 4449 # All stateful Operations constructed in this context will be 4450 # placed in resource container "experiment1". 4451 v3 = tf.Variable([3.0]) 4452 q1 = tf.FIFOQueue(10, tf.float32) 4453 # All stateful Operations constructed in this context will be 4454 # be created in the "experiment0". 4455 v4 = tf.Variable([4.0]) 4456 q1 = tf.FIFOQueue(20, tf.float32) 4457 with g.container(""): 4458 # All stateful Operations constructed in this context will be 4459 # be placed in the default resource container. 4460 v5 = tf.Variable([5.0]) 4461 q3 = tf.FIFOQueue(30, tf.float32) 4462 4463 # Resets container "experiment0", after which the state of v1, v2, v4, q1 4464 # will become undefined (such as uninitialized). 4465 tf.Session.reset(target, ["experiment0"]) 4466 ``` 4467 4468 Args: 4469 container_name: container name string. 
4470 4471 Returns: 4472 A context manager for defining resource containers for stateful ops, 4473 yields the container name. 4474 """ 4475 original_container = self._container 4476 self._container = container_name 4477 try: 4478 yield self._container 4479 finally: 4480 self._container = original_container 4481 4482 # pylint: enable=g-doc-return-or-yield 4483 4484 class _ControlDependenciesController(object): 4485 """Context manager for `control_dependencies()`.""" 4486 4487 def __init__(self, graph, control_inputs): 4488 """Create a new `_ControlDependenciesController`. 4489 4490 A `_ControlDependenciesController` is the context manager for 4491 `with tf.control_dependencies()` blocks. These normally nest, 4492 as described in the documentation for `control_dependencies()`. 4493 4494 The `control_inputs` argument list control dependencies that must be 4495 added to the current set of control dependencies. Because of 4496 uniquification the set can be empty even if the caller passed a list of 4497 ops. The special value `None` indicates that we want to start a new 4498 empty set of control dependencies instead of extending the current set. 4499 4500 In that case we also clear the current control flow context, which is an 4501 additional mechanism to add control dependencies. 4502 4503 Args: 4504 graph: The graph that this controller is managing. 4505 control_inputs: List of ops to use as control inputs in addition 4506 to the current control dependencies. None to indicate that 4507 the dependencies should be cleared. 
4508 """ 4509 self._graph = graph 4510 if control_inputs is None: 4511 self._control_inputs_val = [] 4512 self._new_stack = True 4513 else: 4514 self._control_inputs_val = control_inputs 4515 self._new_stack = False 4516 self._seen_nodes = set() 4517 self._old_stack = None 4518 self._old_control_flow_context = None 4519 4520# pylint: disable=protected-access 4521 4522 def __enter__(self): 4523 if self._new_stack: 4524 # Clear the control_dependencies graph. 4525 self._old_stack = self._graph._control_dependencies_stack 4526 self._graph._control_dependencies_stack = [] 4527 # Clear the control_flow_context too. 4528 self._old_control_flow_context = self._graph._get_control_flow_context() 4529 self._graph._set_control_flow_context(None) 4530 self._graph._push_control_dependencies_controller(self) 4531 4532 def __exit__(self, unused_type, unused_value, unused_traceback): 4533 self._graph._pop_control_dependencies_controller(self) 4534 if self._new_stack: 4535 self._graph._control_dependencies_stack = self._old_stack 4536 self._graph._set_control_flow_context(self._old_control_flow_context) 4537 4538# pylint: enable=protected-access 4539 4540 @property 4541 def control_inputs(self): 4542 return self._control_inputs_val 4543 4544 def add_op(self, op): 4545 self._seen_nodes.add(op) 4546 4547 def op_in_group(self, op): 4548 return op in self._seen_nodes 4549 4550 def _push_control_dependencies_controller(self, controller): 4551 self._control_dependencies_stack.append(controller) 4552 4553 def _pop_control_dependencies_controller(self, controller): 4554 assert self._control_dependencies_stack[-1] is controller 4555 self._control_dependencies_stack.pop() 4556 4557 def _current_control_dependencies(self): 4558 ret = set() 4559 for controller in self._control_dependencies_stack: 4560 for op in controller.control_inputs: 4561 ret.add(op) 4562 return ret 4563 4564 def _control_dependencies_for_inputs(self, input_ops): 4565 """For an op that takes `input_ops` as inputs, compute 
control inputs. 4566 4567 The returned control dependencies should yield an execution that 4568 is equivalent to adding all control inputs in 4569 self._control_dependencies_stack to a newly created op. However, 4570 this function attempts to prune the returned control dependencies 4571 by observing that nodes created within the same `with 4572 control_dependencies(...):` block may have data dependencies that make 4573 the explicit approach redundant. 4574 4575 Args: 4576 input_ops: The data input ops for an op to be created. 4577 4578 Returns: 4579 A list of control inputs for the op to be created. 4580 """ 4581 ret = [] 4582 for controller in self._control_dependencies_stack: 4583 # If any of the input_ops already depends on the inputs from controller, 4584 # we say that the new op is dominated (by that input), and we therefore 4585 # do not need to add control dependencies for this controller's inputs. 4586 dominated = False 4587 for op in input_ops: 4588 if controller.op_in_group(op): 4589 dominated = True 4590 break 4591 if not dominated: 4592 # Don't add a control input if we already have a data dependency on i. 4593 # NOTE(mrry): We do not currently track transitive data dependencies, 4594 # so we may add redundant control inputs. 4595 ret.extend([c for c in controller.control_inputs if c not in input_ops]) 4596 return ret 4597 4598 def _record_op_seen_by_control_dependencies(self, op): 4599 """Record that the given op depends on all registered control dependencies. 4600 4601 Args: 4602 op: An Operation. 4603 """ 4604 for controller in self._control_dependencies_stack: 4605 controller.add_op(op) 4606 4607 def control_dependencies(self, control_inputs): 4608 """Returns a context manager that specifies control dependencies. 4609 4610 Use with the `with` keyword to specify that all operations constructed 4611 within the context should have control dependencies on 4612 `control_inputs`. 
For example: 4613 4614 ```python 4615 with g.control_dependencies([a, b, c]): 4616 # `d` and `e` will only run after `a`, `b`, and `c` have executed. 4617 d = ... 4618 e = ... 4619 ``` 4620 4621 Multiple calls to `control_dependencies()` can be nested, and in 4622 that case a new `Operation` will have control dependencies on the union 4623 of `control_inputs` from all active contexts. 4624 4625 ```python 4626 with g.control_dependencies([a, b]): 4627 # Ops constructed here run after `a` and `b`. 4628 with g.control_dependencies([c, d]): 4629 # Ops constructed here run after `a`, `b`, `c`, and `d`. 4630 ``` 4631 4632 You can pass None to clear the control dependencies: 4633 4634 ```python 4635 with g.control_dependencies([a, b]): 4636 # Ops constructed here run after `a` and `b`. 4637 with g.control_dependencies(None): 4638 # Ops constructed here run normally, not waiting for either `a` or `b`. 4639 with g.control_dependencies([c, d]): 4640 # Ops constructed here run after `c` and `d`, also not waiting 4641 # for either `a` or `b`. 4642 ``` 4643 4644 *N.B.* The control dependencies context applies *only* to ops that 4645 are constructed within the context. Merely using an op or tensor 4646 in the context does not add a control dependency. The following 4647 example illustrates this point: 4648 4649 ```python 4650 # WRONG 4651 def my_func(pred, tensor): 4652 t = tf.matmul(tensor, tensor) 4653 with tf.control_dependencies([pred]): 4654 # The matmul op is created outside the context, so no control 4655 # dependency will be added. 4656 return t 4657 4658 # RIGHT 4659 def my_func(pred, tensor): 4660 with tf.control_dependencies([pred]): 4661 # The matmul op is created in the context, so a control dependency 4662 # will be added. 
4663 return tf.matmul(tensor, tensor) 4664 ``` 4665 4666 Also note that though execution of ops created under this scope will trigger 4667 execution of the dependencies, the ops created under this scope might still 4668 be pruned from a normal tensorflow graph. For example, in the following 4669 snippet of code the dependencies are never executed: 4670 4671 ```python 4672 loss = model.loss() 4673 with tf.control_dependencies(dependencies): 4674 loss = loss + tf.constant(1) # note: dependencies ignored in the 4675 # backward pass 4676 return tf.gradients(loss, model.variables) 4677 ``` 4678 4679 This is because evaluating the gradient graph does not require evaluating 4680 the constant(1) op created in the forward pass. 4681 4682 Args: 4683 control_inputs: A list of `Operation` or `Tensor` objects which 4684 must be executed or computed before running the operations 4685 defined in the context. Can also be `None` to clear the control 4686 dependencies. 4687 4688 Returns: 4689 A context manager that specifies control dependencies for all 4690 operations constructed within the context. 4691 4692 Raises: 4693 TypeError: If `control_inputs` is not a list of `Operation` or 4694 `Tensor` objects. 4695 """ 4696 if control_inputs is None: 4697 return self._ControlDependenciesController(self, None) 4698 # First convert the inputs to ops, and deduplicate them. 4699 # NOTE(mrry): Other than deduplication, we do not currently track direct 4700 # or indirect dependencies between control_inputs, which may result in 4701 # redundant control inputs. 4702 control_ops = [] 4703 current = self._current_control_dependencies() 4704 for c in control_inputs: 4705 # The hasattr(handle) is designed to match ResourceVariables. This is so 4706 # control dependencies on a variable or on an unread variable don't 4707 # trigger reads. 
4708 if (isinstance(c, IndexedSlices) or 4709 (hasattr(c, "_handle") and hasattr(c, "op"))): 4710 c = c.op 4711 c = self.as_graph_element(c) 4712 if isinstance(c, Tensor): 4713 c = c.op 4714 elif not isinstance(c, Operation): 4715 raise TypeError("Control input must be Operation or Tensor: %s" % c) 4716 if c not in current: 4717 control_ops.append(c) 4718 current.add(c) 4719 return self._ControlDependenciesController(self, control_ops) 4720 4721 # pylint: disable=g-doc-return-or-yield 4722 @tf_contextlib.contextmanager 4723 def _attr_scope(self, attr_map): 4724 """EXPERIMENTAL: A context manager for setting attributes on operators. 4725 4726 This context manager can be used to add additional 4727 attributes to operators within the scope of the context. 4728 4729 For example: 4730 4731 with ops.Graph().as_default() as g: 4732 f_1 = Foo() # No extra attributes 4733 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}): 4734 f_2 = Foo() # Additional attribute _a=False 4735 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}): 4736 f_3 = Foo() # Additional attribute _a=False 4737 with g._attr_scope({"_a": None}): 4738 f_4 = Foo() # No additional attributes. 4739 4740 Args: 4741 attr_map: A dictionary mapping attr name strings to 4742 AttrValue protocol buffers or None. 4743 4744 Returns: 4745 A context manager that sets the kernel label to be used for one or more 4746 ops created in that context. 4747 4748 Raises: 4749 TypeError: If attr_map is not a dictionary mapping 4750 strings to AttrValue protobufs. 4751 """ 4752 if not isinstance(attr_map, dict): 4753 raise TypeError("attr_map must be a dictionary mapping " 4754 "strings to AttrValue protocol buffers") 4755 # The saved_attrs dictionary stores any currently-set labels that 4756 # will be overridden by this context manager. 
4757 saved_attrs = {} 4758 # Install the given attribute 4759 for name, attr in attr_map.items(): 4760 if not (isinstance(name, six.string_types) and 4761 (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or 4762 callable(attr))): 4763 raise TypeError("attr_map must be a dictionary mapping " 4764 "strings to AttrValue protocol buffers or " 4765 "callables that emit AttrValue protocol buffers") 4766 try: 4767 saved_attrs[name] = self._attr_scope_map[name] 4768 except KeyError: 4769 pass 4770 if attr is None: 4771 del self._attr_scope_map[name] 4772 else: 4773 self._attr_scope_map[name] = attr 4774 try: 4775 yield # The code within the context runs here. 4776 finally: 4777 # Remove the attributes set for this context, and restore any saved 4778 # attributes. 4779 for name, attr in attr_map.items(): 4780 try: 4781 self._attr_scope_map[name] = saved_attrs[name] 4782 except KeyError: 4783 del self._attr_scope_map[name] 4784 4785 # pylint: enable=g-doc-return-or-yield 4786 4787 # pylint: disable=g-doc-return-or-yield 4788 @tf_contextlib.contextmanager 4789 def _kernel_label_map(self, op_to_kernel_label_map): 4790 """EXPERIMENTAL: A context manager for setting kernel labels. 4791 4792 This context manager can be used to select particular 4793 implementations of kernels within the scope of the context. 4794 4795 For example: 4796 4797 with ops.Graph().as_default() as g: 4798 f_1 = Foo() # Uses the default registered kernel for the Foo op. 4799 with g.kernel_label_map({"Foo": "v_2"}): 4800 f_2 = Foo() # Uses the registered kernel with label "v_2" 4801 # for the Foo op. 4802 with g.kernel_label_map({"Foo": "v_3"}): 4803 f_3 = Foo() # Uses the registered kernel with label "v_3" 4804 # for the Foo op. 4805 with g.kernel_label_map({"Foo": ""}): 4806 f_4 = Foo() # Uses the default registered kernel 4807 # for the Foo op. 4808 4809 Args: 4810 op_to_kernel_label_map: A dictionary mapping op type strings to 4811 kernel label strings. 
4812 4813 Returns: 4814 A context manager that sets the kernel label to be used for one or more 4815 ops created in that context. 4816 4817 Raises: 4818 TypeError: If op_to_kernel_label_map is not a dictionary mapping 4819 strings to strings. 4820 """ 4821 if not isinstance(op_to_kernel_label_map, dict): 4822 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 4823 "strings to strings") 4824 # The saved_labels dictionary stores any currently-set labels that 4825 # will be overridden by this context manager. 4826 saved_labels = {} 4827 # Install the given label 4828 for op_type, label in op_to_kernel_label_map.items(): 4829 if not (isinstance(op_type, six.string_types) and 4830 isinstance(label, six.string_types)): 4831 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 4832 "strings to strings") 4833 try: 4834 saved_labels[op_type] = self._op_to_kernel_label_map[op_type] 4835 except KeyError: 4836 pass 4837 self._op_to_kernel_label_map[op_type] = label 4838 try: 4839 yield # The code within the context runs here. 4840 finally: 4841 # Remove the labels set for this context, and restore any saved labels. 4842 for op_type, label in op_to_kernel_label_map.items(): 4843 try: 4844 self._op_to_kernel_label_map[op_type] = saved_labels[op_type] 4845 except KeyError: 4846 del self._op_to_kernel_label_map[op_type] 4847 4848 # pylint: enable=g-doc-return-or-yield 4849 4850 # pylint: disable=g-doc-return-or-yield 4851 @tf_contextlib.contextmanager 4852 def gradient_override_map(self, op_type_map): 4853 """EXPERIMENTAL: A context manager for overriding gradient functions. 4854 4855 This context manager can be used to override the gradient function 4856 that will be used for ops within the scope of the context. 4857 4858 For example: 4859 4860 ```python 4861 @tf.RegisterGradient("CustomSquare") 4862 def _custom_square_grad(op, grad): 4863 # ... 
4864 4865 with tf.Graph().as_default() as g: 4866 c = tf.constant(5.0) 4867 s_1 = tf.square(c) # Uses the default gradient for tf.square. 4868 with g.gradient_override_map({"Square": "CustomSquare"}): 4869 s_2 = tf.square(s_2) # Uses _custom_square_grad to compute the 4870 # gradient of s_2. 4871 ``` 4872 4873 Args: 4874 op_type_map: A dictionary mapping op type strings to alternative op 4875 type strings. 4876 4877 Returns: 4878 A context manager that sets the alternative op type to be used for one 4879 or more ops created in that context. 4880 4881 Raises: 4882 TypeError: If `op_type_map` is not a dictionary mapping strings to 4883 strings. 4884 """ 4885 if not isinstance(op_type_map, dict): 4886 raise TypeError("op_type_map must be a dictionary mapping " 4887 "strings to strings") 4888 # The saved_mappings dictionary stores any currently-set mappings that 4889 # will be overridden by this context manager. 4890 saved_mappings = {} 4891 # Install the given label 4892 for op_type, mapped_op_type in op_type_map.items(): 4893 if not (isinstance(op_type, six.string_types) and 4894 isinstance(mapped_op_type, six.string_types)): 4895 raise TypeError("op_type_map must be a dictionary mapping " 4896 "strings to strings") 4897 try: 4898 saved_mappings[op_type] = self._gradient_override_map[op_type] 4899 except KeyError: 4900 pass 4901 self._gradient_override_map[op_type] = mapped_op_type 4902 try: 4903 yield # The code within the context runs here. 4904 finally: 4905 # Remove the labels set for this context, and restore any saved labels. 
4906 for op_type, mapped_op_type in op_type_map.items(): 4907 try: 4908 self._gradient_override_map[op_type] = saved_mappings[op_type] 4909 except KeyError: 4910 del self._gradient_override_map[op_type] 4911 4912 # pylint: enable=g-doc-return-or-yield 4913 4914 def prevent_feeding(self, tensor): 4915 """Marks the given `tensor` as unfeedable in this graph.""" 4916 self._unfeedable_tensors.add(tensor) 4917 4918 def is_feedable(self, tensor): 4919 """Returns `True` if and only if `tensor` is feedable.""" 4920 return tensor not in self._unfeedable_tensors 4921 4922 def prevent_fetching(self, op): 4923 """Marks the given `op` as unfetchable in this graph.""" 4924 self._unfetchable_ops.add(op) 4925 4926 def is_fetchable(self, tensor_or_op): 4927 """Returns `True` if and only if `tensor_or_op` is fetchable.""" 4928 if isinstance(tensor_or_op, Tensor): 4929 return tensor_or_op.op not in self._unfetchable_ops 4930 else: 4931 return tensor_or_op not in self._unfetchable_ops 4932 4933 def switch_to_thread_local(self): 4934 """Make device, colocation and dependencies stacks thread-local. 4935 4936 Device, colocation and dependencies stacks are not thread-local be default. 4937 If multiple threads access them, then the state is shared. This means that 4938 one thread may affect the behavior of another thread. 4939 4940 After this method is called, the stacks become thread-local. If multiple 4941 threads access them, then the state is not shared. Each thread uses its own 4942 value; a thread doesn't affect other threads by mutating such a stack. 4943 4944 The initial value for every thread's stack is set to the current value 4945 of the stack when `switch_to_thread_local()` was first called. 
4946 """ 4947 if not self._stack_state_is_thread_local: 4948 self._stack_state_is_thread_local = True 4949 4950 @property 4951 def _device_function_stack(self): 4952 if self._stack_state_is_thread_local: 4953 # This may be called from a thread where device_function_stack doesn't yet 4954 # exist. 4955 # pylint: disable=protected-access 4956 if not hasattr(self._thread_local, "_device_function_stack"): 4957 stack_copy_for_this_thread = self._graph_device_function_stack.copy() 4958 self._thread_local._device_function_stack = stack_copy_for_this_thread 4959 return self._thread_local._device_function_stack 4960 # pylint: enable=protected-access 4961 else: 4962 return self._graph_device_function_stack 4963 4964 @property 4965 def _device_functions_outer_to_inner(self): 4966 user_device_specs = self._device_function_stack.peek_objs() 4967 device_functions = [spec.function for spec in user_device_specs] 4968 device_functions_outer_to_inner = list(reversed(device_functions)) 4969 return device_functions_outer_to_inner 4970 4971 def _snapshot_device_function_stack_metadata(self): 4972 """Return device function stack as a list of TraceableObjects. 4973 4974 Returns: 4975 [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj 4976 member is a displayable name for the user's argument to Graph.device, and 4977 the filename and lineno members point to the code location where 4978 Graph.device was called directly or indirectly by the user. 
4979 """ 4980 traceable_objects = self._device_function_stack.peek_traceable_objs() 4981 snapshot = [] 4982 for obj in traceable_objects: 4983 obj_copy = obj.copy_metadata() 4984 obj_copy.obj = obj.obj.display_name 4985 snapshot.append(obj_copy) 4986 return snapshot 4987 4988 @_device_function_stack.setter 4989 def _device_function_stack(self, device_function_stack): 4990 if self._stack_state_is_thread_local: 4991 # pylint: disable=protected-access 4992 self._thread_local._device_function_stack = device_function_stack 4993 # pylint: enable=protected-access 4994 else: 4995 self._graph_device_function_stack = device_function_stack 4996 4997 @property 4998 def _colocation_stack(self): 4999 """Return thread-local copy of colocation stack.""" 5000 if self._stack_state_is_thread_local: 5001 # This may be called from a thread where colocation_stack doesn't yet 5002 # exist. 5003 # pylint: disable=protected-access 5004 if not hasattr(self._thread_local, "_colocation_stack"): 5005 stack_copy_for_this_thread = self._graph_colocation_stack.copy() 5006 self._thread_local._colocation_stack = stack_copy_for_this_thread 5007 return self._thread_local._colocation_stack 5008 # pylint: enable=protected-access 5009 else: 5010 return self._graph_colocation_stack 5011 5012 def _snapshot_colocation_stack_metadata(self): 5013 """Return colocation stack metadata as a dictionary.""" 5014 traceable_objects = self._colocation_stack.peek_traceable_objs() 5015 return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects} 5016 5017 @_colocation_stack.setter 5018 def _colocation_stack(self, colocation_stack): 5019 if self._stack_state_is_thread_local: 5020 # pylint: disable=protected-access 5021 self._thread_local._colocation_stack = colocation_stack 5022 # pylint: enable=protected-access 5023 else: 5024 self._graph_colocation_stack = colocation_stack 5025 5026 @property 5027 def _control_dependencies_stack(self): 5028 if self._stack_state_is_thread_local: 5029 # This may be called 
from a thread where control_dependencies_stack 5030 # doesn't yet exist. 5031 if not hasattr(self._thread_local, "_control_dependencies_stack"): 5032 self._thread_local._control_dependencies_stack = ( 5033 self._graph_control_dependencies_stack[:]) 5034 return self._thread_local._control_dependencies_stack 5035 else: 5036 return self._graph_control_dependencies_stack 5037 5038 @_control_dependencies_stack.setter 5039 def _control_dependencies_stack(self, control_dependencies): 5040 if self._stack_state_is_thread_local: 5041 self._thread_local._control_dependencies_stack = control_dependencies 5042 else: 5043 self._graph_control_dependencies_stack = control_dependencies 5044 5045 @property 5046 def _distribution_strategy_stack(self): 5047 """A stack to maintain distribution strategy context for each thread.""" 5048 if not hasattr(self._thread_local, "_distribution_strategy_stack"): 5049 self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access 5050 return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access 5051 5052 @_distribution_strategy_stack.setter 5053 def _distribution_strategy_stack(self, _distribution_strategy_stack): 5054 self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access 5055 _distribution_strategy_stack) 5056 5057 @property 5058 def _auto_cast_variable_read_dtype(self): 5059 """The dtype that instances of `AutoCastVariable` will be casted to. 5060 5061 This is None if `AutoCastVariables` should not be casted. 5062 5063 See `AutoCastVariable` for more information. 5064 5065 Returns: 5066 The dtype that instances of `AutoCastVariable` will be casted to. 
5067 """ 5068 if not hasattr(self._thread_local, "_auto_cast_variable_read_dtype"): 5069 self._thread_local._auto_cast_variable_read_dtype = None # pylint: disable=protected-access 5070 return self._thread_local._auto_cast_variable_read_dtype # pylint: disable=protected-access 5071 5072 @_auto_cast_variable_read_dtype.setter 5073 def _auto_cast_variable_read_dtype(self, _auto_cast_variable_read_dtype): 5074 self._thread_local._auto_cast_variable_read_dtype = ( # pylint: disable=protected-access 5075 _auto_cast_variable_read_dtype) 5076 5077 @tf_contextlib.contextmanager 5078 def _enable_auto_casting_variables(self, dtype): 5079 """Context manager to automatically cast AutoCastVariables. 5080 5081 If an AutoCastVariable `var` is used under this context manager, it will be 5082 casted to `dtype` before being used. 5083 5084 See `AutoCastVariable` for more information. 5085 5086 Args: 5087 dtype: The dtype that AutoCastVariables should be casted to. 5088 5089 Yields: 5090 Nothing. 5091 """ 5092 prev_read_dtype = self._auto_cast_variable_read_dtype 5093 try: 5094 self._auto_cast_variable_read_dtype = dtype 5095 yield 5096 finally: 5097 self._auto_cast_variable_read_dtype = prev_read_dtype 5098 5099 def _mutation_lock(self): 5100 """Returns a lock to guard code that creates & mutates ops. 5101 5102 See the comment for self._group_lock for more info. 5103 """ 5104 return self._group_lock.group(_MUTATION_LOCK_GROUP) 5105 5106 def _session_run_lock(self): 5107 """Returns a lock to guard code for Session.run. 5108 5109 See the comment for self._group_lock for more info. 5110 """ 5111 return self._group_lock.group(_SESSION_RUN_LOCK_GROUP) 5112 5113 5114# TODO(agarwal): currently device directives in an outer eager scope will not 5115# apply to inner graph mode code. Fix that. 5116 5117 5118@tf_export(v1=["device"]) 5119def device(device_name_or_function): 5120 """Wrapper for `Graph.device()` using the default graph. 5121 5122 See 5123 `tf.Graph.device` 5124 for more details. 
5125 5126 Args: 5127 device_name_or_function: The device name or function to use in 5128 the context. 5129 5130 Returns: 5131 A context manager that specifies the default device to use for newly 5132 created ops. 5133 5134 Raises: 5135 RuntimeError: If eager execution is enabled and a function is passed in. 5136 """ 5137 if context.executing_eagerly(): 5138 # TODO(agarwal): support device functions in EAGER mode. 5139 if callable(device_name_or_function): 5140 raise RuntimeError( 5141 "tf.device does not support functions when eager execution " 5142 "is enabled.") 5143 return context.device(device_name_or_function) 5144 else: 5145 return get_default_graph().device(device_name_or_function) 5146 5147 5148@tf_export("device", v1=[]) 5149def device_v2(device_name): 5150 """Specifies the device for ops created/executed in this context. 5151 5152 `device_name` can be fully specified, as in "/job:worker/task:1/device:cpu:0", 5153 or partially specified, containing only a subset of the "/"-separated 5154 fields. Any fields which are specified override device annotations from outer 5155 scopes. For example: 5156 5157 with tf.device('/job:foo'): 5158 # ops created here have devices with /job:foo 5159 with tf.device('/job:bar/task:0/device:gpu:2'): 5160 # ops created here have the fully specified device above 5161 with tf.device('/device:gpu:1'): 5162 # ops created here have the device '/job:foo/device:gpu:1' 5163 5164 Args: 5165 device_name: The device name to use in the context. 5166 5167 Returns: 5168 A context manager that specifies the default device to use for newly 5169 created ops. 5170 5171 Raises: 5172 RuntimeError: If a function is passed in. 
5173 """ 5174 if callable(device_name): 5175 raise RuntimeError("tf.device does not support functions.") 5176 if context.executing_eagerly(): 5177 return context.device(device_name) 5178 else: 5179 return get_default_graph().device(device_name) 5180 5181 5182@tf_export(v1=["container"]) 5183def container(container_name): 5184 """Wrapper for `Graph.container()` using the default graph. 5185 5186 Args: 5187 container_name: The container string to use in the context. 5188 5189 Returns: 5190 A context manager that specifies the default container to use for newly 5191 created stateful ops. 5192 """ 5193 return get_default_graph().container(container_name) 5194 5195 5196def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False): 5197 if context.executing_eagerly(): 5198 if op is not None: 5199 if not hasattr(op, "device"): 5200 op = internal_convert_to_tensor_or_indexed_slices(op) 5201 return device(op.device) 5202 else: 5203 return NullContextmanager() 5204 else: 5205 default_graph = get_default_graph() 5206 if isinstance(op, EagerTensor): 5207 if default_graph.building_function: 5208 return default_graph.device(op.device) 5209 else: 5210 raise ValueError("Encountered an Eager-defined Tensor during graph " 5211 "construction, but a function was not being built.") 5212 return default_graph._colocate_with_for_gradient( 5213 op, gradient_uid=gradient_uid, ignore_existing=ignore_existing) 5214 5215 5216# Internal interface to colocate_with. colocate_with has been deprecated from 5217# public API. There are still a few internal uses of colocate_with. Add internal 5218# only API for those uses to avoid deprecation warning. 
5219def colocate_with(op, ignore_existing=False): 5220 return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing) 5221 5222 5223@deprecation.deprecated( 5224 date=None, 5225 instructions="Colocations handled automatically by placer.") 5226@tf_export(v1=["colocate_with"]) 5227def _colocate_with(op, ignore_existing=False): 5228 return colocate_with(op, ignore_existing) 5229 5230 5231@tf_export("control_dependencies") 5232def control_dependencies(control_inputs): 5233 """Wrapper for `Graph.control_dependencies()` using the default graph. 5234 5235 See `tf.Graph.control_dependencies` 5236 for more details. 5237 5238 When eager execution is enabled, any callable object in the `control_inputs` 5239 list will be called. 5240 5241 Args: 5242 control_inputs: A list of `Operation` or `Tensor` objects which 5243 must be executed or computed before running the operations 5244 defined in the context. Can also be `None` to clear the control 5245 dependencies. If eager execution is enabled, any callable object in the 5246 `control_inputs` list will be called. 5247 5248 Returns: 5249 A context manager that specifies control dependencies for all 5250 operations constructed within the context. 5251 """ 5252 if context.executing_eagerly(): 5253 if control_inputs: 5254 # Excute any pending callables. 
5255 for control in control_inputs: 5256 if callable(control): 5257 control() 5258 return NullContextmanager() 5259 else: 5260 return get_default_graph().control_dependencies(control_inputs) 5261 5262 5263class _DefaultStack(threading.local): 5264 """A thread-local stack of objects for providing implicit defaults.""" 5265 5266 def __init__(self): 5267 super(_DefaultStack, self).__init__() 5268 self._enforce_nesting = True 5269 self.stack = [] 5270 5271 def get_default(self): 5272 return self.stack[-1] if len(self.stack) >= 1 else None 5273 5274 def reset(self): 5275 self.stack = [] 5276 5277 def is_cleared(self): 5278 return not self.stack 5279 5280 @property 5281 def enforce_nesting(self): 5282 return self._enforce_nesting 5283 5284 @enforce_nesting.setter 5285 def enforce_nesting(self, value): 5286 self._enforce_nesting = value 5287 5288 @tf_contextlib.contextmanager 5289 def get_controller(self, default): 5290 """A context manager for manipulating a default stack.""" 5291 self.stack.append(default) 5292 try: 5293 yield default 5294 finally: 5295 # stack may be empty if reset() was called 5296 if self.stack: 5297 if self._enforce_nesting: 5298 if self.stack[-1] is not default: 5299 raise AssertionError( 5300 "Nesting violated for default stack of %s objects" % 5301 type(default)) 5302 self.stack.pop() 5303 else: 5304 self.stack.remove(default) 5305 5306 5307_default_session_stack = _DefaultStack() # pylint: disable=protected-access 5308 5309 5310def default_session(session): 5311 """Python "with" handler for defining a default session. 5312 5313 This function provides a means of registering a session for handling 5314 Tensor.eval() and Operation.run() calls. It is primarily intended for use 5315 by session.Session, but can be used with any object that implements 5316 the Session.run() interface. 
5317 5318 Use with the "with" keyword to specify that Tensor.eval() and Operation.run() 5319 invocations within the scope of a block should be executed by a particular 5320 session. 5321 5322 The default session applies to the current thread only, so it is always 5323 possible to inspect the call stack and determine the scope of a default 5324 session. If you create a new thread, and wish to use the default session 5325 in that thread, you must explicitly add a "with ops.default_session(sess):" 5326 block in that thread's function. 5327 5328 Example: 5329 The following code examples are equivalent: 5330 5331 # 1. Using the Session object directly: 5332 sess = ... 5333 c = tf.constant(5.0) 5334 sess.run(c) 5335 5336 # 2. Using default_session(): 5337 sess = ... 5338 with ops.default_session(sess): 5339 c = tf.constant(5.0) 5340 result = c.eval() 5341 5342 # 3. Overriding default_session(): 5343 sess = ... 5344 with ops.default_session(sess): 5345 c = tf.constant(5.0) 5346 with ops.default_session(...): 5347 c.eval(session=sess) 5348 5349 Args: 5350 session: The session to be installed as the default session. 5351 5352 Returns: 5353 A context manager for the default session. 5354 """ 5355 return _default_session_stack.get_controller(session) 5356 5357 5358@tf_export(v1=["get_default_session"]) 5359def get_default_session(): 5360 """Returns the default session for the current thread. 5361 5362 The returned `Session` will be the innermost session on which a 5363 `Session` or `Session.as_default()` context has been entered. 5364 5365 NOTE: The default session is a property of the current thread. If you 5366 create a new thread, and wish to use the default session in that 5367 thread, you must explicitly add a `with sess.as_default():` in that 5368 thread's function. 5369 5370 Returns: 5371 The default `Session` being used in the current thread. 
5372 """ 5373 return _default_session_stack.get_default() 5374 5375 5376def _eval_using_default_session(tensors, feed_dict, graph, session=None): 5377 """Uses the default session to evaluate one or more tensors. 5378 5379 Args: 5380 tensors: A single Tensor, or a list of Tensor objects. 5381 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 5382 numpy ndarrays, TensorProtos, or strings. 5383 graph: The graph in which the tensors are defined. 5384 session: (Optional) A different session to use to evaluate "tensors". 5385 5386 Returns: 5387 Either a single numpy ndarray if "tensors" is a single tensor; or a list 5388 of numpy ndarrays that each correspond to the respective element in 5389 "tensors". 5390 5391 Raises: 5392 ValueError: If no default session is available; the default session 5393 does not have "graph" as its graph; or if "session" is specified, 5394 and it does not have "graph" as its graph. 5395 """ 5396 if session is None: 5397 session = get_default_session() 5398 if session is None: 5399 raise ValueError("Cannot evaluate tensor using `eval()`: No default " 5400 "session is registered. Use `with " 5401 "sess.as_default()` or pass an explicit session to " 5402 "`eval(session=sess)`") 5403 if session.graph is not graph: 5404 raise ValueError("Cannot use the default session to evaluate tensor: " 5405 "the tensor's graph is different from the session's " 5406 "graph. Pass an explicit session to " 5407 "`eval(session=sess)`.") 5408 else: 5409 if session.graph is not graph: 5410 raise ValueError("Cannot use the given session to evaluate tensor: " 5411 "the tensor's graph is different from the session's " 5412 "graph.") 5413 return session.run(tensors, feed_dict) 5414 5415 5416def _run_using_default_session(operation, feed_dict, graph, session=None): 5417 """Uses the default session to run "operation". 5418 5419 Args: 5420 operation: The Operation to be run. 
5421 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 5422 numpy ndarrays, TensorProtos, or strings. 5423 graph: The graph in which "operation" is defined. 5424 session: (Optional) A different session to use to run "operation". 5425 5426 Raises: 5427 ValueError: If no default session is available; the default session 5428 does not have "graph" as its graph; or if "session" is specified, 5429 and it does not have "graph" as its graph. 5430 """ 5431 if session is None: 5432 session = get_default_session() 5433 if session is None: 5434 raise ValueError("Cannot execute operation using `run()`: No default " 5435 "session is registered. Use `with " 5436 "sess.as_default():` or pass an explicit session to " 5437 "`run(session=sess)`") 5438 if session.graph is not graph: 5439 raise ValueError("Cannot use the default session to execute operation: " 5440 "the operation's graph is different from the " 5441 "session's graph. Pass an explicit session to " 5442 "run(session=sess).") 5443 else: 5444 if session.graph is not graph: 5445 raise ValueError("Cannot use the given session to execute operation: " 5446 "the operation's graph is different from the session's " 5447 "graph.") 5448 session.run(operation, feed_dict) 5449 5450 5451class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access 5452 """A thread-local stack of objects for providing an implicit default graph.""" 5453 5454 def __init__(self): 5455 super(_DefaultGraphStack, self).__init__() 5456 self._global_default_graph = None 5457 5458 def get_default(self): 5459 """Override that returns a global default if the stack is empty.""" 5460 ret = super(_DefaultGraphStack, self).get_default() 5461 if ret is None: 5462 ret = self._GetGlobalDefaultGraph() 5463 return ret 5464 5465 def _GetGlobalDefaultGraph(self): 5466 if self._global_default_graph is None: 5467 # TODO(mrry): Perhaps log that the default graph is being used, or set 5468 # provide some other feedback to prevent 
confusion when a mixture of 5469 # the global default graph and an explicit graph are combined in the 5470 # same process. 5471 self._global_default_graph = Graph() 5472 return self._global_default_graph 5473 5474 def reset(self): 5475 super(_DefaultGraphStack, self).reset() 5476 self._global_default_graph = None 5477 5478 @tf_contextlib.contextmanager 5479 def get_controller(self, default): 5480 context.context().context_switches.push( 5481 default.building_function, default.as_default, 5482 default._device_function_stack) 5483 try: 5484 with super(_DefaultGraphStack, self).get_controller( 5485 default) as g, context.graph_mode(): 5486 yield g 5487 finally: 5488 # If an exception is raised here it may be hiding a related exception in 5489 # the try-block (just above). 5490 context.context().context_switches.pop() 5491 5492 5493_default_graph_stack = _DefaultGraphStack() 5494 5495 5496# pylint: disable=g-doc-return-or-yield,line-too-long 5497@tf_export("init_scope") 5498@tf_contextlib.contextmanager 5499def init_scope(): 5500 """A context manager that lifts ops out of control-flow scopes and function-building graphs. 5501 5502 There is often a need to lift variable initialization ops out of control-flow 5503 scopes, function-building graphs, and gradient tapes. Entering an 5504 `init_scope` is a mechanism for satisfying these desiderata. In particular, 5505 entering an `init_scope` has three effects: 5506 5507 (1) All control dependencies are cleared the moment the scope is entered; 5508 this is equivalent to entering the context manager returned from 5509 `control_dependencies(None)`, which has the side-effect of exiting 5510 control-flow scopes like `tf.cond` and `tf.while_loop`. 5511 5512 (2) All operations that are created while the scope is active are lifted 5513 into the lowest context on the `context_stack` that is not building a 5514 graph function. Here, a context is defined as either a graph or an eager 5515 context. 
Every context switch, i.e., every installation of a graph as 5516 the default graph and every switch into eager mode, is logged in a 5517 thread-local stack called `context_switches`; the log entry for a 5518 context switch is popped from the stack when the context is exited. 5519 Entering an `init_scope` is equivalent to crawling up 5520 `context_switches`, finding the first context that is not building a 5521 graph function, and entering it. A caveat is that if graph mode is 5522 enabled but the default graph stack is empty, then entering an 5523 `init_scope` will simply install a fresh graph as the default one. 5524 5525 (3) The gradient tape is paused while the scope is active. 5526 5527 When eager execution is enabled, code inside an init_scope block runs with 5528 eager execution enabled even when defining graph functions via 5529 tf.contrib.eager.defun. For example: 5530 5531 ```python 5532 tf.enable_eager_execution() 5533 5534 @tf.contrib.eager.defun 5535 def func(): 5536 # A defun-decorated function constructs TensorFlow graphs, 5537 # it does not execute eagerly. 5538 assert not tf.executing_eagerly() 5539 with tf.init_scope(): 5540 # Initialization runs with eager execution enabled 5541 assert tf.executing_eagerly() 5542 ``` 5543 5544 Raises: 5545 RuntimeError: if graph state is incompatible with this initialization. 5546 """ 5547 # pylint: enable=g-doc-return-or-yield,line-too-long 5548 5549 if context.executing_eagerly(): 5550 # Fastpath. 5551 with tape.stop_recording(): 5552 yield 5553 else: 5554 # Retrieve the active name scope: entering an `init_scope` preserves 5555 # the name scope of the current context. 5556 default_graph = get_default_graph() 5557 scope = default_graph.get_name_scope() 5558 if scope and scope[-1] != "/": 5559 # Names that end with trailing slashes are treated by `name_scope` as 5560 # absolute. 
5561 scope = scope + "/" 5562 innermost_nonempty_device_stack = default_graph._device_function_stack # pylint: disable=protected-access 5563 5564 outer_context = None 5565 if not _default_graph_stack.stack: 5566 # If the default graph stack is empty, then we cannot be building a 5567 # function. Install the global graph (which, in this case, is also the 5568 # default graph) as the outer context. 5569 if default_graph.building_function: 5570 raise RuntimeError("The global graph is building a function.") 5571 outer_context = default_graph.as_default 5572 else: 5573 # Find a context that is not building a function. 5574 for stack_entry in reversed(context.context().context_switches.stack): 5575 if not innermost_nonempty_device_stack: 5576 innermost_nonempty_device_stack = stack_entry.device_stack 5577 if not stack_entry.is_building_function: 5578 outer_context = stack_entry.enter_context_fn 5579 break 5580 5581 if outer_context is None: 5582 # As a last resort, obtain the global default graph; this graph doesn't 5583 # necessarily live on the graph stack (and hence it doesn't necessarily 5584 # live on the context stack), but it is stored in the graph stack's 5585 # encapsulating object. 5586 outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default # pylint: disable=protected-access 5587 5588 if outer_context is None: 5589 # Sanity check; this shouldn't be triggered. 5590 raise RuntimeError("All graphs are building functions, and no " 5591 "eager context was previously active.") 5592 5593 outer_graph = None 5594 outer_device_stack = None 5595 try: 5596 with outer_context(), name_scope(scope), control_dependencies( 5597 None), tape.stop_recording(): 5598 context_manager = NullContextmanager 5599 context_manager_input = None 5600 if not context.executing_eagerly(): 5601 # The device stack is preserved when lifting into a graph. 
Eager 5602 # execution doesn't implement device stacks and in particular it 5603 # doesn't support device functions, so in general it's not possible 5604 # to do the same when lifting into the eager context. 5605 outer_graph = get_default_graph() 5606 outer_device_stack = outer_graph._device_function_stack # pylint: disable=protected-access 5607 outer_graph._device_function_stack = innermost_nonempty_device_stack # pylint: disable=protected-access 5608 elif innermost_nonempty_device_stack is not None: 5609 for device_spec in innermost_nonempty_device_stack.peek_objs(): 5610 if device_spec.function is None: 5611 break 5612 if device_spec.raw_string: 5613 context_manager = context.device 5614 context_manager_input = device_spec.raw_string 5615 break 5616 # It is currently not possible to have a device function in V2, 5617 # but in V1 we are unable to apply device functions in eager mode. 5618 # This means that we will silently skip some of the entries on the 5619 # device stack in V1 + eager mode. 5620 5621 with context_manager(context_manager_input): 5622 yield 5623 finally: 5624 # If an exception is raised here it may be hiding a related exception in 5625 # try-block (just above). 5626 if outer_graph is not None: 5627 outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access 5628 5629 5630def executing_eagerly_outside_functions(): 5631 """Returns True if executing eagerly, even if inside a graph function.""" 5632 # Fastpath for when this is called eagerly (its not necessary to init_scope). 5633 if context.executing_eagerly(): 5634 return True 5635 5636 with init_scope(): 5637 return context.executing_eagerly() 5638 5639 5640def inside_function(): 5641 return get_default_graph().building_function 5642 5643 5644@tf_export(v1=["enable_eager_execution"]) 5645def enable_eager_execution(config=None, 5646 device_policy=None, 5647 execution_mode=None): 5648 """Enables eager execution for the lifetime of this program. 
5649 5650 Eager execution provides an imperative interface to TensorFlow. With eager 5651 execution enabled, TensorFlow functions execute operations immediately (as 5652 opposed to adding to a graph to be executed later in a `tf.Session`) and 5653 return concrete values (as opposed to symbolic references to a node in a 5654 computational graph). 5655 5656 For example: 5657 5658 ```python 5659 tf.enable_eager_execution() 5660 5661 # After eager execution is enabled, operations are executed as they are 5662 # defined and Tensor objects hold concrete values, which can be accessed as 5663 # numpy.ndarray`s through the numpy() method. 5664 assert tf.multiply(6, 7).numpy() == 42 5665 ``` 5666 5667 Eager execution cannot be enabled after TensorFlow APIs have been used to 5668 create or execute graphs. It is typically recommended to invoke this function 5669 at program startup and not in a library (as most libraries should be usable 5670 both with and without eager execution). 5671 5672 Args: 5673 config: (Optional.) A `tf.ConfigProto` to use to configure the environment 5674 in which operations are executed. Note that `tf.ConfigProto` is also 5675 used to configure graph execution (via `tf.Session`) and many options 5676 within `tf.ConfigProto` are not implemented (or are irrelevant) when 5677 eager execution is enabled. 5678 device_policy: (Optional.) Policy controlling how operations requiring 5679 inputs on a specific device (e.g., a GPU 0) handle inputs on a different 5680 device (e.g. GPU 1 or CPU). When set to None, an appropriate value will be 5681 picked automatically. The value picked may change between TensorFlow 5682 releases. 5683 Valid values: 5684 - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the 5685 placement is not correct. 5686 - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not 5687 on the right device but logs a warning. 5688 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors. 
5689 Note that this may hide performance problems as there is no notification 5690 provided when operations are blocked on the tensor being copied between 5691 devices. 5692 - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies 5693 int32 tensors, raising errors on the other ones. 5694 execution_mode: (Optional.) Policy controlling how operations dispatched are 5695 actually executed. When set to None, an appropriate value will be picked 5696 automatically. The value picked may change between TensorFlow releases. 5697 Valid values: 5698 - tf.contrib.eager.SYNC: executes each operation synchronously. 5699 - tf.contrib.eager.ASYNC: executes each operation asynchronously. These 5700 operations may return "non-ready" handles. 5701 5702 Raises: 5703 ValueError: If eager execution is enabled after creating/executing a 5704 TensorFlow graph, or if options provided conflict with a previous call 5705 to this function. 5706 """ 5707 if context.default_execution_mode != context.EAGER_MODE: 5708 return enable_eager_execution_internal( 5709 config=config, 5710 device_policy=device_policy, 5711 execution_mode=execution_mode, 5712 server_def=None) 5713 5714 5715@tf_export(v1=["disable_eager_execution"]) 5716def disable_eager_execution(): 5717 """Disables eager execution. 5718 5719 This function can only be called before any Graphs, Ops, or Tensors have been 5720 created. It can be used at the beginning of the program for complex migration 5721 projects from TensorFlow 1.x to 2.x. 5722 """ 5723 context.default_execution_mode = context.GRAPH_MODE 5724 c = context.context_safe() 5725 if c is not None: 5726 c._thread_local_data.is_eager = False # pylint: disable=protected-access 5727 5728 5729def enable_eager_execution_internal(config=None, 5730 device_policy=None, 5731 execution_mode=None, 5732 server_def=None): 5733 """Enables eager execution for the lifetime of this program. 5734 5735 Most of the doc string for enable_eager_execution is relevant here as well. 
5736 5737 Args: 5738 config: See enable_eager_execution doc string 5739 device_policy: See enable_eager_execution doc string 5740 execution_mode: See enable_eager_execution doc string 5741 server_def: (Optional.) A tensorflow::ServerDef proto. 5742 Enables execution on remote devices. GrpcServers need to be started by 5743 creating an identical server_def to this, and setting the appropriate 5744 task_indexes, so that the servers can communicate. It will then be 5745 possible to execute operations on remote devices. 5746 5747 Raises: 5748 ValueError 5749 5750 """ 5751 if config is not None and not isinstance(config, config_pb2.ConfigProto): 5752 raise TypeError( 5753 "config must be a tf.ConfigProto, but got %s" % type(config)) 5754 if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT, 5755 context.DEVICE_PLACEMENT_WARN, 5756 context.DEVICE_PLACEMENT_SILENT, 5757 context.DEVICE_PLACEMENT_SILENT_FOR_INT32): 5758 raise ValueError( 5759 "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*" 5760 ) 5761 if execution_mode not in (None, context.SYNC, context.ASYNC): 5762 raise ValueError( 5763 "execution_mode must be one of None, tf.contrib.eager.SYNC, " 5764 "tf.contrib.eager.ASYNC") 5765 if context.default_execution_mode == context.GRAPH_MODE: 5766 graph_mode_has_been_used = ( 5767 _default_graph_stack._global_default_graph is not None) # pylint: disable=protected-access 5768 if graph_mode_has_been_used: 5769 raise ValueError( 5770 "tf.enable_eager_execution must be called at program startup.") 5771 context.default_execution_mode = context.EAGER_MODE 5772 # pylint: disable=protected-access 5773 if context._context is None: 5774 context._context = context.Context( 5775 config=config, 5776 device_policy=device_policy, 5777 execution_mode=execution_mode, 5778 server_def=server_def) 5779 elif ((config is not None and config is not context._context._config) or 5780 (device_policy is not None and 5781 device_policy is not 
context._context._device_policy) or 5782 (execution_mode is not None and 5783 execution_mode is not context._context._execution_mode)): 5784 raise ValueError("Trying to change the options of an active eager" 5785 " execution. Context config: %s, specified config:" 5786 " %s. Context device policy: %s, specified device" 5787 " policy: %s. Context execution mode: %s, " 5788 " specified execution mode %s." % 5789 (context._context._config, config, 5790 context._context._device_policy, device_policy, 5791 context._context._execution_mode, execution_mode)) 5792 else: 5793 raise ValueError( 5794 "tf.enable_eager_execution must be called at program startup.") 5795 5796 # Monkey patch to get rid of an unnecessary conditional since the context is 5797 # now initialized. 5798 context.context = context.context_safe 5799 5800 5801def eager_run(main=None, argv=None): 5802 """Runs the program with an optional main function and argv list. 5803 5804 The program will run with eager execution enabled. 5805 5806 Example: 5807 ```python 5808 import tensorflow as tf 5809 # Import subject to future changes: 5810 from tensorflow.contrib.eager.python import tfe 5811 5812 def main(_): 5813 u = tf.constant(6.0) 5814 v = tf.constant(7.0) 5815 print(u * v) 5816 5817 if __name__ == "__main__": 5818 tfe.run() 5819 ``` 5820 5821 Args: 5822 main: the main function to run. 5823 argv: the arguments to pass to it. 5824 """ 5825 enable_eager_execution() 5826 app.run(main, argv) 5827 5828 5829@tf_export(v1=["reset_default_graph"]) 5830def reset_default_graph(): 5831 """Clears the default graph stack and resets the global default graph. 5832 5833 NOTE: The default graph is a property of the current thread. This 5834 function applies only to the current thread. Calling this function while 5835 a `tf.Session` or `tf.InteractiveSession` is active will result in undefined 5836 behavior. 
Using any previously created `tf.Operation` or `tf.Tensor` objects 5837 after calling this function will result in undefined behavior. 5838 Raises: 5839 AssertionError: If this function is called within a nested graph. 5840 """ 5841 if not _default_graph_stack.is_cleared(): 5842 raise AssertionError("Do not use tf.reset_default_graph() to clear " 5843 "nested graphs. If you need a cleared graph, " 5844 "exit the nesting and create a new graph.") 5845 _default_graph_stack.reset() 5846 5847 5848@tf_export(v1=["get_default_graph"]) 5849def get_default_graph(): 5850 """Returns the default graph for the current thread. 5851 5852 The returned graph will be the innermost graph on which a 5853 `Graph.as_default()` context has been entered, or a global default 5854 graph if none has been explicitly created. 5855 5856 NOTE: The default graph is a property of the current thread. If you 5857 create a new thread, and wish to use the default graph in that 5858 thread, you must explicitly add a `with g.as_default():` in that 5859 thread's function. 5860 5861 Returns: 5862 The default `Graph` being used in the current thread. 5863 """ 5864 return _default_graph_stack.get_default() 5865 5866def has_default_graph(): 5867 """Returns True if there is a default graph.""" 5868 return len(_default_graph_stack.stack) >= 1 5869 5870 5871def get_name_scope(): 5872 """Returns the current name scope in the default_graph. 5873 5874 For example: 5875 5876 ```python 5877 with tf.name_scope('scope1'): 5878 with tf.name_scope('scope2'): 5879 print(tf.get_name_scope()) 5880 ``` 5881 would print the string `scope1/scope2`. 5882 5883 Returns: 5884 A string representing the current name scope. 5885 """ 5886 if context.executing_eagerly(): 5887 return context.context().scope_name.rstrip("/") 5888 return get_default_graph().get_name_scope() 5889 5890 5891def _assert_same_graph(original_item, item): 5892 """Fail if the 2 items are from different graphs. 
5893 5894 Args: 5895 original_item: Original item to check against. 5896 item: Item to check. 5897 5898 Raises: 5899 ValueError: if graphs do not match. 5900 """ 5901 if original_item.graph is not item.graph: 5902 raise ValueError("%s must be from the same graph as %s." % (item, 5903 original_item)) 5904 5905 5906def _get_graph_from_inputs(op_input_list, graph=None): 5907 """Returns the appropriate graph to use for the given inputs. 5908 5909 This library method provides a consistent algorithm for choosing the graph 5910 in which an Operation should be constructed: 5911 5912 1. If the default graph is being used to construct a function, we 5913 use the default graph. 5914 2. If the "graph" is specified explicitly, we validate that all of the inputs 5915 in "op_input_list" are compatible with that graph. 5916 3. Otherwise, we attempt to select a graph from the first Operation- 5917 or Tensor-valued input in "op_input_list", and validate that all other 5918 such inputs are in the same graph. 5919 4. If the graph was not specified and it could not be inferred from 5920 "op_input_list", we attempt to use the default graph. 5921 5922 Args: 5923 op_input_list: A list of inputs to an operation, which may include `Tensor`, 5924 `Operation`, and other objects that may be converted to a graph element. 5925 graph: (Optional) The explicit graph to use. 5926 5927 Raises: 5928 TypeError: If op_input_list is not a list or tuple, or if graph is not a 5929 Graph. 5930 ValueError: If a graph is explicitly passed and not all inputs are from it, 5931 or if the inputs are from multiple graphs, or we could not find a graph 5932 and there was no default graph. 5933 5934 Returns: 5935 The appropriate graph to use for the given inputs. 
5936 5937 """ 5938 if get_default_graph().building_function: 5939 return get_default_graph() 5940 5941 op_input_list = tuple(op_input_list) # Handle generators correctly 5942 if graph and not isinstance(graph, Graph): 5943 raise TypeError("Input graph needs to be a Graph: %s" % graph) 5944 5945 # 1. We validate that all of the inputs are from the same graph. This is 5946 # either the supplied graph parameter, or the first one selected from one 5947 # the graph-element-valued inputs. In the latter case, we hold onto 5948 # that input in original_graph_element so we can provide a more 5949 # informative error if a mismatch is found. 5950 original_graph_element = None 5951 for op_input in op_input_list: 5952 # Determine if this is a valid graph_element. 5953 # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this 5954 # up. 5955 graph_element = None 5956 if (isinstance(op_input, (Operation, _TensorLike)) and 5957 ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)): # pylint: disable=unidiomatic-typecheck 5958 graph_element = op_input 5959 else: 5960 graph_element = _as_graph_element(op_input) 5961 5962 if graph_element is not None: 5963 if not graph: 5964 original_graph_element = graph_element 5965 graph = graph_element.graph 5966 elif original_graph_element is not None: 5967 _assert_same_graph(original_graph_element, graph_element) 5968 elif graph_element.graph is not graph: 5969 raise ValueError("%s is not from the passed-in graph." % graph_element) 5970 5971 # 2. If all else fails, we use the default graph, which is always there. 5972 return graph or get_default_graph() 5973 5974 5975@tf_export(v1=["GraphKeys"]) 5976class GraphKeys(object): 5977 """Standard names to use for graph collections. 5978 5979 The standard library uses various well-known names to collect and 5980 retrieve values associated with a graph. 
For example, the 5981 `tf.Optimizer` subclasses default to optimizing the variables 5982 collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is 5983 specified, but it is also possible to pass an explicit list of 5984 variables. 5985 5986 The following standard keys are defined: 5987 5988 * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared 5989 across distributed environment (model variables are subset of these). See 5990 `tf.global_variables` 5991 for more details. 5992 Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`, 5993 and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`. 5994 * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each 5995 machine. Usually used for temporarily variables, like counters. 5996 Note: use `tf.contrib.framework.local_variable` to add to this collection. 5997 * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the 5998 model for inference (feed forward). Note: use 5999 `tf.contrib.framework.model_variable` to add to this collection. 6000 * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will 6001 be trained by an optimizer. See 6002 `tf.trainable_variables` 6003 for more details. 6004 * `SUMMARIES`: the summary `Tensor` objects that have been created in the 6005 graph. See 6006 `tf.summary.merge_all` 6007 for more details. 6008 * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to 6009 produce input for a computation. See 6010 `tf.train.start_queue_runners` 6011 for more details. 6012 * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also 6013 keep moving averages. See 6014 `tf.moving_average_variables` 6015 for more details. 6016 * `REGULARIZATION_LOSSES`: regularization losses collected during graph 6017 construction. 
6018 6019 The following standard keys are _defined_, but their collections are **not** 6020 automatically populated as many of the others are: 6021 6022 * `WEIGHTS` 6023 * `BIASES` 6024 * `ACTIVATIONS` 6025 """ 6026 6027 # Key to collect Variable objects that are global (shared across machines). 6028 # Default collection for all variables, except local ones. 6029 GLOBAL_VARIABLES = "variables" 6030 # Key to collect local variables that are local to the machine and are not 6031 # saved/restored. 6032 LOCAL_VARIABLES = "local_variables" 6033 # Key to collect local variables which are used to accumulate interal state 6034 # to be used in tf.metrics.*. 6035 METRIC_VARIABLES = "metric_variables" 6036 # Key to collect model variables defined by layers. 6037 MODEL_VARIABLES = "model_variables" 6038 # Key to collect Variable objects that will be trained by the 6039 # optimizers. 6040 TRAINABLE_VARIABLES = "trainable_variables" 6041 # Key to collect summaries. 6042 SUMMARIES = "summaries" 6043 # Key to collect QueueRunners. 6044 QUEUE_RUNNERS = "queue_runners" 6045 # Key to collect table initializers. 6046 TABLE_INITIALIZERS = "table_initializer" 6047 # Key to collect asset filepaths. An asset represents an external resource 6048 # like a vocabulary file. 6049 ASSET_FILEPATHS = "asset_filepaths" 6050 # Key to collect Variable objects that keep moving averages. 6051 MOVING_AVERAGE_VARIABLES = "moving_average_variables" 6052 # Key to collect regularization losses at graph construction. 6053 REGULARIZATION_LOSSES = "regularization_losses" 6054 # Key to collect concatenated sharded variables. 6055 CONCATENATED_VARIABLES = "concatenated_variables" 6056 # Key to collect savers. 
6057 SAVERS = "savers" 6058 # Key to collect weights 6059 WEIGHTS = "weights" 6060 # Key to collect biases 6061 BIASES = "biases" 6062 # Key to collect activations 6063 ACTIVATIONS = "activations" 6064 # Key to collect update_ops 6065 UPDATE_OPS = "update_ops" 6066 # Key to collect losses 6067 LOSSES = "losses" 6068 # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. 6069 SAVEABLE_OBJECTS = "saveable_objects" 6070 # Key to collect all shared resources used by the graph which need to be 6071 # initialized once per cluster. 6072 RESOURCES = "resources" 6073 # Key to collect all shared resources used in this graph which need to be 6074 # initialized once per session. 6075 LOCAL_RESOURCES = "local_resources" 6076 # Trainable resource-style variables. 6077 TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables" 6078 6079 # Key to indicate various ops. 6080 INIT_OP = "init_op" 6081 LOCAL_INIT_OP = "local_init_op" 6082 READY_OP = "ready_op" 6083 READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op" 6084 SUMMARY_OP = "summary_op" 6085 GLOBAL_STEP = "global_step" 6086 6087 # Used to count the number of evaluations performed during a single evaluation 6088 # run. 6089 EVAL_STEP = "eval_step" 6090 TRAIN_OP = "train_op" 6091 6092 # Key for control flow context. 6093 COND_CONTEXT = "cond_context" 6094 WHILE_CONTEXT = "while_context" 6095 6096 # Used to store v2 summary names. 6097 _SUMMARY_COLLECTION = "_SUMMARY_V2" 6098 6099 # List of all collections that keep track of variables. 6100 _VARIABLE_COLLECTIONS = [ 6101 GLOBAL_VARIABLES, 6102 LOCAL_VARIABLES, 6103 METRIC_VARIABLES, 6104 MODEL_VARIABLES, 6105 TRAINABLE_VARIABLES, 6106 MOVING_AVERAGE_VARIABLES, 6107 CONCATENATED_VARIABLES, 6108 TRAINABLE_RESOURCE_VARIABLES, 6109 ] 6110 6111 # Key for streaming model ports. 6112 # NOTE(yuanbyu): internal and experimental. 
6113 _STREAMING_MODEL_PORTS = "streaming_model_ports" 6114 6115 @decorator_utils.classproperty 6116 @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.") 6117 def VARIABLES(cls): # pylint: disable=no-self-argument 6118 return cls.GLOBAL_VARIABLES 6119 6120 6121def dismantle_graph(graph): 6122 """Cleans up reference cycles from a `Graph`. 6123 6124 Helpful for making sure the garbage collector doesn't need to run after a 6125 temporary `Graph` is no longer needed. 6126 6127 Args: 6128 graph: A `Graph` object to destroy. Neither it nor any of its ops are usable 6129 after this function runs. 6130 """ 6131 memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access 6132 6133 # Now clean up Operation<->Graph reference cycles by clearing all of the 6134 # attributes for the Graph and its ops. 6135 graph_operations = graph.get_operations() 6136 for op in graph_operations: 6137 op.__dict__ = {} 6138 graph.__dict__ = {} 6139 6140 6141@tf_export(v1=["add_to_collection"]) 6142def add_to_collection(name, value): 6143 """Wrapper for `Graph.add_to_collection()` using the default graph. 6144 6145 See `tf.Graph.add_to_collection` 6146 for more details. 6147 6148 Args: 6149 name: The key for the collection. For example, the `GraphKeys` class 6150 contains many standard names for collections. 6151 value: The value to add to the collection. 6152 6153 @compatibility(eager) 6154 Collections are only supported in eager when variables are created inside an 6155 EagerVariableStore (e.g. as part of a layer or template). 6156 @end_compatibility 6157 """ 6158 get_default_graph().add_to_collection(name, value) 6159 6160 6161@tf_export(v1=["add_to_collections"]) 6162def add_to_collections(names, value): 6163 """Wrapper for `Graph.add_to_collections()` using the default graph. 6164 6165 See `tf.Graph.add_to_collections` 6166 for more details. 6167 6168 Args: 6169 names: The key for the collections. 
The `GraphKeys` class 6170 contains many standard names for collections. 6171 value: The value to add to the collections. 6172 6173 @compatibility(eager) 6174 Collections are only supported in eager when variables are created inside an 6175 EagerVariableStore (e.g. as part of a layer or template). 6176 @end_compatibility 6177 """ 6178 get_default_graph().add_to_collections(names, value) 6179 6180 6181@tf_export(v1=["get_collection_ref"]) 6182def get_collection_ref(key): 6183 """Wrapper for `Graph.get_collection_ref()` using the default graph. 6184 6185 See `tf.Graph.get_collection_ref` 6186 for more details. 6187 6188 Args: 6189 key: The key for the collection. For example, the `GraphKeys` class 6190 contains many standard names for collections. 6191 6192 Returns: 6193 The list of values in the collection with the given `name`, or an empty 6194 list if no value has been added to that collection. Note that this returns 6195 the collection list itself, which can be modified in place to change the 6196 collection. 6197 6198 @compatibility(eager) 6199 Collections are not supported when eager execution is enabled. 6200 @end_compatibility 6201 """ 6202 return get_default_graph().get_collection_ref(key) 6203 6204 6205@tf_export(v1=["get_collection"]) 6206def get_collection(key, scope=None): 6207 """Wrapper for `Graph.get_collection()` using the default graph. 6208 6209 See `tf.Graph.get_collection` 6210 for more details. 6211 6212 Args: 6213 key: The key for the collection. For example, the `GraphKeys` class 6214 contains many standard names for collections. 6215 scope: (Optional.) If supplied, the resulting list is filtered to include 6216 only items whose `name` attribute matches using `re.match`. Items 6217 without a `name` attribute are never returned if a scope is supplied and 6218 the choice or `re.match` means that a `scope` without special tokens 6219 filters by prefix. 
6220 6221 Returns: 6222 The list of values in the collection with the given `name`, or 6223 an empty list if no value has been added to that collection. The 6224 list contains the values in the order under which they were 6225 collected. 6226 6227 @compatibility(eager) 6228 Collections are not supported when eager execution is enabled. 6229 @end_compatibility 6230 """ 6231 return get_default_graph().get_collection(key, scope) 6232 6233 6234def get_all_collection_keys(): 6235 """Returns a list of collections used in the default graph.""" 6236 return get_default_graph().get_all_collection_keys() 6237 6238 6239name_scope_cache = {} 6240 6241 6242# Named like a function for backwards compatibility with the 6243# @tf_contextlib.contextmanager version, which was switched to a class to avoid 6244# some object creation overhead. 6245@tf_export(v1=["name_scope"]) 6246class name_scope(object): # pylint: disable=invalid-name 6247 """A context manager for use when defining a Python op. 6248 6249 This context manager validates that the given `values` are from the 6250 same graph, makes that graph the default graph, and pushes a 6251 name scope in that graph (see 6252 `tf.Graph.name_scope` 6253 for more details on that). 6254 6255 For example, to define a new Python op called `my_op`: 6256 6257 ```python 6258 def my_op(a, b, c, name=None): 6259 with tf.name_scope(name, "MyOp", [a, b, c]) as scope: 6260 a = tf.convert_to_tensor(a, name="a") 6261 b = tf.convert_to_tensor(b, name="b") 6262 c = tf.convert_to_tensor(c, name="c") 6263 # Define some computation that uses `a`, `b`, and `c`. 6264 return foo_op(..., name=scope) 6265 ``` 6266 """ 6267 6268 @property 6269 def name(self): 6270 return self._name 6271 6272 def __init__(self, name, default_name=None, values=None): 6273 """Initialize the context manager. 6274 6275 Args: 6276 name: The name argument that is passed to the op function. 6277 default_name: The default name to use if the `name` argument is `None`. 
6278 values: The list of `Tensor` arguments that are passed to the op function. 6279 6280 Raises: 6281 TypeError: if `default_name` is passed in but not a string. 6282 """ 6283 if not (default_name is None or isinstance(default_name, six.string_types)): 6284 raise TypeError( 6285 "`default_name` type (%s) is not a string type. You likely meant to " 6286 "pass this into the `values` kwarg." 6287 % type(default_name)) 6288 self._name = default_name if name is None else name 6289 self._default_name = default_name 6290 self._values = values 6291 self._ctx = context.context() 6292 self._in_eager_mode = self._ctx.executing_eagerly() 6293 self._has_symbolic_input_in_eager = False 6294 if self._values and self._in_eager_mode: 6295 # The presence of a graph tensor in `self._values` overrides the context. 6296 for value in self._values: 6297 if hasattr(value, "graph"): 6298 self._has_symbolic_input_in_eager = True 6299 self._name_scope = value.graph.name_scope(self._name) 6300 6301 def __enter__(self): 6302 """Start the scope block. 6303 6304 Returns: 6305 The scope name. 6306 6307 Raises: 6308 ValueError: if neither `name` nor `default_name` is provided 6309 but `values` are. 6310 """ 6311 if self._has_symbolic_input_in_eager: 6312 return self._name_scope.__enter__() 6313 6314 if self._in_eager_mode: 6315 self._old_name = self._ctx.scope_name 6316 if not self._name: 6317 scope_name = "" 6318 else: 6319 cache_key = self._name, self._old_name, self._default_name 6320 if cache_key in name_scope_cache: 6321 self._ctx.scope_name = name_scope_cache[cache_key] 6322 return self._ctx.scope_name 6323 elif self._name[-1] == "/": 6324 # A trailing slash breaks out of nested name scopes, indicating a 6325 # fully specified scope name, for compatibility with Graph.name_scope. 
6326 scope_name = self._name 6327 else: 6328 name_with_trailing_slash = self._name + "/" 6329 scope_name = ( 6330 self._old_name + name_with_trailing_slash 6331 if self._old_name else name_with_trailing_slash) 6332 name_scope_cache[cache_key] = scope_name 6333 self._ctx.scope_name = scope_name 6334 return scope_name 6335 else: 6336 if self._name is None and self._values is not None: 6337 # We only raise an error if values is not None (provided) because 6338 # currently tf.name_scope(None) (values=None then) is sometimes used as 6339 # an idiom to reset to top scope. 6340 raise ValueError( 6341 "At least one of name (%s) and default_name (%s) must be provided." 6342 % (self._name, self._default_name)) 6343 if self._values is None: 6344 self._values = [] 6345 g = _get_graph_from_inputs(self._values) 6346 self._g_manager = g.as_default() 6347 self._g_manager.__enter__() 6348 try: 6349 self._name_scope = g.name_scope(self._name) 6350 return self._name_scope.__enter__() 6351 except: 6352 self._g_manager.__exit__(*sys.exc_info()) 6353 raise 6354 6355 def __exit__(self, type_arg, value_arg, traceback_arg): 6356 if self._has_symbolic_input_in_eager: 6357 self._name_scope.__exit__(type_arg, value_arg, traceback_arg) 6358 elif self._in_eager_mode: 6359 self._ctx.scope_name = self._old_name 6360 else: 6361 self._name_scope.__exit__(type_arg, value_arg, traceback_arg) 6362 self._g_manager.__exit__(type_arg, value_arg, traceback_arg) 6363 return False # False values do not suppress exceptions 6364 6365 6366@tf_export("name_scope", v1=[]) 6367class name_scope_v2(name_scope): 6368 """A context manager for use when defining a Python op. 6369 6370 This context manager pushes a name scope, which will make the name of all 6371 operations added within it have a prefix. 
6372 6373 For example, to define a new Python op called `my_op`: 6374 6375 ```python 6376 def my_op(a, b, c, name=None): 6377 with tf.name_scope("MyOp") as scope: 6378 a = tf.convert_to_tensor(a, name="a") 6379 b = tf.convert_to_tensor(b, name="b") 6380 c = tf.convert_to_tensor(c, name="c") 6381 # Define some computation that uses `a`, `b`, and `c`. 6382 return foo_op(..., name=scope) 6383 ``` 6384 6385 When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`, 6386 and `MyOp/c`. 6387 6388 If the scope name already exists, the name will be made unique by appending 6389 `_n`. For example, calling `my_op` the second time will generate `MyOp_1/a`, 6390 etc. 6391 """ 6392 6393 def __init__(self, name): 6394 """Initialize the context manager. 6395 6396 Args: 6397 name: The prefix to use on all names created within the name scope. 6398 6399 Raises: 6400 ValueError: If name is None, or not a string. 6401 """ 6402 if name is None or not isinstance(name, six.string_types): 6403 raise ValueError("name for name_scope must be a string.") 6404 super(name_scope_v2, self).__init__(name=None, default_name=name) 6405 6406 6407def strip_name_scope(name, export_scope): 6408 """Removes name scope from a name. 6409 6410 Args: 6411 name: A `string` name. 6412 export_scope: Optional `string`. Name scope to remove. 6413 6414 Returns: 6415 Name with name scope removed, or the original name if export_scope 6416 is None. 6417 """ 6418 if export_scope: 6419 if export_scope[-1] == "/": 6420 export_scope = export_scope[:-1] 6421 6422 try: 6423 # Strips export_scope/, export_scope///, 6424 # ^export_scope/, loc:@export_scope/. 6425 str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)" 6426 return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1) 6427 except TypeError as e: 6428 # If the name is not of a type we can process, simply return it. 
6429 logging.warning(e) 6430 return name 6431 else: 6432 return name 6433 6434 6435def prepend_name_scope(name, import_scope): 6436 """Prepends name scope to a name. 6437 6438 Args: 6439 name: A `string` name. 6440 import_scope: Optional `string`. Name scope to add. 6441 6442 Returns: 6443 Name with name scope added, or the original name if import_scope 6444 is None. 6445 """ 6446 if import_scope: 6447 if import_scope[-1] == "/": 6448 import_scope = import_scope[:-1] 6449 6450 try: 6451 str_to_replace = r"([\^]|loc:@|^)(.*)" 6452 return re.sub(str_to_replace, r"\1" + import_scope + r"/\2", 6453 compat.as_str(name)) 6454 except TypeError as e: 6455 # If the name is not of a type we can process, simply return it. 6456 logging.warning(e) 6457 return name 6458 else: 6459 return name 6460 6461 6462# pylint: disable=g-doc-return-or-yield 6463# pylint: disable=not-context-manager 6464@tf_export(v1=["op_scope"]) 6465@tf_contextlib.contextmanager 6466def op_scope(values, name, default_name=None): 6467 """DEPRECATED. Same as name_scope above, just different argument order.""" 6468 logging.warn("tf.op_scope(values, name, default_name) is deprecated," 6469 " use tf.name_scope(name, default_name, values)") 6470 with name_scope(name, default_name=default_name, values=values) as scope: 6471 yield scope 6472 6473 6474_proto_function_registry = registry.Registry("proto functions") 6475 6476 6477def register_proto_function(collection_name, 6478 proto_type=None, 6479 to_proto=None, 6480 from_proto=None): 6481 """Registers `to_proto` and `from_proto` functions for collection_name. 6482 6483 `to_proto` function converts a Python object to the corresponding protocol 6484 buffer, and returns the protocol buffer. 6485 6486 `from_proto` function converts protocol buffer into a Python object, and 6487 returns the object.. 6488 6489 Args: 6490 collection_name: Name of the collection. 
6491 proto_type: Protobuf type, such as `saver_pb2.SaverDef`, 6492 `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`.. 6493 to_proto: Function that implements Python object to protobuf conversion. 6494 from_proto: Function that implements protobuf to Python object conversion. 6495 """ 6496 if to_proto and not callable(to_proto): 6497 raise TypeError("to_proto must be callable.") 6498 if from_proto and not callable(from_proto): 6499 raise TypeError("from_proto must be callable.") 6500 6501 _proto_function_registry.register((proto_type, to_proto, from_proto), 6502 collection_name) 6503 6504 6505def get_collection_proto_type(collection_name): 6506 """Returns the proto_type for collection_name.""" 6507 try: 6508 return _proto_function_registry.lookup(collection_name)[0] 6509 except LookupError: 6510 return None 6511 6512 6513def get_to_proto_function(collection_name): 6514 """Returns the to_proto function for collection_name.""" 6515 try: 6516 return _proto_function_registry.lookup(collection_name)[1] 6517 except LookupError: 6518 return None 6519 6520 6521def get_from_proto_function(collection_name): 6522 """Returns the from_proto function for collection_name.""" 6523 try: 6524 return _proto_function_registry.lookup(collection_name)[2] 6525 except LookupError: 6526 return None 6527 6528 6529def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): 6530 """Produce a nice error if someone converts an Operation to a Tensor.""" 6531 raise TypeError(("Can't convert Operation '%s' to Tensor " 6532 "(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype, 6533 name, as_ref)) 6534 6535 6536def _op_to_colocate_with(v): 6537 """Operation object corresponding to v to use for colocation constraints.""" 6538 if v is None: 6539 return None 6540 if isinstance(v, Operation): 6541 return v 6542 # We always want to colocate with the reference op. 6543 # When 'v' is a ResourceVariable, the reference op is the handle creating op. 
6544 # 6545 # What this should be is: 6546 # if isinstance(v, ResourceVariable): 6547 # return v.handle.op 6548 # However, that would require a circular import dependency. 6549 # As of October 2018, there were attempts underway to remove 6550 # colocation constraints altogether. Assuming that will 6551 # happen soon, perhaps this hack to work around the circular 6552 # import dependency is acceptable. 6553 if hasattr(v, "handle") and hasattr(v.handle, "op") and isinstance( 6554 v.handle.op, Operation): 6555 return v.handle.op 6556 return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op 6557 6558 6559def _is_keras_symbolic_tensor(x): 6560 return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph" 6561 6562 6563register_tensor_conversion_function(Operation, _operation_conversion_error) 6564