# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as py_collections
import itertools
import weakref

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import compat
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export

ALLOWLIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]


_EAGER_CONST_THRESHOLD = 128


class UnknownArgument(object):
  """Signifies an argument which is not currently handled."""
  pass


def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where top level collection is a list or a
      tuple.
    arg_names: Optional list of argument names, with the same number of
      elements as the top level of `structure`; used to name the corresponding
      TensorSpecs.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.
  """
  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
      name = "/".join(str(p) for p in path)
      return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name,
                                                trainable=arg.trainable)
    if isinstance(arg, (
        int,
        float,
        bool,
        str,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." % arg_names)
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)
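
# A minimal usage sketch (hedged; the placeholder setup below is illustrative,
# not taken from this module's tests):
#
#   with ops.Graph().as_default():
#     x = array_ops.placeholder(dtypes.float32, shape=[2])
#     convert_structure_to_signature([x], arg_names=["x"])
#     # -> [tensor_spec.TensorSpec(shape=(2,), dtype=float32, name="x")]
#
# Plain Python scalars, strings, DTypes, and TensorSpecs pass through
# unchanged; unsupported objects become UnknownArgument().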


@tf_export("__internal__.FuncGraph", v1=[])
class FuncGraph(ops.Graph):
  """Graph representing a function body.

  Attributes:
    name: The name of the function.
    inputs: Placeholder tensors representing the inputs to this function. The
      tensors are in this FuncGraph. This represents "regular" inputs as well as
      captured inputs (i.e. the values of self.captures), with the regular
      inputs coming first.
    outputs: Tensors that will be returned by this function. The tensors are in
      this FuncGraph.
    control_outputs: Operations that must be executed before the function
      represented by this graph can be said to have been executed.
    structured_input_signature: A tuple of (args, kwargs), which are both
      possibly-nested python objects that were received by this function. Note
      that these structures might contain Python `None`s.
    structured_outputs: A possibly-nested python object which will be returned
      by this function. The Tensors in this structure are the same as those of
      self.outputs. Note that this structure might contain Python `None`s.
    variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another FuncGraph
      or the global default Graph.
    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
      The entries are in the order they were captured.
    control_captures: Set of external ops on which this graph has a control
      dependency.
    seed: The graph-level random seed.
    capture_by_value: If True, the func graph will capture Variables by value
      instead of reference.
  """

  def __init__(self, name, collections=None, capture_by_value=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not allowlisted, and both
        read and write to the outer graph's collections that are allowlisted.
        The current allowlisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit
        from outer graphs, and failing that will default to False.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = object_identity.ObjectIdentitySet()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self._watched_variables = object_identity.ObjectIdentityWeakSet()
    self.is_control_flow_graph = False

    outer_graph = ops.get_default_graph()
    self._weak_outer_graph = weakref.ref(outer_graph)
    while outer_graph.building_function:
      outer_graph = outer_graph.outer_graph
    # If self._weak_outer_graph is deleted, we revert to the outermost Graph
    # active when the FuncGraph was traced. This will not be a FuncGraph.
    self._fallback_outer_graph = outer_graph
    self._captures = py_collections.OrderedDict()
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    # Maps arbitrary key -> (closure, nest of placeholders), where at function
    # call time the value of closure() will be used to feed the nest of
    # placeholders.
    self._deferred_captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(
        self.outer_graph, FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [For tf.data user migration from TF 1.x to 2.x] `_seed_used` keeps
      # track of whether any random op in the function had a None op_seed, in
      # which case we end up using the function seed, which could be
      # unintended behavior for the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in ALLOWLIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in ALLOWLIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
    # dependent functions as unsaveable.
    self._saveable = True
    self._saving_errors = set()

    # Keep track of callbacks to run when this graph exits default scope
    self._scope_exit_callbacks = None

  def __str__(self):
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def watch_variable(self, v):
    """Marks the variable v as accessed while building this graph."""
    while self is not None and isinstance(self, FuncGraph):
      self._watched_variables.add(v)
      self = self.outer_graph

  def capture_call_time_value(self, closure, spec, key=None):
    """Creates a placeholder which at call time has the value closure().

    Useful, for example, to respect TensorFlow context managers, which are often
    dynamically scoped.

    Args:
      closure: function which takes no arguments, to be evaluated at function
       call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to `capture_call_time_value`
       with the same key in the same graph will return the same placeholder,
       and the first closure will be used at function call time.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      ValueError: at function call time, if the return value of closure() is
       not compatible with `spec`.
    """
    if key is None:
      key = object()
    if key not in self._deferred_captures:

      def convert_to_placeholder(s):
        if not isinstance(s, tensor_spec.DenseSpec):
          raise TypeError(
              "Expected a nest of `TypeSpec` objects, found %s of type %s." %
              (s, type(s)))
        return array_ops.placeholder(dtype=s.dtype, shape=s.shape)

      placeholder = nest.map_structure(
          convert_to_placeholder, spec, expand_composites=True)

      def wrapped_closure():
        ret_nest = closure()
        nest.assert_same_structure(spec, ret_nest, expand_composites=True)
        # This uses the tensor dtype defined in `spec` when converting values
        # in `ret_nest` to tensors.
        # pylint: disable=protected-access
        y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest,
                               expand_composites=False)
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

      self._deferred_captures[key] = (wrapped_closure, placeholder)
    return self._deferred_captures[key][1]
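
  # Usage sketch (hedged illustration, not from this module's tests):
  #
  #   g = FuncGraph("f")
  #   with g.as_default():
  #     ph = g.capture_call_time_value(
  #         lambda: constant_op.constant(1.0),        # re-evaluated per call
  #         tensor_spec.TensorSpec([], dtypes.float32))
  #   # `ph` is a placeholder living in `g`; each function call feeds it with
  #   # the closure's then-current value.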

  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which
        must be executed or computed before running the operations
        defined in the context.  Can also be `None` to clear the control
        dependencies.

    Returns:
     A context manager that specifies control dependencies for all
     operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    if control_inputs is None:
      return super(FuncGraph, self).control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable
      if (isinstance(c, ops.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      if graph_element is not None and getattr(
          graph_element, "graph", None) is not self:
        self.control_captures.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)
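
  # Behavior sketch (hedged; `some_outer_op` is hypothetical): ops defined in
  # this FuncGraph pass through to Graph.control_dependencies, while anything
  # from an outer graph is recorded in `control_captures` instead.
  #
  #   with FuncGraph("f").as_default() as g:
  #     inner = constant_op.constant(1.0)
  #     with g.control_dependencies([inner, some_outer_op]):
  #       pass
  #   # `inner` becomes a regular control input; `some_outer_op` ends up in
  #   # g.control_captures for the graph's consumers to handle.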

  def as_default(self):
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      graph = ops.get_default_graph()
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)

      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's placement,
      # so that the same trace of the Python function may be placed on several
      # different devices and saved functions may be placed on new devices when
      # restored.
      # However, we need to preserve the outer device stack in the following
      # cases in non eager context:
      # 1. device stack is callable
      # 2. When using distribution strategy with legacy graph mode.
      old_device_stack = self._device_function_stack
      if (not context.executing_eagerly() and
          (device_stack_has_callable(graph._device_function_stack) or
           (self._distribution_strategy_stack and
            not ops.executing_eagerly_outside_functions()))):
        # Hard-code devices from device functions in the function body
        self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key
    return inner_cm()

  @property
  def outer_graph(self):
    """The Graph this FuncGraph is nested in.
    Functions may capture Tensors from graphs they are transitively nested in.

    Returns:
      A Graph object. Initially set to the current default graph when the
      FuncGraph was created. If the previous `outer_graph` was deleted because
      the function that owns it was deleted, `outer_graph` is reset to the
      outermost default graph active when the FuncGraph was created. This
      FuncGraph won't have captured anything from the new `outer_graph` (and
      likely not from the previous setting, since that would have created a
      strong reference), but it is returned so that FuncGraphs always have a
      parent.
    """
    current = self._weak_outer_graph()
    if current is None:
      return self._fallback_outer_graph
    return current

  @outer_graph.setter
  def outer_graph(self, new_outer_graph):
    """Sets `outer_graph` to `new_outer_graph`."""
    self._weak_outer_graph = weakref.ref(new_outer_graph)

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    return [t.shape for t in self.outputs]

  @property
  def trainable_variables(self):
    """A sequence of trainable variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of trainable variables for this func graph.
    """
    return tuple(v for v in self.variables if v.trainable)

  @property
  def variables(self):
    """A sequence of variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of variables for this func graph.
    """
    def deref(weak_v):
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      return v

    return tuple(deref(v) for v in self._weak_variables)

  @variables.setter
  def variables(self, var_list):
    self._weak_variables = [weakref.ref(v) for v in var_list]
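
  # The weak-reference behavior, sketched (hedged; assumes CPython's immediate
  # reclamation of unreferenced objects):
  #
  #   g = FuncGraph("f")
  #   v = resource_variable_ops.ResourceVariable(1.0)
  #   g.variables = [v]   # stored as weakref.ref(v)
  #   g.variables         # -> (v,)
  #   del v               # no other strong references remain
  #   g.variables         # raises AssertionError via deref()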

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # When capturing by value, do the read outside
    reverse_captures = dict((id(v), k) for k, v in self.captures)
    uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type,
            uncaptured_inputs,
            dtypes,
            input_types,
            name,
            attrs,
            op_def,
            compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This overload adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function graphs
    if this is a nested function. See `capture` for more information.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of the
        tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    if self.capture_by_value and op_type in ["ReadVariableOp",
                                             "ResourceGather"]:
      return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
                                    attrs, op_def, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can lead
    # the context to get confused and try to create an Enter for an Enter. We
    # can detect this here and skip the additional Enter which can confuse loop
    # validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create placeholders
    # to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    # Use a different list to avoid modifying the original inputs list.
    captured_inputs = []
    for inp in inputs:
      # TPU Estimator defines a control flow context with no AddValue method.
      if ctxt is not None and hasattr(ctxt, "AddValue"):
        inp = ctxt.AddValue(inp)
      inp = self.capture(inp)
      captured_inputs.append(inp)
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
        op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)

  def capture(self, tensor, name=None, shape=None):
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs.  Multiple calls to this method with
    the same `tensor` argument will return the same placeholder. If `tensor` is
    from this graph, returns `tensor`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.
      shape: Optional shape if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      InaccessibleTensorError: if any tensors are accessed in a manner that
      bypasses the mechanisms required for the data dependencies to be correctly
      wired.
    """
    if isinstance(tensor, ops.EagerTensor):
      if name is None:
        name = str(ops.uid())

      # Small EagerTensors are captured with Const ops
      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
        return self.capture_eager_tensor(tensor, name)

      # Large EagerTensors and resources are captured with Placeholder ops
      return self._capture_helper(tensor, name, shape)
    if tensor.graph is not self:
      if name is None:
        name = tensor.op.name
      inner_graph = tensor.graph
      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
        if inner_graph is self:
          raise errors.InaccessibleTensorError(
              "The tensor '%s' cannot be accessed here: it is defined"
              " in another function or code block. Use return values,"
              " explicit Python locals or TensorFlow collections to access"
              " it. Defined in: %s; accessed from: %s.\n"
              % (tensor, tensor.graph, self))
        inner_graph = inner_graph.outer_graph
      return self._capture_helper(tensor, name)
    return tensor
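
  # Capture behavior, sketched (hedged; assumes eager execution outside the
  # graph, and the tensor sizes are illustrative):
  #
  #   g = FuncGraph("f")
  #   small = constant_op.constant([1.0, 2.0])      # 2 elements <= threshold
  #   big = constant_op.constant(np.zeros([1000]))  # 1000 elements > threshold
  #   with g.as_default():
  #     g.capture(small)  # inlined into `g` as a Const op
  #     g.capture(big)    # becomes a placeholder appended to g.inputs
  #   # Both entries are recorded in g.captures; calling capture() again with
  #   # the same tensor returns the same internal tensor.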

  def _capture_helper(self, tensor, name, shape=None):
    capture = self._captures.get(id(tensor))
    if capture is None:
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype, shape=shape)
      # Record the composite device as an attribute to the placeholder.
      # This attribute is propagated into the arg_attr of the FunctionDef.
      # Currently, a packed eager tensor is always placed on a CompositeDevice.
      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_composite_device",
            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
      self.add_capture(tensor, placeholder)
    else:
      placeholder = capture[1]
    tape.record_operation("captured_value", [placeholder], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return placeholder

  @property
  def captures(self):
674    """Order list of tuples containing external and internal captures."""
    return self._captures.values()

  def add_capture(self, tensor, placeholder):
    """Capture a specific tensor and utilize the provided placeholder.

    Args:
      tensor: Tensor to capture.
      placeholder: Provided placeholder for the tensor.
    """
    self._captures[id(tensor)] = (tensor, placeholder)
    self.inputs.append(placeholder)

  def replace_capture(self, tensor, placeholder):
    """Replace already existing capture."""
    self._captures[id(tensor)] = (tensor, placeholder)

  def reset_captures(self, capture_list):
    """Set the captures with the provided list of captures & placeholder."""
    self._captures = py_collections.OrderedDict()
    for tensor, placeholder in capture_list:
      self._captures[id(tensor)] = (tensor, placeholder)

  def pop_capture(self, tensor):
    """Remove the capture and return the generated placeholder."""
    capture = self._captures.pop(id(tensor), None)
    if capture is None:
      return None

    return capture[1]

  def clear_captures(self):
    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
    # Clearing captures using clear() leaves some cycles around.
    while self._captures:
      self._captures.popitem()
    memory.dismantle_ordered_dict(self._captures)
    while self._deferred_captures:
      self._deferred_captures.popitem()
    memory.dismantle_ordered_dict(self._deferred_captures)

  def capture_distributed_variable(self, variable, placeholder):
    """Add given distributed variable to captures with given placeholder."""
    self._captures[id(variable)] = (variable, placeholder)
    tape.record_operation("captured_value", [placeholder], [variable],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])

  def capture_eager_tensor(self, tensor, name):
    capture = self._captures.get(id(tensor))
    if capture is None:
      # We clear all control dependencies and place the Const op on the same
      # device as the source tensor. The device placement may be relaxed at
      # a later date.
      with ops.control_dependencies(None), self.device(tensor.device):
        constant_value = tensor_util.constant_value(tensor)
        if constant_value is None:
          # Some eager tensors, e.g. parallel tensors, are not convertible to a
          # single constant. We'll use a placeholder for this case.
          return self._capture_helper(tensor, name)
        graph_const = constant_op.constant(constant_value, dtype=tensor.dtype,
                                           shape=tensor.shape, name=name)
      self.add_capture(tensor, graph_const)
    else:
      graph_const = capture[1]
    tape.record_operation("captured_value", [graph_const], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return graph_const

  def captured(self, tensor):
    """Check if the specified tensor has been captured."""
    return id(tensor) in self._captures

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return [c[0] for c in self._captures.values()]

  @property
  def internal_captures(self):
755    """Placeholders in this function corresponding captured tensors."""
    return [c[1] for c in self._captures.values()]

  @property
  def deferred_external_captures(self):
    """Ordered nest of tensors whose placeholders will be fed at call time."""
    return [c[0] for c in self._deferred_captures.values()]

  @property
  def deferred_internal_captures(self):
    """List of nests of placeholders which will be fed at call time."""
    return [c[1] for c in self._deferred_captures.values()]

  @property
  def variable_captures(self):
770    """Map of python object ids of variables to variables which are captured."""
    return {
        id(self._captures[id(v)][1]): v
        for v in self.variables
        if id(v) in self._captures
    }

  def mark_as_unsaveable(self, error_message):
    """Marks this FuncGraph as unsaveable.

    Any attempts to export this FuncGraph will raise an error with the specified
    message.

    Args:
      error_message: List or string containing the error message to be raised
        when saving this FuncGraph to SavedModel.
    """
    self._saveable = False
    if isinstance(error_message, str):
      error_message = [error_message]
    self._saving_errors.update(error_message)
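
  # Sketch (hedged; the error string is illustrative):
  #
  #   g = FuncGraph("f")
  #   g.mark_as_unsaveable("uses a hypothetical unsupported op")
  #   g.saveable        # -> False
  #   g.saving_errors   # -> {"uses a hypothetical unsupported op"}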

  @property
  def saveable(self):
    """Returns whether this FuncGraph is saveable."""
    return self._saveable

  @property
  def saving_errors(self):
    """Returns set of errors preventing this FuncGraph from being saved."""
    return self._saving_errors

  def _add_scope_exit_callback(self, fn):
    """Add a function to call when this graph exits the default scope."""
    if not callable(fn):
      raise TypeError("fn is not callable: {}".format(fn))
    if self._scope_exit_callbacks is None:
      raise RuntimeError(
          "Attempting to add a scope exit callback, but the default graph is "
          "not the context scope graph.  Did you forget to call "
          "'with graph.as_default(): ...'?")
    self._scope_exit_callbacks.append(fn)


# TODO(mdan): Too many threaded arguments. Accept an ACD ctx manager instead.
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            override_flat_arg_shapes=None,
                            acd_record_initial_resource_uses=False):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph; otherwise a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not allowlisted, and both
      read and write to the outer graph's collections that are allowlisted.
      The current allowlisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.
    override_flat_arg_shapes: An optional list of instances that are either
      `None` or `TensorShape`.  The length must match that of
      `nest.flatten((args, kwargs), expand_composites=True)`.  The entries
      containing value `None` must match entries in flattened arguments
      containing non-tensors, while entries containing a `TensorShape` must
      match entries in the flattened arguments containing tensors.
    acd_record_initial_resource_uses: If `True` and `add_control_dependencies`
      is enabled, the results (those marked with
      AutomaticControlDependencies.mark_result) will be annotated with a private
      attribute, "_res_first_used_by", which points to the first nodes which
      used any of the resources that the result op is using.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
    ValueError: If both `signature` and `override_flat_arg_shapes` are
      passed in.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name, collections=collections,
                           capture_by_value=capture_by_value)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    deps_control_manager = auto_control_deps.AutomaticControlDependencies(
        record_initial_resource_uses=acd_record_initial_resource_uses)
  else:
    deps_control_manager = ops.NullContextmanager()

  with func_graph.as_default(), deps_control_manager as deps_ctx:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None and override_flat_arg_shapes is not None:
      raise ValueError(
          "Passed both signature and override_flat_arg_shapes: %s and %s."
          % (signature, override_flat_arg_shapes))

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    if override_flat_arg_shapes is not None:
      flat_args = nest.flatten(args, expand_composites=True)
      arg_shapes = override_flat_arg_shapes[:len(flat_args)]
      kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
    else:
      arg_shapes = None
      kwarg_shapes = None
    func_args = _get_defun_inputs_from_args(
        args, arg_names, flat_shapes=arg_shapes)
    func_kwargs = _get_defun_inputs_from_kwargs(
        kwargs, flat_shapes=kwarg_shapes)

    # Convert all Tensors into TensorSpecs before saving the structured inputs.
    # If storing pure concrete functions that are not called through polymorphic
    # functions, we don't have access to FunctionSpec, so we need to call the
    # TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(func_args, arg_names),
        convert_structure_to_signature(func_kwargs))

    flat_func_args = nest.flatten(func_args, expand_composites=True)
    flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
    # Temporarily set inputs to allow graph building code to inspect
    # them. Reassigned below.
    func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
                         if isinstance(arg, ops.Tensor)]

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function
    # Copy the recursive list, tuple and map structure, but not base objects
    func_args_before = nest.pack_sequence_as(func_args, flat_func_args,
                                             expand_composites=True)
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, flat_func_kwargs, expand_composites=True)

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps, so
        # this won't work if x needs to be captured (i.e. if python_func returns
        # captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = deps_ctx.mark_as_return(x)
      return x

    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def autograph_handler(*args, **kwargs):
          """Calls a converted version of original_func."""
          # TODO(mdan): Push this block higher in tf.function's call stack.
          try:
            return autograph.converted_call(
                original_func,
                args,
                kwargs,
                options=autograph.ConversionOptions(
                    recursive=True,
                    optional_features=autograph_options,
                    user_requested=True,
                ))
          except Exception as e:  # pylint:disable=broad-except
            if hasattr(e, "ag_error_metadata"):
              raise e.ag_error_metadata.to_exception(e)
            else:
              raise

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(
            original_func, autograph_handler)
        python_func = tf_decorator.rewrap(python_func, original_func,
                                          converted_func)

      else:
        _, original_func = tf_decorator.unwrap(python_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, CompositeTensors,
      # TensorArrays and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs,
                                        expand_composites=True)

      check_mutation(func_args_before, func_args, original_func)
      check_mutation(func_kwargs_before, func_kwargs, original_func)
    finally:
      current_scope.set_use_resource(default_use_resource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
    arg_variables = object_identity.ObjectIdentitySet()
    inputs = []
    for arg in (nest.flatten(func_args, expand_composites=True) +
                nest.flatten(func_kwargs, expand_composites=True)):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.pop_capture(arg.handle)
        if resource_placeholder is None:
          continue
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in graph_variables if v not in arg_variables]
    func_graph.inputs = (
        inputs + func_graph.internal_captures + nest.flatten(
            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  if add_control_dependencies:
    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
    func_graph.collective_manager_ids_used = (
        deps_control_manager.collective_manager_ids_used)

  return func_graph
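
# A minimal tracing sketch (hedged; simplified relative to how tf.function
# drives this helper, with illustrative argument values):
#
#   def f(x):
#     return x * 2.0
#
#   fg = func_graph_from_py_func(
#       "f", f, args=(constant_op.constant(3.0),), kwargs={},
#       arg_names=["x"])
#   fg.inputs   # [the placeholder created for `x`]
#   fg.outputs  # [the `mul` output tensor, living in `fg`]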


def maybe_captured(tensor):
1060  """If t is a captured value placeholder, returns the original captured value.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures:
      if tensor == placeholder_t:
        return maybe_captured(input_t)
  return tensor
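
# Sketch (hedged; `outer_tensor` is hypothetical): inside a traced function,
# `maybe_captured` walks a capture placeholder back to the external tensor it
# stands for:
#
#   with FuncGraph("f").as_default() as g:
#     ph = g.capture(outer_tensor)   # `outer_tensor` from an enclosing graph
#   maybe_captured(ph)               # -> outer_tensor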


def device_stack_has_callable(device_stack):
  """Checks whether a device stack contains a callable."""
  return any(callable(spec._device_name_or_function)  # pylint: disable=protected-access
             for spec in device_stack.peek_objs())


def check_mutation(n1, n2, func):
1084  """Check if two list of arguments are exactly the same."""
  func_name = getattr(func, "__name__", func)

  errmsg = ("{}() should not modify its Python input arguments."
            " Check if it modifies any lists or dicts passed as"
            " arguments. Modifying a copy is allowed.".format(func_name))
  try:
    # TODO(mdan): Compare more robustly so that argument names can be reported.
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    raise ValueError(errmsg)

  for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True),
                        nest.flatten(n2, expand_composites=True)):
    if arg1 is not arg2:
      raise ValueError(errmsg)


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def flatten(sequence):
  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A list of tensors.
  """
  flat_sequence = nest.flatten(sequence, expand_composites=True)
  return [
      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
      for item in flat_sequence]


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors.

  Returns:
    A nested structure.

  Raises:
    ValueError: If `structure` and `flat_sequence` do not have the same number
      of elements.
1133  """
1134  flat_sequence = list(flat_sequence)
1135  flattened_structure = nest.flatten(structure, expand_composites=True)
1136  if len(flattened_structure) != len(flat_sequence):
1137    raise ValueError("Mismatch in element count")
1138  for i in range(len(flat_sequence)):
1139    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
1140      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
1141          old_ta=flattened_structure[i], flow=flat_sequence[i])
1142  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)
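
# Round-trip sketch for `flatten` and `pack_sequence_as` (hedged):
#
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=2)
#   structure = {"ta": ta, "x": constant_op.constant(1.0)}
#   flat = flatten(structure)           # [ta.flow, x]
#   rebuilt = pack_sequence_as(structure, flat)
#   # rebuilt["ta"] is a TensorArray rebuilt from the flow tensor;
#   # rebuilt["x"] is x unchanged.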


def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it."""
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  if shape is None:
    shape = value.shape
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  handle_data_util.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A flat list of user-specified arguments.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`.  If provided, then length must match
      that of `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding `TensorShape`
      in `flat_shapes`.  May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d).  args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):

    # Replace any composite tensors with their TypeSpecs.  This is important
    # for ensuring that shape information that's not preserved by the TypeSpec
    # (such as the number of values in a SparseTensor) gets properly masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)

    flattened = nest.flatten(arg_value, expand_composites=True)

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        arg_is_spec = isinstance(arg, tensor_spec.TensorSpec)
        if arg_is_spec and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if not arg_is_spec:
          handle_data_util.copy_handle_data(arg, placeholder)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        if isinstance(arg, resource_variable_ops.VariableSpec):
          name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(dtypes.resource, arg.shape,
                                            name=name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=name,
                trainable=arg.trainable)
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=name)
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_user_specified_name",
            attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs."""
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)


def override_func_graph_name_scope(func_graph, name_scope):
  func_graph._name_stack = name_scope  # pylint: disable=protected-access