1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""FuncGraph and related functionality.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections as py_collections 22import itertools 23import weakref 24 25import numpy as np 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.python.eager import context 29from tensorflow.python.eager import execute 30from tensorflow.python.eager import tape 31from tensorflow.python.eager.graph_only_ops import graph_placeholder 32from tensorflow.python.framework import auto_control_deps 33from tensorflow.python.framework import composite_tensor 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_spec 39from tensorflow.python.framework import tensor_util 40from tensorflow.python.framework import type_spec 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import custom_gradient 43from tensorflow.python.ops import resource_variable_ops 44from tensorflow.python.ops import tensor_array_ops 45from tensorflow.python.ops import variable_scope 46from tensorflow.python.util 
import compat 47from tensorflow.python.util import memory 48from tensorflow.python.util import nest 49from tensorflow.python.util import object_identity 50from tensorflow.python.util import tf_contextlib 51from tensorflow.python.util import tf_decorator 52from tensorflow.python.util.tf_export import tf_export 53 54ALLOWLIST_COLLECTIONS = [ 55 ops.GraphKeys.GLOBAL_VARIABLES, 56 ops.GraphKeys.LOCAL_VARIABLES, 57 ops.GraphKeys.TRAINABLE_VARIABLES, 58 variable_scope._VARSTORE_KEY, # pylint: disable=protected-access 59 variable_scope._VARSCOPESTORE_KEY # pylint: disable=protected-access 60] 61 62 63_EAGER_CONST_THRESHOLD = 128 64 65 66class UnknownArgument(object): 67 """Signifies an argument which is not currently handled.""" 68 pass 69 70 71def convert_structure_to_signature(structure, arg_names=None): 72 """Convert a potentially nested structure to a signature. 73 74 Args: 75 structure: Structure to convert, where top level collection is a list or a 76 tuple. 77 arg_names: Optional list of arguments that has equal number of elements as 78 `structure` and is used for naming corresponding TensorSpecs. 79 80 Returns: 81 Identical structure that has TensorSpec objects instead of Tensors and 82 UnknownArgument instead of any unsupported types. 83 """ 84 def encode_arg(arg, path): 85 """A representation for this argument, for converting into signatures.""" 86 if isinstance(arg, ops.Tensor): 87 user_specified_name = None 88 try: 89 user_specified_name = compat.as_str( 90 arg.op.get_attr("_user_specified_name")) 91 except ValueError: 92 pass 93 94 if path and user_specified_name and user_specified_name != path[0]: 95 # The user has explicitly named the argument differently than the name 96 # of the function argument. 97 name = user_specified_name 98 else: 99 name = "/".join(str(p) for p in path) 100 return tensor_spec.TensorSpec(arg.shape, arg.dtype, name) 101 if isinstance(arg, composite_tensor.CompositeTensor): 102 # TODO(b/133606651) Do we need to inject arg_name? 
103 return arg._type_spec # pylint: disable=protected-access 104 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 105 name = "/".join(str(p) for p in path) 106 return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name) 107 if isinstance(arg, ( 108 int, 109 float, 110 bool, 111 str, 112 type(None), 113 dtypes.DType, 114 tensor_spec.TensorSpec, 115 type_spec.TypeSpec, 116 )): 117 return arg 118 return UnknownArgument() 119 120 # We are using the flattened paths to name the TensorSpecs. We need an 121 # explicit name for them downstream. 122 flattened = nest.flatten_with_tuple_paths(structure) 123 if arg_names: 124 if len(arg_names) != len(structure): 125 raise ValueError( 126 "Passed in arg_names don't match actual signature (%s)." % arg_names) 127 # Replace all top-level names with their actual arg_names. If a path before 128 # was "(2,'a',1)", it will become "(arg_names[2],'a',1)". 129 flattened = [ 130 ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened 131 ] 132 133 mapped = [encode_arg(arg, path) for path, arg in flattened] 134 return nest.pack_sequence_as(structure, mapped) 135 136 137@tf_export("__internal__.FuncGraph", v1=[]) 138class FuncGraph(ops.Graph): 139 """Graph representing a function body. 140 141 Attributes: 142 name: The name of the function. 143 inputs: Placeholder tensors representing the inputs to this function. The 144 tensors are in this FuncGraph. This represents "regular" inputs as well as 145 captured inputs (i.e. the values of self.captures), with the regular 146 inputs coming first. 147 outputs: Tensors that will be returned by this function. The tensors are in 148 this FuncGraph. 149 control_outputs: Operations that must be executed before the function 150 represented by this graph can be said to have been executed. 151 structured_input_signature: A tuple of (args, kwargs), which are both 152 possibly-nested python objects that were received by this function. 
Note 153 that these structures might contain Python `None`s. 154 structured_outputs: A possibly-nested python object which will be returned 155 by this function. The Tensors in this structure are the same as those of 156 self.outputs. Note that this structure might contain Python `None`s. 157 variables: Variables that should be watched during function execution. 158 outer_graph: The graph this function is defined in. May be another FuncGraph 159 or the global default Graph. 160 captures: Maps external tensor -> internal tensor (i.e. input placeholder). 161 The entries are in the order they were captured. 162 control_captures: Set of external ops on which this graph has a control 163 dependency. 164 seed: The graph-level random seed. 165 capture_by_value: If True, the func graph will capture Variables by value 166 instead of reference. 167 """ 168 169 def __init__(self, name, collections=None, capture_by_value=None): 170 """Construct a new FuncGraph. 171 172 The graph will inherit its graph key, collections, seed, and distribution 173 strategy stack from the current context or graph. 174 175 Args: 176 name: the name of the function. 177 collections: a dictionary of collections this FuncGraph should start 178 with. If not specified (None), the FuncGraph will read (but not write 179 to) the outer graph's collections that are not allowlisted, and both 180 read and write to the outer graph's collections that are allowlisted. 181 The current allowlisted collections are the global variables, the 182 local variables, and the trainable variables. 183 Defaults to None. 184 capture_by_value: An optional boolean. If True, the func graph will 185 capture Variables by value instead of reference. By default inherit 186 from outer graphs, and failing that will default to False. 
187 """ 188 super(FuncGraph, self).__init__() 189 190 self.name = name 191 self.inputs = [] 192 self.outputs = [] 193 self.control_outputs = [] 194 self.control_captures = set() 195 self.structured_input_signature = None 196 self.structured_outputs = None 197 self._weak_variables = [] 198 self._watched_variables = object_identity.ObjectIdentityWeakSet() 199 self.is_control_flow_graph = False 200 201 outer_graph = ops.get_default_graph() 202 self._weak_outer_graph = weakref.ref(outer_graph) 203 while outer_graph.building_function: 204 outer_graph = outer_graph.outer_graph 205 # If self._weak_outer_graph is deleted, we revert to the outermost Graph 206 # active when the FuncGraph was traced. This will not be a FuncGraph. 207 self._fallback_outer_graph = outer_graph 208 self._captures = py_collections.OrderedDict() 209 # If not None, records the names of output args of this function. Used to 210 # preserve the output names in the signature of a serialized+deserialized 211 # function. Private at the moment mostly because it's often out of date. 212 self._output_names = None 213 # Maps arbitrary key -> (closure, nest of placeholders), where at function 214 # call time the value of closure() will be used to feed the nest of 215 # placeholders. 216 self._deferred_captures = py_collections.OrderedDict() 217 # Inherit capture-by-value from outer graph. 218 if capture_by_value is not None: 219 self.capture_by_value = capture_by_value 220 elif self.outer_graph is not None and isinstance( 221 self.outer_graph, FuncGraph): 222 self.capture_by_value = self.outer_graph.capture_by_value 223 else: 224 self.capture_by_value = False 225 226 self._building_function = True 227 # Map from resource tensor name to last op (in program order) which uses 228 # this tensor. Used to enforce that execution order matches program order 229 # for resource tensors. 
230 self._last_op_using_resource_tensor = {} 231 232 graph = self.outer_graph 233 234 if context.executing_eagerly(): 235 self.seed = context.global_seed() 236 # [for tf-data user migration from TF1.0 to 2.0] seed_used keep track of 237 # any None op_seed for random_op in the function, in which case we end up 238 # using function seed, which could be unintended behavior for the op. 239 self._seed_used = False 240 else: 241 self.seed = graph.seed 242 self._seed_used = False 243 # TODO(allenl): Figure out if we can remove colocation stack 244 # specialization (currently used in cond_v2), here and in the cache key. 245 self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access 246 247 if collections is None: 248 for collection_name in graph.get_all_collection_keys(): 249 if collection_name not in ALLOWLIST_COLLECTIONS: 250 self._collections[collection_name] = graph.get_collection( 251 collection_name) 252 for collection_name in ALLOWLIST_COLLECTIONS: 253 self._collections[collection_name] = graph.get_collection_ref( 254 collection_name) 255 else: 256 self._collections = collections 257 258 # Keep track of whether this FuncGraph is exportable to SavedModel. Use 259 # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any 260 # dependent functions as unsaveable. 261 self._saveable = True 262 self._saving_errors = set() 263 264 # Keep track of callbacks to run when this graph exits default scope 265 self._scope_exit_callbacks = None 266 267 def __str__(self): 268 return "FuncGraph(name=%s, id=%s)" % (self.name, id(self)) 269 270 def watch_variable(self, v): 271 """Marks the variable v as accessed while building this graph.""" 272 while self is not None and isinstance(self, FuncGraph): 273 self._watched_variables.add(v) 274 self = self.outer_graph 275 276 def capture_call_time_value(self, closure, spec, key=None): 277 """Creates a placeholder which at call time has the value closure(). 
278 279 Useful, for example, to respect TensorFlow context managers, which are often 280 dynamically scoped. 281 282 Args: 283 closure: function which takes no arguments, to be evaluated at function 284 call time, returning a nest of tensors compatible with `spec`. 285 spec: nest of TypeSpec for the value to capture. 286 key: optional. If not None, multiple calls to lazy_capture with the same 287 key in the same graph will return the same placeholder, and the 288 first closure will be used at function call time. 289 290 Returns: 291 Nest of placeholders which, at function call time, will be fed with the 292 result of calling closure(). 293 294 Raises: 295 ValueError: at function call time, if the return value of closure() is 296 not compatible with `spec`. 297 """ 298 if key is None: 299 key = object() 300 if key not in self._deferred_captures: 301 302 def convert_to_placeholder(s): 303 if not isinstance(s, tensor_spec.DenseSpec): 304 raise TypeError( 305 "Expected a nest of `TypeSpec` objects, found %s of type %s." % 306 (s, type(s))) 307 return array_ops.placeholder(dtype=s.dtype, shape=s.shape) 308 309 placeholder = nest.map_structure( 310 convert_to_placeholder, spec, expand_composites=True) 311 312 def wrapped_closure(): 313 ret_nest = closure() 314 nest.assert_same_structure(spec, ret_nest, expand_composites=True) 315 # This uses the tensor dtype defined in `spec` when converting values 316 # in `ret_nest` to tensors. 317 # pylint: disable=protected-access 318 y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest, 319 expand_composites=False) 320 # pylint: enable=protected-access 321 return nest.flatten(y, expand_composites=True) 322 323 self._deferred_captures[key] = (wrapped_closure, placeholder) 324 return self._deferred_captures[key][1] 325 326 def control_dependencies(self, control_inputs): 327 """Handles control dependencies. 
328 329 FuncGraph wraps Graph's control_dependencies logic by first filtering out 330 any external tensors / operations and storing them in the graph's 331 control_captures member. Any consumers of this function graph must then 332 decide how to handle the control captures. 333 334 Args: 335 control_inputs: A list of `Operation` or `Tensor` objects which 336 must be executed or computed before running the operations 337 defined in the context. Can also be `None` to clear the control 338 dependencies. 339 340 Returns: 341 A context manager that specifies control dependencies for all 342 operations constructed within the context. 343 344 Raises: 345 TypeError: If `control_inputs` is not a list of `Operation` or 346 `Tensor` objects. 347 """ 348 if control_inputs is None: 349 return super(FuncGraph, self).control_dependencies(control_inputs) 350 351 filtered_control_inputs = [] 352 for c in control_inputs: 353 # Check for _UnreadVariable 354 if (isinstance(c, ops.IndexedSlices) or 355 (hasattr(c, "_handle") and hasattr(c, "op"))): 356 c = c.op 357 graph_element = ops._as_graph_element(c) # pylint: disable=protected-access 358 if graph_element is None: 359 graph_element = c 360 if graph_element is not None and getattr( 361 graph_element, "graph", None) is not self: 362 self.control_captures.add(graph_element) 363 else: 364 filtered_control_inputs.append(graph_element) 365 return super(FuncGraph, self).control_dependencies(filtered_control_inputs) 366 367 def as_default(self): 368 outer_cm = super(FuncGraph, self).as_default() 369 370 @tf_contextlib.contextmanager 371 def inner_cm(): 372 """Context manager for copying distribute.Strategy scope information.""" 373 # pylint: disable=protected-access 374 # TODO(b/112906995, nareshmodi): distribution strategy depends on 375 # inheriting this stack from the default graph even in eager mode. Maybe 376 # it should be part of the eager context? 
This would also allow us to 377 # remove a get_default_graph() call from the function cache lookup. 378 graph = ops.get_default_graph() 379 old_strategy_stack = self._distribution_strategy_stack 380 self._distribution_strategy_stack = list( 381 graph._distribution_strategy_stack) 382 383 # We ignore device placements from any outer scopes while tracing the 384 # function when possible, to avoid hard-coding them in the function 385 # graph. "Default" placements come from the PartitionedCallOp's placement, 386 # so that the same trace of the Python function may be placed on several 387 # different devices and saved functions may be placed on new devices when 388 # restored. 389 # However, we need to preserve the outer device stack in the following 390 # cases in non eager context: 391 # 1. device stack is callable 392 # 2. When using distribution strategy with legacy graph mode. 393 old_device_stack = self._device_function_stack 394 if (not context.executing_eagerly() and 395 (device_stack_has_callable(graph._device_function_stack) or 396 (self._distribution_strategy_stack and 397 not ops.executing_eagerly_outside_functions()))): 398 # Hard-code devices from device functions in the function body 399 self._device_function_stack = graph._device_function_stack.copy() 400 401 old_creator_stack = self._variable_creator_stack 402 self._variable_creator_stack = graph._variable_creator_stack 403 # Inherit the graph key, since this is used for matching variables in 404 # optimizers. 
405 old_graph_key = self._graph_key 406 self._graph_key = graph._graph_key 407 # pylint: enable=protected-access 408 409 old_scope_exit_callbacks = self._scope_exit_callbacks 410 self._scope_exit_callbacks = [] 411 412 with outer_cm as g: 413 try: 414 yield g 415 finally: 416 try: 417 for fn in self._scope_exit_callbacks: 418 fn() 419 finally: 420 self._scope_exit_callbacks = old_scope_exit_callbacks 421 self._distribution_strategy_stack = old_strategy_stack 422 self._device_function_stack = old_device_stack 423 self._variable_creator_stack = old_creator_stack 424 self._graph_key = old_graph_key 425 return inner_cm() 426 427 @property 428 def outer_graph(self): 429 """The Graph this FuncGraph is nested in. 430 431 Functions may capture Tensors from graphs they are nested in (transitive). 432 433 Returns: 434 A Graph object. Initially set to the current default graph when the 435 FuncGraph was created. If the previous `outer_graph` was deleted because 436 the function that owns it was deleted, `outer_graph` is reset to the 437 outermost default graph active when the FuncGraph was created. This 438 FuncGraph won't have captured anything from the new `outer_graph` (and 439 likely not from the previous setting, since that would have created a 440 strong reference), but it is returned so that FuncGraphs always have a 441 parent. 442 """ 443 current = self._weak_outer_graph() 444 if current is None: 445 return self._fallback_outer_graph 446 return current 447 448 @outer_graph.setter 449 def outer_graph(self, new_outer_graph): 450 """Sets `outer_graph` to `new_outer_graph`.""" 451 self._weak_outer_graph = weakref.ref(new_outer_graph) 452 453 @property 454 def output_types(self): 455 return [t.dtype for t in self.outputs] 456 457 @property 458 def output_shapes(self): 459 return [t.shape for t in self.outputs] 460 461 @property 462 def trainable_variables(self): 463 """A sequence of trainable variables accessed by this FuncGraph. 
464 465 Note that functions keep only weak references to variables. Calling the 466 function after a variable it accesses has been deleted is an error. 467 468 Returns: 469 Sequence of trainable variables for this func graph. 470 """ 471 return tuple(v for v in self.variables if v.trainable) 472 473 @property 474 def variables(self): 475 """A sequence of variables accessed by this FuncGraph. 476 477 Note that functions keep only weak references to variables. Calling the 478 function after a variable it accesses has been deleted is an error. 479 480 Returns: 481 Sequence of variables for this func graph. 482 """ 483 def deref(weak_v): 484 v = weak_v() 485 if v is None: 486 raise AssertionError( 487 "Called a function referencing variables which have been deleted. " 488 "This likely means that function-local variables were created and " 489 "not referenced elsewhere in the program. This is generally a " 490 "mistake; consider storing variables in an object attribute on " 491 "first call.") 492 return v 493 494 return tuple(deref(v) for v in self._weak_variables) 495 496 @variables.setter 497 def variables(self, var_list): 498 self._weak_variables = [weakref.ref(v) for v in var_list] 499 500 def _capture_by_value( 501 self, 502 op_type, 503 inputs, 504 dtypes, # pylint: disable=redefined-outer-name 505 input_types=None, 506 name=None, 507 attrs=None, 508 op_def=None, 509 compute_device=True): 510 # When capturing by value, do the read outside 511 reverse_captures = dict((id(v), k) for k, v in self.captures) 512 uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs] 513 with ops.init_scope(): 514 if context.executing_eagerly(): 515 attr_list = ("dtype", int(attrs["dtype"].type)) 516 value, = execute.execute( 517 compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list, 518 context.context()) 519 else: 520 op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access 521 op_type, 522 uncaptured_inputs, 523 dtypes, 524 input_types, 
525 name, 526 attrs, 527 op_def, 528 compute_device) 529 value = op.outputs[0] 530 captured_value = self.capture(value) 531 return captured_value.op 532 533 def _create_op_internal( 534 self, 535 op_type, 536 inputs, 537 dtypes=None, # pylint: disable=redefined-outer-name 538 input_types=None, 539 name=None, 540 attrs=None, 541 op_def=None, 542 compute_device=True): 543 """Like Graph.create_op, except handles external input tensors. 544 545 This overload adds functionality to create_op to "capture" any external 546 input tensors, i.e. tensors from the eager context or outer function graphs 547 if this is a nested function. See `capture` for more information. 548 549 Args: 550 op_type: The `Operation` type to create. This corresponds to the 551 `OpDef.name` field for the proto that defines the operation. 552 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 553 dtypes: (Optional) A list of `DType` objects that will be the types of the 554 tensors that the operation produces. 555 input_types: (Optional.) A list of `DType`s that will be the types of 556 the tensors that the operation consumes. By default, uses the base 557 `DType` of each input in `inputs`. Operations that expect 558 reference-typed inputs must specify `input_types` explicitly. 559 name: (Optional.) A string name for the operation. If not specified, a 560 name is generated based on `op_type`. 561 attrs: (Optional.) A dictionary where the key is the attribute name (a 562 string) and the value is the respective `attr` attribute of the 563 `NodeDef` proto that will represent the operation (an `AttrValue` 564 proto). 565 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 566 the operation will have. 567 compute_device: (Optional.) If True, device functions will be executed 568 to compute the device property of the Operation. 569 570 Returns: 571 An `Operation` object. 
572 """ 573 if self.capture_by_value and op_type in ["ReadVariableOp", 574 "ResourceGather"]: 575 return self._capture_by_value(op_type, inputs, dtypes, input_types, name, 576 attrs, op_def, compute_device) 577 578 # This capturing logic interacts poorly with control flow contexts which 579 # want to replace inputs of ops far too late in the process. This can lead 580 # the context to get confused and try to create an Enter for an Enter. We 581 # can detect this here and skip the additional Enter which can confuse loop 582 # validation logic. 583 if op_type == "Enter" and inputs[0].op.type == "Enter": 584 if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s: 585 return inputs[0].op 586 # Calling AddValue on the control flow contexts to force creation of the 587 # backward accumulators in the original graph before we create placeholders 588 # to capture the inputs. 589 ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access 590 # Use a different list to avoid modifying the original inputs list. 591 captured_inputs = [] 592 for inp in inputs: 593 # TPU Estimator defines a control flow context with no AddValue method. 594 if ctxt is not None and hasattr(ctxt, "AddValue"): 595 inp = ctxt.AddValue(inp) 596 inp = self.capture(inp) 597 captured_inputs.append(inp) 598 return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access 599 op_type, captured_inputs, dtypes, input_types, name, attrs, op_def, 600 compute_device) 601 602 def capture(self, tensor, name=None, shape=None): 603 """Captures `tensor` if it's external to this graph. 604 605 If `tensor` is from a different graph, returns a placeholder for it. 606 `tensor` and the placeholder will appear in self.captures, and the 607 placeholder will appear in self.inputs. Multiple calls to this method with 608 the same `tensor` argument will return the same placeholder. If `tensor` is 609 from this graph, returns `tensor`. 610 611 Args: 612 tensor: Tensor. 
May be from this FuncGraph or a different graph. 613 name: Optional name if a placeholder is created. 614 shape: Optional shape if a placeholder is created. 615 616 Returns: 617 Tensor from this FuncGraph. 618 619 Raises: 620 InaccessibleTensorError: if any tensors are accessed in a manner that 621 bypasses the mechanisms required for the data dependencies to be correctly 622 wired. 623 """ 624 if isinstance(tensor, ops.EagerTensor): 625 if name is None: 626 name = str(ops.uid()) 627 628 # Small EagerTensors are captured with Const ops 629 if (tensor.dtype in dtypes.TF_VALUE_DTYPES and 630 np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD): 631 return self.capture_eager_tensor(tensor, name) 632 633 # Large EagerTensors and resources are captured with Placeholder ops 634 return self._capture_helper(tensor, name, shape) 635 if tensor.graph is not self: 636 if name is None: 637 name = tensor.op.name 638 inner_graph = tensor.graph 639 while inner_graph is not None and isinstance(inner_graph, FuncGraph): 640 if inner_graph is self: 641 raise errors.InaccessibleTensorError( 642 "The tensor '%s' cannot be accessed here: it is defined" 643 " in another function or code block. Use return values," 644 " explicit Python locals or TensorFlow collections to access" 645 " it. Defined in: %s; accessed from: %s.\n" 646 % (tensor, tensor.graph, self)) 647 inner_graph = inner_graph.outer_graph 648 return self._capture_helper(tensor, name) 649 return tensor 650 651 def _capture_helper(self, tensor, name, shape=None): 652 capture = self._captures.get(id(tensor)) 653 if capture is None: 654 placeholder = _create_substitute_placeholder( 655 tensor, name=name, dtype=tensor.dtype, shape=shape) 656 # Record the composite device as an attribute to the placeholder. 657 # This attribute would be propogated into the arg_attr of the FunctionDef. 658 # Currently, a packed eager tensor is always placed on a CompositeDevice. 
659 if isinstance(tensor, ops.EagerTensor) and tensor.is_packed: 660 placeholder.op._set_attr( # pylint: disable=protected-access 661 "_composite_device", 662 attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device))) 663 self.add_capture(tensor, placeholder) 664 else: 665 placeholder = capture[1] 666 tape.record_operation("captured_value", [placeholder], [tensor], 667 backward_function=lambda x: [x], 668 forward_function=lambda x: [x]) 669 return placeholder 670 671 @property 672 def captures(self): 673 """Order list of tuples containing external and internal captures.""" 674 return self._captures.values() 675 676 def add_capture(self, tensor, placeholder): 677 """Capture a specific tensor and utilize the provided placeholder. 678 679 Args: 680 tensor: Tensor to captures. 681 placeholder: Provided placeholder for the tensor. 682 """ 683 self._captures[id(tensor)] = (tensor, placeholder) 684 self.inputs.append(placeholder) 685 686 def replace_capture(self, tensor, placeholder): 687 """Replace already existing capture.""" 688 self._captures[id(tensor)] = (tensor, placeholder) 689 690 def reset_captures(self, capture_list): 691 """Set the captures with the provided list of captures & placeholder.""" 692 self._captures = py_collections.OrderedDict() 693 for tensor, placeholder in capture_list: 694 self._captures[id(tensor)] = (tensor, placeholder) 695 696 def pop_capture(self, tensor): 697 """Remove the capture and return the generated placeholder.""" 698 capture = self._captures.pop(id(tensor), None) 699 if capture is None: 700 return None 701 702 return capture[1] 703 704 def clear_captures(self): 705 # TODO(b/115366440): Delete this method when a custom OrderedDict is added. 706 # Clearing captures using clear() leaves some cycles around. 
707 while self._captures: 708 self._captures.popitem() 709 memory.dismantle_ordered_dict(self._captures) 710 while self._deferred_captures: 711 self._deferred_captures.popitem() 712 memory.dismantle_ordered_dict(self._deferred_captures) 713 714 def capture_distributed_variable(self, variable, placeholder): 715 """Add given distributed variable to captures with given placeholder.""" 716 self._captures[id(variable)] = (variable, placeholder) 717 tape.record_operation("captured_value", [placeholder], [variable], 718 backward_function=lambda x: [x], 719 forward_function=lambda x: [x]) 720 721 def capture_eager_tensor(self, tensor, name): 722 capture = self._captures.get(id(tensor)) 723 if capture is None: 724 # We clear all control dependencies and place the Const op on the same 725 # device as the source tensor. The device placement may be relaxed at 726 # a later date. 727 with ops.control_dependencies(None), self.device(tensor.device): 728 constant_value = tensor_util.constant_value(tensor) 729 if constant_value is None: 730 # Some eager tensors, e.g. parallel tensors, are not convertible to a 731 # single constant. We'll use a placeholder for this case. 
732 return self._capture_helper(tensor, name) 733 graph_const = constant_op.constant(constant_value, dtype=tensor.dtype, 734 shape=tensor.shape, name=name) 735 self.add_capture(tensor, graph_const) 736 else: 737 graph_const = capture[1] 738 tape.record_operation("captured_value", [graph_const], [tensor], 739 backward_function=lambda x: [x], 740 forward_function=lambda x: [x]) 741 return graph_const 742 743 def captured(self, tensor): 744 """Returns whether the specified tensor has been captured by this graph.""" 745 return id(tensor) in self._captures 746 747 @property 748 def external_captures(self): 749 """External tensors captured by this function.""" 750 return [c[0] for c in self._captures.values()] 751 752 @property 753 def internal_captures(self): 754 """Placeholders in this function corresponding to captured tensors.""" 755 return [c[1] for c in self._captures.values()] 756 757 @property 758 def deferred_external_captures(self): 759 """List of tensor nests whose placeholders will be fed at call time.""" 760 return [c[0] for c in self._deferred_captures.values()] 761 762 @property 763 def deferred_internal_captures(self): 764 """List of placeholder nests which will be fed at call time.""" 765 return [c[1] for c in self._deferred_captures.values()] 766 767 @property 768 def variable_captures(self): 769 """Map from python object ids of internal capture placeholders to the captured variables.""" 770 return { 771 id(self._captures[id(v)][1]): v 772 for v in self.variables 773 if id(v) in self._captures 774 } 775 776 def mark_as_unsaveable(self, error_message): 777 """Marks this FuncGraph as unsaveable. 778 779 Any attempts to export this FuncGraph will raise an error with the specified 780 message. Messages from repeated calls are accumulated. 781 782 Args: 783 error_message: List or string containing the error message to be raised 784 when saving this FuncGraph to SavedModel.
785 """ 786 self._saveable = False 787 if isinstance(error_message, str): 788 error_message = [error_message] 789 self._saving_errors.update(error_message) 790 791 @property 792 def saveable(self): 793 """Returns whether this FuncGraph is saveable.""" 794 return self._saveable 795 796 @property 797 def saving_errors(self): 798 """Returns the set of errors preventing this FuncGraph from being saved.""" 799 return self._saving_errors 800 801 def _add_scope_exit_callback(self, fn): 802 """Add a function to call when this graph exits the default scope.""" 803 if not callable(fn): 804 raise TypeError("fn is not callable: {}".format(fn)) 805 if self._scope_exit_callbacks is None: 806 raise RuntimeError( 807 "Attempting to add a scope exit callback, but the default graph is " 808 "not the context scope graph. Did you forget to call " 809 "'with graph.as_default(): ...'?") 810 self._scope_exit_callbacks.append(fn) 811 812 813def func_graph_from_py_func(name, 814 python_func, 815 args, 816 kwargs, 817 signature=None, 818 func_graph=None, 819 autograph=False, 820 autograph_options=None, 821 add_control_dependencies=True, 822 arg_names=None, 823 op_return_value=None, 824 collections=None, 825 capture_by_value=None, 826 override_flat_arg_shapes=None): 827 """Returns a `FuncGraph` generated from `python_func`. 828 829 Args: 830 name: an identifier for the function. 831 python_func: the Python function to trace. 832 args: the positional args with which the Python function should be called; 833 ignored if a signature is provided. 834 kwargs: the keyword args with which the Python function should be called; 835 ignored if a signature is provided. 836 signature: a possibly nested sequence of `TensorSpecs` specifying the shapes 837 and dtypes of the arguments. When a signature is provided, `args` and 838 `kwargs` are ignored, and `python_func` is traced with Tensors conforming 839 to `signature`. If `None`, the shapes and dtypes are inferred from the 840 inputs.
841 func_graph: Optional. An instance of FuncGraph. If provided, we will use 842 this graph; otherwise a new one is built and returned. 843 autograph: whether to use autograph to compile `python_func`. 844 See https://www.tensorflow.org/guide/autograph for more information. 845 autograph_options: additional knobs to control when `autograph=True`. 846 See https://www.tensorflow.org/guide/autograph for more information. 847 add_control_dependencies: If True, automatically adds control dependencies 848 to ensure program order matches execution order and stateful ops always 849 execute. 850 arg_names: Optional list of argument names, used to give input placeholders 851 recognizable names. 852 op_return_value: Optional. A Tensor. If set and `python_func` returns 853 Operations, those return values will be replaced with this value. If not 854 set, returning an Operation triggers an error. 855 collections: a dictionary of collections this FuncGraph should start 856 with. If not specified (None), the FuncGraph will read (but not write to) 857 the outer graph's collections that are not allowlisted, and both 858 read and write to the outer graph's collections that are allowlisted. 859 The current allowlisted collections are the global variables, the 860 local variables, and the trainable variables. 861 Defaults to None. 862 capture_by_value: An optional boolean. If True, the func graph will capture 863 Variables by value instead of reference. By default inherit from outer 864 graphs, and failing that will default to False. 865 override_flat_arg_shapes: An optional list of instances that are either 866 `None` or `TensorShape`. The length must match that of 867 `nest.flatten((args, kwargs), expand_composites=True)`. The entries 868 containing value `None` must match entries in flattened arguments 869 containing non-tensors, while entries containing a `TensorShape` must 870 match entries in the flattened arguments containing tensors. 871 872 Returns: 873 A FuncGraph.
874 875 Raises: 876 TypeError: If any of `python_func`'s return values is neither `None`, a 877 `TensorArray`, nor convertible to a `Tensor` or `CompositeTensor`. 878 ValueError: If both `signature` and `override_flat_arg_shapes` are 879 passed in, or if `python_func` modifies its Python input arguments. 880 """ 881 if op_return_value is not None: 882 assert isinstance(op_return_value, ops.Tensor), op_return_value 883 if func_graph is None: 884 func_graph = FuncGraph(name, collections=collections, 885 capture_by_value=capture_by_value) 886 assert isinstance(func_graph, FuncGraph) 887 if add_control_dependencies: 888 deps_control_manager = auto_control_deps.AutomaticControlDependencies() 889 else: 890 deps_control_manager = ops.NullContextmanager() 891 892 with func_graph.as_default(), deps_control_manager as deps_ctx: 893 current_scope = variable_scope.get_variable_scope() 894 default_use_resource = current_scope.use_resource 895 current_scope.set_use_resource(True) 896 897 if signature is not None and override_flat_arg_shapes is not None: 898 raise ValueError( 899 "Passed both signature and override_flat_arg_shapes: %s and %s." 900 % (signature, override_flat_arg_shapes)) 901 902 if signature is not None: 903 args = signature 904 kwargs = {} 905 906 # Creates and names placeholders for all arguments. 907 if override_flat_arg_shapes is not None: 908 flat_args = nest.flatten(args, expand_composites=True) 909 arg_shapes = override_flat_arg_shapes[:len(flat_args)] 910 kwarg_shapes = override_flat_arg_shapes[len(flat_args):] 911 else: 912 arg_shapes = None 913 kwarg_shapes = None 914 func_args = _get_defun_inputs_from_args( 915 args, arg_names, flat_shapes=arg_shapes) 916 func_kwargs = _get_defun_inputs_from_kwargs( 917 kwargs, flat_shapes=kwarg_shapes) 918 919 # Convert all Tensors into TensorSpecs before saving the structured inputs. 920 # If storing pure concrete functions that are not called through polymorphic 921 # functions, we don't have access to FunctionSpec, so we need to call the 922 # TensorSpecs by their `arg_names` for later binding.
923 func_graph.structured_input_signature = ( 924 convert_structure_to_signature(func_args, arg_names), 925 convert_structure_to_signature(func_kwargs)) 926 927 flat_func_args = nest.flatten(func_args, expand_composites=True) 928 flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True) 929 # Temporarily set inputs to allow graph building code to inspect 930 # them. Reassigned below. 931 func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs 932 if isinstance(arg, ops.Tensor)] 933 934 # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. 935 # Variables to help check whether mutation happens in calling the function 936 # Copy the recursive list, tuple and map structure, but not base objects 937 func_args_before = nest.pack_sequence_as(func_args, flat_func_args, 938 expand_composites=True) 939 func_kwargs_before = nest.pack_sequence_as( 940 func_kwargs, flat_func_kwargs, expand_composites=True) 941 942 def convert(x): 943 """Converts a function output to a Tensor; `None` and TensorArrays pass through.""" 944 if x is None: 945 return None 946 if op_return_value is not None and isinstance(x, ops.Operation): 947 # TODO(b/79881896): we currently can't capture external control deps, so 948 # this won't work if x needs to be captured (i.e. if python_func returns 949 # captured Operations). 950 with ops.control_dependencies([x]): 951 x = array_ops.identity(op_return_value) 952 elif not isinstance(x, tensor_array_ops.TensorArray): 953 try: 954 x = ops.convert_to_tensor_or_composite(x) 955 except (ValueError, TypeError): 956 raise TypeError( 957 "To be compatible with tf.eager.defun, Python functions " 958 "must return zero or more Tensors; in compilation of %s, found " 959 "return value of type %s, which is not a Tensor."
% 960 (str(python_func), type(x))) 961 if add_control_dependencies: 962 x = deps_ctx.mark_as_return(x) 963 return x 964 965 try: 966 if autograph: 967 from tensorflow.python import autograph # pylint: disable=g-import-not-at-top 968 _, original_func = tf_decorator.unwrap(python_func) 969 970 def wrapper(*args, **kwargs): 971 """Calls a converted version of original_func.""" 972 # TODO(mdan): Push this block higher in tf.function's call stack. 973 try: 974 return autograph.converted_call( 975 original_func, 976 args, 977 kwargs, 978 options=autograph.ConversionOptions( 979 recursive=True, 980 optional_features=autograph_options, 981 user_requested=True, 982 )) 983 except Exception as e: # pylint:disable=broad-except 984 if hasattr(e, "ag_error_metadata"): 985 raise e.ag_error_metadata.to_exception(e) 986 else: 987 raise 988 989 # Wrapping around a decorator allows checks like tf_inspect.getargspec 990 # to be accurate. 991 converted_func = tf_decorator.make_decorator(original_func, wrapper) 992 python_func = tf_decorator.rewrap(python_func, original_func, 993 converted_func) 994 995 else: 996 _, original_func = tf_decorator.unwrap(python_func) 997 998 func_outputs = python_func(*func_args, **func_kwargs) 999 1000 # invariant: `func_outputs` contains only Tensors, CompositeTensors, 1001 # TensorArrays and `None`s. 1002 func_outputs = nest.map_structure(convert, func_outputs, 1003 expand_composites=True) 1004 1005 check_mutation(func_args_before, func_args, original_func) 1006 check_mutation(func_kwargs_before, func_kwargs, original_func) 1007 finally: 1008 current_scope.set_use_resource(default_use_resource) 1009 1010 # Variables in `func_args`, `func_kwargs` should be explicit inputs 1011 # to the function, not captured inputs.
1012 graph_variables = list(func_graph._watched_variables) # pylint: disable=protected-access 1013 arg_variables = object_identity.ObjectIdentitySet() 1014 inputs = [] 1015 for arg in (nest.flatten(func_args, expand_composites=True) + 1016 nest.flatten(func_kwargs, expand_composites=True)): 1017 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 1018 # Even if an argument variable was not used in the function, we've 1019 # already manually captured the resource Tensor when creating argument 1020 # placeholders. 1021 resource_placeholder = func_graph.pop_capture(arg.handle) 1022 if resource_placeholder is None: 1023 continue 1024 arg_variables.add(arg) 1025 inputs.append(resource_placeholder) 1026 elif isinstance(arg, ops.Tensor): 1027 inputs.append(arg) 1028 variables = [v for v in graph_variables if v not in arg_variables] 1029 func_graph.inputs = ( 1030 inputs + func_graph.internal_captures + nest.flatten( 1031 func_graph.deferred_internal_captures, expand_composites=True)) 1032 func_graph.structured_outputs = func_outputs 1033 # Returning a closed-over tensor does not trigger convert_to_tensor. 1034 func_graph.outputs.extend( 1035 func_graph.capture(x) 1036 for x in flatten(func_graph.structured_outputs) 1037 if x is not None) 1038 1039 func_graph.variables = variables 1040 1041 if add_control_dependencies: 1042 func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run) 1043 func_graph.collective_manager_ids_used = ( 1044 deps_control_manager.collective_manager_ids_used) 1045 1046 return func_graph 1047 1048 1049def maybe_captured(tensor): 1050 """If `tensor` is a captured value placeholder, returns the original captured value. 1051 1052 Args: 1053 tensor: A Tensor. EagerTensors and non-placeholder tensors are returned unchanged. 1054 1055 Returns: 1056 A tensor, potentially from a different Graph/FuncGraph; nested captures are followed recursively.
1057 """ 1058 if (not isinstance(tensor, ops.EagerTensor) and 1059 tensor.op.graph.building_function and tensor.op.type == "Placeholder"): 1060 for input_t, placeholder_t in tensor.op.graph.captures: 1061 if tensor == placeholder_t: 1062 return maybe_captured(input_t) 1063 # pylint: enable=protected-access 1064 return tensor 1065 1066 1067def device_stack_has_callable(device_stack): 1068 """Checks whether any device spec on a device stack is a callable.""" 1069 return any(callable(spec._device_name_or_function) # pylint: disable=protected-access 1070 for spec in device_stack.peek_objs()) 1071 1072 1073def check_mutation(n1, n2, func): 1074 """Check that two lists of arguments match in structure and element identity; raise ValueError otherwise.""" 1075 func_name = getattr(func, "__name__", func) 1076 1077 errmsg = ("{}() should not modify its Python input arguments." 1078 " Check if it modifies any lists or dicts passed as" 1079 " arguments. Modifying a copy is allowed.".format(func_name)) 1080 try: 1081 # TODO(mdan): Compare more robustly so that argument names can be reported. 1082 nest.assert_same_structure(n1, n2, expand_composites=True) 1083 except ValueError: 1084 raise ValueError(errmsg) 1085 1086 for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True), 1087 nest.flatten(n2, expand_composites=True)): 1088 if arg1 is not arg2: 1089 raise ValueError(errmsg) 1090 1091 1092# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this. 1093def flatten(sequence): 1094 """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays. 1095 1096 Args: 1097 sequence: A nested structure of Tensors, CompositeTensors, and 1098 TensorArrays. 1099 1100 Returns: 1101 A list of tensors. 1102 """ 1103 flat_sequence = nest.flatten(sequence, expand_composites=True) 1104 return [ 1105 item.flow if isinstance(item, tensor_array_ops.TensorArray) else item 1106 for item in flat_sequence] 1107 1108 1109# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors. An entry whose position
      matches a TensorArray in `structure` is used as that TensorArray's new
      flow.

  Returns:
    A nested structure.

  Raises:
    ValueError: if `structure` (flattened with `expand_composites=True`) and
      `flat_sequence` do not have the same number of elements.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  # Rebuild every TensorArray around its new flow; pass other entries through.
  flat_sequence = [
      tensor_array_ops.build_ta_with_new_flow(old_ta=old, flow=new)
      if isinstance(old, tensor_array_ops.TensorArray) else new
      for old, new in zip(flattened_structure, flat_sequence)
  ]
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)


def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it.

  Args:
    value: The tensor being substituted. Its handle data is copied onto the
      new placeholder.
    name: Optional name for the placeholder op.
    dtype: Optional dtype; defaults to `value.dtype` when falsy.
    shape: Optional shape; defaults to `value.shape` when `None`.

  Returns:
    The new placeholder tensor.
  """
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  if shape is None:
    shape = value.shape
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  custom_gradient.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A list of user-specified arguments. Each element may itself be a
      nested structure; it is flattened with `expand_composites=True`.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used. If it is
      shorter than `args`, the remaining arguments are treated as unnamed.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`. If provided, then length must match
      that of `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding `TensorShape`
      in `flat_shapes`. May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  elif len(names) < len(args):
    # `zip` below would otherwise silently drop the trailing arguments, which
    # only surfaces later as a confusing element-count mismatch.
    names = list(names) + [None] * (len(args) - len(names))
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d).  args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):

    # Replace any composite tensors with their TypeSpecs.  This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)

    flattened = nest.flatten(arg_value, expand_composites=True)

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        arg_is_spec = isinstance(arg, tensor_spec.TensorSpec)
        if arg_is_spec and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if not arg_is_spec:
          custom_gradient.copy_handle_data(arg, placeholder)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        # Use a per-element name so that a VariableSpec's own name does not
        # leak into the names of later elements of the same nested argument.
        var_name = name
        if isinstance(arg, resource_variable_ops.VariableSpec):
          var_name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(dtypes.resource, arg.shape,
                                            name=var_name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=var_name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=var_name)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=var_name)
        if var_name is not None:
          # compat.as_bytes rejects None; mirror the guard used for tensors.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(var_name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs.

  Keyword arguments are processed in sorted key order, with each key used as
  the argument's name.
  """
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)


def override_func_graph_name_scope(func_graph, name_scope):
  """Replaces `func_graph`'s current name stack with `name_scope`.

  Args:
    func_graph: The `FuncGraph` whose name stack is overwritten in place.
    name_scope: The new name stack value.
  """
  func_graph._name_stack = name_scope  # pylint: disable=protected-access