1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""FuncGraph and related functionality.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections as py_collections 22import itertools 23import weakref 24 25import numpy as np 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.python.eager import context 29from tensorflow.python.eager import execute 30from tensorflow.python.eager import tape 31from tensorflow.python.eager.graph_only_ops import graph_placeholder 32from tensorflow.python.framework import auto_control_deps 33from tensorflow.python.framework import composite_tensor 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_spec 39from tensorflow.python.framework import type_spec 40from tensorflow.python.ops import array_ops 41from tensorflow.python.ops import custom_gradient 42from tensorflow.python.ops import resource_variable_ops 43from tensorflow.python.ops import tensor_array_ops 44from tensorflow.python.ops import variable_scope 45from tensorflow.python.util import compat 46from tensorflow.python.util import 
memory 47from tensorflow.python.util import nest 48from tensorflow.python.util import object_identity 49from tensorflow.python.util import tf_contextlib 50from tensorflow.python.util import tf_decorator 51 52WHITELIST_COLLECTIONS = [ 53 ops.GraphKeys.GLOBAL_VARIABLES, 54 ops.GraphKeys.LOCAL_VARIABLES, 55 ops.GraphKeys.TRAINABLE_VARIABLES, 56 variable_scope._VARSTORE_KEY, # pylint: disable=protected-access 57 variable_scope._VARSCOPESTORE_KEY # pylint: disable=protected-access 58] 59 60 61_EAGER_CONST_THRESHOLD = 128 62 63 64class UnknownArgument(object): 65 """Signifies an argument which is not currently handled.""" 66 pass 67 68 69def convert_structure_to_signature(structure, arg_names=None): 70 """Convert a potentially nested structure to a signature. 71 72 Args: 73 structure: Structure to convert, where top level collection is a list or a 74 tuple. 75 arg_names: Optional list of arguments that has equal number of elements as 76 `structure` and is used for naming corresponding TensorSpecs. 77 78 Returns: 79 Identical structure that has TensorSpec objects instead of Tensors and 80 UknownArgument instead of any unsupported types. 81 """ 82 def encode_arg(arg, path): 83 """A representation for this argument, for converting into signatures.""" 84 if isinstance(arg, ops.Tensor): 85 user_specified_name = None 86 try: 87 user_specified_name = compat.as_str( 88 arg.op.get_attr("_user_specified_name")) 89 except ValueError: 90 pass 91 92 if path and user_specified_name and user_specified_name != path[0]: 93 # The user has explicitly named the argument differently than the name 94 # of the function argument. 95 name = user_specified_name 96 else: 97 name = "/".join(str(p) for p in path) 98 return tensor_spec.TensorSpec(arg.shape, arg.dtype, name) 99 if isinstance(arg, composite_tensor.CompositeTensor): 100 # TODO(b/133606651) Do we need to inject arg_name? 
101 return arg._type_spec # pylint: disable=protected-access 102 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 103 name = "/".join(str(p) for p in path) 104 return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name) 105 if isinstance(arg, ( 106 int, 107 float, 108 bool, 109 type(None), 110 dtypes.DType, 111 tensor_spec.TensorSpec, 112 type_spec.TypeSpec, 113 )): 114 return arg 115 return UnknownArgument() 116 117 # We are using the flattened paths to name the TensorSpecs. We need an 118 # explicit name for them downstream. 119 flattened = nest.flatten_with_tuple_paths(structure) 120 if arg_names: 121 if len(arg_names) != len(structure): 122 raise ValueError( 123 "Passed in arg_names don't match actual signature (%s)." % arg_names) 124 # Replace all top-level names with their actual arg_names. If a path before 125 # was "(2,'a',1)", it will become "(arg_names[2],'a',1)". 126 flattened = [ 127 ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened 128 ] 129 130 mapped = [encode_arg(arg, path) for path, arg in flattened] 131 return nest.pack_sequence_as(structure, mapped) 132 133 134class FuncGraph(ops.Graph): 135 """Graph representing a function body. 136 137 Attributes: 138 name: The name of the function. 139 inputs: Placeholder tensors representing the inputs to this function. The 140 tensors are in this FuncGraph. This represents "regular" inputs as well as 141 captured inputs (i.e. the values of self.captures), with the regular 142 inputs coming first. 143 outputs: Tensors that will be returned by this function. The tensors are in 144 this FuncGraph. 145 control_outputs: Operations that must be executed before the function 146 represented by this graph can be said to have been executed. 147 structured_input_signature: A tuple of (args, kwargs), which are both 148 possibly-nested python objects that were received by this function. Note 149 that these structures might contain Python `None`s. 
150 structured_outputs: A possibly-nested python object which will be returned 151 by this function. The Tensors in this structure are the same as those of 152 self.outputs. Note that this structure might contain Python `None`s. 153 variables: Variables that should be watched during function execution. 154 outer_graph: The graph this function is defined in. May be another FuncGraph 155 or the global default Graph. 156 captures: Maps external tensor -> internal tensor (i.e. input placeholder). 157 The entries are in the order they were captured. 158 control_captures: Set of external ops on which this graph has a control 159 dependency. 160 seed: The graph-level random seed. 161 capture_by_value: If True, the func graph will capture Variables by value 162 instead of reference. 163 """ 164 165 def __init__(self, name, collections=None, capture_by_value=None): 166 """Construct a new FuncGraph. 167 168 The graph will inherit its graph key, collections, seed, and distribution 169 strategy stack from the current context or graph. 170 171 Args: 172 name: the name of the function. 173 collections: a dictionary of collections this FuncGraph should start 174 with. If not specified (None), the FuncGraph will read (but not write 175 to) the outer graph's collections that are not whitelisted, and both 176 read and write to the outer graph's collections that are whitelisted. 177 The current whitelisted collections are the global variables, the 178 local variables, and the trainable variables. 179 Defaults to None. 180 capture_by_value: An optional boolean. If True, the func graph will 181 capture Variables by value instead of reference. By default inherit 182 from outer graphs, and failing that will default to False. 
183 """ 184 super(FuncGraph, self).__init__() 185 186 self.name = name 187 self.inputs = [] 188 self.outputs = [] 189 self.control_outputs = [] 190 self.control_captures = set() 191 self.structured_input_signature = None 192 self.structured_outputs = None 193 self._weak_variables = [] 194 self._watched_variables = object_identity.ObjectIdentityWeakSet() 195 196 outer_graph = ops.get_default_graph() 197 self._weak_outer_graph = weakref.ref(outer_graph) 198 while outer_graph.building_function: 199 outer_graph = outer_graph.outer_graph 200 # If self._weak_outer_graph is deleted, we revert to the outermost Graph 201 # active when the FuncGraph was traced. This will not be a FuncGraph. 202 self._fallback_outer_graph = outer_graph 203 self._captures = py_collections.OrderedDict() 204 # If not None, records the names of output args of this function. Used to 205 # preserve the output names in the signature of a serialized+deserialized 206 # function. Private at the moment mostly because it's often out of date. 207 self._output_names = None 208 # Maps arbitrary key -> (closure, nest of placeholders), where at function 209 # call time the value of closure() will be used to feed the nest of 210 # placeholders. 211 self._deferred_captures = py_collections.OrderedDict() 212 # Inherit capture-by-value from outer graph. 213 if capture_by_value is not None: 214 self.capture_by_value = capture_by_value 215 elif self.outer_graph is not None and isinstance( 216 self.outer_graph, FuncGraph): 217 self.capture_by_value = self.outer_graph.capture_by_value 218 else: 219 self.capture_by_value = False 220 221 self._building_function = True 222 # Map from resource tensor name to last op (in program order) which uses 223 # this tensor. Used to enforce that execution order matches program order 224 # for resource tensors. 
225 self._last_op_using_resource_tensor = {} 226 227 graph = self.outer_graph 228 229 if context.executing_eagerly(): 230 self.seed = context.global_seed() 231 # [for tf-data user migration from TF1.0 to 2.0] seed_used keep track of 232 # any None op_seed for random_op in the function, in which case we end up 233 # using function seed, which could be unintended behavior for the op. 234 self._seed_used = False 235 else: 236 self.seed = graph.seed 237 self._seed_used = False 238 # TODO(allenl): Figure out if we can remove colocation stack 239 # specialization (currently used in cond_v2), here and in the cache key. 240 self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access 241 242 if collections is None: 243 for collection_name in graph.get_all_collection_keys(): 244 if collection_name not in WHITELIST_COLLECTIONS: 245 self._collections[collection_name] = graph.get_collection( 246 collection_name) 247 for collection_name in WHITELIST_COLLECTIONS: 248 self._collections[collection_name] = graph.get_collection_ref( 249 collection_name) 250 else: 251 self._collections = collections 252 253 # Keep track of whether this FuncGraph is exportable to SavedModel. Use 254 # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any 255 # dependent functions as unsaveable. 256 self._saveable = True 257 self._saving_errors = set() 258 259 # Keep track of callbacks to run when this graph exits default scope 260 self._scope_exit_callbacks = None 261 262 def __str__(self): 263 return "FuncGraph(name=%s, id=%s)" % (self.name, id(self)) 264 265 def watch_variable(self, v): 266 """Marks the variable v as accessed while building this graph.""" 267 while self is not None and isinstance(self, FuncGraph): 268 self._watched_variables.add(v) 269 self = self.outer_graph 270 271 def capture_call_time_value(self, closure, spec, key=None): 272 """Creates a placeholder which at call time has the value closure(). 
273 274 Useful, for example, to respect TensorFlow context managers, which are often 275 dynamically scoped. 276 277 Args: 278 closure: function which takes no arguments, to be evaluated at function 279 call time, returning a nest of tensors compatible with `spec`. 280 spec: nest of TypeSpec for the value to capture. 281 key: optional. If not None, multiple calls to lazy_capture with the same 282 key in the same graph will return the same placeholder, and the 283 first closure will be used at function call time. 284 285 Returns: 286 Nest of placeholders which, at function call time, will be fed with the 287 result of calling closure(). 288 289 Raises: 290 ValueError: at function call time, if the return value of closure() is 291 not compatible with `spec`. 292 """ 293 if key is None: 294 key = object() 295 if key not in self._deferred_captures: 296 297 def convert_to_placeholder(s): 298 if not isinstance(s, tensor_spec.DenseSpec): 299 raise TypeError( 300 "Expected a nest of `TypeSpec` objects, found %s of type %s." % 301 (s, type(s))) 302 return array_ops.placeholder(dtype=s.dtype, shape=s.shape) 303 304 placeholder = nest.map_structure( 305 convert_to_placeholder, spec, expand_composites=True) 306 307 def wrapped_closure(): 308 ret_nest = closure() 309 nest.assert_same_structure(spec, ret_nest, expand_composites=True) 310 # This uses the tensor dtype defined in `spec` when converting values 311 # in `ret_nest` to tensors. 312 # pylint: disable=protected-access 313 y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest, 314 expand_composites=False) 315 # pylint: enable=protected-access 316 return nest.flatten(y, expand_composites=True) 317 318 self._deferred_captures[key] = (wrapped_closure, placeholder) 319 return self._deferred_captures[key][1] 320 321 def control_dependencies(self, control_inputs): 322 """Handles control dependencies. 
323 324 FuncGraph wraps Graph's control_dependencies logic by first filtering out 325 any external tensors / operations and storing them in the graph's 326 control_captures member. Any consumers of this function graph must then 327 decide how to handle the control captures. 328 329 Args: 330 control_inputs: A list of `Operation` or `Tensor` objects which 331 must be executed or computed before running the operations 332 defined in the context. Can also be `None` to clear the control 333 dependencies. 334 335 Returns: 336 A context manager that specifies control dependencies for all 337 operations constructed within the context. 338 339 Raises: 340 TypeError: If `control_inputs` is not a list of `Operation` or 341 `Tensor` objects. 342 """ 343 if control_inputs is None: 344 return super(FuncGraph, self).control_dependencies(control_inputs) 345 346 filtered_control_inputs = [] 347 for c in control_inputs: 348 # Check for _UnreadVariable 349 if (isinstance(c, ops.IndexedSlices) or 350 (hasattr(c, "_handle") and hasattr(c, "op"))): 351 c = c.op 352 graph_element = ops._as_graph_element(c) # pylint: disable=protected-access 353 if graph_element is None: 354 graph_element = c 355 if graph_element is not None and getattr( 356 graph_element, "graph", None) is not self: 357 self.control_captures.add(graph_element) 358 else: 359 filtered_control_inputs.append(graph_element) 360 return super(FuncGraph, self).control_dependencies(filtered_control_inputs) 361 362 def as_default(self): 363 outer_cm = super(FuncGraph, self).as_default() 364 365 @tf_contextlib.contextmanager 366 def inner_cm(): 367 """Context manager for copying distribute.Strategy scope information.""" 368 graph = ops.get_default_graph() 369 # pylint: disable=protected-access 370 # TODO(b/112906995, nareshmodi): distribution strategy depends on 371 # inheriting this stack from the default graph even in eager mode. Maybe 372 # it should be part of the eager context? 
This would also allow us to 373 # remove a get_default_graph() call from the function cache lookup. 374 old_strategy_stack = self._distribution_strategy_stack 375 self._distribution_strategy_stack = list( 376 graph._distribution_strategy_stack) 377 uses_distribution_strategy = ( 378 self._distribution_strategy_stack and 379 self._distribution_strategy_stack[-1].strategy.extended 380 ._retrace_functions_for_each_device) 381 # We ignore device placements from any outer scopes while tracing the 382 # function when possible, to avoid hard-coding them in the function 383 # graph. "Default" placements come from the PartitionedCallOp's placement, 384 # so that the same trace of the Python function may be placed on several 385 # different devices and saved functions may be placed on new devices when 386 # restored. 387 old_device_stack = self._device_function_stack 388 if context.executing_eagerly(): 389 if uses_distribution_strategy: 390 self._device_function_stack = self._device_function_stack.copy() 391 self._add_device_to_stack(context.context().device_name) 392 else: 393 if (uses_distribution_strategy or 394 device_stack_has_callable(graph._device_function_stack)): 395 # Hard-code devices from device functions in the function body 396 self._device_function_stack = graph._device_function_stack.copy() 397 398 old_creator_stack = self._variable_creator_stack 399 self._variable_creator_stack = graph._variable_creator_stack 400 # Inherit the graph key, since this is used for matching variables in 401 # optimizers. 402 old_graph_key = self._graph_key 403 self._graph_key = graph._graph_key 404 # Inherit the auto_cast_variable_read_dtype, since this should not change 405 # inside a function. 
406 old_auto_cast_var_read_dtype = self._auto_cast_variable_read_dtype 407 self._auto_cast_variable_read_dtype = graph._auto_cast_variable_read_dtype 408 # pylint: enable=protected-access 409 410 old_scope_exit_callbacks = self._scope_exit_callbacks 411 self._scope_exit_callbacks = [] 412 413 with outer_cm as g: 414 try: 415 yield g 416 finally: 417 try: 418 for fn in self._scope_exit_callbacks: 419 fn() 420 finally: 421 self._scope_exit_callbacks = old_scope_exit_callbacks 422 self._distribution_strategy_stack = old_strategy_stack 423 self._device_function_stack = old_device_stack 424 self._variable_creator_stack = old_creator_stack 425 self._graph_key = old_graph_key 426 self._auto_cast_variable_read_dtype = old_auto_cast_var_read_dtype 427 return inner_cm() 428 429 @property 430 def outer_graph(self): 431 """The Graph this FuncGraph is nested in. 432 433 Functions may capture Tensors from graphs they are nested in (transitive). 434 435 Returns: 436 A Graph object. Initially set to the current default graph when the 437 FuncGraph was created. If the previous `outer_graph` was deleted because 438 the function that owns it was deleted, `outer_graph` is reset to the 439 outermost default graph active when the FuncGraph was created. This 440 FuncGraph won't have captured anything from the new `outer_graph` (and 441 likely not from the previous setting, since that would have created a 442 strong reference), but it is returned so that FuncGraphs always have a 443 parent. 444 """ 445 current = self._weak_outer_graph() 446 if current is None: 447 return self._fallback_outer_graph 448 return current 449 450 @property 451 def output_types(self): 452 return [t.dtype for t in self.outputs] 453 454 @property 455 def output_shapes(self): 456 return [t.shape for t in self.outputs] 457 458 @property 459 def trainable_variables(self): 460 """A sequence of trainable variables accessed by this FuncGraph. 461 462 Note that functions keep only weak references to variables. 
Calling the 463 function after a variable it accesses has been deleted is an error. 464 465 Returns: 466 Sequence of trainable variables for this func graph. 467 """ 468 return tuple(v for v in self.variables if v.trainable) 469 470 @property 471 def variables(self): 472 """A sequence of variables accessed by this FuncGraph. 473 474 Note that functions keep only weak references to variables. Calling the 475 function after a variable it accesses has been deleted is an error. 476 477 Returns: 478 Sequence of variables for this func graph. 479 """ 480 def deref(weak_v): 481 v = weak_v() 482 if v is None: 483 raise AssertionError( 484 "Called a function referencing variables which have been deleted. " 485 "This likely means that function-local variables were created and " 486 "not referenced elsewhere in the program. This is generally a " 487 "mistake; consider storing variables in an object attribute on " 488 "first call.") 489 return v 490 491 return tuple(deref(v) for v in self._weak_variables) 492 493 @variables.setter 494 def variables(self, var_list): 495 self._weak_variables = [weakref.ref(v) for v in var_list] 496 497 def _capture_by_value( 498 self, 499 op_type, 500 inputs, 501 dtypes, # pylint: disable=redefined-outer-name 502 input_types=None, 503 name=None, 504 attrs=None, 505 op_def=None, 506 compute_device=True): 507 # When capturing by value, do the read outside 508 reverse_captures = dict((id(v), k) for k, v in self.captures) 509 uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs] 510 with ops.init_scope(): 511 if context.executing_eagerly(): 512 attr_list = ("dtype", int(attrs["dtype"].type)) 513 value, = execute.execute( 514 compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list, 515 context.context()) 516 else: 517 op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access 518 op_type, 519 uncaptured_inputs, 520 dtypes, 521 input_types, 522 name, 523 attrs, 524 op_def, 525 compute_device) 526 value = 
op.outputs[0] 527 captured_value = self.capture(value) 528 return captured_value.op 529 530 def _create_op_internal( 531 self, 532 op_type, 533 inputs, 534 dtypes=None, # pylint: disable=redefined-outer-name 535 input_types=None, 536 name=None, 537 attrs=None, 538 op_def=None, 539 compute_device=True): 540 """Like Graph.create_op, except handles external input tensors. 541 542 This overload adds functionality to create_op to "capture" any external 543 input tensors, i.e. tensors from the eager context or outer function graphs 544 if this is a nested function. See `capture` for more information. 545 546 Args: 547 op_type: The `Operation` type to create. This corresponds to the 548 `OpDef.name` field for the proto that defines the operation. 549 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 550 dtypes: (Optional) A list of `DType` objects that will be the types of the 551 tensors that the operation produces. 552 input_types: (Optional.) A list of `DType`s that will be the types of 553 the tensors that the operation consumes. By default, uses the base 554 `DType` of each input in `inputs`. Operations that expect 555 reference-typed inputs must specify `input_types` explicitly. 556 name: (Optional.) A string name for the operation. If not specified, a 557 name is generated based on `op_type`. 558 attrs: (Optional.) A dictionary where the key is the attribute name (a 559 string) and the value is the respective `attr` attribute of the 560 `NodeDef` proto that will represent the operation (an `AttrValue` 561 proto). 562 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 563 the operation will have. 564 compute_device: (Optional.) If True, device functions will be executed 565 to compute the device property of the Operation. 566 567 Returns: 568 An `Operation` object. 
569 """ 570 if self.capture_by_value and op_type in ["ReadVariableOp", 571 "ResourceGather"]: 572 return self._capture_by_value(op_type, inputs, dtypes, input_types, name, 573 attrs, op_def, compute_device) 574 575 # This capturing logic interacts poorly with control flow contexts which 576 # want to replace inputs of ops far too late in the process. This can lead 577 # the context to get confused and try to create an Enter for an Enter. We 578 # can detect this here and skip the additional Enter which can confuse loop 579 # validation logic. 580 if op_type == "Enter" and inputs[0].op.type == "Enter": 581 if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s: 582 return inputs[0].op 583 # Calling AddValue on the control flow contexts to force creation of the 584 # backward accumulators in the original graph before we create placeholders 585 # to capture the inputs. 586 ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access 587 for i, inp in enumerate(inputs): 588 # TPU Estimator defines a control flow context with no AddValue method. 589 if ctxt is not None and hasattr(ctxt, "AddValue"): 590 inp = ctxt.AddValue(inp) 591 inp = self.capture(inp) 592 inputs[i] = inp 593 return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access 594 op_type, inputs, dtypes, input_types, name, attrs, op_def, 595 compute_device) 596 597 def capture(self, tensor, name=None, shape=None): 598 """Captures `tensor` if it's external to this graph. 599 600 If `tensor` is from a different graph, returns a placeholder for it. 601 `tensor` and the placeholder will appear in self.captures, and the 602 placeholder will appear in self.inputs. Multiple calls to this method with 603 the same `tensor` argument will return the same placeholder. If `tensor` is 604 from this graph, returns `tensor`. 605 606 Args: 607 tensor: Tensor. May be from this FuncGraph or a different graph. 608 name: Optional name if a placeholder is created. 
609 shape: Optional shape if a placeholder is created. 610 611 Returns: 612 Tensor from this FuncGraph. 613 614 Raises: 615 InaccessibleTensorError: if any tensors are accessed in a manner that 616 bypasses the mechanisms required for the data dependencies to be correctly 617 wired. 618 """ 619 if isinstance(tensor, ops.EagerTensor): 620 if name is None: 621 name = str(ops.uid()) 622 623 # Small EagerTensors are captured with Const ops 624 if (tensor.dtype in dtypes.TF_VALUE_DTYPES and 625 np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD): 626 return self.capture_eager_tensor(tensor, name) 627 628 # Large EagerTensors and resources are captured with Placeholder ops 629 return self._capture_helper(tensor, name, shape) 630 if tensor.graph is not self: 631 if name is None: 632 name = tensor.op.name 633 inner_graph = tensor.graph 634 while inner_graph is not None and isinstance(inner_graph, FuncGraph): 635 if inner_graph is self: 636 raise errors.InaccessibleTensorError( 637 "The tensor '%s' cannot be accessed here: it is defined" 638 " in another function or code block. Use return values," 639 " explicit Python locals or TensorFlow collections to access" 640 " it. 
Defined in: %s; accessed from: %s.\n" 641 % (tensor, tensor.graph, self)) 642 inner_graph = inner_graph.outer_graph 643 return self._capture_helper(tensor, name) 644 return tensor 645 646 def _capture_helper(self, tensor, name, shape=None): 647 capture = self._captures.get(id(tensor)) 648 if capture is None: 649 placeholder = _create_substitute_placeholder( 650 tensor, name=name, dtype=tensor.dtype, shape=shape) 651 self.add_capture(tensor, placeholder) 652 else: 653 placeholder = capture[1] 654 tape.record_operation("captured_value", [placeholder], [tensor], 655 backward_function=lambda x: [x], 656 forward_function=lambda x: [x]) 657 return placeholder 658 659 @property 660 def captures(self): 661 """Order list of tuples containing external and internal captures.""" 662 return self._captures.values() 663 664 def add_capture(self, tensor, placeholder): 665 """Capture a specific tensor and utilize the provided placeholder. 666 667 Args: 668 tensor: Tensor to captures. 669 placeholder: Provided placeholder for the tensor. 670 """ 671 self._captures[id(tensor)] = (tensor, placeholder) 672 self.inputs.append(placeholder) 673 674 def replace_capture(self, tensor, placeholder): 675 """Replace already existing capture.""" 676 self._captures[id(tensor)] = (tensor, placeholder) 677 678 def reset_captures(self, capture_list): 679 """Set the captures with the provided list of captures & placeholder.""" 680 self._captures = py_collections.OrderedDict() 681 for tensor, placeholder in capture_list: 682 self._captures[id(tensor)] = (tensor, placeholder) 683 684 def pop_capture(self, tensor): 685 """Remove the capture and return the generated placeholder.""" 686 capture = self._captures.pop(id(tensor), None) 687 if capture is None: 688 return None 689 690 return capture[1] 691 692 def clear_captures(self): 693 # TODO(b/115366440): Delete this method when a custom OrderedDict is added. 694 # Clearing captures using clear() leaves some cycles around. 
695 while self._captures: 696 self._captures.popitem() 697 memory.dismantle_ordered_dict(self._captures) 698 while self._deferred_captures: 699 self._deferred_captures.popitem() 700 memory.dismantle_ordered_dict(self._deferred_captures) 701 702 def capture_distributed_variable(self, variable, placeholder): 703 """Add given distributed variable to captures with given placeholder.""" 704 self._captures[id(variable)] = (variable, placeholder) 705 tape.record_operation("captured_value", [placeholder], [variable], 706 backward_function=lambda x: [x], 707 forward_function=lambda x: [x]) 708 709 def capture_eager_tensor(self, tensor, name): 710 capture = self._captures.get(id(tensor)) 711 if capture is None: 712 # We clear all control dependencies and place the Const op on the same 713 # device as the source tensor. The device placement may be relaxed at 714 # a later date. 715 with ops.control_dependencies(None), self.device(tensor.device): 716 graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype, 717 shape=tensor.shape, name=name) 718 self.add_capture(tensor, graph_const) 719 else: 720 graph_const = capture[1] 721 tape.record_operation("captured_value", [graph_const], [tensor], 722 backward_function=lambda x: [x], 723 forward_function=lambda x: [x]) 724 return graph_const 725 726 def captured(self, tensor): 727 """Check if the specified tensor has been captured.""" 728 return id(tensor) in self._captures 729 730 @property 731 def external_captures(self): 732 """External tensors captured by this function.""" 733 return [c[0] for c in self._captures.values()] 734 735 @property 736 def internal_captures(self): 737 """Placeholders in this function corresponding captured tensors.""" 738 return [c[1] for c in self._captures.values()] 739 740 @property 741 def deferred_external_captures(self): 742 """Ordered nest of tensors whose placeholders will be fed at call time.""" 743 return [c[0] for c in self._deferred_captures.values()] 744 745 @property 746 def 
deferred_internal_captures(self): 747 """List of nest of placeholders which at call time will be fed.""" 748 return [c[1] for c in self._deferred_captures.values()] 749 750 @property 751 def variable_captures(self): 752 """Map of python object ids of variables to variables which are captured.""" 753 return { 754 id(self._captures[id(v)][1]): v 755 for v in self.variables 756 if id(v) in self._captures 757 } 758 759 def mark_as_unsaveable(self, error_message): 760 """Marks this FuncGraph as unsaveable. 761 762 Any attempts to export this FuncGraph will raise an error with the specified 763 message. 764 765 Args: 766 error_message: List or string containing the error message to be raised 767 when saving this FuncGraph to SavedModel. 768 """ 769 self._saveable = False 770 if isinstance(error_message, str): 771 error_message = [error_message] 772 self._saving_errors.update(error_message) 773 774 @property 775 def saveable(self): 776 """Returns whether this FuncGraph is saveable.""" 777 return self._saveable 778 779 @property 780 def saving_errors(self): 781 """Returns set of errors preventing this FuncGraph from being saved.""" 782 return self._saving_errors 783 784 def _add_scope_exit_callback(self, fn): 785 """Add a function to call when this graph exits the default scope.""" 786 if not callable(fn): 787 raise TypeError("fn is not callable: {}".format(fn)) 788 if self._scope_exit_callbacks is None: 789 raise RuntimeError( 790 "Attempting to add a scope exit callback, but the default graph is " 791 "not the context scope graph. 
Did you forget to call " 792 "'with graph.as_default(): ...'?") 793 self._scope_exit_callbacks.append(fn) 794 795 796def func_graph_from_py_func(name, 797 python_func, 798 args, 799 kwargs, 800 signature=None, 801 func_graph=None, 802 autograph=False, 803 autograph_options=None, 804 add_control_dependencies=True, 805 arg_names=None, 806 op_return_value=None, 807 collections=None, 808 capture_by_value=None, 809 override_flat_arg_shapes=None): 810 """Returns a `FuncGraph` generated from `python_func`. 811 812 Args: 813 name: an identifier for the function. 814 python_func: the Python function to trace. 815 args: the positional args with which the Python function should be called; 816 ignored if a signature is provided. 817 kwargs: the keyword args with which the Python function should be called; 818 ignored if a signature is provided. 819 signature: a possibly nested sequence of `TensorSpecs` specifying the shapes 820 and dtypes of the arguments. When a signature is provided, `args` and 821 `kwargs` are ignored, and `python_func` is traced with Tensors conforming 822 to `signature`. If `None`, the shapes and dtypes are inferred from the 823 inputs. 824 func_graph: Optional. An instance of FuncGraph. If provided, we will use 825 this graph else a new one is built and returned. 826 autograph: whether to use autograph to compile `python_func`. 827 See https://www.tensorflow.org/guide/autograph for more information. 828 autograph_options: additional knobs to control when `autograph=True`. 829 See https://www.tensorflow.org/guide/autograph for more information. 830 add_control_dependencies: If True, automatically adds control dependencies 831 to ensure program order matches execution order and stateful ops always 832 execute. 833 arg_names: Optional list of argument names, used to give input placeholders 834 recognizable names. 835 op_return_value: Optional. A Tensor. If set and `python_func` returns 836 Operations, those return values will be replaced with this value. 
If not 837 set, returning an Operation triggers an error. 838 collections: a dictionary of collections this FuncGraph should start 839 with. If not specified (None), the FuncGraph will read (but not write to) 840 the outer graph's collections that are not whitelisted, and both 841 read and write to the outer graph's collections that are whitelisted. 842 The current whitelisted collections are the global variables, the 843 local variables, and the trainable variables. 844 Defaults to None. 845 capture_by_value: An optional boolean. If True, the func graph will capture 846 Variables by value instead of reference. By default inherit from outer 847 graphs, and failing that will default to False. 848 override_flat_arg_shapes: An optional list of instances that are either 849 `None` or `TensorShape`. The length must match that of 850 `nest.flatten((args, kwargs), expand_composites=True)`. The entries 851 containing value `None` must match entries in flattened arguments 852 containing non-tensors, while entries containing a `TensorShape` must 853 match entries in the flattened arguments containing tensors. 854 855 Returns: 856 A FuncGraph. 857 858 Raises: 859 TypeError: If any of `python_func`'s return values is neither `None` nor a 860 `Tensor`. 861 ValueError: If both `signature` and `override_flat_arg_shapes` are 862 passed in. 
863 """ 864 if op_return_value is not None: 865 assert isinstance(op_return_value, ops.Tensor), op_return_value 866 if func_graph is None: 867 func_graph = FuncGraph(name, collections=collections, 868 capture_by_value=capture_by_value) 869 assert isinstance(func_graph, FuncGraph) 870 if add_control_dependencies: 871 deps_control_manager = auto_control_deps.AutomaticControlDependencies() 872 else: 873 deps_control_manager = ops.NullContextmanager() 874 875 with func_graph.as_default(), deps_control_manager as deps_ctx: 876 current_scope = variable_scope.get_variable_scope() 877 default_use_recource = current_scope.use_resource 878 current_scope.set_use_resource(True) 879 880 if signature is not None and override_flat_arg_shapes is not None: 881 raise ValueError( 882 "Passed both signature and override_flat_arg_shapes: %s and %s." 883 % (signature, override_flat_arg_shapes)) 884 885 if signature is not None: 886 args = signature 887 kwargs = {} 888 889 # Creates and names placeholders for all arguments. 890 if override_flat_arg_shapes is not None: 891 flat_args = nest.flatten(args, expand_composites=True) 892 arg_shapes = override_flat_arg_shapes[:len(flat_args)] 893 kwarg_shapes = override_flat_arg_shapes[len(flat_args):] 894 else: 895 arg_shapes = None 896 kwarg_shapes = None 897 func_args = _get_defun_inputs_from_args( 898 args, arg_names, flat_shapes=arg_shapes) 899 func_kwargs = _get_defun_inputs_from_kwargs( 900 kwargs, flat_shapes=kwarg_shapes) 901 902 # Convert all Tensors into TensorSpecs before saving the structured inputs. 903 # If storing pure concrete functions that are not called through polymorphic 904 # functions, we don't have access to FunctionSpec, so we need to call the 905 # TensorSpecs by their `arg_names` for later binding. 
906 func_graph.structured_input_signature = ( 907 convert_structure_to_signature(func_args, arg_names), 908 convert_structure_to_signature(func_kwargs)) 909 910 flat_func_args = nest.flatten(func_args, expand_composites=True) 911 flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True) 912 # Temporarily set inputs to allow graph building code to inspect 913 # them. Reassigned below. 914 func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs 915 if isinstance(arg, ops.Tensor)] 916 917 # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. 918 # Variables to help check whether mutation happens in calling the function 919 # Copy the recursive list, tuple and map structure, but not base objects 920 func_args_before = nest.pack_sequence_as(func_args, flat_func_args, 921 expand_composites=True) 922 func_kwargs_before = nest.pack_sequence_as( 923 func_kwargs, flat_func_kwargs, expand_composites=True) 924 925 def convert(x): 926 """Converts a function output to a Tensor.""" 927 if x is None: 928 return None 929 if op_return_value is not None and isinstance(x, ops.Operation): 930 # TODO(b/79881896): we currently can't capture external control deps, so 931 # this won't work if x needs to be captured (i.e. if python_func returns 932 # captured Operations). 933 with ops.control_dependencies([x]): 934 x = array_ops.identity(op_return_value) 935 elif not isinstance(x, tensor_array_ops.TensorArray): 936 try: 937 x = ops.convert_to_tensor_or_composite(x) 938 except (ValueError, TypeError): 939 raise TypeError( 940 "To be compatible with tf.contrib.eager.defun, Python functions " 941 "must return zero or more Tensors; in compilation of %s, found " 942 "return value of type %s, which is not a Tensor." 
% 943 (str(python_func), type(x))) 944 if add_control_dependencies: 945 x = deps_ctx.mark_as_return(x) 946 return x 947 948 try: 949 if autograph: 950 from tensorflow.python import autograph # pylint: disable=g-import-not-at-top 951 _, original_func = tf_decorator.unwrap(python_func) 952 953 def wrapper(*args, **kwargs): 954 """Calls a converted version of original_func.""" 955 # TODO(mdan): Push this block higher in tf.function's call stack. 956 try: 957 return autograph.converted_call( 958 original_func, 959 args, 960 kwargs, 961 options=autograph.ConversionOptions( 962 recursive=True, 963 optional_features=autograph_options, 964 user_requested=True, 965 )) 966 except Exception as e: # pylint:disable=broad-except 967 if hasattr(e, "ag_error_metadata"): 968 raise e.ag_error_metadata.to_exception(e) 969 else: 970 raise 971 972 # Wrapping around a decorator allows checks like tf_inspect.getargspec 973 # to be accurate. 974 converted_func = tf_decorator.make_decorator(original_func, wrapper) 975 python_func = tf_decorator.rewrap(python_func, original_func, 976 converted_func) 977 978 else: 979 _, original_func = tf_decorator.unwrap(python_func) 980 981 func_outputs = python_func(*func_args, **func_kwargs) 982 983 # invariant: `func_outputs` contains only Tensors, CompositeTensors, 984 # TensorArrays and `None`s. 985 func_outputs = nest.map_structure(convert, func_outputs, 986 expand_composites=True) 987 988 check_mutation(func_args_before, func_args, original_func) 989 check_mutation(func_kwargs_before, func_kwargs, original_func) 990 finally: 991 current_scope.set_use_resource(default_use_recource) 992 993 # Variables in `func_args`, `func_kwargs` should be explicit inputs 994 # to the function, not captured inputs. 
995 graph_variables = list(func_graph._watched_variables) # pylint: disable=protected-access 996 arg_variables = object_identity.ObjectIdentitySet() 997 inputs = [] 998 for arg in (nest.flatten(func_args, expand_composites=True) + 999 nest.flatten(func_kwargs, expand_composites=True)): 1000 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 1001 # Even if an argument variable was not used in the function, we've 1002 # already manually captured the resource Tensor when creating argument 1003 # placeholders. 1004 resource_placeholder = func_graph.pop_capture(arg.handle) 1005 if resource_placeholder is None: 1006 continue 1007 arg_variables.add(arg) 1008 inputs.append(resource_placeholder) 1009 elif isinstance(arg, ops.Tensor): 1010 inputs.append(arg) 1011 variables = [v for v in graph_variables if v not in arg_variables] 1012 func_graph.inputs = ( 1013 inputs + func_graph.internal_captures + nest.flatten( 1014 func_graph.deferred_internal_captures, expand_composites=True)) 1015 func_graph.structured_outputs = func_outputs 1016 # Returning a closed-over tensor does not trigger convert_to_tensor. 1017 func_graph.outputs.extend( 1018 func_graph.capture(x) 1019 for x in flatten(func_graph.structured_outputs) 1020 if x is not None) 1021 1022 func_graph.variables = variables 1023 1024 if add_control_dependencies: 1025 func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run) 1026 1027 return func_graph 1028 1029 1030def maybe_captured(tensor): 1031 """If t is a captured value placeholder, returns the original captured value. 1032 1033 Args: 1034 tensor: Tensor. 1035 1036 Returns: 1037 A tensor, potentially from a different Graph/FuncGraph. 
1038 """ 1039 if (not isinstance(tensor, ops.EagerTensor) and 1040 tensor.op.graph.building_function and tensor.op.type == "Placeholder"): 1041 for input_t, placeholder_t in tensor.op.graph.captures: 1042 if tensor == placeholder_t: 1043 return maybe_captured(input_t) 1044 # pylint: enable=protected-access 1045 return tensor 1046 1047 1048def device_stack_has_callable(device_stack): 1049 """Checks whether a device stack contains a callable.""" 1050 return any(callable(spec._device_name_or_function) # pylint: disable=protected-access 1051 for spec in device_stack.peek_objs()) 1052 1053 1054def check_mutation(n1, n2, func): 1055 """Check if two list of arguments are exactly the same.""" 1056 func_name = getattr(func, "__name__", func) 1057 1058 errmsg = ("{}() should not modify its Python input arguments." 1059 " Check if it modifies any lists or dicts passed as" 1060 " arguments. Modifying a copy is allowed.".format(func_name)) 1061 try: 1062 # TODO(mdan): Compare more robustly so that argument names can be reported. 1063 nest.assert_same_structure(n1, n2, expand_composites=True) 1064 except ValueError: 1065 raise ValueError(errmsg) 1066 1067 for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True), 1068 nest.flatten(n2, expand_composites=True)): 1069 if arg1 is not arg2: 1070 raise ValueError(errmsg) 1071 1072 1073# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this. 1074def flatten(sequence): 1075 """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays. 1076 1077 Args: 1078 sequence: A nested structure of Tensors, CompositeTensors, and 1079 TensorArrays. 1080 1081 Returns: 1082 A list of tensors. 1083 """ 1084 flat_sequence = nest.flatten(sequence, expand_composites=True) 1085 return [ 1086 item.flow if isinstance(item, tensor_array_ops.TensorArray) else item 1087 for item in flat_sequence] 1088 1089 1090# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this. 
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors.

  Returns:
    A nested structure.

  Raises:
    ValueError: if `structure` and `flat_sequence` do not have the same number
      of flattened elements.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  for i, template in enumerate(flattened_structure):
    # Where the template holds a TensorArray, the flat value is its flow
    # tensor; rebuild a TensorArray around that flow.
    if isinstance(template, tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=template, flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)


def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it.

  Args:
    value: The value to substitute; supplies the default dtype and shape, and
      its handle data is copied onto the placeholder.
    name: Optional name for the placeholder.
    dtype: Optional dtype override; defaults to `value.dtype`.
    shape: Optional shape override; defaults to `value.shape`.

  Returns:
    The newly created placeholder.
  """
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  if shape is None:
    shape = value.shape
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  custom_gradient.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A flat list of user-specified arguments.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`.  If provided, then length must match
      that of `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding `TensorShape`
      in `flat_shapes`.  May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
    ValueError: if only some of the `TensorSpec`s nested in a single argument
     carry a name.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d). args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):

    # Replace any composite tensors with their TypeSpecs.  This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)

    flattened = nest.flatten(arg_value, expand_composites=True)
    tensor_specs = [
        arg for arg in flattened if isinstance(arg, tensor_spec.DenseSpec)
    ]
    specified_names = [arg.name for arg in tensor_specs if arg.name]
    if specified_names and len(specified_names) < len(tensor_specs):
      raise ValueError("If specifying TensorSpec names for nested structures, "
                       "either zero or all names have to be specified.")

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        if isinstance(arg, tensor_spec.TensorSpec) and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        # Keep the variable's name local to this leaf: rebinding `name` would
        # leak it onto the remaining leaves of the same nested argument.
        variable_name = name
        if isinstance(arg, resource_variable_ops.VariableSpec):
          variable_name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(dtypes.resource, arg.shape,
                                            name=variable_name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=variable_name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=variable_name)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=variable_name)
        if variable_name is not None:
          # Mirrors the tensor branch: there is nothing to record when no name
          # was supplied, and `compat.as_bytes(None)` would raise.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(variable_name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'. args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs."""
  if kwargs:
    # Sort by key so that names and values are processed in a deterministic
    # order.
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)