1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""FuncGraph and related functionality.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections as py_collections 22import itertools 23import weakref 24 25import numpy as np 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.python.eager import context 29from tensorflow.python.eager import execute 30from tensorflow.python.eager import tape 31from tensorflow.python.eager.graph_only_ops import graph_placeholder 32from tensorflow.python.framework import auto_control_deps 33from tensorflow.python.framework import composite_tensor 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_spec 39from tensorflow.python.framework import tensor_util 40from tensorflow.python.framework import type_spec 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import handle_data_util 43from tensorflow.python.ops import resource_variable_ops 44from tensorflow.python.ops import tensor_array_ops 45from tensorflow.python.ops import variable_scope 46from 
tensorflow.python.util import compat 47from tensorflow.python.util import memory 48from tensorflow.python.util import nest 49from tensorflow.python.util import object_identity 50from tensorflow.python.util import tf_contextlib 51from tensorflow.python.util import tf_decorator 52from tensorflow.python.util.tf_export import tf_export 53 54ALLOWLIST_COLLECTIONS = [ 55 ops.GraphKeys.GLOBAL_VARIABLES, 56 ops.GraphKeys.LOCAL_VARIABLES, 57 ops.GraphKeys.TRAINABLE_VARIABLES, 58 variable_scope._VARSTORE_KEY, # pylint: disable=protected-access 59 variable_scope._VARSCOPESTORE_KEY # pylint: disable=protected-access 60] 61 62 63_EAGER_CONST_THRESHOLD = 128 64 65 66class UnknownArgument(object): 67 """Signifies an argument which is not currently handled.""" 68 pass 69 70 71def convert_structure_to_signature(structure, arg_names=None): 72 """Convert a potentially nested structure to a signature. 73 74 Args: 75 structure: Structure to convert, where top level collection is a list or a 76 tuple. 77 arg_names: Optional list of arguments that has equal number of elements as 78 `structure` and is used for naming corresponding TensorSpecs. 79 80 Returns: 81 Identical structure that has TensorSpec objects instead of Tensors and 82 UnknownArgument instead of any unsupported types. 83 """ 84 def encode_arg(arg, path): 85 """A representation for this argument, for converting into signatures.""" 86 if isinstance(arg, ops.Tensor): 87 user_specified_name = None 88 try: 89 user_specified_name = compat.as_str( 90 arg.op.get_attr("_user_specified_name")) 91 except ValueError: 92 pass 93 94 if path and user_specified_name and user_specified_name != path[0]: 95 # The user has explicitly named the argument differently than the name 96 # of the function argument. 
97 name = user_specified_name 98 else: 99 name = "/".join(str(p) for p in path) 100 return tensor_spec.TensorSpec(arg.shape, arg.dtype, name) 101 if isinstance(arg, composite_tensor.CompositeTensor): 102 # TODO(b/133606651) Do we need to inject arg_name? 103 return arg._type_spec # pylint: disable=protected-access 104 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 105 name = "/".join(str(p) for p in path) 106 return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name, 107 trainable=arg.trainable) 108 if isinstance(arg, ( 109 int, 110 float, 111 bool, 112 str, 113 type(None), 114 dtypes.DType, 115 tensor_spec.TensorSpec, 116 type_spec.TypeSpec, 117 )): 118 return arg 119 return UnknownArgument() 120 121 # We are using the flattened paths to name the TensorSpecs. We need an 122 # explicit name for them downstream. 123 flattened = nest.flatten_with_tuple_paths(structure) 124 if arg_names: 125 if len(arg_names) != len(structure): 126 raise ValueError( 127 "Passed in arg_names don't match actual signature (%s)." % arg_names) 128 # Replace all top-level names with their actual arg_names. If a path before 129 # was "(2,'a',1)", it will become "(arg_names[2],'a',1)". 130 flattened = [ 131 ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened 132 ] 133 134 mapped = [encode_arg(arg, path) for path, arg in flattened] 135 return nest.pack_sequence_as(structure, mapped) 136 137 138@tf_export("__internal__.FuncGraph", v1=[]) 139class FuncGraph(ops.Graph): 140 """Graph representing a function body. 141 142 Attributes: 143 name: The name of the function. 144 inputs: Placeholder tensors representing the inputs to this function. The 145 tensors are in this FuncGraph. This represents "regular" inputs as well as 146 captured inputs (i.e. the values of self.captures), with the regular 147 inputs coming first. 148 outputs: Tensors that will be returned by this function. The tensors are in 149 this FuncGraph. 
150 control_outputs: Operations that must be executed before the function 151 represented by this graph can be said to have been executed. 152 structured_input_signature: A tuple of (args, kwargs), which are both 153 possibly-nested python objects that were received by this function. Note 154 that these structures might contain Python `None`s. 155 structured_outputs: A possibly-nested python object which will be returned 156 by this function. The Tensors in this structure are the same as those of 157 self.outputs. Note that this structure might contain Python `None`s. 158 variables: Variables that should be watched during function execution. 159 outer_graph: The graph this function is defined in. May be another FuncGraph 160 or the global default Graph. 161 captures: Maps external tensor -> internal tensor (i.e. input placeholder). 162 The entries are in the order they were captured. 163 control_captures: Set of external ops on which this graph has a control 164 dependency. 165 seed: The graph-level random seed. 166 capture_by_value: If True, the func graph will capture Variables by value 167 instead of reference. 168 """ 169 170 def __init__(self, name, collections=None, capture_by_value=None): 171 """Construct a new FuncGraph. 172 173 The graph will inherit its graph key, collections, seed, and distribution 174 strategy stack from the current context or graph. 175 176 Args: 177 name: the name of the function. 178 collections: a dictionary of collections this FuncGraph should start 179 with. If not specified (None), the FuncGraph will read (but not write 180 to) the outer graph's collections that are not allowlisted, and both 181 read and write to the outer graph's collections that are allowlisted. 182 The current allowlisted collections are the global variables, the 183 local variables, and the trainable variables. 184 Defaults to None. 185 capture_by_value: An optional boolean. If True, the func graph will 186 capture Variables by value instead of reference. 
By default inherit 187 from outer graphs, and failing that will default to False. 188 """ 189 super(FuncGraph, self).__init__() 190 191 self.name = name 192 self.inputs = [] 193 self.outputs = [] 194 self.control_outputs = [] 195 self.control_captures = object_identity.ObjectIdentitySet() 196 self.structured_input_signature = None 197 self.structured_outputs = None 198 self._weak_variables = [] 199 self._watched_variables = object_identity.ObjectIdentityWeakSet() 200 self.is_control_flow_graph = False 201 202 outer_graph = ops.get_default_graph() 203 self._weak_outer_graph = weakref.ref(outer_graph) 204 while outer_graph.building_function: 205 outer_graph = outer_graph.outer_graph 206 # If self._weak_outer_graph is deleted, we revert to the outermost Graph 207 # active when the FuncGraph was traced. This will not be a FuncGraph. 208 self._fallback_outer_graph = outer_graph 209 self._captures = py_collections.OrderedDict() 210 # If not None, records the names of output args of this function. Used to 211 # preserve the output names in the signature of a serialized+deserialized 212 # function. Private at the moment mostly because it's often out of date. 213 self._output_names = None 214 # Maps arbitrary key -> (closure, nest of placeholders), where at function 215 # call time the value of closure() will be used to feed the nest of 216 # placeholders. 217 self._deferred_captures = py_collections.OrderedDict() 218 # Inherit capture-by-value from outer graph. 219 if capture_by_value is not None: 220 self.capture_by_value = capture_by_value 221 elif self.outer_graph is not None and isinstance( 222 self.outer_graph, FuncGraph): 223 self.capture_by_value = self.outer_graph.capture_by_value 224 else: 225 self.capture_by_value = False 226 227 self._building_function = True 228 # Map from resource tensor name to last op (in program order) which uses 229 # this tensor. Used to enforce that execution order matches program order 230 # for resource tensors. 
231 self._last_op_using_resource_tensor = {} 232 233 graph = self.outer_graph 234 235 if context.executing_eagerly(): 236 self.seed = context.global_seed() 237 # [for tf-data user migration from TF1.0 to 2.0] seed_used keep track of 238 # any None op_seed for random_op in the function, in which case we end up 239 # using function seed, which could be unintended behavior for the op. 240 self._seed_used = False 241 else: 242 self.seed = graph.seed 243 self._seed_used = False 244 # TODO(allenl): Figure out if we can remove colocation stack 245 # specialization (currently used in cond_v2), here and in the cache key. 246 self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access 247 248 if collections is None: 249 for collection_name in graph.get_all_collection_keys(): 250 if collection_name not in ALLOWLIST_COLLECTIONS: 251 self._collections[collection_name] = graph.get_collection( 252 collection_name) 253 for collection_name in ALLOWLIST_COLLECTIONS: 254 self._collections[collection_name] = graph.get_collection_ref( 255 collection_name) 256 else: 257 self._collections = collections 258 259 # Keep track of whether this FuncGraph is exportable to SavedModel. Use 260 # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any 261 # dependent functions as unsaveable. 262 self._saveable = True 263 self._saving_errors = set() 264 265 # Keep track of callbacks to run when this graph exits default scope 266 self._scope_exit_callbacks = None 267 268 def __str__(self): 269 return "FuncGraph(name=%s, id=%s)" % (self.name, id(self)) 270 271 def watch_variable(self, v): 272 """Marks the variable v as accessed while building this graph.""" 273 while self is not None and isinstance(self, FuncGraph): 274 self._watched_variables.add(v) 275 self = self.outer_graph 276 277 def capture_call_time_value(self, closure, spec, key=None): 278 """Creates a placeholder which at call time has the value closure(). 
279 280 Useful, for example, to respect TensorFlow context managers, which are often 281 dynamically scoped. 282 283 Args: 284 closure: function which takes no arguments, to be evaluated at function 285 call time, returning a nest of tensors compatible with `spec`. 286 spec: nest of TypeSpec for the value to capture. 287 key: optional. If not None, multiple calls to lazy_capture with the same 288 key in the same graph will return the same placeholder, and the 289 first closure will be used at function call time. 290 291 Returns: 292 Nest of placeholders which, at function call time, will be fed with the 293 result of calling closure(). 294 295 Raises: 296 ValueError: at function call time, if the return value of closure() is 297 not compatible with `spec`. 298 """ 299 if key is None: 300 key = object() 301 if key not in self._deferred_captures: 302 303 def convert_to_placeholder(s): 304 if not isinstance(s, tensor_spec.DenseSpec): 305 raise TypeError( 306 "Expected a nest of `TypeSpec` objects, found %s of type %s." % 307 (s, type(s))) 308 return array_ops.placeholder(dtype=s.dtype, shape=s.shape) 309 310 placeholder = nest.map_structure( 311 convert_to_placeholder, spec, expand_composites=True) 312 313 def wrapped_closure(): 314 ret_nest = closure() 315 nest.assert_same_structure(spec, ret_nest, expand_composites=True) 316 # This uses the tensor dtype defined in `spec` when converting values 317 # in `ret_nest` to tensors. 318 # pylint: disable=protected-access 319 y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest, 320 expand_composites=False) 321 # pylint: enable=protected-access 322 return nest.flatten(y, expand_composites=True) 323 324 self._deferred_captures[key] = (wrapped_closure, placeholder) 325 return self._deferred_captures[key][1] 326 327 def control_dependencies(self, control_inputs): 328 """Handles control dependencies. 
329 330 FuncGraph wraps Graph's control_dependencies logic by first filtering out 331 any external tensors / operations and storing them in the graph's 332 control_captures member. Any consumers of this function graph must then 333 decide how to handle the control captures. 334 335 Args: 336 control_inputs: A list of `Operation` or `Tensor` objects which 337 must be executed or computed before running the operations 338 defined in the context. Can also be `None` to clear the control 339 dependencies. 340 341 Returns: 342 A context manager that specifies control dependencies for all 343 operations constructed within the context. 344 345 Raises: 346 TypeError: If `control_inputs` is not a list of `Operation` or 347 `Tensor` objects. 348 """ 349 if control_inputs is None: 350 return super(FuncGraph, self).control_dependencies(control_inputs) 351 352 filtered_control_inputs = [] 353 for c in control_inputs: 354 # Check for _UnreadVariable 355 if (isinstance(c, ops.IndexedSlices) or 356 (hasattr(c, "_handle") and hasattr(c, "op"))): 357 c = c.op 358 graph_element = ops._as_graph_element(c) # pylint: disable=protected-access 359 if graph_element is None: 360 graph_element = c 361 if graph_element is not None and getattr( 362 graph_element, "graph", None) is not self: 363 self.control_captures.add(graph_element) 364 else: 365 filtered_control_inputs.append(graph_element) 366 return super(FuncGraph, self).control_dependencies(filtered_control_inputs) 367 368 def as_default(self): 369 outer_cm = super(FuncGraph, self).as_default() 370 371 @tf_contextlib.contextmanager 372 def inner_cm(): 373 """Context manager for copying distribute.Strategy scope information.""" 374 # pylint: disable=protected-access 375 # TODO(b/112906995, nareshmodi): distribution strategy depends on 376 # inheriting this stack from the default graph even in eager mode. Maybe 377 # it should be part of the eager context? 
This would also allow us to 378 # remove a get_default_graph() call from the function cache lookup. 379 graph = ops.get_default_graph() 380 old_strategy_stack = self._distribution_strategy_stack 381 self._distribution_strategy_stack = list( 382 graph._distribution_strategy_stack) 383 384 # We ignore device placements from any outer scopes while tracing the 385 # function when possible, to avoid hard-coding them in the function 386 # graph. "Default" placements come from the PartitionedCallOp's placement, 387 # so that the same trace of the Python function may be placed on several 388 # different devices and saved functions may be placed on new devices when 389 # restored. 390 # However, we need to preserve the outer device stack in the following 391 # cases in non eager context: 392 # 1. device stack is callable 393 # 2. When using distribution strategy with legacy graph mode. 394 old_device_stack = self._device_function_stack 395 if (not context.executing_eagerly() and 396 (device_stack_has_callable(graph._device_function_stack) or 397 (self._distribution_strategy_stack and 398 not ops.executing_eagerly_outside_functions()))): 399 # Hard-code devices from device functions in the function body 400 self._device_function_stack = graph._device_function_stack.copy() 401 402 old_creator_stack = self._variable_creator_stack 403 self._variable_creator_stack = graph._variable_creator_stack 404 # Inherit the graph key, since this is used for matching variables in 405 # optimizers. 
406 old_graph_key = self._graph_key 407 self._graph_key = graph._graph_key 408 # pylint: enable=protected-access 409 410 old_scope_exit_callbacks = self._scope_exit_callbacks 411 self._scope_exit_callbacks = [] 412 413 with outer_cm as g: 414 try: 415 yield g 416 finally: 417 try: 418 for fn in self._scope_exit_callbacks: 419 fn() 420 finally: 421 self._scope_exit_callbacks = old_scope_exit_callbacks 422 self._distribution_strategy_stack = old_strategy_stack 423 self._device_function_stack = old_device_stack 424 self._variable_creator_stack = old_creator_stack 425 self._graph_key = old_graph_key 426 return inner_cm() 427 428 @property 429 def outer_graph(self): 430 """The Graph this FuncGraph is nested in. 431 432 Functions may capture Tensors from graphs they are nested in (transitive). 433 434 Returns: 435 A Graph object. Initially set to the current default graph when the 436 FuncGraph was created. If the previous `outer_graph` was deleted because 437 the function that owns it was deleted, `outer_graph` is reset to the 438 outermost default graph active when the FuncGraph was created. This 439 FuncGraph won't have captured anything from the new `outer_graph` (and 440 likely not from the previous setting, since that would have created a 441 strong reference), but it is returned so that FuncGraphs always have a 442 parent. 443 """ 444 current = self._weak_outer_graph() 445 if current is None: 446 return self._fallback_outer_graph 447 return current 448 449 @outer_graph.setter 450 def outer_graph(self, new_outer_graph): 451 """Sets `outer_graph` to `new_outer_graph`.""" 452 self._weak_outer_graph = weakref.ref(new_outer_graph) 453 454 @property 455 def output_types(self): 456 return [t.dtype for t in self.outputs] 457 458 @property 459 def output_shapes(self): 460 return [t.shape for t in self.outputs] 461 462 @property 463 def trainable_variables(self): 464 """A sequence of trainable variables accessed by this FuncGraph. 
465 466 Note that functions keep only weak references to variables. Calling the 467 function after a variable it accesses has been deleted is an error. 468 469 Returns: 470 Sequence of trainable variables for this func graph. 471 """ 472 return tuple(v for v in self.variables if v.trainable) 473 474 @property 475 def variables(self): 476 """A sequence of variables accessed by this FuncGraph. 477 478 Note that functions keep only weak references to variables. Calling the 479 function after a variable it accesses has been deleted is an error. 480 481 Returns: 482 Sequence of variables for this func graph. 483 """ 484 def deref(weak_v): 485 v = weak_v() 486 if v is None: 487 raise AssertionError( 488 "Called a function referencing variables which have been deleted. " 489 "This likely means that function-local variables were created and " 490 "not referenced elsewhere in the program. This is generally a " 491 "mistake; consider storing variables in an object attribute on " 492 "first call.") 493 return v 494 495 return tuple(deref(v) for v in self._weak_variables) 496 497 @variables.setter 498 def variables(self, var_list): 499 self._weak_variables = [weakref.ref(v) for v in var_list] 500 501 def _capture_by_value( 502 self, 503 op_type, 504 inputs, 505 dtypes, # pylint: disable=redefined-outer-name 506 input_types=None, 507 name=None, 508 attrs=None, 509 op_def=None, 510 compute_device=True): 511 # When capturing by value, do the read outside 512 reverse_captures = dict((id(v), k) for k, v in self.captures) 513 uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs] 514 with ops.init_scope(): 515 if context.executing_eagerly(): 516 attr_list = ("dtype", int(attrs["dtype"].type)) 517 value, = execute.execute( 518 compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list, 519 context.context()) 520 else: 521 op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access 522 op_type, 523 uncaptured_inputs, 524 dtypes, 525 input_types, 
526 name, 527 attrs, 528 op_def, 529 compute_device) 530 value = op.outputs[0] 531 captured_value = self.capture(value) 532 return captured_value.op 533 534 def _create_op_internal( 535 self, 536 op_type, 537 inputs, 538 dtypes=None, # pylint: disable=redefined-outer-name 539 input_types=None, 540 name=None, 541 attrs=None, 542 op_def=None, 543 compute_device=True): 544 """Like Graph.create_op, except handles external input tensors. 545 546 This overload adds functionality to create_op to "capture" any external 547 input tensors, i.e. tensors from the eager context or outer function graphs 548 if this is a nested function. See `capture` for more information. 549 550 Args: 551 op_type: The `Operation` type to create. This corresponds to the 552 `OpDef.name` field for the proto that defines the operation. 553 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 554 dtypes: (Optional) A list of `DType` objects that will be the types of the 555 tensors that the operation produces. 556 input_types: (Optional.) A list of `DType`s that will be the types of 557 the tensors that the operation consumes. By default, uses the base 558 `DType` of each input in `inputs`. Operations that expect 559 reference-typed inputs must specify `input_types` explicitly. 560 name: (Optional.) A string name for the operation. If not specified, a 561 name is generated based on `op_type`. 562 attrs: (Optional.) A dictionary where the key is the attribute name (a 563 string) and the value is the respective `attr` attribute of the 564 `NodeDef` proto that will represent the operation (an `AttrValue` 565 proto). 566 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 567 the operation will have. 568 compute_device: (Optional.) If True, device functions will be executed 569 to compute the device property of the Operation. 570 571 Returns: 572 An `Operation` object. 
573 """ 574 if self.capture_by_value and op_type in ["ReadVariableOp", 575 "ResourceGather"]: 576 return self._capture_by_value(op_type, inputs, dtypes, input_types, name, 577 attrs, op_def, compute_device) 578 579 # This capturing logic interacts poorly with control flow contexts which 580 # want to replace inputs of ops far too late in the process. This can lead 581 # the context to get confused and try to create an Enter for an Enter. We 582 # can detect this here and skip the additional Enter which can confuse loop 583 # validation logic. 584 if op_type == "Enter" and inputs[0].op.type == "Enter": 585 if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s: 586 return inputs[0].op 587 # Calling AddValue on the control flow contexts to force creation of the 588 # backward accumulators in the original graph before we create placeholders 589 # to capture the inputs. 590 ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access 591 # Use a different list to avoid modifying the original inputs list. 592 captured_inputs = [] 593 for inp in inputs: 594 # TPU Estimator defines a control flow context with no AddValue method. 595 if ctxt is not None and hasattr(ctxt, "AddValue"): 596 inp = ctxt.AddValue(inp) 597 inp = self.capture(inp) 598 captured_inputs.append(inp) 599 return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access 600 op_type, captured_inputs, dtypes, input_types, name, attrs, op_def, 601 compute_device) 602 603 def capture(self, tensor, name=None, shape=None): 604 """Captures `tensor` if it's external to this graph. 605 606 If `tensor` is from a different graph, returns a placeholder for it. 607 `tensor` and the placeholder will appear in self.captures, and the 608 placeholder will appear in self.inputs. Multiple calls to this method with 609 the same `tensor` argument will return the same placeholder. If `tensor` is 610 from this graph, returns `tensor`. 611 612 Args: 613 tensor: Tensor. 
May be from this FuncGraph or a different graph. 614 name: Optional name if a placeholder is created. 615 shape: Optional shape if a placeholder is created. 616 617 Returns: 618 Tensor from this FuncGraph. 619 620 Raises: 621 InaccessibleTensorError: if any tensors are accessed in a manner that 622 bypasses the mechanisms required for the data dependencies to be correctly 623 wired. 624 """ 625 if isinstance(tensor, ops.EagerTensor): 626 if name is None: 627 name = str(ops.uid()) 628 629 # Small EagerTensors are captured with Const ops 630 if (tensor.dtype in dtypes.TF_VALUE_DTYPES and 631 np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD): 632 return self.capture_eager_tensor(tensor, name) 633 634 # Large EagerTensors and resources are captured with Placeholder ops 635 return self._capture_helper(tensor, name, shape) 636 if tensor.graph is not self: 637 if name is None: 638 name = tensor.op.name 639 inner_graph = tensor.graph 640 while inner_graph is not None and isinstance(inner_graph, FuncGraph): 641 if inner_graph is self: 642 raise errors.InaccessibleTensorError( 643 "The tensor '%s' cannot be accessed here: it is defined" 644 " in another function or code block. Use return values," 645 " explicit Python locals or TensorFlow collections to access" 646 " it. Defined in: %s; accessed from: %s.\n" 647 % (tensor, tensor.graph, self)) 648 inner_graph = inner_graph.outer_graph 649 return self._capture_helper(tensor, name) 650 return tensor 651 652 def _capture_helper(self, tensor, name, shape=None): 653 capture = self._captures.get(id(tensor)) 654 if capture is None: 655 placeholder = _create_substitute_placeholder( 656 tensor, name=name, dtype=tensor.dtype, shape=shape) 657 # Record the composite device as an attribute to the placeholder. 658 # This attribute would be propogated into the arg_attr of the FunctionDef. 659 # Currently, a packed eager tensor is always placed on a CompositeDevice. 
660 if isinstance(tensor, ops.EagerTensor) and tensor.is_packed: 661 placeholder.op._set_attr( # pylint: disable=protected-access 662 "_composite_device", 663 attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device))) 664 self.add_capture(tensor, placeholder) 665 else: 666 placeholder = capture[1] 667 tape.record_operation("captured_value", [placeholder], [tensor], 668 backward_function=lambda x: [x], 669 forward_function=lambda x: [x]) 670 return placeholder 671 672 @property 673 def captures(self): 674 """Order list of tuples containing external and internal captures.""" 675 return self._captures.values() 676 677 def add_capture(self, tensor, placeholder): 678 """Capture a specific tensor and utilize the provided placeholder. 679 680 Args: 681 tensor: Tensor to captures. 682 placeholder: Provided placeholder for the tensor. 683 """ 684 self._captures[id(tensor)] = (tensor, placeholder) 685 self.inputs.append(placeholder) 686 687 def replace_capture(self, tensor, placeholder): 688 """Replace already existing capture.""" 689 self._captures[id(tensor)] = (tensor, placeholder) 690 691 def reset_captures(self, capture_list): 692 """Set the captures with the provided list of captures & placeholder.""" 693 self._captures = py_collections.OrderedDict() 694 for tensor, placeholder in capture_list: 695 self._captures[id(tensor)] = (tensor, placeholder) 696 697 def pop_capture(self, tensor): 698 """Remove the capture and return the generated placeholder.""" 699 capture = self._captures.pop(id(tensor), None) 700 if capture is None: 701 return None 702 703 return capture[1] 704 705 def clear_captures(self): 706 # TODO(b/115366440): Delete this method when a custom OrderedDict is added. 707 # Clearing captures using clear() leaves some cycles around. 
708 while self._captures: 709 self._captures.popitem() 710 memory.dismantle_ordered_dict(self._captures) 711 while self._deferred_captures: 712 self._deferred_captures.popitem() 713 memory.dismantle_ordered_dict(self._deferred_captures) 714 715 def capture_distributed_variable(self, variable, placeholder): 716 """Add given distributed variable to captures with given placeholder.""" 717 self._captures[id(variable)] = (variable, placeholder) 718 tape.record_operation("captured_value", [placeholder], [variable], 719 backward_function=lambda x: [x], 720 forward_function=lambda x: [x]) 721 722 def capture_eager_tensor(self, tensor, name): 723 capture = self._captures.get(id(tensor)) 724 if capture is None: 725 # We clear all control dependencies and place the Const op on the same 726 # device as the source tensor. The device placement may be relaxed at 727 # a later date. 728 with ops.control_dependencies(None), self.device(tensor.device): 729 constant_value = tensor_util.constant_value(tensor) 730 if constant_value is None: 731 # Some eager tensors, e.g. parallel tensors, are not convertible to a 732 # single constant. We'll use a placeholder for this case. 
733 return self._capture_helper(tensor, name) 734 graph_const = constant_op.constant(constant_value, dtype=tensor.dtype, 735 shape=tensor.shape, name=name) 736 self.add_capture(tensor, graph_const) 737 else: 738 graph_const = capture[1] 739 tape.record_operation("captured_value", [graph_const], [tensor], 740 backward_function=lambda x: [x], 741 forward_function=lambda x: [x]) 742 return graph_const 743 744 def captured(self, tensor): 745 """Check if the specified tensor has been captured.""" 746 return id(tensor) in self._captures 747 748 @property 749 def external_captures(self): 750 """External tensors captured by this function.""" 751 return [c[0] for c in self._captures.values()] 752 753 @property 754 def internal_captures(self): 755 """Placeholders in this function corresponding captured tensors.""" 756 return [c[1] for c in self._captures.values()] 757 758 @property 759 def deferred_external_captures(self): 760 """Ordered nest of tensors whose placeholders will be fed at call time.""" 761 return [c[0] for c in self._deferred_captures.values()] 762 763 @property 764 def deferred_internal_captures(self): 765 """List of nest of placeholders which at call time will be fed.""" 766 return [c[1] for c in self._deferred_captures.values()] 767 768 @property 769 def variable_captures(self): 770 """Map of python object ids of variables to variables which are captured.""" 771 return { 772 id(self._captures[id(v)][1]): v 773 for v in self.variables 774 if id(v) in self._captures 775 } 776 777 def mark_as_unsaveable(self, error_message): 778 """Marks this FuncGraph as unsaveable. 779 780 Any attempts to export this FuncGraph will raise an error with the specified 781 message. 782 783 Args: 784 error_message: List or string containing the error message to be raised 785 when saving this FuncGraph to SavedModel. 
786 """ 787 self._saveable = False 788 if isinstance(error_message, str): 789 error_message = [error_message] 790 self._saving_errors.update(error_message) 791 792 @property 793 def saveable(self): 794 """Returns whether this FuncGraph is saveable.""" 795 return self._saveable 796 797 @property 798 def saving_errors(self): 799 """Returns set of errors preventing this FuncGraph from being saved.""" 800 return self._saving_errors 801 802 def _add_scope_exit_callback(self, fn): 803 """Add a function to call when this graph exits the default scope.""" 804 if not callable(fn): 805 raise TypeError("fn is not callable: {}".format(fn)) 806 if self._scope_exit_callbacks is None: 807 raise RuntimeError( 808 "Attempting to add a scope exit callback, but the default graph is " 809 "not the context scope graph. Did you forget to call " 810 "'with graph.as_default(): ...'?") 811 self._scope_exit_callbacks.append(fn) 812 813 814# TODO(mdan): Too many threaded arguments. Accept an ACD ctx manager instead. 815def func_graph_from_py_func(name, 816 python_func, 817 args, 818 kwargs, 819 signature=None, 820 func_graph=None, 821 autograph=False, 822 autograph_options=None, 823 add_control_dependencies=True, 824 arg_names=None, 825 op_return_value=None, 826 collections=None, 827 capture_by_value=None, 828 override_flat_arg_shapes=None, 829 acd_record_initial_resource_uses=False): 830 """Returns a `FuncGraph` generated from `python_func`. 831 832 Args: 833 name: an identifier for the function. 834 python_func: the Python function to trace. 835 args: the positional args with which the Python function should be called; 836 ignored if a signature is provided. 837 kwargs: the keyword args with which the Python function should be called; 838 ignored if a signature is provided. 839 signature: a possibly nested sequence of `TensorSpecs` specifying the shapes 840 and dtypes of the arguments. 
When a signature is provided, `args` and 841 `kwargs` are ignored, and `python_func` is traced with Tensors conforming 842 to `signature`. If `None`, the shapes and dtypes are inferred from the 843 inputs. 844 func_graph: Optional. An instance of FuncGraph. If provided, we will use 845 this graph else a new one is built and returned. 846 autograph: whether to use autograph to compile `python_func`. 847 See https://www.tensorflow.org/guide/autograph for more information. 848 autograph_options: additional knobs to control when `autograph=True`. 849 See https://www.tensorflow.org/guide/autograph for more information. 850 add_control_dependencies: If True, automatically adds control dependencies 851 to ensure program order matches execution order and stateful ops always 852 execute. 853 arg_names: Optional list of argument names, used to give input placeholders 854 recognizable names. 855 op_return_value: Optional. A Tensor. If set and `python_func` returns 856 Operations, those return values will be replaced with this value. If not 857 set, returning an Operation triggers an error. 858 collections: a dictionary of collections this FuncGraph should start 859 with. If not specified (None), the FuncGraph will read (but not write to) 860 the outer graph's collections that are not allowlisted, and both 861 read and write to the outer graph's collections that are allowlisted. 862 The current allowlisted collections are the global variables, the 863 local variables, and the trainable variables. 864 Defaults to None. 865 capture_by_value: An optional boolean. If True, the func graph will capture 866 Variables by value instead of reference. By default inherit from outer 867 graphs, and failing that will default to False. 868 override_flat_arg_shapes: An optional list of instances that are either 869 `None` or `TensorShape`. The length must match that of 870 `nest.flatten((args, kwargs), expand_composites=True)`. 
The entries 871 containing value `None` must match entries in flattened arguments 872 containing non-tensors, while entries containing a `TensorShape` must 873 match entries in the flattened arguments containing tensors. 874 acd_record_initial_resource_uses: If `True` and `add_control_dependencies` 875 is enabled, the results (those marked with 876 AutomaticControlDependencies.mark_result) will be annotated with a private 877 attribute, "_res_first_used_by", which points to the first nodes which 878 used the any of the resources that the result op is using. 879 880 Returns: 881 A FuncGraph. 882 883 Raises: 884 TypeError: If any of `python_func`'s return values is neither `None` nor a 885 `Tensor`. 886 ValueError: If both `signature` and `override_flat_arg_shapes` are 887 passed in. 888 """ 889 if op_return_value is not None: 890 assert isinstance(op_return_value, ops.Tensor), op_return_value 891 if func_graph is None: 892 func_graph = FuncGraph(name, collections=collections, 893 capture_by_value=capture_by_value) 894 assert isinstance(func_graph, FuncGraph) 895 if add_control_dependencies: 896 deps_control_manager = auto_control_deps.AutomaticControlDependencies( 897 record_initial_resource_uses=acd_record_initial_resource_uses) 898 else: 899 deps_control_manager = ops.NullContextmanager() 900 901 with func_graph.as_default(), deps_control_manager as deps_ctx: 902 current_scope = variable_scope.get_variable_scope() 903 default_use_resource = current_scope.use_resource 904 current_scope.set_use_resource(True) 905 906 if signature is not None and override_flat_arg_shapes is not None: 907 raise ValueError( 908 "Passed both signature and override_flat_arg_shapes: %s and %s." 909 % (signature, override_flat_arg_shapes)) 910 911 if signature is not None: 912 args = signature 913 kwargs = {} 914 915 # Creates and names placeholders for all arguments. 
916 if override_flat_arg_shapes is not None: 917 flat_args = nest.flatten(args, expand_composites=True) 918 arg_shapes = override_flat_arg_shapes[:len(flat_args)] 919 kwarg_shapes = override_flat_arg_shapes[len(flat_args):] 920 else: 921 arg_shapes = None 922 kwarg_shapes = None 923 func_args = _get_defun_inputs_from_args( 924 args, arg_names, flat_shapes=arg_shapes) 925 func_kwargs = _get_defun_inputs_from_kwargs( 926 kwargs, flat_shapes=kwarg_shapes) 927 928 # Convert all Tensors into TensorSpecs before saving the structured inputs. 929 # If storing pure concrete functions that are not called through polymorphic 930 # functions, we don't have access to FunctionSpec, so we need to call the 931 # TensorSpecs by their `arg_names` for later binding. 932 func_graph.structured_input_signature = ( 933 convert_structure_to_signature(func_args, arg_names), 934 convert_structure_to_signature(func_kwargs)) 935 936 flat_func_args = nest.flatten(func_args, expand_composites=True) 937 flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True) 938 # Temporarily set inputs to allow graph building code to inspect 939 # them. Reassigned below. 940 func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs 941 if isinstance(arg, ops.Tensor)] 942 943 # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. 
944 # Variables to help check whether mutation happens in calling the function 945 # Copy the recursive list, tuple and map structure, but not base objects 946 func_args_before = nest.pack_sequence_as(func_args, flat_func_args, 947 expand_composites=True) 948 func_kwargs_before = nest.pack_sequence_as( 949 func_kwargs, flat_func_kwargs, expand_composites=True) 950 951 def convert(x): 952 """Converts a function output to a Tensor.""" 953 if x is None: 954 return None 955 if op_return_value is not None and isinstance(x, ops.Operation): 956 # TODO(b/79881896): we currently can't capture external control deps, so 957 # this won't work if x needs to be captured (i.e. if python_func returns 958 # captured Operations). 959 with ops.control_dependencies([x]): 960 x = array_ops.identity(op_return_value) 961 elif not isinstance(x, tensor_array_ops.TensorArray): 962 try: 963 x = ops.convert_to_tensor_or_composite(x) 964 except (ValueError, TypeError): 965 raise TypeError( 966 "To be compatible with tf.eager.defun, Python functions " 967 "must return zero or more Tensors; in compilation of %s, found " 968 "return value of type %s, which is not a Tensor." % 969 (str(python_func), type(x))) 970 if add_control_dependencies: 971 x = deps_ctx.mark_as_return(x) 972 return x 973 974 try: 975 if autograph: 976 from tensorflow.python import autograph # pylint: disable=g-import-not-at-top 977 _, original_func = tf_decorator.unwrap(python_func) 978 979 def autograph_handler(*args, **kwargs): 980 """Calls a converted version of original_func.""" 981 # TODO(mdan): Push this block higher in tf.function's call stack. 
982 try: 983 return autograph.converted_call( 984 original_func, 985 args, 986 kwargs, 987 options=autograph.ConversionOptions( 988 recursive=True, 989 optional_features=autograph_options, 990 user_requested=True, 991 )) 992 except Exception as e: # pylint:disable=broad-except 993 if hasattr(e, "ag_error_metadata"): 994 raise e.ag_error_metadata.to_exception(e) 995 else: 996 raise 997 998 # Wrapping around a decorator allows checks like tf_inspect.getargspec 999 # to be accurate. 1000 converted_func = tf_decorator.make_decorator( 1001 original_func, autograph_handler) 1002 python_func = tf_decorator.rewrap(python_func, original_func, 1003 converted_func) 1004 1005 else: 1006 _, original_func = tf_decorator.unwrap(python_func) 1007 1008 func_outputs = python_func(*func_args, **func_kwargs) 1009 1010 # invariant: `func_outputs` contains only Tensors, CompositeTensors, 1011 # TensorArrays and `None`s. 1012 func_outputs = nest.map_structure(convert, func_outputs, 1013 expand_composites=True) 1014 1015 check_mutation(func_args_before, func_args, original_func) 1016 check_mutation(func_kwargs_before, func_kwargs, original_func) 1017 finally: 1018 current_scope.set_use_resource(default_use_resource) 1019 1020 # Variables in `func_args`, `func_kwargs` should be explicit inputs 1021 # to the function, not captured inputs. 1022 graph_variables = list(func_graph._watched_variables) # pylint: disable=protected-access 1023 arg_variables = object_identity.ObjectIdentitySet() 1024 inputs = [] 1025 for arg in (nest.flatten(func_args, expand_composites=True) + 1026 nest.flatten(func_kwargs, expand_composites=True)): 1027 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 1028 # Even if an argument variable was not used in the function, we've 1029 # already manually captured the resource Tensor when creating argument 1030 # placeholders. 
1031 resource_placeholder = func_graph.pop_capture(arg.handle) 1032 if resource_placeholder is None: 1033 continue 1034 arg_variables.add(arg) 1035 inputs.append(resource_placeholder) 1036 elif isinstance(arg, ops.Tensor): 1037 inputs.append(arg) 1038 variables = [v for v in graph_variables if v not in arg_variables] 1039 func_graph.inputs = ( 1040 inputs + func_graph.internal_captures + nest.flatten( 1041 func_graph.deferred_internal_captures, expand_composites=True)) 1042 func_graph.structured_outputs = func_outputs 1043 # Returning a closed-over tensor does not trigger convert_to_tensor. 1044 func_graph.outputs.extend( 1045 func_graph.capture(x) 1046 for x in flatten(func_graph.structured_outputs) 1047 if x is not None) 1048 1049 func_graph.variables = variables 1050 1051 if add_control_dependencies: 1052 func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run) 1053 func_graph.collective_manager_ids_used = ( 1054 deps_control_manager.collective_manager_ids_used) 1055 1056 return func_graph 1057 1058 1059def maybe_captured(tensor): 1060 """If t is a captured value placeholder, returns the original captured value. 1061 1062 Args: 1063 tensor: Tensor. 1064 1065 Returns: 1066 A tensor, potentially from a different Graph/FuncGraph. 
1067 """ 1068 if (not isinstance(tensor, ops.EagerTensor) and 1069 tensor.op.graph.building_function and tensor.op.type == "Placeholder"): 1070 for input_t, placeholder_t in tensor.op.graph.captures: 1071 if tensor == placeholder_t: 1072 return maybe_captured(input_t) 1073 # pylint: enable=protected-access 1074 return tensor 1075 1076 1077def device_stack_has_callable(device_stack): 1078 """Checks whether a device stack contains a callable.""" 1079 return any(callable(spec._device_name_or_function) # pylint: disable=protected-access 1080 for spec in device_stack.peek_objs()) 1081 1082 1083def check_mutation(n1, n2, func): 1084 """Check if two list of arguments are exactly the same.""" 1085 func_name = getattr(func, "__name__", func) 1086 1087 errmsg = ("{}() should not modify its Python input arguments." 1088 " Check if it modifies any lists or dicts passed as" 1089 " arguments. Modifying a copy is allowed.".format(func_name)) 1090 try: 1091 # TODO(mdan): Compare more robustly so that argument names can be reported. 1092 nest.assert_same_structure(n1, n2, expand_composites=True) 1093 except ValueError: 1094 raise ValueError(errmsg) 1095 1096 for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True), 1097 nest.flatten(n2, expand_composites=True)): 1098 if arg1 is not arg2: 1099 raise ValueError(errmsg) 1100 1101 1102# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this. 1103def flatten(sequence): 1104 """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays. 1105 1106 Args: 1107 sequence: A nested structure of Tensors, CompositeTensors, and 1108 TensorArrays. 1109 1110 Returns: 1111 A list of tensors. 1112 """ 1113 flat_sequence = nest.flatten(sequence, expand_composites=True) 1114 return [ 1115 item.flow if isinstance(item, tensor_array_ops.TensorArray) else item 1116 for item in flat_sequence] 1117 1118 1119# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this. 
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors. Entries whose position
      corresponds to a TensorArray in `structure` are treated as that
      TensorArray's new flow tensor.

  Returns:
    A nested structure.

  Raises:
    ValueError: if `structure` (fully flattened with `expand_composites=True`)
      and `flat_sequence` do not have the same number of elements.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  # A TensorArray is carried through flat sequences as its flow tensor; rebuild
  # a TensorArray around the new flow wherever the structure expects one.
  flat_sequence = [
      tensor_array_ops.build_ta_with_new_flow(old_ta=original, flow=item)
      if isinstance(original, tensor_array_ops.TensorArray) else item
      for original, item in zip(flattened_structure, flat_sequence)
  ]
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)


def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it.

  Args:
    value: The tensor the placeholder stands in for. Its handle data is copied
      onto the placeholder.
    name: Optional name for the placeholder op.
    dtype: Optional dtype; defaults to `value.dtype`.
    shape: Optional shape; defaults to `value.shape`.

  Returns:
    The new placeholder tensor.
  """
  if shape is None:
    shape = value.shape
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  handle_data_util.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A list of user-specified arguments, one entry per top-level argument.
      Each entry may itself be a nested structure; it is flattened with
      `expand_composites=True` and every leaf is mapped individually.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case placeholders are created without a
      requested name and no `_user_specified_name` attribute is recorded.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`. If provided, then length must match
      that of `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding `TensorShape`
      in `flat_shapes`. May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d). args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):

    # Replace any composite tensors with their TypeSpecs. This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)

    flattened = nest.flatten(arg_value, expand_composites=True)

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        arg_is_spec = isinstance(arg, tensor_spec.TensorSpec)
        if arg_is_spec and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if not arg_is_spec:
          handle_data_util.copy_handle_data(arg, placeholder)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        # Use a per-element name: a VariableSpec's own name must not leak into
        # the names of later leaves of the same nested argument (rebinding the
        # loop variable `name` here would do exactly that).
        var_name = name
        if isinstance(arg, resource_variable_ops.VariableSpec):
          var_name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(dtypes.resource, arg.shape,
                                            name=var_name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=var_name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=var_name,
                trainable=arg.trainable)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=var_name)
        if var_name is not None:
          # Mirrors the Tensor branch: compat.as_bytes(None) raises TypeError,
          # so only record a user-specified name when one exists.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(var_name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'. args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs.

  Keyword arguments are processed in sorted-key order so the placeholder order
  matches `nest.flatten` of the kwargs dict, which also sorts by key.
  """
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)


def override_func_graph_name_scope(func_graph, name_scope):
  """Replaces the name stack of `func_graph` with `name_scope`.

  Args:
    func_graph: The graph whose private `_name_stack` attribute is overwritten.
    name_scope: The new name stack value.
  """
  func_graph._name_stack = name_scope  # pylint: disable=protected-access