# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""FuncGraph and related functionality."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections as py_collections
22import itertools
23import weakref
24
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.python.eager import context
27from tensorflow.python.eager import execute
28from tensorflow.python.eager import tape
29from tensorflow.python.eager.graph_only_ops import graph_placeholder
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_spec
33from tensorflow.python.framework.auto_control_deps import AutomaticControlDependencies
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import custom_gradient
36from tensorflow.python.ops import resource_variable_ops
37from tensorflow.python.ops import tensor_array_ops
38from tensorflow.python.ops import variable_scope
39from tensorflow.python.util import compat
40from tensorflow.python.util import memory
41from tensorflow.python.util import nest
42from tensorflow.python.util import tf_contextlib
43from tensorflow.python.util import tf_decorator
44from tensorflow.python.util.lazy_loader import LazyLoader
45
46# This is to avoid a circular dependency:
47# function -> func_graph
48function = LazyLoader("function", globals(),
49                      "tensorflow.python.eager.function")
50def_function = LazyLoader(
51    "def_function", globals(),
52    "tensorflow.python.eager.def_function")
53
54WHITELIST_COLLECTIONS = [
55    ops.GraphKeys.GLOBAL_VARIABLES,
56    ops.GraphKeys.LOCAL_VARIABLES,
57    ops.GraphKeys.TRAINABLE_VARIABLES,
58    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
59    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
60]
61
62
63class UnknownArgument(object):
64  """Signifies an argument which is not currently handled."""
65  pass
66
67
def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where the top-level collection is a list
      or a tuple.
    arg_names: Optional list of argument names with the same number of
      elements as `structure`, used for naming the corresponding TensorSpecs.

  Returns:
    An identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.
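
  For example (an illustrative sketch): converting `structure=(x, True)`,
  where `x` is a Tensor placeholder created during tracing, with
  `arg_names=["x", "flag"]` yields roughly
  `(TensorSpec(x.shape, x.dtype, "x"), True)`: the Tensor becomes a named
  TensorSpec while the plain bool passes through unchanged.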
80  """
81  def encode_arg(arg, path):
82    """A representation for this argument, for converting into signatures."""
83    if isinstance(arg, ops.Tensor):
84      user_specified_name = None
85      try:
86        user_specified_name = compat.as_str(
87            arg.op.get_attr("_user_specified_name"))
88      except ValueError:
89        pass
90
91      if path and user_specified_name and user_specified_name != path[0]:
92        # The user has explicitly named the argument differently than the name
93        # of the function argument.
94        name = user_specified_name
95      else:
96        name = "/".join([str(p) for p in path])
97      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
98    if isinstance(arg, (
99        int,
100        float,
101        bool,
102        type(None),
103        dtypes.DType,
104        tensor_spec.TensorSpec,
105    )):
106      return arg
107    return UnknownArgument()
108
109  # We are using the flattened paths to name the TensorSpecs. We need an
110  # explicit name for them downstream.
111  flattened = nest.flatten_with_tuple_paths(structure)
112  if arg_names:
113    if len(arg_names) != len(structure):
114      raise ValueError(
115          "Passed in arg_names don't match actual signature (%s)." % arg_names)
116    # Replace all top-level names with their actual arg_names. If a path before
117    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
118    flattened = [
119        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
120    ]
121
122  mapped = [encode_arg(arg, path) for path, arg in flattened]
123  return nest.pack_sequence_as(structure, mapped)
124
125
126class FuncGraph(ops.Graph):
127  """Graph representing a function body.
128
129  Attributes:
130    name: The name of the function.
131    inputs: Placeholder tensors representing the inputs to this function. The
132      tensors are in this FuncGraph. This represents "regular" inputs as well as
133      captured inputs (i.e. the values of self.captures), with the regular
134      inputs coming first.
135    outputs: Tensors that will be returned by this function. The tensors are in
136      this FuncGraph.
137    control_outputs: Operations that must be executed before the function
138      represented by this graph can be said to have been executed.
139    structured_input_signature: A tuple of (args, kwargs), which are both
140      possibly-nested python objects that were received by this function. Note
141      that these structures might contain Python `None`s.
142    structured_outputs: A possibly-nested python object which will be returned
143      by this function. The Tensors in this structure are the same as those of
144      self.outputs. Note that this structure might contain Python `None`s.
145    variables: Variables that should be watched during function execution.
146    outer_graph: The graph this function is defined in. May be another FuncGraph
147      or the global default Graph.
148    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
149      The entries are in the order they were captured.
150    control_captures: Set of external ops on which this graph has a control
151      dependency.
152    seed: The graph-level random seed.
153    capture_by_value: If True, the func graph will capture Variables by value
154      instead of reference.
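
  Example (a minimal sketch, assuming eager execution and that `constant_op`
  has been imported):

    g = FuncGraph("f")
    with g.as_default():
      c = constant_op.constant(1.0)  # a Const op recorded in `g`
    assert c.graph is g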
155  """
156
157  def __init__(self, name, collections=None, capture_by_value=None):
158    """Construct a new FuncGraph.
159
160    The graph will inherit its graph key, collections, seed, and distribution
161    strategy stack from the current context or graph.
162
163    Args:
164      name: the name of the function.
165      collections: a dictionary of collections this FuncGraph should start
166        with. If not specified (None), the FuncGraph will read (but not write
167        to) the outer graph's collections that are not whitelisted, and both
168        read and write to the outer graph's collections that are whitelisted.
169        The current whitelisted collections are the global variables, the
170        local variables, and the trainable variables.
171        Defaults to None.
172      capture_by_value: An optional boolean. If True, the func graph will
173        capture Variables by value instead of reference. By default inherit
174        from outer graphs, and failing that will default to False.
175    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = set()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self.outer_graph = ops.get_default_graph()
    self.captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(
        self.outer_graph, FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf-data user migration from TF 1.0 to 2.0] `_seed_used` keeps
      # track of whether any random op in the function had a None op_seed, in
      # which case we end up using the function seed, which could be
      # unintended behavior for the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in WHITELIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in WHITELIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

  def __str__(self):
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which
        must be executed or computed before running the operations
        defined in the context.  Can also be `None` to clear the control
        dependencies.

    Returns:
     A context manager that specifies control dependencies for all
     operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    if control_inputs is None:
      return super(FuncGraph, self).control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable
      if (isinstance(c, ops.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      if graph_element is not None and getattr(
          graph_element, "graph", None) is not self:
        self.control_captures.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)

  def as_default(self):
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)
      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's placement,
      # so that the same trace of the Python function may be placed on several
      # different devices and saved functions may be placed on new devices when
      # restored.
      old_device_stack = self._device_function_stack
      if context.executing_eagerly():
        if self._distribution_strategy_stack:
          self._add_device_to_stack(context.context().device_name)
      else:
        if (self._distribution_strategy_stack
            or device_stack_has_callable(graph._device_function_stack)):
          # Hard-code devices from device functions in the function body
          self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      with outer_cm as g:
        try:
          yield g
        finally:
          self._distribution_strategy_stack = old_strategy_stack
          self._device_function_stack = old_device_stack
          self._variable_creator_stack = old_creator_stack
          self._graph_key = old_graph_key
    return inner_cm()

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    return [t.shape for t in self.outputs]

  @property
  def variables(self):
    """A list of variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Yields:
      Strong references to variables accessed by this FuncGraph.
    """
    for weak_v in self._weak_variables:
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      yield v

  @variables.setter
  def variables(self, var_list):
    self._weak_variables = [weakref.ref(v) for v in var_list]

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_shapes=True,
      compute_device=True):
    # When capturing by value, run the op in the outermost context (via
    # init_scope) and capture the resulting value, rather than the resource.
    reverse_captures = dict((v, k) for k, v in self.captures.items())
    uncaptured_inputs = [reverse_captures.get(t, t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph().create_op(
            op_type, uncaptured_inputs, dtypes, input_types, name, attrs,
            op_def, compute_shapes, compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def create_op(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_shapes=True,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This overload adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function graphs
    if this is a nested function. See `capture` for more information.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of the
        tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
        computed).
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    if self.capture_by_value and op_type in ["ReadVariableOp",
                                             "ResourceGather"]:
      return self._capture_by_value(
          op_type, inputs, dtypes, input_types, name, attrs, op_def,
          compute_shapes, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can lead
    # the context to get confused and try to create an Enter for an Enter. We
    # can detect this here and skip the additional Enter which can confuse loop
    # validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create placeholders
    # to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    for i, inp in enumerate(inputs):
      # TPU Estimator defines a control flow context with no AddValue method.
      if ctxt is not None and hasattr(ctxt, "AddValue"):
        inp = ctxt.AddValue(inp)
      inp = self.capture(inp)
      inputs[i] = inp
    return super(FuncGraph, self).create_op(
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
        compute_device=compute_device)

  def capture(self, tensor, name=None):
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs.  Multiple calls to this method with
    the same `tensor` argument will return the same placeholder. If `tensor` is
    from this graph, returns `tensor`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.
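
    Example (an illustrative sketch, assuming eager execution and that
    `constant_op` has been imported):

      x = constant_op.constant(1.0)  # external EagerTensor
      g = FuncGraph("f")
      with g.as_default():
        x_ph = g.capture(x)          # placeholder for `x` inside `g`
      assert g.captures[x] is x_ph
      assert x_ph in g.inputs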
469    """
470    if isinstance(tensor, ops.EagerTensor):
471      if name is None:
472        name = str(ops.uid())
473      return self._capture_helper(tensor, name)
474    if tensor.graph is not self:
475      if name is None:
476        name = tensor.op.name
477      inner_graph = tensor.graph
478      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
479        if inner_graph is self:
480          raise ValueError(
481              "Trying to capture a tensor from an inner function. This can be "
482              "caused by accessing a tensor defined inside a loop or "
483              "conditional body, or a subfunction, from a calling function, "
484              "without going through the proper return value mechanism. "
485              "Consider using TensorFlow mechanisms such as TensorArrays "
486              "to return tensors from inner functions or loop / conditional "
487              "bodies. Tensor: %s; tensor graph: %s; this graph: %s"
488              % (tensor, tensor.graph, self))
489        inner_graph = inner_graph.outer_graph
490      return self._capture_helper(tensor, name)
491    return tensor
492
  def _capture_helper(self, tensor, name):
    captured_tensor = self.captures.get(tensor, None)
    if captured_tensor is None:
      captured_tensor = _create_substitute_placeholder(tensor, name=name,
                                                       dtype=tensor.dtype)
      self.captures[tensor] = captured_tensor
      self.inputs.append(captured_tensor)
    tape.record_operation("captured_value", [captured_tensor], [tensor],
                          lambda x: [x])
    return captured_tensor

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return list(self.captures.keys())

  @property
  def internal_captures(self):
511    """Placeholders in this function corresponding captured tensors."""
512    return list(self.captures.values())
513
514
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            override_flat_arg_shapes=None):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, tracing uses
      this graph; otherwise a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not whitelisted, and both
      read and write to the outer graph's collections that are whitelisted.
      The current whitelisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default this is inherited
      from outer graphs, falling back to False.
    override_flat_arg_shapes: An optional list of instances that are either
      `None` or `TensorShape`.  The length must match that of
      `nest.flatten((args, kwargs))`.  The entries containing value `None`
      must match entries in flattened arguments containing non-tensors, while
      entries containing a `TensorShape` must match entries in the flattened
      arguments containing tensors.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
    ValueError: If both `signature` and `override_flat_arg_shapes` are
      passed in.
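
  Example (an illustrative sketch, assuming eager execution and that
  `constant_op` has been imported):

    def f(x):
      return x * 2.0

    g = func_graph_from_py_func(
        "f", f, args=(constant_op.constant(3.0),), kwargs={})
    # `g.inputs` holds the placeholder traced for `x`, and `g.outputs`
    # holds the tensor computing `x * 2.0`.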
582  """
583  if op_return_value is not None:
584    assert isinstance(op_return_value, ops.Tensor), op_return_value
585  if func_graph is None:
586    func_graph = FuncGraph(name, collections=collections,
587                           capture_by_value=capture_by_value)
588  assert isinstance(func_graph, FuncGraph)
589  if add_control_dependencies:
590    control_manager = AutomaticControlDependencies()
591  else:
592    control_manager = ops.NullContextmanager()
593  with func_graph.as_default(), control_manager as a:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None and override_flat_arg_shapes is not None:
      raise ValueError(
          "Passed both signature and override_flat_arg_shapes: %s and %s."
          % (signature, override_flat_arg_shapes))

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    if override_flat_arg_shapes is not None:
      flat_args = nest.flatten(args)
      arg_shapes = override_flat_arg_shapes[:len(flat_args)]
      kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
    else:
      arg_shapes = None
      kwarg_shapes = None
    func_args = _get_defun_inputs_from_args(
        args, arg_names, flat_shapes=arg_shapes)
    func_kwargs = _get_defun_inputs_from_kwargs(
        kwargs, flat_shapes=kwarg_shapes)

    # Convert all Tensors into TensorSpecs before saving the structured inputs.
    # If storing pure concrete functions that are not called through polymorphic
    # functions, we don't have access to FunctionSpec, so we need to call the
    # TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(func_args, arg_names),
        convert_structure_to_signature(func_kwargs))

    flat_func_args = nest.flatten(func_args)
    flat_func_kwargs = nest.flatten(func_kwargs)
    # Temporarily set inputs to allow graph building code to inspect
    # them. Reassigned below.
    func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
                         if isinstance(arg, ops.Tensor)]

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function
    # Copy the recursive list, tuple and map structure, but not base objects
    func_args_before = nest.pack_sequence_as(func_args, flat_func_args)
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, flat_func_kwargs)

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps, so
        # this won't work if x needs to be captured (i.e. if python_func returns
        # captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.contrib.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = a.mark_as_return(x)
      return x

    this_tape = tape.push_new_tape()
    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def wrapper(*args, **kwargs):
          # Note: functions annotated with @tf.function should always be
          # converted even though they would meet autograph's whitelisting
          # criteria.
          # If this assumption is ever broken, converted_call will need to
          # handle the possibility of original_func still being a shim, e.g.
          # bound to WeakrefSelf.
          return autograph.converted_call(
              original_func, None,
              autograph.ConversionOptions(
                  recursive=True,
                  optional_features=autograph_options,
                  force_conversion=True,
              ), args, kwargs)

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func, wrapper)
        python_func = tf_decorator.rewrap(python_func, original_func,
                                          converted_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, IndexedSlices,
      # SparseTensors, TensorArrays and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs)

      check_mutation(func_args_before, func_args)
      check_mutation(func_kwargs_before, func_kwargs)
    finally:
      tape.pop_tape(this_tape)
      current_scope.set_use_resource(default_use_resource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    tape_variables = this_tape.watched_variables()
    arg_variables = set()
    inputs = []
    for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.captures.pop(arg.handle, None)
        if resource_placeholder is None:
          continue
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in tape_variables if v not in arg_variables]
    func_graph.inputs = inputs + list(func_graph.captures.values())

    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  if add_control_dependencies:
    func_graph.control_outputs.extend(control_manager.ops_which_must_run)

  # Register any other functions defined in the graph.
  with ops.init_scope():
    if context.executing_eagerly():
      for f in func_graph._functions.values():  # pylint: disable=protected-access
        # TODO(ashankar): What about the gradient registry?
        context.add_function(f._c_func.func)  # pylint: disable=protected-access

  return func_graph


def maybe_captured(tensor):
  """If `tensor` is a captured value placeholder, returns the original value.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
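
  For example, if `tensor` is the placeholder that `FuncGraph.capture`
  created for some outer-graph tensor `t`, this returns `t`, following
  chains of nested captures recursively.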
754  """
755  if (not isinstance(tensor, ops.EagerTensor) and
756      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
757    for input_t, placeholder_t in tensor.op.graph.captures.items():
758      if tensor == placeholder_t:
759        return maybe_captured(input_t)
760  # pylint: enable=protected-access
761  return tensor
762
763
764def device_stack_has_callable(device_stack):
765  """Checks whether a device stack contains a callable."""
766  return any(callable(spec._device_name_or_function)  # pylint: disable=protected-access
767             for spec in device_stack.peek_objs())
768
769
def check_mutation(n1, n2):
  """Checks that two argument structures are exactly the same."""
  errmsg = ("Function to be traced should not modify structure of input "
            "arguments. Check if your function has list and dictionary "
            "operations that alter input arguments, "
            "such as `list.pop`, `list.append`")
  try:
    nest.assert_same_structure(n1, n2)
  except ValueError:
    raise ValueError(errmsg)

  for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)):
    if arg1 is not arg2:
      raise ValueError(errmsg)
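
# An illustrative sketch of how `check_mutation` is used (not executed here):
# `func_graph_from_py_func` snapshots the argument structure before tracing,
# then compares it to the post-trace structure, e.g.
#
#   before = nest.pack_sequence_as(args, nest.flatten(args))  # shallow copy
#   python_func(*args)            # suppose this does args[0].append(x)
#   check_mutation(before, args)  # raises ValueError: structure was mutated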


def flatten(sequence):
  """Like `nest.flatten` but also unpacks other Tensor-like objects.

  Flattens non-tensor objects into their constituent tensors.

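  For example (an illustrative note), a `TensorArray` in `sequence` is
  replaced by its `flow` tensor, and composite tensors such as
  `IndexedSlices` are expanded into their component tensors.
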
  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A list of tensors.
797  """
798  # TODO(akshayka): Support `SparseTensor` in a similar fashion.
799  flat_sequence = nest.flatten(sequence, expand_composites=True)
800  return [
801      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
802      for item in flat_sequence]
803
804
805def pack_sequence_as(structure, flat_sequence):
806  """Like `nest.pack_sequence_as` but also packs other Tensor-like objects.
807
808  Args:
809    structure: The structure to pack into. May contain Tensors,
810      CompositeTensors, or TensorArrays.
811    flat_sequence: An iterable containing tensors.
812
813  Returns:
814    A nested structure.
815
816  Raises:
817    AssertionError if `structure` and `flat_sequence` are not compatible.
818  """
819  flat_sequence = list(flat_sequence)
820  flattened_structure = nest.flatten(structure, expand_composites=True)
821  if len(flattened_structure) != len(flat_sequence):
822    raise ValueError("Mismatch in element count")
823  for i in range(len(flat_sequence)):
824    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
825      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
826          old_ta=flattened_structure[i], flow=flat_sequence[i])
827  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)
828
829
830
def _create_substitute_placeholder(value, name=None, dtype=None):
  """Creates a placeholder for `value` and propagates shape info to it."""
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
  custom_gradient.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A flat list of user-specified arguments.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`.  If provided, then length must match
      that of `nest.flatten(args)`; and locations where `args` are
      instances of `Tensor` must have a corresponding `TensorShape` in
      `flat_shapes`.  May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
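
  For example (an illustrative sketch): called with a single float32 Tensor
  in `args` and `names=["x"]`, this creates one float32 placeholder named
  "x" in the default FuncGraph and returns it packed into the same structure
  as `structure`.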
872  """
873  func_graph = ops.get_default_graph()
874  function_inputs = []
875  if names is None:
876    names = [None] * len(args)
877  if flat_shapes is None:
878    shapes_iter = itertools.repeat(None)
879  else:
880    len_flat_args = len(nest.flatten(args))
881    if len_flat_args != len(flat_shapes):
882      raise RuntimeError(
883          "Length of fully flat shapes (%d) must match that of "
884          "flatten(args) (%d).  args: %s, flat_shapes: %s"
885          % (len(flat_shapes),
886             len_flat_args,
887             args,
888             flat_shapes))
889    shapes_iter = iter(flat_shapes)
890  for arg_value, name in zip(args, names):
891    flattened = nest.flatten(arg_value)
892    tensor_specs = [
893        arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec)
894    ]
895    specified_names = [arg.name for arg in tensor_specs if arg.name]
896    if specified_names and len(specified_names) < len(tensor_specs):
897      raise ValueError("If specifying TensorSpec names for nested structures, "
898                       "either zero or all names have to be specified.")
899
    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        if isinstance(arg, tensor_spec.TensorSpec) and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if name is not None:
          # Record the requested/user-specified name in case it's different
          # from the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, resource_variable_ops.ResourceVariable):
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced, since otherwise
        # they would just be added back with new placeholders when the
        # variables are referenced.
        placeholder = func_graph.capture(arg.handle, name=name)
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_user_specified_name",
            attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs."""
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in the `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
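
  Example (an illustrative note): `dismantle_func_graph(fn.graph)`, where
  `fn` is a hypothetical traced function object that is no longer needed,
  breaks the graph's internal cycles; the graph must not be used afterwards.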
966  """
967  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
968  # Clearing captures using clear() leaves some cycles around.
969  while func_graph.captures:
970    func_graph.captures.popitem()
971  memory.dismantle_ordered_dict(func_graph.captures)
972  ops.dismantle_graph(func_graph)
973