# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as py_collections
import itertools
import weakref

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import compat
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator

WHITELIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]


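# Eager tensors with at most this many elements are captured by embedding a
# Const op in the graph (see `FuncGraph.capture`); larger tensors and
# resources are captured with Placeholder ops instead.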
_EAGER_CONST_THRESHOLD = 128


class UnknownArgument(object):
  """Signifies an argument which is not currently handled."""
  pass


def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where top level collection is a list or a
      tuple.
    arg_names: Optional list of argument names, with the same number of
      elements as `structure`, used for naming the corresponding TensorSpecs.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.
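
  Example:
    A rough sketch of the conversion (assumes `t` is a float32 Tensor of
    shape [2] created while tracing a function):

      convert_structure_to_signature([t, {"x": 3}], arg_names=["t", "d"])
      # -> [TensorSpec(shape=(2,), dtype=float32, name="t"), {"x": 3}]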
  """
  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
      name = "/".join(str(p) for p in path)
      return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." % arg_names)
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)


class FuncGraph(ops.Graph):
  """Graph representing a function body.

  Attributes:
    name: The name of the function.
    inputs: Placeholder tensors representing the inputs to this function. The
      tensors are in this FuncGraph. This represents "regular" inputs as well
      as captured inputs (i.e. the values of self.captures), with the regular
      inputs coming first.
    outputs: Tensors that will be returned by this function. The tensors are in
      this FuncGraph.
    control_outputs: Operations that must be executed before the function
      represented by this graph can be said to have been executed.
    structured_input_signature: A tuple of (args, kwargs), which are both
      possibly-nested python objects that were received by this function. Note
      that these structures might contain Python `None`s.
    structured_outputs: A possibly-nested python object which will be returned
      by this function. The Tensors in this structure are the same as those of
      self.outputs. Note that this structure might contain Python `None`s.
    variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another
      FuncGraph or the global default Graph.
    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
      The entries are in the order they were captured.
    control_captures: Set of external ops on which this graph has a control
      dependency.
    seed: The graph-level random seed.
    capture_by_value: If True, the func graph will capture Variables by value
      instead of reference.
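
  Example:
    A minimal sketch of building a function body by hand (normally
    `func_graph_from_py_func` below does this for you):

      fg = FuncGraph("f")
      with fg.as_default():
        x = graph_placeholder(dtypes.float32, [])
        fg.inputs.append(x)
        fg.outputs.append(x * 2.0)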
  """

  def __init__(self, name, collections=None, capture_by_value=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not whitelisted, and both
        read and write to the outer graph's collections that are whitelisted.
        The current whitelisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default, this is
        inherited from outer graphs; failing that, it defaults to False.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = set()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self._watched_variables = object_identity.ObjectIdentityWeakSet()

    outer_graph = ops.get_default_graph()
    self._weak_outer_graph = weakref.ref(outer_graph)
    while outer_graph.building_function:
      outer_graph = outer_graph.outer_graph
    # If self._weak_outer_graph is deleted, we revert to the outermost Graph
    # active when the FuncGraph was traced. This will not be a FuncGraph.
    self._fallback_outer_graph = outer_graph
    self._captures = py_collections.OrderedDict()
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    # Maps arbitrary key -> (closure, nest of placeholders), where at function
    # call time the value of closure() will be used to feed the nest of
    # placeholders.
    self._deferred_captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(
        self.outer_graph, FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf.data user migration from TF 1.x to 2.x] `_seed_used` tracks
      # whether any random op in the function was created with a None op_seed,
      # in which case the function seed ends up being used, which could be
      # unintended behavior for the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in WHITELIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in WHITELIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
    # dependent functions as unsaveable.
    self._saveable = True
    self._saving_errors = set()

    # Keep track of callbacks to run when this graph exits default scope.
    self._scope_exit_callbacks = None

  def __str__(self):
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def watch_variable(self, v):
    """Marks the variable v as accessed while building this graph."""
    while self is not None and isinstance(self, FuncGraph):
      self._watched_variables.add(v)
      self = self.outer_graph

  def capture_call_time_value(self, closure, spec, key=None):
    """Creates a placeholder which at call time has the value closure().

    Useful, for example, to respect TensorFlow context managers, which are
    often dynamically scoped.

    Args:
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to this method with the same
        key in the same graph will return the same placeholder, and the
        first closure will be used at function call time.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      ValueError: at function call time, if the return value of closure() is
        not compatible with `spec`.
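
    Example:
      A minimal sketch (assumes `fg` is a FuncGraph currently being built as
      the default graph):

        ph = fg.capture_call_time_value(
            lambda: constant_op.constant(1.0),
            tensor_spec.TensorSpec([], dtypes.float32))
        # `ph` is a scalar float32 placeholder; each call of the traced
        # function evaluates the closure and feeds its value into `ph`.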
    """
    if key is None:
      key = object()
    if key not in self._deferred_captures:

      def convert_to_placeholder(s):
        if not isinstance(s, tensor_spec.DenseSpec):
          raise TypeError(
              "Expected a nest of `TypeSpec` objects, found %s of type %s." %
              (s, type(s)))
        return array_ops.placeholder(dtype=s.dtype, shape=s.shape)

      placeholder = nest.map_structure(
          convert_to_placeholder, spec, expand_composites=True)

      def wrapped_closure():
        ret_nest = closure()
        nest.assert_same_structure(spec, ret_nest, expand_composites=True)
        # This uses the tensor dtype defined in `spec` when converting values
        # in `ret_nest` to tensors.
        # pylint: disable=protected-access
        y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest,
                               expand_composites=False)
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

      self._deferred_captures[key] = (wrapped_closure, placeholder)
    return self._deferred_captures[key][1]

  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which
        must be executed or computed before running the operations
        defined in the context.  Can also be `None` to clear the control
        dependencies.

    Returns:
      A context manager that specifies control dependencies for all
      operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
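
    Example:
      A sketch of the filtering (assumes `outer_op` was created in an outer
      graph and `inner_op` in this FuncGraph `fg`):

        with fg.control_dependencies([outer_op, inner_op]):
          ...
        # `outer_op` is recorded in fg.control_captures; only `inner_op`
        # becomes a regular control dependency.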
    """
    if control_inputs is None:
      return super(FuncGraph, self).control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable.
      if (isinstance(c, ops.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      if graph_element is not None and getattr(
          graph_element, "graph", None) is not self:
        self.control_captures.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)

  def as_default(self):
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)
      uses_distribution_strategy = (
          self._distribution_strategy_stack and
          self._distribution_strategy_stack[-1].strategy.extended
          ._retrace_functions_for_each_device)
      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's
      # placement, so that the same trace of the Python function may be placed
      # on several different devices and saved functions may be placed on new
      # devices when restored.
      old_device_stack = self._device_function_stack
      if context.executing_eagerly():
        if uses_distribution_strategy:
          self._device_function_stack = self._device_function_stack.copy()
          self._add_device_to_stack(context.context().device_name)
      else:
        if (uses_distribution_strategy or
            device_stack_has_callable(graph._device_function_stack)):
          # Hard-code devices from device functions in the function body.
          self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # Inherit the auto_cast_variable_read_dtype, since this should not change
      # inside a function.
      old_auto_cast_var_read_dtype = self._auto_cast_variable_read_dtype
      self._auto_cast_variable_read_dtype = graph._auto_cast_variable_read_dtype
      # pylint: enable=protected-access

      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key
            self._auto_cast_variable_read_dtype = old_auto_cast_var_read_dtype
    return inner_cm()

  @property
  def outer_graph(self):
    """The Graph this FuncGraph is nested in.

    Functions may capture Tensors from graphs they are nested in
    (transitively).

    Returns:
      A Graph object. Initially set to the current default graph when the
      FuncGraph was created. If the previous `outer_graph` was deleted because
      the function that owns it was deleted, `outer_graph` is reset to the
      outermost default graph active when the FuncGraph was created. This
      FuncGraph won't have captured anything from the new `outer_graph` (and
      likely not from the previous setting, since that would have created a
      strong reference), but it is returned so that FuncGraphs always have a
      parent.
    """
    current = self._weak_outer_graph()
    if current is None:
      return self._fallback_outer_graph
    return current

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    return [t.shape for t in self.outputs]

  @property
  def trainable_variables(self):
    """A sequence of trainable variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of trainable variables for this func graph.
    """
    return tuple(v for v in self.variables if v.trainable)

  @property
  def variables(self):
    """A sequence of variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of variables for this func graph.
    """
    def deref(weak_v):
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      return v

    return tuple(deref(v) for v in self._weak_variables)

  @variables.setter
  def variables(self, var_list):
    self._weak_variables = [weakref.ref(v) for v in var_list]

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # When capturing by value, do the read outside of this graph, in the init
    # scope, so the function captures the result of the read rather than the
    # resource itself.
    reverse_captures = dict((id(v), k) for k, v in self.captures)
    uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type,
            uncaptured_inputs,
            dtypes,
            input_types,
            name,
            attrs,
            op_def,
            compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This overload adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function graphs
    if this is a nested function. See `capture` for more information.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the
        `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of
        the tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    if self.capture_by_value and op_type in ["ReadVariableOp",
                                             "ResourceGather"]:
      return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
                                    attrs, op_def, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can lead
    # the context to get confused and try to create an Enter for an Enter. We
    # can detect this here and skip the additional Enter which can confuse loop
    # validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create placeholders
    # to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    for i, inp in enumerate(inputs):
      # TPU Estimator defines a control flow context with no AddValue method.
      if ctxt is not None and hasattr(ctxt, "AddValue"):
        inp = ctxt.AddValue(inp)
      inp = self.capture(inp)
      inputs[i] = inp
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)

  def capture(self, tensor, name=None, shape=None):
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs.  Multiple calls to this method with
    the same `tensor` argument will return the same placeholder. If `tensor` is
    from this graph, returns `tensor`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.
      shape: Optional shape if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      InaccessibleTensorError: if any tensors are accessed in a manner that
        bypasses the mechanisms required for the data dependencies to be
        correctly wired.
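
    Example:
      A sketch of the caching behavior (assumes `external_t` comes from an
      outer graph and `fg` is this FuncGraph):

        internal_t = fg.capture(external_t)  # placeholder created inside fg
        assert fg.capture(external_t) is internal_t  # placeholder is reused
        assert fg.capture(internal_t) is internal_t  # already in this graph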
    """
    if isinstance(tensor, ops.EagerTensor):
      if name is None:
        name = str(ops.uid())

      # Small EagerTensors are captured with Const ops.
      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
        return self.capture_eager_tensor(tensor, name)

      # Large EagerTensors and resources are captured with Placeholder ops.
      return self._capture_helper(tensor, name, shape)
    if tensor.graph is not self:
      if name is None:
        name = tensor.op.name
      inner_graph = tensor.graph
      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
        if inner_graph is self:
          raise errors.InaccessibleTensorError(
              "The tensor '%s' cannot be accessed here: it is defined"
              " in another function or code block. Use return values,"
              " explicit Python locals or TensorFlow collections to access"
              " it. Defined in: %s; accessed from: %s.\n"
              % (tensor, tensor.graph, self))
        inner_graph = inner_graph.outer_graph
      return self._capture_helper(tensor, name)
    return tensor

  def _capture_helper(self, tensor, name, shape=None):
    capture = self._captures.get(id(tensor))
    if capture is None:
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype, shape=shape)
      self.add_capture(tensor, placeholder)
    else:
      placeholder = capture[1]
    tape.record_operation("captured_value", [placeholder], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return placeholder

  @property
  def captures(self):
    """Ordered list of tuples containing external and internal captures."""
    return self._captures.values()

  def add_capture(self, tensor, placeholder):
    """Capture a specific tensor and utilize the provided placeholder.

    Args:
      tensor: Tensor to capture.
      placeholder: Provided placeholder for the tensor.
    """
    self._captures[id(tensor)] = (tensor, placeholder)
    self.inputs.append(placeholder)

  def replace_capture(self, tensor, placeholder):
    """Replace already existing capture."""
    self._captures[id(tensor)] = (tensor, placeholder)

  def reset_captures(self, capture_list):
    """Set the captures with the provided list of captures and placeholders."""
    self._captures = py_collections.OrderedDict()
    for tensor, placeholder in capture_list:
      self._captures[id(tensor)] = (tensor, placeholder)

  def pop_capture(self, tensor):
    """Remove the capture and return the generated placeholder."""
    capture = self._captures.pop(id(tensor), None)
    if capture is None:
      return None

    return capture[1]

  def clear_captures(self):
    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
    # Clearing captures using clear() leaves some cycles around.
    while self._captures:
      self._captures.popitem()
    memory.dismantle_ordered_dict(self._captures)
    while self._deferred_captures:
      self._deferred_captures.popitem()
    memory.dismantle_ordered_dict(self._deferred_captures)

  def capture_distributed_variable(self, variable, placeholder):
    """Add given distributed variable to captures with given placeholder."""
    self._captures[id(variable)] = (variable, placeholder)
    tape.record_operation("captured_value", [placeholder], [variable],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])

  def capture_eager_tensor(self, tensor, name):
    capture = self._captures.get(id(tensor))
    if capture is None:
      # We clear all control dependencies and place the Const op on the same
      # device as the source tensor. The device placement may be relaxed at
      # a later date.
      with ops.control_dependencies(None), self.device(tensor.device):
        graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype,
                                           shape=tensor.shape, name=name)
      self.add_capture(tensor, graph_const)
    else:
      graph_const = capture[1]
    tape.record_operation("captured_value", [graph_const], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return graph_const

  def captured(self, tensor):
    """Check if the specified tensor has been captured."""
    return id(tensor) in self._captures

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return [c[0] for c in self._captures.values()]

  @property
  def internal_captures(self):
    """Placeholders in this function corresponding to captured tensors."""
    return [c[1] for c in self._captures.values()]

  @property
  def deferred_external_captures(self):
    """Ordered nest of tensors whose placeholders will be fed at call time."""
    return [c[0] for c in self._deferred_captures.values()]

  @property
  def deferred_internal_captures(self):
    """List of nests of placeholders which will be fed at call time."""
    return [c[1] for c in self._deferred_captures.values()]

  @property
  def variable_captures(self):
    """Map from ids of captured variables' placeholders to the variables."""
    return {
        id(self._captures[id(v)][1]): v
        for v in self.variables
        if id(v) in self._captures
    }

  def mark_as_unsaveable(self, error_message):
    """Marks this FuncGraph as unsaveable.

    Any attempts to export this FuncGraph will raise an error with the
    specified message.

    Args:
      error_message: List or string containing the error message to be raised
        when saving this FuncGraph to SavedModel.
    """
    self._saveable = False
    if isinstance(error_message, str):
      error_message = [error_message]
    self._saving_errors.update(error_message)

  @property
  def saveable(self):
    """Returns whether this FuncGraph is saveable."""
    return self._saveable

  @property
  def saving_errors(self):
    """Returns set of errors preventing this FuncGraph from being saved."""
    return self._saving_errors

  def _add_scope_exit_callback(self, fn):
    """Add a function to call when this graph exits the default scope."""
    if not callable(fn):
      raise TypeError("fn is not callable: {}".format(fn))
    if self._scope_exit_callbacks is None:
      raise RuntimeError(
          "Attempting to add a scope exit callback, but the default graph is "
          "not the context scope graph.  Did you forget to call "
          "'with graph.as_default(): ...'?")
    self._scope_exit_callbacks.append(fn)


def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            override_flat_arg_shapes=None):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the
      shapes and dtypes of the arguments. When a signature is provided, `args`
      and `kwargs` are ignored, and `python_func` is traced with Tensors
      conforming to `signature`. If `None`, the shapes and dtypes are inferred
      from the inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph, else a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not whitelisted, and both
      read and write to the outer graph's collections that are whitelisted.
      The current whitelisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.
    override_flat_arg_shapes: An optional list of instances that are either
      `None` or `TensorShape`.  The length must match that of
      `nest.flatten((args, kwargs), expand_composites=True)`.  The entries
      containing value `None` must match entries in flattened arguments
      containing non-tensors, while entries containing a `TensorShape` must
      match entries in the flattened arguments containing tensors.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
    ValueError: If both `signature` and `override_flat_arg_shapes` are
      passed in.
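
  Example:
    A minimal sketch of tracing a Python function (assumes eager execution):

      def f(x):
        return x * 2.0

      fg = func_graph_from_py_func(
          "f", f, args=(constant_op.constant(1.0),), kwargs={})
      # fg.inputs holds the placeholder created for `x`; fg.outputs holds
      # the tensor computing `x * 2.0`.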
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name, collections=collections,
                           capture_by_value=capture_by_value)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    deps_control_manager = auto_control_deps.AutomaticControlDependencies()
  else:
    deps_control_manager = ops.NullContextmanager()

  with func_graph.as_default(), deps_control_manager as deps_ctx:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None and override_flat_arg_shapes is not None:
      raise ValueError(
          "Passed both signature and override_flat_arg_shapes: %s and %s."
          % (signature, override_flat_arg_shapes))

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    if override_flat_arg_shapes is not None:
      flat_args = nest.flatten(args, expand_composites=True)
      arg_shapes = override_flat_arg_shapes[:len(flat_args)]
      kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
    else:
      arg_shapes = None
      kwarg_shapes = None
    func_args = _get_defun_inputs_from_args(
        args, arg_names, flat_shapes=arg_shapes)
    func_kwargs = _get_defun_inputs_from_kwargs(
        kwargs, flat_shapes=kwarg_shapes)

    # Convert all Tensors into TensorSpecs before saving the structured inputs.
    # If storing pure concrete functions that are not called through
    # polymorphic functions, we don't have access to FunctionSpec, so we need
    # to call the TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(func_args, arg_names),
        convert_structure_to_signature(func_kwargs))

    flat_func_args = nest.flatten(func_args, expand_composites=True)
    flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
    # Temporarily set inputs to allow graph building code to inspect
    # them. Reassigned below.
    func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
                         if isinstance(arg, ops.Tensor)]

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function.
    # Copy the recursive list, tuple and map structure, but not base objects.
    func_args_before = nest.pack_sequence_as(func_args, flat_func_args,
                                             expand_composites=True)
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, flat_func_kwargs, expand_composites=True)

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps, so
        # this won't work if x needs to be captured (i.e. if python_func returns
        # captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.contrib.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = deps_ctx.mark_as_return(x)
      return x

    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def wrapper(*args, **kwargs):
          """Calls a converted version of original_func."""
          # TODO(mdan): Push this block higher in tf.function's call stack.
          try:
            return autograph.converted_call(
                original_func,
                args,
                kwargs,
                options=autograph.ConversionOptions(
                    recursive=True,
                    optional_features=autograph_options,
                    user_requested=True,
                ))
          except Exception as e:  # pylint:disable=broad-except
            if hasattr(e, "ag_error_metadata"):
              raise e.ag_error_metadata.to_exception(e)
            else:
              raise

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func, wrapper)
        python_func = tf_decorator.rewrap(python_func, original_func,
                                          converted_func)

      else:
        _, original_func = tf_decorator.unwrap(python_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, CompositeTensors,
      # TensorArrays and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs,
                                        expand_composites=True)

      check_mutation(func_args_before, func_args, original_func)
      check_mutation(func_kwargs_before, func_kwargs, original_func)
    finally:
      current_scope.set_use_resource(default_use_resource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
    arg_variables = object_identity.ObjectIdentitySet()
    inputs = []
    for arg in (nest.flatten(func_args, expand_composites=True) +
                nest.flatten(func_kwargs, expand_composites=True)):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.pop_capture(arg.handle)
        if resource_placeholder is None:
          continue
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in graph_variables if v not in arg_variables]
    func_graph.inputs = (
        inputs + func_graph.internal_captures + nest.flatten(
            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  if add_control_dependencies:
    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)

  return func_graph


def maybe_captured(tensor):
  """Returns the original captured value if `tensor` is a capture placeholder.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
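
  Example:
    A sketch (assumes `outer_t` is a tensor from an outer graph, not itself a
    capture placeholder, and `fg` is a FuncGraph being built):

      placeholder = fg.capture(outer_t)
      assert maybe_captured(placeholder) is outer_t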
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures:
      if tensor == placeholder_t:
        return maybe_captured(input_t)
  return tensor


def device_stack_has_callable(device_stack):
  """Checks whether a device stack contains a callable."""
  return any(callable(spec._device_name_or_function)  # pylint: disable=protected-access
             for spec in device_stack.peek_objs())


def check_mutation(n1, n2, func):
  """Check if two lists of arguments are exactly the same."""
  func_name = getattr(func, "__name__", func)

  errmsg = ("{}() should not modify its Python input arguments."
            " Check if it modifies any lists or dicts passed as"
            " arguments. Modifying a copy is allowed.".format(func_name))
  try:
    # TODO(mdan): Compare more robustly so that argument names can be reported.
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    raise ValueError(errmsg)

  for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True),
                        nest.flatten(n2, expand_composites=True)):
    if arg1 is not arg2:
      raise ValueError(errmsg)


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def flatten(sequence):
  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A list of tensors.
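
  Example:
    A sketch (assumes `t` is a Tensor and `ta` is a TensorArray):

      flatten({"a": t, "b": ta})  # -> [t, ta.flow]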
  """
  flat_sequence = nest.flatten(sequence, expand_composites=True)
  return [
      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
      for item in flat_sequence]


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors.

  Returns:
    A nested structure.

  Raises:
    ValueError: if `structure` and `flat_sequence` are not compatible.
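
  Example:
    A sketch, the inverse of `flatten` above (assumes `t` is a Tensor and
    `ta` is a TensorArray):

      pack_sequence_as({"a": t, "b": ta}, [t, ta.flow])
      # -> {"a": t, "b": <TensorArray rebuilt around ta.flow>}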
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError(
        "Mismatch in element count: structure has %d elements, flat_sequence "
        "has %d." % (len(flattened_structure), len(flat_sequence)))
  for i in range(len(flat_sequence)):
    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=flattened_structure[i], flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)


def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it."""
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  if shape is None:
    shape = value.shape
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  custom_gradient.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A flat list of user-specified arguments.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`.  If provided, then length must match
      that of `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding `TensorShape`
      in `flat_shapes`.  May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
      `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None
      for an argument that is not a `Tensor`, `TensorSpec`,
      or `ResourceVariable`.
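
  Example:
    A sketch of the positional-args case (assumes this runs while a FuncGraph
    is the default graph; `args` mixes a tensor and a non-tensor):

      args = (constant_op.constant(1.0), "s")
      _get_defun_inputs(args, ["x", "y"], structure=args)
      # -> (scalar float32 placeholder named "x", "s")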
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d).  args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):

    # Replace any composite tensors with their TypeSpecs.  This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)

    flattened = nest.flatten(arg_value, expand_composites=True)
    tensor_specs = [
        arg for arg in flattened if isinstance(arg, tensor_spec.DenseSpec)
    ]
    specified_names = [arg.name for arg in tensor_specs if arg.name]
    if specified_names and len(specified_names) < len(tensor_specs):
      raise ValueError("If specifying TensorSpec names for nested structures, "
                       "either zero or all names have to be specified.")

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        if isinstance(arg, tensor_spec.TensorSpec) and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if name is not None:
          # Record the requested/user-specified name in case it differs from
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        if isinstance(arg, resource_variable_ops.VariableSpec):
          name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(dtypes.resource, arg.shape,
                                            name=name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=name)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=name)
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_user_specified_name",
            attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs."""
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph`, a FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)