# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as py_collections
import itertools
import weakref

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import compat
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export

ALLOWLIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]

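# Eager tensors with at most this many elements (and a value dtype) are
# captured as Const ops; larger tensors and resources are captured with
# placeholders. See `FuncGraph.capture`.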
_EAGER_CONST_THRESHOLD = 128


class UnknownArgument(object):
  """Signifies an argument which is not currently handled."""
  pass

def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where the top-level collection is a list
      or a tuple.
    arg_names: Optional list of argument names, with the same number of
      elements as `structure`; used to name the corresponding TensorSpecs.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.
  """
  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
      name = "/".join(str(p) for p in path)
      return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, (
        int,
        float,
        bool,
        str,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." % arg_names)
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)

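# Illustrative sketch (not part of the original module): what
# `convert_structure_to_signature` produces for a small structure. The
# argument names and values below are hypothetical.
def _example_convert_structure_to_signature():
  with ops.Graph().as_default():
    # One graph Tensor and one plain int, with matching top-level arg names.
    structure = [array_ops.placeholder(dtypes.float32, shape=[2]), 7]
    signature = convert_structure_to_signature(
        structure, arg_names=["x", "steps"])
  # Expected: [TensorSpec(shape=(2,), dtype=tf.float32, name="x"), 7].
  # Tensors become TensorSpecs named after their (renamed) flattened path,
  # while supported plain Python values pass through unchanged.
  return signature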

@tf_export("__internal__.FuncGraph", v1=[])
class FuncGraph(ops.Graph):
  """Graph representing a function body.

  Attributes:
    name: The name of the function.
    inputs: Placeholder tensors representing the inputs to this function. The
      tensors are in this FuncGraph. This represents "regular" inputs as well
      as captured inputs (i.e. the values of self.captures), with the regular
      inputs coming first.
    outputs: Tensors that will be returned by this function. The tensors are in
      this FuncGraph.
    control_outputs: Operations that must be executed before the function
      represented by this graph can be said to have been executed.
    structured_input_signature: A tuple of (args, kwargs), which are both
      possibly-nested python objects that were received by this function. Note
      that these structures might contain Python `None`s.
    structured_outputs: A possibly-nested python object which will be returned
      by this function. The Tensors in this structure are the same as those of
      self.outputs. Note that this structure might contain Python `None`s.
    variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another FuncGraph
      or the global default Graph.
    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
      The entries are in the order they were captured.
    control_captures: Set of external ops on which this graph has a control
      dependency.
    seed: The graph-level random seed.
    capture_by_value: If True, the func graph will capture Variables by value
      instead of reference.
  """

  def __init__(self, name, collections=None, capture_by_value=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not allowlisted, and both
        read and write to the outer graph's collections that are allowlisted.
        The current allowlisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit
        from outer graphs, and failing that will default to False.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = set()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self._watched_variables = object_identity.ObjectIdentityWeakSet()
    self.is_control_flow_graph = False

    outer_graph = ops.get_default_graph()
    self._weak_outer_graph = weakref.ref(outer_graph)
    while outer_graph.building_function:
      outer_graph = outer_graph.outer_graph
    # If self._weak_outer_graph is deleted, we revert to the outermost Graph
    # active when the FuncGraph was traced. This will not be a FuncGraph.
    self._fallback_outer_graph = outer_graph
    self._captures = py_collections.OrderedDict()
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    # Maps arbitrary key -> (closure, nest of placeholders), where at function
    # call time the value of closure() will be used to feed the nest of
    # placeholders.
    self._deferred_captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(
        self.outer_graph, FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf.data user migration from TF 1.x to 2.x] `_seed_used` keeps
      # track of whether any random op in the function was given a None
      # op_seed, in which case we end up using the function seed, which could
      # be unintended behavior for the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in ALLOWLIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in ALLOWLIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
    # dependent functions as unsaveable.
    self._saveable = True
    self._saving_errors = set()

    # Keep track of callbacks to run when this graph exits default scope
    self._scope_exit_callbacks = None

  def __str__(self):
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def watch_variable(self, v):
    """Marks the variable v as accessed while building this graph."""
    while self is not None and isinstance(self, FuncGraph):
      self._watched_variables.add(v)
      self = self.outer_graph

  def capture_call_time_value(self, closure, spec, key=None):
    """Creates a placeholder which at call time has the value closure().

    Useful, for example, to respect TensorFlow context managers, which are often
    dynamically scoped.

    Args:
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to this method with the same
        key in the same graph will return the same placeholder, and the
        first closure will be used at function call time.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      ValueError: at function call time, if the return value of closure() is
        not compatible with `spec`.
    """
    if key is None:
      key = object()
    if key not in self._deferred_captures:

      def convert_to_placeholder(s):
        if not isinstance(s, tensor_spec.DenseSpec):
          raise TypeError(
              "Expected a nest of `TypeSpec` objects, found %s of type %s." %
              (s, type(s)))
        return array_ops.placeholder(dtype=s.dtype, shape=s.shape)

      placeholder = nest.map_structure(
          convert_to_placeholder, spec, expand_composites=True)

      def wrapped_closure():
        ret_nest = closure()
        nest.assert_same_structure(spec, ret_nest, expand_composites=True)
        # This uses the tensor dtype defined in `spec` when converting values
        # in `ret_nest` to tensors.
        # pylint: disable=protected-access
        y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest,
                               expand_composites=False)
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

      self._deferred_captures[key] = (wrapped_closure, placeholder)
    return self._deferred_captures[key][1]

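  # Illustrative sketch (not part of the original class) for
  # `capture_call_time_value` above. While this graph is being built, one
  # might capture a dynamically scoped value like so:
  #
  #   spec = tensor_spec.TensorSpec([], dtypes.float32)
  #   ph = graph.capture_call_time_value(lambda: current_loss_scale(), spec)
  #
  # `ph` is a placeholder in this graph; at each function call the closure is
  # re-evaluated and its result is fed for `ph`. (`current_loss_scale` is a
  # hypothetical user function.)
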
  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which
        must be executed or computed before running the operations
        defined in the context.  Can also be `None` to clear the control
        dependencies.

    Returns:
      A context manager that specifies control dependencies for all
      operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    if control_inputs is None:
      return super(FuncGraph, self).control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable
      if (isinstance(c, ops.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      if graph_element is not None and getattr(
          graph_element, "graph", None) is not self:
        self.control_captures.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)

  def as_default(self):
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      graph = ops.get_default_graph()
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)

      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's placement,
      # so that the same trace of the Python function may be placed on several
      # different devices and saved functions may be placed on new devices when
      # restored.
      # However, we need to preserve the outer device stack in the following
      # cases in a non-eager context:
      # 1. The device stack is callable.
      # 2. When using distribution strategy with legacy graph mode.
      old_device_stack = self._device_function_stack
      if (not context.executing_eagerly() and
          (device_stack_has_callable(graph._device_function_stack) or
           (self._distribution_strategy_stack and
            not ops.executing_eagerly_outside_functions()))):
        # Hard-code devices from device functions in the function body
        self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key
    return inner_cm()

  @property
  def outer_graph(self):
    """The Graph this FuncGraph is nested in.

    Functions may capture Tensors from graphs they are nested in (transitive).

    Returns:
      A Graph object. Initially set to the current default graph when the
      FuncGraph was created. If the previous `outer_graph` was deleted because
      the function that owns it was deleted, `outer_graph` is reset to the
      outermost default graph active when the FuncGraph was created. This
      FuncGraph won't have captured anything from the new `outer_graph` (and
      likely not from the previous setting, since that would have created a
      strong reference), but it is returned so that FuncGraphs always have a
      parent.
    """
    current = self._weak_outer_graph()
    if current is None:
      return self._fallback_outer_graph
    return current

  @outer_graph.setter
  def outer_graph(self, new_outer_graph):
    """Sets `outer_graph` to `new_outer_graph`."""
    self._weak_outer_graph = weakref.ref(new_outer_graph)

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    return [t.shape for t in self.outputs]

  @property
  def trainable_variables(self):
    """A sequence of trainable variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of trainable variables for this func graph.
    """
    return tuple(v for v in self.variables if v.trainable)

  @property
  def variables(self):
    """A sequence of variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of variables for this func graph.
    """
    def deref(weak_v):
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      return v

    return tuple(deref(v) for v in self._weak_variables)

  @variables.setter
  def variables(self, var_list):
    self._weak_variables = [weakref.ref(v) for v in var_list]

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # When capturing by value, do the read outside this graph and capture the
    # resulting value instead of the resource.
    reverse_captures = dict((id(v), k) for k, v in self.captures)
    uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type,
            uncaptured_inputs,
            dtypes,
            input_types,
            name,
            attrs,
            op_def,
            compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This overload adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function graphs
    if this is a nested function. See `capture` for more information.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of the
        tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    if self.capture_by_value and op_type in ["ReadVariableOp",
                                             "ResourceGather"]:
      return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
                                    attrs, op_def, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can lead
    # the context to get confused and try to create an Enter for an Enter. We
    # can detect this here and skip the additional Enter which can confuse loop
    # validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create placeholders
    # to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    # Use a different list to avoid modifying the original inputs list.
    captured_inputs = []
    for inp in inputs:
      # TPU Estimator defines a control flow context with no AddValue method.
      if ctxt is not None and hasattr(ctxt, "AddValue"):
        inp = ctxt.AddValue(inp)
      inp = self.capture(inp)
      captured_inputs.append(inp)
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
        op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)

  def capture(self, tensor, name=None, shape=None):
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs.  Multiple calls to this method with
    the same `tensor` argument will return the same placeholder. If `tensor` is
    from this graph, returns `tensor`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.
      shape: Optional shape if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      InaccessibleTensorError: if any tensors are accessed in a manner that
        bypasses the mechanisms required for the data dependencies to be
        correctly wired.
    """
    if isinstance(tensor, ops.EagerTensor):
      if name is None:
        name = str(ops.uid())

      # Small EagerTensors are captured with Const ops
      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
        return self.capture_eager_tensor(tensor, name)

      # Large EagerTensors and resources are captured with Placeholder ops
      return self._capture_helper(tensor, name, shape)
    if tensor.graph is not self:
      if name is None:
        name = tensor.op.name
      inner_graph = tensor.graph
      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
        if inner_graph is self:
          raise errors.InaccessibleTensorError(
              "The tensor '%s' cannot be accessed here: it is defined"
              " in another function or code block. Use return values,"
              " explicit Python locals or TensorFlow collections to access"
              " it. Defined in: %s; accessed from: %s.\n"
              % (tensor, tensor.graph, self))
        inner_graph = inner_graph.outer_graph
      return self._capture_helper(tensor, name)
    return tensor

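  # Illustrative sketch (not part of the original class) for `capture` above.
  # When tracing a nested function that references a tensor `x` from an outer
  # graph, conceptually:
  #
  #   ph = inner_graph.capture(x)          # placeholder living in inner_graph
  #   assert inner_graph.captured(x)       # (x, ph) is recorded
  #   assert inner_graph.capture(x) is ph  # repeated capture is idempotent
  #
  # The placeholder is also appended to `inner_graph.inputs`, so the captured
  # value is fed at call time like a regular input.
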
  def _capture_helper(self, tensor, name, shape=None):
    capture = self._captures.get(id(tensor))
    if capture is None:
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype, shape=shape)
      # Record the composite device as an attribute of the placeholder.
      # This attribute will be propagated into the arg_attr of the FunctionDef.
      # Currently, a packed eager tensor is always placed on a CompositeDevice.
      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_composite_device",
            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
      self.add_capture(tensor, placeholder)
    else:
      placeholder = capture[1]
    tape.record_operation("captured_value", [placeholder], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return placeholder

  @property
  def captures(self):
    """Ordered list of tuples containing external and internal captures."""
    return self._captures.values()

  def add_capture(self, tensor, placeholder):
    """Capture a specific tensor and utilize the provided placeholder.

    Args:
      tensor: Tensor to capture.
      placeholder: Provided placeholder for the tensor.
    """
    self._captures[id(tensor)] = (tensor, placeholder)
    self.inputs.append(placeholder)

  def replace_capture(self, tensor, placeholder):
    """Replace an already existing capture."""
    self._captures[id(tensor)] = (tensor, placeholder)

  def reset_captures(self, capture_list):
    """Sets the captures from a list of (tensor, placeholder) pairs."""
    self._captures = py_collections.OrderedDict()
    for tensor, placeholder in capture_list:
      self._captures[id(tensor)] = (tensor, placeholder)

  def pop_capture(self, tensor):
    """Remove the capture and return the generated placeholder."""
    capture = self._captures.pop(id(tensor), None)
    if capture is None:
      return None

    return capture[1]

  def clear_captures(self):
    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
    # Clearing captures using clear() leaves some cycles around.
    while self._captures:
      self._captures.popitem()
    memory.dismantle_ordered_dict(self._captures)
    while self._deferred_captures:
      self._deferred_captures.popitem()
    memory.dismantle_ordered_dict(self._deferred_captures)

  def capture_distributed_variable(self, variable, placeholder):
    """Add given distributed variable to captures with given placeholder."""
    self._captures[id(variable)] = (variable, placeholder)
    tape.record_operation("captured_value", [placeholder], [variable],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])

  def capture_eager_tensor(self, tensor, name):
    capture = self._captures.get(id(tensor))
    if capture is None:
      # We clear all control dependencies and place the Const op on the same
      # device as the source tensor. The device placement may be relaxed at
      # a later date.
      with ops.control_dependencies(None), self.device(tensor.device):
        constant_value = tensor_util.constant_value(tensor)
        if constant_value is None:
          # Some eager tensors, e.g. parallel tensors, are not convertible to a
          # single constant. We'll use a placeholder for this case.
          return self._capture_helper(tensor, name)
        graph_const = constant_op.constant(constant_value, dtype=tensor.dtype,
                                           shape=tensor.shape, name=name)
      self.add_capture(tensor, graph_const)
    else:
      graph_const = capture[1]
    tape.record_operation("captured_value", [graph_const], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return graph_const

  def captured(self, tensor):
    """Check if the specified tensor has been captured."""
    return id(tensor) in self._captures

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return [c[0] for c in self._captures.values()]

  @property
  def internal_captures(self):
    """Placeholders in this function corresponding to captured tensors."""
    return [c[1] for c in self._captures.values()]

  @property
  def deferred_external_captures(self):
    """Ordered list of closures producing the values fed at call time."""
    return [c[0] for c in self._deferred_captures.values()]

  @property
  def deferred_internal_captures(self):
    """List of placeholder nests which will be fed at function call time."""
    return [c[1] for c in self._deferred_captures.values()]

  @property
  def variable_captures(self):
    """Map from ids of captured variable placeholders to the variables."""
    return {
        id(self._captures[id(v)][1]): v
        for v in self.variables
        if id(v) in self._captures
    }

  def mark_as_unsaveable(self, error_message):
    """Marks this FuncGraph as unsaveable.

    Any attempts to export this FuncGraph will raise an error with the specified
    message.

    Args:
      error_message: List or string containing the error message to be raised
        when saving this FuncGraph to SavedModel.
    """
    self._saveable = False
    if isinstance(error_message, str):
      error_message = [error_message]
    self._saving_errors.update(error_message)

  @property
  def saveable(self):
    """Returns whether this FuncGraph is saveable."""
    return self._saveable

  @property
  def saving_errors(self):
    """Returns set of errors preventing this FuncGraph from being saved."""
    return self._saving_errors

  def _add_scope_exit_callback(self, fn):
    """Add a function to call when this graph exits the default scope."""
    if not callable(fn):
      raise TypeError("fn is not callable: {}".format(fn))
    if self._scope_exit_callbacks is None:
      raise RuntimeError(
          "Attempting to add a scope exit callback, but the default graph is "
          "not the context scope graph.  Did you forget to call "
          "'with graph.as_default(): ...'?")
    self._scope_exit_callbacks.append(fn)

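# Illustrative sketch (not part of the original module): how a FuncGraph
# records captures. Assumes TF 2.x eager execution at the top level.
def _example_func_graph_capture():
  x = constant_op.constant([[1.0] * 100] * 100)  # large: placeholder capture
  g = FuncGraph("example")
  with g.as_default():
    y = x * 2.0  # referencing `x` inside `g` triggers a capture
  external, internal = zip(*g.captures)
  # `external` holds the original eager tensor `x`; `internal` holds the
  # placeholder standing in for it inside `g`, which also appears in g.inputs.
  return external, internal, y
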

def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            override_flat_arg_shapes=None):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph, else a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not allowlisted, and both
      read and write to the outer graph's collections that are allowlisted.
      The current allowlisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.
    override_flat_arg_shapes: An optional list of instances that are either
      `None` or `TensorShape`.  The length must match that of
      `nest.flatten((args, kwargs), expand_composites=True)`.  The entries
      containing value `None` must match entries in flattened arguments
      containing non-tensors, while entries containing a `TensorShape` must
      match entries in the flattened arguments containing tensors.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
    ValueError: If both `signature` and `override_flat_arg_shapes` are
      passed in.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name, collections=collections,
                           capture_by_value=capture_by_value)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    deps_control_manager = auto_control_deps.AutomaticControlDependencies()
  else:
    deps_control_manager = ops.NullContextmanager()

  with func_graph.as_default(), deps_control_manager as deps_ctx:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None and override_flat_arg_shapes is not None:
      raise ValueError(
          "Passed both signature and override_flat_arg_shapes: %s and %s."
          % (signature, override_flat_arg_shapes))

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    if override_flat_arg_shapes is not None:
      flat_args = nest.flatten(args, expand_composites=True)
      arg_shapes = override_flat_arg_shapes[:len(flat_args)]
      kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
    else:
      arg_shapes = None
      kwarg_shapes = None
    func_args = _get_defun_inputs_from_args(
        args, arg_names, flat_shapes=arg_shapes)
    func_kwargs = _get_defun_inputs_from_kwargs(
        kwargs, flat_shapes=kwarg_shapes)

    # Convert all Tensors into TensorSpecs before saving the structured inputs.
    # If storing pure concrete functions that are not called through polymorphic
    # functions, we don't have access to FunctionSpec, so we need to name the
    # TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(func_args, arg_names),
        convert_structure_to_signature(func_kwargs))

    flat_func_args = nest.flatten(func_args, expand_composites=True)
    flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
    # Temporarily set inputs to allow graph building code to inspect
    # them. Reassigned below.
    func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
                         if isinstance(arg, ops.Tensor)]

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function
    # Copy the recursive list, tuple and map structure, but not base objects
    func_args_before = nest.pack_sequence_as(func_args, flat_func_args,
                                             expand_composites=True)
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, flat_func_kwargs, expand_composites=True)

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps, so
        # this won't work if x needs to be captured (i.e. if python_func returns
        # captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = deps_ctx.mark_as_return(x)
      return x

    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def wrapper(*args, **kwargs):
          """Calls a converted version of original_func."""
          # TODO(mdan): Push this block higher in tf.function's call stack.
          try:
            return autograph.converted_call(
                original_func,
                args,
                kwargs,
                options=autograph.ConversionOptions(
                    recursive=True,
                    optional_features=autograph_options,
                    user_requested=True,
                ))
          except Exception as e:  # pylint:disable=broad-except
            if hasattr(e, "ag_error_metadata"):
              raise e.ag_error_metadata.to_exception(e)
            else:
              raise

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func, wrapper)
        python_func = tf_decorator.rewrap(python_func, original_func,
                                          converted_func)

      else:
        _, original_func = tf_decorator.unwrap(python_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, CompositeTensors,
      # TensorArrays and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs,
                                        expand_composites=True)

      check_mutation(func_args_before, func_args, original_func)
      check_mutation(func_kwargs_before, func_kwargs, original_func)
    finally:
      current_scope.set_use_resource(default_use_resource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
    arg_variables = object_identity.ObjectIdentitySet()
    inputs = []
    for arg in (nest.flatten(func_args, expand_composites=True) +
                nest.flatten(func_kwargs, expand_composites=True)):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.pop_capture(arg.handle)
        if resource_placeholder is None:
          continue
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in graph_variables if v not in arg_variables]
    func_graph.inputs = (
        inputs + func_graph.internal_captures + nest.flatten(
            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  if add_control_dependencies:
    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
    func_graph.collective_manager_ids_used = (
        deps_control_manager.collective_manager_ids_used)

  return func_graph

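# Illustrative sketch (not part of the original module): tracing a Python
# function into a FuncGraph. The function and names below are hypothetical.
def _example_func_graph_from_py_func():
  def f(x, y):
    return x + y

  g = func_graph_from_py_func(
      name="add",
      python_func=f,
      args=(tensor_spec.TensorSpec([2], dtypes.float32),
            tensor_spec.TensorSpec([2], dtypes.float32)),
      kwargs={})
  # `g.inputs` now holds two float32 placeholders, `g.outputs` holds the sum,
  # and `g.structured_input_signature` holds a signature built from them.
  return g
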

def maybe_captured(tensor):
  """If `tensor` is a capture placeholder, returns the original captured value.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures:
      if tensor == placeholder_t:
        return maybe_captured(input_t)
  return tensor


def device_stack_has_callable(device_stack):
  """Checks whether a device stack contains a callable."""
  return any(callable(spec._device_name_or_function)  # pylint: disable=protected-access
             for spec in device_stack.peek_objs())

def check_mutation(n1, n2, func):
  """Check whether two structures of arguments are exactly the same."""
  func_name = getattr(func, "__name__", func)

  errmsg = ("{}() should not modify its Python input arguments."
            " Check if it modifies any lists or dicts passed as"
            " arguments. Modifying a copy is allowed.".format(func_name))
  try:
    # TODO(mdan): Compare more robustly so that argument names can be reported.
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    raise ValueError(errmsg)

  for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True),
                        nest.flatten(n2, expand_composites=True)):
    if arg1 is not arg2:
      raise ValueError(errmsg)

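# Illustrative sketch (not part of the original module): `check_mutation`
# flags functions that modify their argument structures in place.
def _example_check_mutation():
  def bad_fn(d):
    d["new_key"] = 1.0  # mutates the caller-visible dict

  args_before = ({"w": 2.0},)  # snapshot of the structure before the call
  args = ({"w": 2.0},)
  bad_fn(*args)
  # Raises ValueError: "bad_fn() should not modify its Python input
  # arguments. ..."
  check_mutation(args_before, args, bad_fn)
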

# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def flatten(sequence):
  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A list of tensors.
  """
  flat_sequence = nest.flatten(sequence, expand_composites=True)
  return [
      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
      for item in flat_sequence]


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors.

  Returns:
    A nested structure.

  Raises:
    ValueError: if `structure` and `flat_sequence` are not compatible.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  for i in range(len(flat_sequence)):
    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=flattened_structure[i], flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)

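# Illustrative sketch (not part of the original module): round-tripping a
# structure containing a TensorArray through `flatten` / `pack_sequence_as`.
def _example_tensor_array_round_trip():
  with ops.Graph().as_default():
    ta = tensor_array_ops.TensorArray(dtypes.float32, size=2)
    structure = {"ta": ta, "x": constant_op.constant(1.0)}
    flat = flatten(structure)  # the TensorArray contributes its `flow` tensor
    rebuilt = pack_sequence_as(structure, flat)
  # `rebuilt["ta"]` is a TensorArray rebuilt around the original's flow tensor.
  return rebuilt
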

def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it."""
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  if shape is None:
    shape = value.shape
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  custom_gradient.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A flat list of user-specified arguments.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`.  If provided, then length must match
      that of `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding `TensorShape`
      in `flat_shapes`.  May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
      `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None
      for an argument that is not a `Tensor`, `TensorSpec`,
      or `ResourceVariable`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d).  args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):

    # Replace any composite tensors with their TypeSpecs.  This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)

    flattened = nest.flatten(arg_value, expand_composites=True)

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        arg_is_spec = isinstance(arg, tensor_spec.TensorSpec)
        if arg_is_spec and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if not arg_is_spec:
          custom_gradient.copy_handle_data(arg, placeholder)
        if name is not None:
          # Record the requested/user-specified name in case it's different
          # from the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        if isinstance(arg, resource_variable_ops.VariableSpec):
          name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(dtypes.resource, arg.shape,
                                            name=name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=name)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=name)
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_user_specified_name",
            attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs."""
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)

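# Illustrative sketch (not part of the original module): breaking reference
# cycles once a traced graph is no longer needed.
def _example_dismantle():
  g = func_graph_from_py_func(
      name="square",
      python_func=lambda x: x * x,
      args=(tensor_spec.TensorSpec([], dtypes.float32),),
      kwargs={})
  dismantle_func_graph(g)  # `g` must not be used after this point
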

def override_func_graph_name_scope(func_graph, name_scope):
  func_graph._name_stack = name_scope  # pylint: disable=protected-access