• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""Defun decorator for defining graph-mode functions."""
17
18import collections
19import pprint
20import threading
21import types as types_lib
22from typing import List
23import weakref
24
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.core.framework import function_pb2
27from tensorflow.core.function.polymorphism import function_cache
28from tensorflow.python import pywrap_tfe
29from tensorflow.python.client import pywrap_tf_session
30from tensorflow.python.eager import backprop
31from tensorflow.python.eager import backprop_util
32from tensorflow.python.eager import context
33from tensorflow.python.eager import execute
34from tensorflow.python.eager import forwardprop_util
35from tensorflow.python.eager import function_context
36from tensorflow.python.eager import function_saved_model_utils
37from tensorflow.python.eager import function_spec
38from tensorflow.python.eager import monitoring
39from tensorflow.python.eager import tape
40from tensorflow.python.eager.graph_only_ops import graph_placeholder
41from tensorflow.python.framework import c_api_util
42from tensorflow.python.framework import composite_tensor
43from tensorflow.python.framework import dtypes
44from tensorflow.python.framework import error_interpolation
45from tensorflow.python.framework import errors
46from tensorflow.python.framework import func_graph as func_graph_module
47from tensorflow.python.framework import indexed_slices
48from tensorflow.python.framework import ops
49from tensorflow.python.framework import tensor_shape
50from tensorflow.python.framework import tensor_spec
51from tensorflow.python.framework import type_spec
52from tensorflow.python.ops import array_ops
53from tensorflow.python.ops import default_gradient
54from tensorflow.python.ops import functional_ops
55from tensorflow.python.ops import gradients_util
56from tensorflow.python.ops import handle_data_util
57from tensorflow.python.ops import resource_variable_ops
58from tensorflow.python.platform import tf_logging as logging
59from tensorflow.python.profiler import trace
60from tensorflow.python.trackable import base as trackable
61from tensorflow.python.types import core
62from tensorflow.python.util import _pywrap_utils
63from tensorflow.python.util import compat
64from tensorflow.python.util import function_utils
65from tensorflow.python.util import lazy_loader
66from tensorflow.python.util import memory
67from tensorflow.python.util import nest
68from tensorflow.python.util import object_identity
69from tensorflow.python.util import tf_decorator
70from tensorflow.python.util import tf_inspect
71from tensorflow.python.util.tf_export import tf_export
72
# Loaded lazily due to a circular dependency (roughly
# tf.function -> autograph -> dataset -> tf.function).
# TODO(b/133251390): Use a regular import.
ag_ctx = lazy_loader.LazyLoader(
    "ag_ctx",
    globals(),
    "tensorflow.python.autograph.core.ag_ctx",
)
np_arrays = lazy_loader.LazyLoader(
    "np_arrays",
    globals(),
    "tensorflow.python.ops.numpy_ops.np_arrays",
)
82
83
# Function attributes linking a generated forward function and its backward
# counterpart to each other (set by _create_forward_backward_with_graph).
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"

# Attribute that is stripped from generated forward/backward functions.
IMPLEMENTS_ATTRIBUTE_NAME = "_implements"

SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"
88
# Monitoring counter for the time tf.function spends building graphs, in
# microseconds.
_graph_building_time_counter = monitoring.Counter(
    "/tensorflow/core/tf_function/graph_building_time_usecs",
    "Time for tf.function to build a graph (us).",
)
92
93
def _type_spec_for(x):
  """Returns a TypeSpec describing `x`, or `x` itself if there is none.

  * Tensors map to a new, unnamed `TensorSpec` of the same shape and dtype.
  * TypeSpecs are returned unchanged.
  * CompositeTensors map to their own `_type_spec`.
  * Everything else is returned unchanged.
  """
  if isinstance(x, ops.Tensor):
    # No name on purpose: a TensorSpec name would override arg_name in
    # `_get_defun_inputs` (func_graph.py).
    return tensor_spec.TensorSpec(x.shape, x.dtype)
  already_a_spec = isinstance(x, type_spec.TypeSpec)
  if not already_a_spec and isinstance(x, composite_tensor.CompositeTensor):
    return x._type_spec  # pylint: disable=protected-access
  return x
107
108
def _is_type_subset(a, b):
  """Checks whether type `b` is a subset of type `a`.

  Returns:
    True if `a` is not a TypeSpec. Otherwise, whether the most specific type
    compatible with both `a` and `b` is `a` itself.
  """
  if not isinstance(a, type_spec.TypeSpec):
    return True
  return a.most_specific_compatible_type(b) == a
114
115
def common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Args:
    x: A `TensorShape`, or None if the associated input was not a Tensor.
    y: A `TensorShape`, or None if the associated input was not a Tensor.

  Returns:
    None if both `x` and `y` are None. An unknown-rank `TensorShape` if the
    ranks differ or are unknown. Otherwise a `TensorShape` that keeps every
    dimension on which `x` and `y` agree with a known value, and sets all other
    dimensions to None.

  Raises:
    RuntimeError: If exactly one of `x` and `y` is None.
    TypeError: If a non-None argument is not a `TensorShape`.
  """
  # Both sides must be parenthesized: unparenthesized, `x is None != y is None`
  # is a chained comparison (`x is None and None != y and y is None`) that can
  # never be true, so the mismatch was silently ignored.
  if (x is None) != (y is None):
    raise RuntimeError(
        "Cannot find a common shape when LHS shape is None but RHS shape "
        f"is not (or vice versa): {x} vs. {y}.")
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError(f"`x` must be a TensorShape, got type {type(x)}.")
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError(f"`y` must be a TensorShape, got type {type(y)}.")
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    # Keep a dimension only when both sides know it and agree on it.
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
139
140
def _parse_func_attrs(attributes):
  """Converts Python attribute values into function_def `AttrValue` protos.

  `AttrValue` protos are passed through untouched. Otherwise only primitive
  values are supported: bool, int, float, and str/bytes.

  Args:
    attributes: dict mapping attribute names to their values.

  Returns:
    A new dict with the same keys, each mapped to an `AttrValue` proto.

  Raises:
    ValueError: If a value has an unsupported type.
  """

  def _to_attr_value(key, value):
    """Converts a single attribute value, naming `key` in any error."""
    if isinstance(value, attr_value_pb2.AttrValue):
      return value
    # bool has to be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
      return attr_value_pb2.AttrValue(b=value)
    if isinstance(value, int):
      return attr_value_pb2.AttrValue(i=value)
    if isinstance(value, float):
      return attr_value_pb2.AttrValue(f=value)
    if isinstance(value, (str, bytes)):
      return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
    raise ValueError(f"Attribute {key} must be bool, int, float, string, or "
                     f"AttrValue. Got {type(value)}.")

  return {key: _to_attr_value(key, value) for key, value in attributes.items()}
172
173
class _InterpolateFunctionError(object):
  """Context manager that enriches OpError messages from `top_level_func`.

  When the managed block raises an `errors.OpError`, the function tags in the
  error message are followed, starting at `top_level_func`. The graph of the
  last function that could be resolved is used to interpolate the message.
  The exception itself is never suppressed.
  """

  __slots__ = ["_func"]

  def __init__(self, top_level_func):
    self._func = top_level_func

  def __enter__(self):
    pass

  def __exit__(self, typ, exc, tb):
    # Only OpErrors carry function tags worth interpolating.
    if not (exc and isinstance(exc, errors.OpError)):
      return False
    text = compat.as_text(exc.message)
    _, tags, _ = error_interpolation.parse_message(text)
    current_graph = None
    for tag in tags:
      # TODO(mdan): Tests should cover this.
      if tag.name == compat.as_str(self._func.name):
        current_graph = self._func.graph
        continue
      if not current_graph:
        continue
      nested = current_graph._get_function(tag.name)  # pylint: disable=protected-access
      if nested is not None and isinstance(nested, _EagerDefinedFunction):
        current_graph = nested.graph
    if current_graph:
      exc._message = error_interpolation.interpolate(text, current_graph)  # pylint: disable=protected-access
    return False
203
204
# Callbacks invoked at the start of every `_EagerDefinedFunction` construction.
_function_callbacks = set()


def add_function_callback(function_callback):
  """Registers a callback that runs whenever a Function is created.

  The callback is called as

    `function_callback(function, name, graph, inputs, outputs)`

  with:
  - `function`: the `_EagerDefinedFunction` under construction, before its
      graph is finalized. Do not modify it directly; modify the graph instead.
  - `name`: the function's name.
  - `graph`: the function's Graph.
  - `inputs`: `tuple` of the tensors used as function inputs.
  - `outputs`: `tuple` of the tensors used as function outputs.

  Callbacks run first thing in `_EagerDefinedFunction` construction, so they
  get a last chance to edit the graph. Do not change `graph`, `inputs` or
  `outputs` by hand; instead make `graph` the default graph and define ops.

  Registering the same callback more than once has no additional effect.
  Use `remove_function_callback()` to unregister it.

  Args:
    function_callback: The callback to register.
  """
  _function_callbacks.add(function_callback)


def remove_function_callback(function_callback):
  """Unregisters a callback previously added with `add_function_callback()`.

  Args:
    function_callback: The callback to remove.

  Raises:
    KeyError: If `function_callback` is not currently registered.
  """
  _function_callbacks.remove(function_callback)


def clear_function_callbacks():
  """Unregisters every function callback, if any have been registered."""
  _function_callbacks.clear()
252
253
# Name prefixes for the functions generated from a traced graph.
_FORWARD_PREFIX = "__forward_"  # forward pass with extra side outputs
_BACKWARD_PREFIX = "__backward_"  # gradient function
_INFERENCE_PREFIX = "__inference_"  # forward pass without gradient support
257
258
def _forward_name(n):
  """Builds a generated forward function's name: prefix, `n`, `ops.uid()`."""
  return f"{_FORWARD_PREFIX}{n}_{ops.uid()}"
262
263
def _backward_name(n):
  """Builds a generated backward function's name: prefix, `n`, `ops.uid()`."""
  return f"{_BACKWARD_PREFIX}{n}_{ops.uid()}"
267
268
def _inference_name(n):
  """Builds an inference (forward, no gradient) function's name for `n`."""
  return f"{_INFERENCE_PREFIX}{n}_{ops.uid()}"
272
273
class _EagerDefinedFunctionDeleter(object):
  """Removes the named function from the eager context when collected."""

  __slots__ = ["name"]

  def __init__(self, name):
    self.name = name

  def __del__(self):
    try:
      context.remove_function(self.name)
    except (TypeError, AttributeError):
      # Best effort only. During module teardown the `context` module (or
      # parts of it) may already be unloaded. That shows up as
      # "'NoneType' object is not callable" (TypeError) or
      # "'NoneType' object has no attribute ..." (AttributeError). Exceptions
      # escaping __del__ would only be printed as stderr warnings, which is
      # noise in these cases.
      pass
296
297
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
class _EagerDefinedFunction(object):
  """Callable with the interface of `framework.function._DefinedFunction`.

  `_EagerDefinedFunction` encapsulates a function definition and its properties,
  and it provides a method for calling the encapsulated function. Some Ops
  take functions as attributes, which have type `func`; an instance of this
  class may be provided as the value of these `func` attributes.

  The definition is built from a graph through the TF C API
  (`TF_GraphToFunction_wrapper`). If construction happens while the outermost
  context is eager, the function is also registered with that context. An
  `_EagerDefinedFunctionDeleter` then removes it again when this object's
  deleter is garbage collected.
  """

  def __init__(self, name, graph, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Registered function callbacks (see `add_function_callback`) run first, so
    they may still edit `graph` before it is converted into a function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs from the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    for function_callback in _function_callbacks:
      function_callback(self, name, graph, tuple(inputs), tuple(outputs))

    # Every op in the graph except the ones producing `inputs` becomes part of
    # the function body.
    input_ops = set(arg.op for arg in inputs)
    operations = [op for op in graph.get_operations() if op not in input_ops]

    # Reuse the graph's recorded output names only if every output has one and
    # they are unique. An empty list lets the C API auto-name the outputs.
    graph_output_names = graph._output_names  # pylint: disable=protected-access
    if (graph_output_names is not None and
        all(ops.tensor_id(t) in graph_output_names for t in outputs)):
      output_names = [
          compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
      ]
      if len(set(output_names)) != len(output_names):
        # There are duplicate names for some reason, probably an invalid
        # signature. Revert to auto-naming.
        output_names = []
    else:
      output_names = []
    # NOTE(review): the bare `False`, `None` and `""` below are presumably
    # append_hash_to_fn_name, opts and description, following the C API's
    # TF_GraphToFunctionWithControlOutputs argument order. Confirm.
    with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
      fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
          c_graph,
          compat.as_str(name),
          False,
          [o._c_op for o in operations],  # pylint: disable=protected-access
          [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
          [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
          output_names,
          [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
          [],  # control_output_names
          None,
          compat.as_str(""))

    self._c_func = c_api_util.ScopedTFFunction(fn, name)

    # Note: this loop rebinds the `name` parameter. `name` is not read again
    # below; the stored name comes from the generated signature instead.
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name),
                                                     serialized)

    # NOTE(feyu): Do not cache signature and definition at initialization to
    # save memory usage of concrete functions never called through Python. We
    # cache them on the first call of .definition and .signature.
    signature = self._get_definition().signature

    # `name` is exposed as bytes.
    self._name = compat.as_bytes(signature.name)
    # If the outermost context is eager, register the function there. The
    # deleter unregisters it when it is garbage collected.
    with ops.init_scope():
      if context.executing_eagerly():
        context.ensure_initialized()
        context.add_function(fn)
        self._function_deleter = _EagerDefinedFunctionDeleter(self.name)
        self._registered_on_context = True

    self._num_outputs = len(signature.output_arg)
    self._output_types = [o.type for o in signature.output_arg]
    # Static shapes of the graph outputs; `call` re-applies them to graph-mode
    # call outputs.
    self._output_shapes = [o.shape for o in outputs]
    self._control_captures = graph.control_captures
    # Shallow copy outputs since ConcreteFunction may mutate it.
    self._func_graph_outputs = list(outputs)
    self.grad_func_name = None
    self.python_grad_func = None
    self._grad_func = None
    self.graph = graph
    self._stateful_ops = tuple(op for op in operations if op._is_stateful)  # pylint: disable=protected-access

  @property
  def signature(self):
    """The `signature` field of this function's FunctionDef, cached on first use."""
    try:
      return self._signature
    except AttributeError:
      self._signature = self.definition.signature
    return self._signature

  @property
  def definition(self):
    """This function's `FunctionDef` proto, cached on first use."""
    try:
      return self._definition
    except AttributeError:
      self._definition = self._get_definition()
    return self._definition

  def _get_definition(self):
    """Serializes the underlying C function into a new `FunctionDef` proto."""
    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      with self._c_func.get() as func:
        pywrap_tf_session.TF_FunctionToFunctionDef(func, buffer_)
      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    return function_def

  def add_to_graph(self, g=None):
    """Add the function to the current context or a graph, if supplied.

    When adding to a graph, every function in this function's own graph
    (`self.graph._functions`) that `g` does not yet know is added as well.

    NOTE(review): if `g` is None while not executing eagerly, the else branch
    calls `g._is_function` on None and raises AttributeError. Callers
    presumably always pass a graph in graph mode.

    Args:
      g: the graph to add the function to. If not supplied, the function will
        be added to the current context.
    """
    # pylint: disable=protected-access
    if not g and context.executing_eagerly():
      ctx = context.context()
      if not ctx.has_function(self.name):
        ctx.add_function_def(self.definition)
    else:
      if not g._is_function(self.name):
        g._add_function(self)
      for f in self.graph._functions.values():
        if not g._is_function(f.name):
          g._add_function(f)
    # pylint: enable=protected-access

  @property
  def name(self):
    """The function name (bytes), as recorded in its generated signature."""
    return self._name

  @property
  def stateful_ops(self):
    """Tuple of the stateful operations in the function body."""
    return self._stateful_ops

  def call(self, ctx, args, cancellation_manager=None):
    """Calls this function with `args` as inputs.

    `ConcreteFunction` execution respects device annotations only if the
    function won't be compiled with xla.

    Args:
      ctx: a Context object
      args: a list of arguments to supply this function with.
      cancellation_manager: a `CancellationManager` object that can be used to
        cancel function execution.

    Returns:
      The outputs of the function call. When executing eagerly, a function with
      no outputs yields None rather than an empty list.

    Raises:
      ValueError: if the number of arguments is incorrect.
      FunctionAlreadyGarbageCollectedError: if the function is no longer
        available to be called because it has been garbage collected.
    """
    if len(args) != len(self.signature.input_arg):
      raise ValueError(
          f"Signature specifies {len(list(self.signature.input_arg))} "
          f"arguments, got: {len(args)}.")

    # Without a caller-supplied serialized ConfigProto, fall back to
    # `function_utils.get_disabled_rewriter_config()`.
    function_call_options = ctx.function_call_options
    if function_call_options.config_proto_serialized is None:
      config = function_utils.get_disabled_rewriter_config()
    else:
      config = function_call_options.config_proto_serialized
    executor_type = function_call_options.executor_type or ""

    executing_eagerly = ctx.executing_eagerly()
    attrs = ("executor_type", executor_type, "config_proto", config)
    if executing_eagerly:
      # Eager: run the function op directly by name. OpError messages are
      # interpolated against this function's graph.
      with _InterpolateFunctionError(self):
        if cancellation_manager is None:
          outputs = execute.execute(
              str(self.signature.name),
              num_outputs=self._num_outputs,
              inputs=args,
              attrs=attrs,
              ctx=ctx)
        else:
          outputs = execute.execute_with_cancellation(
              str(self.signature.name),
              num_outputs=self._num_outputs,
              inputs=args,
              attrs=attrs,
              ctx=ctx,
              cancellation_manager=cancellation_manager)
      # Replace empty list with None
      outputs = outputs or None
    else:
      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
      # creates `PartitionedCallOp` kernels by default, or remove the previous
      # branch if a TPU kernel is registered for `PartitionedCall`.
      with _InterpolateFunctionError(self):
        with ops.control_dependencies(self._control_captures):
          # The caller must use record_operation to record this operation in the
          # eager case, so we enforce the same requirement for the non-eager
          # case by explicitly pausing recording. We don't have a gradient
          # registered for PartitionedCall, so recording this operation confuses
          # forwardprop code (GradientTape manages to ignore it).
          with tape.stop_recording():
            outputs = functional_ops.partitioned_call(
                args=args,
                f=self,
                tout=self._output_types,
                executing_eagerly=executing_eagerly,
                config=config,
                executor_type=executor_type)

    # Propagate resource handle data from the function graph's outputs to the
    # corresponding call outputs.
    for i, func_graph_output in enumerate(self._func_graph_outputs):
      handle_data_util.copy_handle_data(func_graph_output, outputs[i])
    if executing_eagerly:
      return outputs
    else:
      # TODO(b/128924522): This additional set_shape should not be
      # necessary. ShapeRefiner likely needs to inspect handle_data. Remove this
      # once that's done.
      for i, shape in enumerate(self._output_shapes):
        outputs[i].set_shape(shape)
      return outputs
525
526
def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph):
  """Builds the forward/backward function pair for the given graphs.

  The forward function's name is generated up front and recorded on the
  backward function. The backward function is then created, and its name is
  recorded on the forward function.

  Args:
    attrs: dict of function attributes shared by both functions. Any
      `_implements` attribute is dropped.
    forward_graph: the forward-pass FuncGraph.
    backwards_graph: the backward-pass FuncGraph.

  Returns:
    A `(forward_function, backward_function)` tuple: an `_EagerDefinedFunction`
    and a `ConcreteFunction` respectively.
  """
  forward_name = _forward_name(forward_graph.name)

  # Both generated functions must drop "_implements": their signatures contain
  # all the intermediate tensors they compute, so there is no stable signature
  # that downstream optimizations could rely on. See:
  # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
  shared_attrs = dict(attrs)
  shared_attrs.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)

  # Shared attributes are applied last, so they win over the link attribute.
  backward_attrs = _parse_func_attrs(
      {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_name})
  backward_attrs.update(shared_attrs)
  backward = ConcreteFunction(backwards_graph, attrs=backward_attrs)

  forward_attrs = _parse_func_attrs(
      {BACKWARD_FUNCTION_ATTRIBUTE_NAME: backward.name})
  forward_attrs.update(shared_attrs)
  forward = _EagerDefinedFunction(
      forward_name,
      forward_graph,
      forward_graph.inputs,
      forward_graph.outputs,
      forward_attrs)
  return forward, backward
551
552
class _DelayedRewriteGradientFunctions(object):
  """Caches forward/backward functions with a delayed forward rewrite.

  An inference function (no side outputs) is built immediately. A backward
  graph is only traced once gradients are requested. Any forward-graph tensors
  it captures are then appended to the forward graph's outputs. The existing
  inference call op is rewritten in place to call the resulting forward
  function, and the backward function is fed from that op's outputs.
  """

  def __init__(self, func_graph, attrs, func_graph_deleter):
    """Construct an inference function and initialize caches.

    Args:
      func_graph: the FuncGraph to build functions from. Its `outputs` list is
        appended to later when backward functions are constructed.
      attrs: dict of function attributes (AttrValue protos) for the generated
        functions.
      func_graph_deleter: stored on the instance but not read by this class's
        methods. Presumably kept to tie some cleanup to this object's
        lifetime; TODO confirm.
    """
    # A map from the number of forward function outputs with accepted gradients
    # to forward and backward functions, used to cache non-tape backward
    # function generation.
    self._cached_function_pairs = {}
    self._func_graph = func_graph
    self._inference_function = _EagerDefinedFunction(
        _inference_name(self._func_graph.name), self._func_graph,
        self._func_graph.inputs, self._func_graph.outputs, attrs)
    self._attrs = attrs
    self._gradient_name = None
    # Note that the FuncGraph is mutated later, so we need to inspect it now to
    # figure out the user-specified outputs of the inference function.
    self._num_inference_outputs = len(self._func_graph.outputs)
    self._func_graph_deleter = func_graph_deleter

  def forward_backward(self, num_doutputs=None):
    """A possibly-cached pair of forward and backward functions.

    Args:
      num_doutputs: number of leading forward outputs that receive gradients.
        Defaults to the number of user-specified inference outputs.

    Returns:
      A `(forward_function, backward_function)` tuple, cached per
      `num_doutputs`.
    """
    if num_doutputs is None:
      num_doutputs = self._num_inference_outputs
    forward_backward = self._cached_function_pairs.get(num_doutputs)
    if forward_backward is not None:
      return forward_backward
    forward, backward = self._construct_forward_backward(num_doutputs)
    self._cached_function_pairs[num_doutputs] = (forward, backward)
    return forward, backward

  def _construct_forward_backward(self, num_doutputs):
    """Constructs a pair of forward and backward functions.

    Side effect: forward-graph tensors captured by the backward graph are
    appended to `self._func_graph.outputs` if they are not already outputs.

    Args:
      num_doutputs: The constructed backprop function will take output gradients
        for the first `num_doutputs` outputs of the forward function. Defaults
        to the number of outputs for the inference function, but when
        higher-order gradients are computed this will increase to include side
        outputs.

    Returns:
      A pair of (forward_function, backward_function):
        forward_function: A re-generated inference function (an
          _EagerDefinedFunction) to account for new side outputs, if any extra
          were required when building the backward pass.
        backward_function: A ConcreteFunction that Takes `num_doutputs`
          arguments and returns gradients with respect to inputs of the forward
          function.
    """
    # Only trainable outputs among the first `num_doutputs` take gradients.
    trainable_outputs = [
        output for output in self._func_graph.outputs[:num_doutputs]
        if backprop_util.IsTrainable(output)]

    # One TensorSpec per trainable output describes its incoming gradient.
    signature = []
    for t in trainable_outputs:
      signature.append(
          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))

    # Traced into `backwards_graph`: differentiates the trainable outputs with
    # respect to the forward graph's inputs. `ops.device(None)` presumably
    # clears any enclosing device scope.
    def _backprop_function(*grad_ys):
      with ops.device(None):
        return gradients_util._GradientsHelper(  # pylint: disable=protected-access
            trainable_outputs,
            self._func_graph.inputs,
            grad_ys=grad_ys,
            src_graph=self._func_graph)

    with self._func_graph.as_default():
      backwards_graph = func_graph_module.FuncGraph(
          _backward_name(self._func_graph.name))
      func_graph_module.func_graph_from_py_func(
          name=backwards_graph.name,
          python_func=_backprop_function,
          args=[], kwargs={},
          signature=signature,
          func_graph=backwards_graph)
      backwards_graph_captures = backwards_graph.external_captures
      # Non-eager tensors the backward graph captured from the forward graph.
      captures_from_forward = [
          c for c in backwards_graph_captures if
          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]

      # Expose those captures as extra (side) outputs of the forward graph,
      # skipping any that already are outputs (compared by identity).
      existing_outputs = object_identity.ObjectIdentitySet(
          self._func_graph.outputs)
      for capture in captures_from_forward:
        if capture not in existing_outputs:
          existing_outputs.add(capture)
          self._func_graph.outputs.append(capture)

      forward_function, backward_function = _create_forward_backward_with_graph(
          self._attrs, self._func_graph, backwards_graph)
      return forward_function, backward_function

  def _rewrite_forward_and_call_backward(self, op, *doutputs):
    """Add outputs to the forward call and feed them to the grad function.

    Args:
      op: the call op to rewrite in place into a forward-function call.
      *doutputs: gradients for the leading outputs of `op`. None entries become
        zeros and IndexedSlices are converted to dense Tensors.

    Returns:
      The backward function's outputs. If it has no outputs, its structured
      outputs are returned without rewriting `op`.
    """
    forward_function, backwards_function = self.forward_backward(len(doutputs))
    if not backwards_function.outputs:
      return backwards_function.structured_outputs
    forward_function.add_to_graph(op.graph)

    # pylint: disable=protected-access
    # Rewrite an inference call op to be a forward call op
    op._set_func_attr("f", forward_function.name)
    op._set_type_list_attr("Tout", forward_function._output_types)
    # Append only the outputs `op` does not have yet (the new side outputs).
    op._add_outputs(
        forward_function._output_types[len(op.outputs):],
        forward_function._output_shapes[len(op.outputs):])
    for i in range(len(op.outputs)):
      func_graph_output = forward_function._func_graph_outputs[i]
      handle_data_util.copy_handle_data(func_graph_output, op.outputs[i])
    # pylint: enable=protected-access

    # Feed the backward function's captures from the call op's outputs instead
    # of from tensors inside the forward graph, wherever such an output exists.
    capture_mapping = dict(
        zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
    remapped_captures = [
        capture_mapping.get(ops.tensor_id(capture), capture)
        for capture in backwards_function.captured_inputs
    ]

    # Replace Nones with zeros since we're calling a graph function which
    # expects numeric inputs.
    cleaned_doutputs = []
    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
      if backprop_util.IsTrainable(placeholder):
        if isinstance(doutput, indexed_slices.IndexedSlices):
          # Gradient passed to a backward ConcreteFunction must be tf.Tensor,
          # so we convert tf.IndexedSlices to tf.Tensor.
          cleaned_doutputs.append(ops.convert_to_tensor(doutput))
        elif doutput is not None:
          cleaned_doutputs.append(doutput)
        else:
          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))

    # Compute the gradients using the side outputs
    return backwards_function._call_flat(  # pylint: disable=protected-access
        cleaned_doutputs, remapped_captures)

  def get_gradient_function(self):
    """Returns gradient function.

    The gradient rewrites an inference call op to a forward call op, but does
    not modify a pre-existing forward call op. It then computes the gradient
    from the output's gradients and the side outputs of the forward op.
    """
    return self._rewrite_forward_and_call_backward

  def forward(self, inference_args=None, input_tangents=None):
    """A forward function with only user-specified outputs.

    The call operation for the returned inference function can be rewritten into
    a forward function. This only happens if the backward function (from the
    `backward` method) ends up being used to compute gradients.

    This approach avoids constructing unnecessary graphs, but it only works if
    we are calling this function when not executing eagerly.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function. Unused, but taken for compatibility with
        _TapeGradientFunctions.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`. Unused; if required, tape functions must be used
        instead.

    Returns:
      An _EagerDefinedFunction.

    Raises:
      errors.InternalError: if `input_tangents` is non-empty.
    """
    del inference_args  # unused
    if input_tangents:
      # This class does not support special-cased forwardprop. The arguments are
      # here for compatibility with _TapeGradientFunctions.
      raise errors.InternalError("unexpectedly got forwardprop information in "
                                 "a class that does not support forwardprop.")
    return self._inference_function

  def _backward(self, outputs):
    """Fetch a backward function for `outputs` from the forward function.

    The returned closure rewrites the op that produced `outputs[0]` when it is
    invoked. `outputs` is returned unchanged as the tensors to record.
    """
    def _backward_function(*args):
      call_op = outputs[0].op
      return self._rewrite_forward_and_call_backward(call_op, *args)
    return _backward_function, outputs

  def record(self, flat_outputs, inference_args, input_tangents):
    """Record the function call operation.

    _DelayedRewriteGradientFunctions supports only first-order backprop tape
    gradients (and then only when graph building). It does not work with
    higher-order tape gradients or forward autodiff, but does work with
    higher-order symbolic gradients (tf.gradients).

    Args:
      flat_outputs: The result of running `forward`.
      inference_args: A flat list of Tensors with inference inputs to the
        operation.
      input_tangents: A flat list of Tensors with input tangents consumed by the
        operation.
    """
    backward_function, to_record = self._backward(flat_outputs)
    # The recorded inputs are the inference args followed by the tangents.
    tape.record_operation(self._inference_function.signature.name,
                          to_record, inference_args + input_tangents,
                          backward_function)
753
754
# Describes a forward function that has been wrapped to also compute jvps.
#
# Fields:
#   graph: The wrapper Graph.
#   outputs: A flat list of non-tangent Tensor outputs from the wrapped forward
#     function.
#   output_indices: Indices for output tangents, in the same format as
#     forwardprop_util.pack_tangents.
#   output_tangents: A flat list of tangents for `outputs`.
_ForwardWrapper = collections.namedtuple(
    "_ForwardWrapper",
    ["graph", "outputs", "output_indices", "output_tangents"])
768
769
class _TapeGradientFunctions(object):
  """Caches forward and backward functions compatible with eager gradients.

  In contrast to the delayed-rewrite approach in
  `_DelayedRewriteGradientFunctions` which only works with delayed execution,
  the forward function generated by this class has a fixed set of outputs which
  may be preserved by a tape in order to compute gradients later.

  This class is abstract; its child classes differ in how many side outputs of
  the forward function their backward function accepts gradients for, which
  determines whether higher-order tape gradients are possible. Subclasses
  provide `_forward_and_backward_functions(inference_args, input_tangents)`,
  which `forward` calls lazily the first time it is needed.
  """

  def __init__(self, func_graph, attrs, func_graph_deleter,
               forwardprop_input_indices, delayed_rewrite_functions,
               need_gradients_for_jvps):
    self._func_graph = func_graph
    self._forward_graph = None
    self._attrs = attrs
    # The forward/backward pair is built on the first `forward` call, then
    # cached.
    self._forward = None
    self._backward = None
    self._num_outputs = len(func_graph.outputs)
    self._func_graph_deleter = func_graph_deleter
    self._forwardprop_input_indices = forwardprop_input_indices
    self._forwardprop_output_indices = None
    self._num_forwardprop_outputs = 0
    # Snapshot of the user-visible output count: `func_graph.outputs` grows
    # later when captures needed by the backward function are appended to it.
    self._num_inference_outputs = len(func_graph.outputs)
    self._num_trainable_inference_outputs = len(
        [t for t in func_graph.outputs if backprop_util.IsTrainable(t)])
    self._delayed_rewrite_functions = delayed_rewrite_functions
    self._need_gradients_for_jvps = need_gradients_for_jvps

  def _build_functions_for_outputs(
      self, outputs, inference_args, input_tangents):
    """Forward+backward functions where the backward function sees `outputs`.

    As a side effect, Tensors of `self._func_graph` captured by the backward
    graph are appended to `self._func_graph.outputs` so that the forward
    function returns them.

    Args:
      outputs: Tensors of `self._func_graph`. The backward function accepts a
        gradient for each trainable one.
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`. When empty, no forwardprop wrapping is performed.

    Returns:
      A tuple (forward_function, forward_graph, backward_function,
      forwardprop_output_indices, num_forwardprop_outputs). The last two are
      `None` and 0 when `input_tangents` is empty.

    Raises:
      errors.InternalError: If the wrapped forward graph does not take exactly
        `len(inference_args) + len(input_tangents)` inputs.
    """
    # First figure out which of `outputs` are trainable. We'll accept gradients
    # for each of these in the backward function.
    handles_to_variables = self._func_graph.variable_captures
    trainable_outputs = []
    for output in outputs:
      if backprop_util.IsTrainable(output):
        # Swap in the Variable object for resource handles if we can so
        # sparse gradients work.
        output = handles_to_variables.get(id(output), output)
        trainable_outputs.append(output)

    backwards_graph = func_graph_module.FuncGraph(
        _backward_name(self._func_graph.name))
    with backwards_graph.as_default():
      gradients_wrt_outputs = []
      for output in trainable_outputs:
        gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
            output)
        gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
        handle_data_util.copy_handle_data(output, gradient_placeholder)
        gradients_wrt_outputs.append(gradient_placeholder)
      with ops.device(None):
        gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
            trainable_outputs,
            self._func_graph.inputs,
            grad_ys=gradients_wrt_outputs,
            src_graph=self._func_graph)

      if input_tangents:
        # Convert IndexedSlices to dense tensors (as we do elsewhere for
        # function gradients). Our C++ bindings don't know how to handle them
        # currently.
        gradients_wrt_inputs = nest.map_structure(
            lambda x: ops.convert_to_tensor(x) if x is not None else None,
            gradients_wrt_inputs)
      # Tensors the backward graph captured from the forward graph must be
      # returned by the forward function so they can be fed back in.
      captures_from_forward = [
          c for c in backwards_graph.external_captures
          if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
      ]
      existing_outputs = object_identity.ObjectIdentitySet(
          self._func_graph.outputs)
      for capture in captures_from_forward:
        if capture not in existing_outputs:
          existing_outputs.add(capture)
          self._func_graph.outputs.append(capture)

    # The ordering of `backwards_graph.inputs` is important: inputs of
    # `backward_function` correspond to outputs (including side outputs) of
    # the forward function built from `self._func_graph`.
    backwards_graph.inputs = (
        gradients_wrt_outputs + backwards_graph.internal_captures)
    backwards_graph.outputs.extend(
        grad
        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
        if grad is not None)
    backwards_graph.structured_outputs = gradients_wrt_inputs

    forward_function, backward_function = _create_forward_backward_with_graph(
        self._attrs, self._func_graph, backwards_graph)

    if not input_tangents:
      # There is no need to special-case forwardprop, so we can return the
      # forward+backward pair we've created without further wrapping.
      return (forward_function, self._func_graph, backward_function,
              # No forwardprop outputs.
              None, 0)
    forward_wrapper = self._wrap_forward_function_with_jvps(
        forward_function, backward_function, inference_args, input_tangents)
    (wrapped_backwards_graph,
     forward_wrapper) = self._wrap_backward_function_with_jvp_backprop(
         backward_function, gradients_wrt_outputs, forward_wrapper)
    # Now that we've added new captures, we need to make sure forward outputs
    # are in the same order the backward function expects them to be in:
    # [inference outputs] + [jvps] + [side outputs] + [captures].
    forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)
    (wrapped_forward_function,
     wrapped_backward_function) = _create_forward_backward_with_graph(
         self._attrs, forward_wrapper.graph, wrapped_backwards_graph)
    if (len(inference_args) + len(input_tangents)
        != len(forward_wrapper.graph.inputs)):
      raise errors.InternalError(
          f"The forward graph had {len(forward_wrapper.graph.inputs)} inputs, "
          f"but we expected {len(inference_args) + len(input_tangents)} "
          f"({len(inference_args)} inference inputs and "
          f"{len(input_tangents)} input tangents).")
    return (wrapped_forward_function, forward_wrapper.graph,
            wrapped_backward_function, forward_wrapper.output_indices,
            len(forward_wrapper.output_tangents))

  def _wrap_forward_function_with_jvps(
      self, forward_function, backward_function,
      inference_args, input_tangents):
    """Adds inline JVP computation to a forward function.

    Args:
      forward_function: The forward function built from `self._func_graph`.
      backward_function: The backward function paired with
        `forward_function`; used when recording the wrapped call for
        forwardprop.
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A `_ForwardWrapper` holding the new wrapper graph, the outputs of the
      wrapped call, and the output tangents captured into the wrapper graph.

    Raises:
      errors.InternalError: If a jvp index does not line up with the number of
        wrapper graph inputs created so far.
    """
    forward_wrapper_graph = func_graph_module.FuncGraph(
        _forward_name(self._func_graph.name))
    with forward_wrapper_graph.as_default():
      # Tell forward accumulators to free up space for new JVP computations,
      # since one may be in the process of computing a JVP (if that computation
      # triggered this function building).
      #
      # We'll make symbolic versions of input JVPs, run the forward function
      # under forward accumulators to get symbolic output JVPs, then set those
      # as outputs of the new wrapped forward function.
      with forwardprop_util.push_forwardprop_state():
        forward_captures = {
            ops.tensor_id(internal): external
            for external, internal in self._func_graph.captures}
        for real_input in self._func_graph.inputs:
          # This loop is more or less equivalent to running tf.identity on each
          # of self._func_graph.inputs. However, doing that also captures jvps
          # for resource handles, which confuses the jvp capturing code below
          # (since primal inputs are interwoven with jvp inputs).
          input_placeholder = array_ops.placeholder(
              dtype=real_input.dtype,
              shape=real_input.shape)
          capture = forward_captures.get(ops.tensor_id(real_input))
          if capture is not None:
            forward_wrapper_graph.add_capture(capture, input_placeholder)
            if capture.dtype == dtypes.resource:
              handle_data_util.copy_handle_data(capture, input_placeholder)
          else:
            forward_wrapper_graph.inputs.append(input_placeholder)
        for inp, arg in zip(forward_wrapper_graph.inputs, inference_args):
          tape.record_operation(
              "captured_value", [inp], [arg],
              backward_function=lambda x: [x],
              forward_function=lambda x: [x])
        num_inference_inputs = len(inference_args)
        for tape_indices in self._forwardprop_input_indices:
          for input_index, jvp_index in tape_indices:
            input_placeholder = forward_wrapper_graph.inputs[input_index]
            # Jvp placeholders must be appended in exactly the order given by
            # `jvp_index`.
            if len(forward_wrapper_graph.inputs) != jvp_index:
              raise errors.InternalError(
                  f"Expected {jvp_index} forward graph inputs, "
                  f"got {len(forward_wrapper_graph.inputs)}.")
            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
                input_placeholder)
            jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
            external_jvp = input_tangents[jvp_index - num_inference_inputs]
            forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder)
            tensor_shape.TensorShape(
                external_jvp.shape).assert_is_compatible_with(
                    jvp_placeholder.shape)
            tape.record_operation(
                "captured_value",
                [jvp_placeholder],
                [external_jvp],
                backward_function=lambda x: [x],
                forward_function=lambda x: [x])
        forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs]
        gradient_function = (
            self._delayed_rewrite_functions._rewrite_forward_and_call_backward)  # pylint: disable=protected-access
        with ops.get_default_graph()._override_gradient_function(  # pylint: disable=protected-access
            {"PartitionedCall": gradient_function,
             "StatefulPartitionedCall": gradient_function}):
          forward_outputs = forward_function.call(context.context(),
                                                  forward_inputs)
          if isinstance(forward_outputs, ops.Operation):
            # _wrapped_backward_function expects a list, but if the function has
            # no outputs its call() returns an Operation. We need to undo that
            # so we don't cause problems later.
            forward_outputs = []
        py_backward, _ = self._wrap_backward_function(
            self._func_graph, backward_function, forward_outputs)
      # We will never request backward tape gradients for this operation
      # directly since we're wrapping the call; forwardprop will call the
      # backward function (and nested forward accumulators may build
      # higher-order gradients), but any watching GradientTapes should ignore
      # it.
      #
      # TODO(allenl): It might be better to explicitly stop backward recording
      # so we don't use the second-order tape cases unnecessarily.
      tape.record_operation_forwardprop_only(
          forward_function.signature.name,
          forward_outputs, forward_inputs, py_backward, None)
      output_indices, output_tangents = (
          pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
      output_tangents = [forward_wrapper_graph.capture(t)
                         for t in output_tangents]
    return _ForwardWrapper(
        graph=forward_wrapper_graph, outputs=forward_outputs,
        output_indices=output_indices, output_tangents=output_tangents)

  def _wrap_backward_function_with_jvp_backprop(
      self, backward_function, gradients_wrt_outputs, forward_wrapper):
    """Wraps `backward_function` to include gradients for JVPs.

    Note: `forward_wrapper.outputs` is extended in place with any Tensors of
    the wrapper graph that the new backward graph captures.

    Args:
      backward_function: The backward function paired with the unwrapped
        forward function.
      gradients_wrt_outputs: Gradient placeholders of the original backward
        graph, one per trainable forward output; used for dtype/shape only.
      forward_wrapper: The `_ForwardWrapper` from
        `_wrap_forward_function_with_jvps`.

    Returns:
      A tuple (wrapped_backwards_graph, forward_wrapper), where the returned
      wrapper has updated `output_indices` and `output_tangents`.
    """
    wrapped_backwards_graph = func_graph_module.FuncGraph(
        _backward_name(self._func_graph.name))
    with wrapped_backwards_graph.as_default():
      py_backward, recorded_outputs = self._wrap_backward_function(
          self._func_graph, backward_function, forward_wrapper.outputs)
      trainable_index = 0
      forward_doutputs = []
      doutput_args = []
      for output in recorded_outputs:
        if backprop_util.IsTrainable(output):
          doutput = gradients_wrt_outputs[trainable_index]
          doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape)
          doutput_args.append(doutput_placeholder)
          forward_doutputs.append(doutput_placeholder)
          trainable_index += 1
        else:
          doutput_args.append(None)

      dinputs = py_backward(*doutput_args)
      existing_outputs = object_identity.ObjectIdentitySet(
          forward_wrapper.outputs + forward_wrapper.output_tangents)
      num_processed_output_tangents = 0
      gradients_wrt_output_tangents = []
      tangent_doutputs = []
      output_tangents = forward_wrapper.output_tangents
      output_indices = forward_wrapper.output_indices
      if self._need_gradients_for_jvps:
        # TODO(allenl): Consider using a throwaway graph to avoid extra gradient
        # evaluations; gradients for jvps may have common subgraphs.
        # Computing these gradients may capture new forward Tensors, which adds
        # forward outputs and hence new output tangents, so iterate until the
        # set of output tangents stops growing.
        while num_processed_output_tangents != len(output_tangents):
          for output in output_tangents[num_processed_output_tangents:]:
            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
                output)
            placeholder = graph_placeholder(gradient_dtype, gradient_shape)
            gradients_wrt_output_tangents.append(placeholder)
            tangent_doutputs.append(placeholder)
          num_processed_output_tangents = len(output_tangents)
          with ops.device(None):
            gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
                output_tangents,
                forward_wrapper.graph.inputs,
                grad_ys=gradients_wrt_output_tangents,
                src_graph=forward_wrapper.graph)
          # `dinputs` is positional (entry i is the gradient for input i), so a
          # position with no gradient on either side must stay as `None`.
          # Dropping it would shift every later gradient onto the wrong input.
          dinputs = [
              None if existing is None and new is None
              else backprop.aggregate_indexed_slices_gradients((existing, new))
              for existing, new in zip(dinputs, gradients_wrt_inputs)]
          # The wrapper graph may have more inputs than `dinputs` covers; take
          # the gradients for those trailing inputs as-is.
          dinputs.extend(gradients_wrt_inputs[len(dinputs):])
          captures_from_forward = [
              c for c in wrapped_backwards_graph.external_captures
              if (not isinstance(c, ops.EagerTensor)
                  and c.graph is forward_wrapper.graph)]
          for capture in captures_from_forward:
            if capture not in existing_outputs:
              existing_outputs.add(capture)
              forward_wrapper.outputs.append(capture)
          output_indices, output_tangents = (
              forwardprop_util.pack_tangents(forward_wrapper.outputs))
          output_tangents = [forward_wrapper.graph.capture(t)
                             for t in output_tangents]
          for t in output_tangents:
            existing_outputs.add(t)
    # Input order mirrors the forward output order established by
    # `_shuffle_forward_outputs`:
    # [inference outputs] + [jvps] + [side outputs] + [captures].
    wrapped_backwards_graph.inputs = (
        forward_doutputs[:self._num_trainable_inference_outputs]
        + tangent_doutputs
        + forward_doutputs[self._num_trainable_inference_outputs:]
        + wrapped_backwards_graph.internal_captures)
    wrapped_backwards_graph.structured_outputs = dinputs
    wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None]
    return (wrapped_backwards_graph,
            forward_wrapper._replace(output_indices=output_indices,
                                     output_tangents=output_tangents))

  def _shuffle_forward_outputs(self, forward_wrapper):
    """Reorders function outputs so captures are last.

    `forward_wrapper.output_indices` refer to the layout
    `outputs + output_tangents`. This method sets the wrapper graph's outputs
    to `outputs[:n] + output_tangents + outputs[n:]`, where `n` is the number
    of inference outputs, and remaps the indices to match.

    Args:
      forward_wrapper: A `_ForwardWrapper`.

    Returns:
      `forward_wrapper` with remapped `output_indices`.
    """
    def _index_map(original):
      if original < self._num_inference_outputs:
        return original
      if original >= len(forward_wrapper.outputs):
        # Index into the tangents region: tangents now follow the inference
        # outputs directly.
        return (original - len(forward_wrapper.outputs)
                + self._num_inference_outputs)
      # Side output: shifted right past the tangents.
      return original + len(forward_wrapper.output_tangents)
    output_indices = nest.map_structure(
        _index_map, forward_wrapper.output_indices)
    forward_wrapper.graph.outputs = (
        forward_wrapper.outputs[:self._num_inference_outputs]
        + forward_wrapper.output_tangents
        + forward_wrapper.outputs[self._num_inference_outputs:])
    return forward_wrapper._replace(output_indices=output_indices)

  def forward(self, inference_args, input_tangents):
    """Construct or fetch a forward function with side-outputs.

    When graph building without a tape active, symbolic gradients rely on
    regenerating the backward function for higher-order gradients (to account
    for new side outputs of the rewritten forward function call). Thus there is
    no fixed backward function for this case. However, when a tape is active
    (eager or graph building), we generate fixed backward and forward functions
    at forward function call time.

    This difference between the tape and non-tape cases is to avoid building
    unneeded backward functions while graph building (where we may or may not
    eventually need gradients).

    The functions are built once, on the first call, and cached; the arguments
    of later calls do not affect the returned function.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A forward _EagerDefinedFunction.
    """
    if self._forward is None:
      (self._forward, self._forward_graph, self._backward,
       self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
           self._forward_and_backward_functions(inference_args, input_tangents))
    return self._forward

  def _wrap_backward_function(self, forward_graph, backward, outputs):
    """Create a backward function given `outputs` from the forward function.

    Args:
      forward_graph: The graph whose outputs correspond positionally to
        `outputs`.
      backward: The backward function. Its captured inputs that are outputs of
        `forward_graph` are replaced by the matching element of `outputs`.
      outputs: The flat outputs of a call to the forward function.

    Returns:
      A tuple (backward_function_wrapper, recorded_outputs). The wrapper takes
      output gradients, skips those for non-trainable outputs, fills in zeros
      for missing ones, and calls `backward`. `recorded_outputs` are the
      outputs that should be recorded on the tape.

    Raises:
      errors.InternalError: If some captured input of `backward` still refers
        to `forward_graph` after remapping.
    """
    capture_mapping = dict(
        zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
    captured_inputs = backward.captured_inputs
    remapped_captures = [
        capture_mapping.get(ops.tensor_id(capture), capture)
        for capture in captured_inputs
    ]
    # Captures still living in `forward_graph` were not found among its
    # outputs, so they cannot be fed from the call's results.
    unmapped_captures = [
        t for t in remapped_captures
        if not isinstance(t, ops.EagerTensor) and t.graph is forward_graph
    ]
    if unmapped_captures:
      raise errors.InternalError("Failed to map all backward graph captures to "
                                 "the forward graph. Incorrectly mapped: "
                                 f"{unmapped_captures}.")
    # We may need to use zeros_like to get a zero for variant Tensors with
    # unconnected gradients. We do that in advance so we don't have to hold on
    # to the outputs themselves, which may not be needed otherwise.
    variant_zeros_like = {}
    backward_function_inputs = (len(backward.inputs) - len(captured_inputs))
    recorded_outputs = []
    trainable_recorded_outputs = 0
    # Set, not list: only used for membership tests in the wrapper below.
    skip_positions = set()
    if self._num_forwardprop_outputs and not self._need_gradients_for_jvps:
      # The backward function takes no gradients for the jvp outputs.
      relevant_outputs = (
          outputs[:self._num_inference_outputs]
          + outputs[self._num_inference_outputs
                    + self._num_forwardprop_outputs:])
    else:
      relevant_outputs = outputs
    for output_index, output in enumerate(relevant_outputs):
      if trainable_recorded_outputs < backward_function_inputs:
        recorded_outputs.append(output)
      if backprop_util.IsTrainable(output):
        trainable_recorded_outputs += 1
      else:
        skip_positions.add(output_index)
      if output.dtype == dtypes.variant:
        variant_zeros_like[output_index] = default_gradient.zeros_like(output)

    def _backward_function_wrapper(*args):
      """Process output gradients and call the backward function."""
      if not backward.outputs:
        return backward.structured_outputs

      processed_args = []
      input_index = 0
      for output_index, arg in enumerate(args):
        # Convert IndexedSlices to dense tensors. The IndexedSlices optimization
        # is only really effective when doing tf.gather(variable) as the
        # adjoint functions for most operations are unlikely to preserve the
        # sparsity in IndexedSlices.
        if isinstance(arg, indexed_slices.IndexedSlices):
          arg = ops.convert_to_tensor(arg)
        if output_index in skip_positions:
          continue
        if arg is None:
          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
          # have a Tensor value for each Tensor we thought would be trainable
          # based on its dtype, even if it ended up being unconnected.
          input_placeholder = backward.inputs[input_index]
          if input_placeholder.dtype == dtypes.variant:
            arg = variant_zeros_like[output_index]
          else:
            arg = array_ops.zeros(
                *default_gradient.shape_and_dtype(input_placeholder))
        processed_args.append(arg)
        input_index += 1
        if input_index >= backward_function_inputs:
          break
      return backward._call_flat(  # pylint: disable=protected-access
          processed_args, remapped_captures)

    return _backward_function_wrapper, recorded_outputs

  def record(self, flat_outputs, inference_args, input_tangents):
    """Record the function call operation.

    For backprop, indicates the backward function to use and which new Tensors
    must be watched. For forwardprop from eager, the function call itself will
    have produced tangents which need to be recorded.

    Args:
      flat_outputs: The result of running `forward`.
      inference_args: A flat list of Tensors with inference inputs to the
        operation.
      input_tangents: A flat list of Tensors with input tangents consumed by the
        operation.
    """
    backward_function, to_record = self._wrap_backward_function(
        self._forward_graph, self._backward, flat_outputs)
    if self._forwardprop_output_indices:
      # The forward function computes jvps itself: record backprop and
      # forwardprop separately so tapes and accumulators each see the right
      # inputs.
      tape.record_operation_backprop_only(
          self._forward.signature.name,
          to_record, inference_args,
          backward_function)
      tape.record_operation_forwardprop_only(
          self._forward.signature.name,
          flat_outputs, inference_args + input_tangents,
          backward_function,
          self._forwardprop_output_indices)
    else:
      tape.record_operation(self._forward.signature.name,
                            to_record, inference_args + input_tangents,
                            backward_function)
1221
1222
class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions):
  """Caches tape-friendly functions for first-order gradients.

  Takes the same constructor arguments as `_TapeGradientFunctions`; the only
  difference is which outputs the backward function accepts gradients for.
  """

  def _forward_and_backward_functions(self, inference_args, input_tangents):
    """Shortcut for when only first-order gradients are required.

    The returned backward function does not accept gradients with respect to
    side output of forward_function. This is fine as long as the user can't
    possibly request second order tape gradients, as when they've used a single
    non-persistent GradientTape. Since we don't need the backward function to
    take gradients with respect to side outputs, we can skip some potentially
    slow graph building.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A tuple (forward_function, forward_graph, backward_function,
      forwardprop_output_indices, num_forwardprop_outputs):
        forward_function: Takes the same inputs as the inference function, but
          returns side outputs used by backward_function in addition to the
          inference function's outputs.
        forward_graph: The graph `forward_function` was created from.
        backward_function: Takes side outputs from forward_function and
          gradients with respect to the "real" outputs of forward_function and
          returns gradients with respect to the inputs.
        forwardprop_output_indices: Indices for output tangents (same format
          as forwardprop_util.pack_tangents), or None when `input_tangents` is
          empty.
        num_forwardprop_outputs: The number of output tangents; 0 when
          `input_tangents` is empty.
    """
    # Only the user-visible outputs receive gradients; side outputs appended to
    # `self._func_graph.outputs` are excluded.
    outputs = self._func_graph.outputs[:self._num_inference_outputs]
    return self._build_functions_for_outputs(
        outputs, inference_args, input_tangents)
1263
1264
class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
  """Caches tape-friendly functions for higher-order gradients."""

  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
  # generalizing if so.
  def _forward_and_backward_functions(self, inference_args, input_tangents):
    """Forward and backward functions suitable for higher-order gradients.

    Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built by
    this method accepts gradients for all of the outputs of the returned forward
    function, including side outputs.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A tuple (forward_function, forward_graph, backward_function,
      forwardprop_output_indices, num_forwardprop_outputs):
        forward_function: Takes the same inputs as the inference function, but
          returns side outputs used by backward_function in addition to the
          inference function's outputs.
        forward_graph: The graph `forward_function` was created from.
        backward_function: Takes side outputs from forward_function and
          gradients with respect to all of its outputs, real and side. Returns
          gradients with respect to the inputs.
        forwardprop_output_indices: Indices for output tangents (same format
          as forwardprop_util.pack_tangents), or None when `input_tangents` is
          empty.
        num_forwardprop_outputs: The number of output tangents; 0 when
          `input_tangents` is empty.

    Raises:
      errors.InternalError: If building the final functions added new trainable
        outputs to the forward function.
    """
    outputs = []
    iteration_count = 0
    # First we need to figure out how many side outputs from the forward pass
    # will be required. We do this in a temporary graph to avoid actually
    # running multiple copies of the backward pass (one per _GradientsHelper
    # call).
    #
    # While computing gradients, the backward function captures Tensors from
    # the forward function. We add these as side outputs of the original
    # function. However, we then need to accept output gradients with respect
    # to these side outputs for higher order gradients to work. Thus we loop
    # until the number of outputs of the function stabilizes. Note that this
    # is only required for tape gradients, where we need to declare in advance
    # all of the forward op's outputs: symbolic gradients with tf.gradients
    # instead rely on regenerating backward functions when higher-order
    # gradients are requested.
    while (len(outputs) < len(self._func_graph.outputs)
           # It's possible for gradient generation to add new ops to the forward
           # pass. If all of the new outputs are non-trainable, there's no
           # reason to continue.
           and any(backprop_util.IsTrainable(output)
                   for output in self._func_graph.outputs[len(outputs):])):
      iteration_count += 1
      if iteration_count >= 20 and iteration_count % 5 == 0:
        new_op_with_trainable_output = None
        num_new_trainable_outputs = 0
        for output in self._func_graph.outputs[len(outputs):]:
          if backprop_util.IsTrainable(output):
            num_new_trainable_outputs += 1
            new_op_with_trainable_output = output.op
        logging.warning(
            ("Determining side outputs for the function '{}' is taking longer "
             "than expected ({} iterations, typically this converges in 5 or "
             "so). This could indicate that a gradient registration is adding "
             "new ops to the forward pass every time gradients are generated. "
             "{} new trainable output(s) were added this iteration, one from "
             "the following op:\n {}\nThis may indicate a TensorFlow bug, or "
             "an issue in a tf.custom_gradient.")
            .format(
                self._func_graph.name, iteration_count,
                num_new_trainable_outputs, new_op_with_trainable_output))
      outputs = list(self._func_graph.outputs)
      # Called for its side effect only: it appends the captures needed by the
      # backward function to `self._func_graph.outputs`. The built functions
      # are discarded.
      self._build_functions_for_outputs(
          outputs, inference_args, input_tangents)

    (forward_function, forward_graph,
     backward_function, output_indices, num_output_tangents) = (
         self._build_functions_for_outputs(
             outputs, inference_args, input_tangents))
    if (len(self._func_graph.outputs) > len(outputs)
        and any(backprop_util.IsTrainable(output)
                for output in self._func_graph.outputs[len(outputs):])):
      raise errors.InternalError(
          "Unexpectedly added new outputs to the forward function when "
          "building the backward function: "
          f"{self._func_graph.outputs[len(outputs):]}.")
    return (forward_function, forward_graph, backward_function, output_indices,
            num_output_tangents)
1350
1351
class _ForwardBackwardCall(object):
  """Call-scoped state bridging function execution and gradient recording.

  `forward` yields the function to execute together with its full flat
  argument list; `record` hands the produced outputs back to the
  gradient-function provider when a tape is watching.
  """

  __slots__ = [
      "_functions", "_inference_args", "_input_tangents", "_tape_watching"
  ]

  def __init__(self, functions, inference_args, input_tangents, tape_watching):
    """Stores everything needed to run and later record one function call.

    Args:
      functions: Provider of forward and backward functions, either a
        _DelayedRewriteGradientFunctions or a _TapeGradientFunctions object.
      inference_args: Flat list of Tensors passed to the inference function.
      input_tangents: Flat list of Tensors holding the jvps that correspond to
        `inference_args`.
      tape_watching: True when the call must be recorded for gradients.
    """
    self._functions, self._inference_args = functions, inference_args
    self._input_tangents, self._tape_watching = input_tangents, tape_watching

  def forward(self):
    """Returns `(forward_function, flat_args)` for executing this call.

    `flat_args` is the inference arguments followed by the input tangents.
    """
    inference_args = self._inference_args
    input_tangents = self._input_tangents
    selected = self._functions.forward(inference_args, input_tangents)
    return selected, inference_args + input_tangents

  def record(self, flat_outputs):
    """Records the executed call on the tape when recording is warranted.

    Nothing is recorded when no tape is watching, or when the call produced
    no outputs (`flat_outputs` is None or a bare `ops.Operation`).
    """
    if not self._tape_watching:
      return
    if isinstance(flat_outputs, ops.Operation) or flat_outputs is None:
      return
    self._functions.record(
        flat_outputs, self._inference_args, self._input_tangents)
1391
1392
1393class ConcreteFunction(core.ConcreteFunction, trackable.Trackable):
1394  """A `tf.types.experimental.ConcreteFunction` created from `tf.function`."""
1395
1396  def __init__(self, func_graph, attrs=None, shared_func_graph=True, spec=None):
1397    """Initialize a `ConcreteFunction`.
1398
1399    Args:
1400      func_graph: An instance of FuncGraph: the function body to wrap.
1401      attrs: (optional) dict mapping names of attributes to their AttrValue
1402        values. Attributes in `attrs` will be included in this function's
1403        definition.
1404     shared_func_graph: If False, the ConcreteFunction takes ownership of
1405       `func_graph` and will break reference cycles when it is deleted. This
1406       makes the FuncGraph inoperable.
1407     spec: FunctionSpec for the original function.  If not specified, then this
1408       ConcreteFunction may only be called using the flat signature.
1409
1410    Raises:
1411      ValueError: If number of input_placeholders is not equal to the number
1412        of function inputs.
1413    """
1414    # _arg_keywords and _num_positional_args define the flat signature.  They
1415    # are assigned after construction.
1416    self._arg_keywords = None
1417    self._num_positional_args = None
1418
1419    self._func_graph = func_graph
1420    self._captured_inputs = self._func_graph.external_captures + self._func_graph.deferred_external_captures
1421
1422    # spec defines the structured signature.
1423    self._set_function_spec(spec)
1424
1425    if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
1426      # The alternative is to silently drop "implements" tag
1427      # but it seems likely it would lead to hard to catch bugs.
1428      # Another alternative is to make func_body to preserve the order
1429      # of arguments if variables are present. Yet another option
1430      # is to automatically replace variables as arguments to functions
1431      # to v.read_value() whenever "implements" tag is present
1432      # Anytime we annotate existing function we probably want to wrap
1433      # it with safe read_value for backward compatibility.
1434      has_resource_vars = any(
1435          inp.dtype == dtypes.resource for inp in self.inputs)
1436
1437      assert not any((has_resource_vars, self._captured_inputs)), (
1438          'Function {name} has "{attr}={value}" attribute and thus can not '
1439          "depend on any tensors outside of its signature or modify variables. "
1440          "\n\nNote: variables are always captured and cause function "
1441          "re-tracing for every variable called.\n"
1442          "  inputs: {inputs}\n  captures: {captured}\n\n"
1443          "To pass a variable to such function use  "
1444          "use variable.read_value().".format(
1445              name=func_graph.name,
1446              attr=IMPLEMENTS_ATTRIBUTE_NAME,
1447              value=attrs[IMPLEMENTS_ATTRIBUTE_NAME],
1448              inputs=self.inputs,
1449              captured=self._captured_inputs))
1450    self._output_shapes = tuple(
1451        output.shape for output in self._func_graph.outputs)
1452    self._attrs = _parse_func_attrs(attrs or {})
1453
1454    if shared_func_graph:
1455      self._garbage_collector = None
1456    else:
1457      self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph)
1458
1459    # Pairs of forward and backward functions used for computing gradients.
1460    #
1461    # These each get a reference to the FuncGraph deleter since they use the
1462    # FuncGraph directly.
1463    self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions(
1464        func_graph, self._attrs, self._garbage_collector)
1465    self._first_order_tape_functions = {}
1466    self._higher_order_tape_functions = {}
1467    # Cache the inference function to avoid a (Python) function call when not
1468    # building gradients.
1469    self._inference_function = self._delayed_rewrite_functions.forward()
1470
1471  def _set_function_spec(self, spec):
1472    """Enables the structured signature by supplying a spec."""
1473    self._function_spec = None
1474    self._pre_initialized_function_spec = spec
1475    self._initialize_function_spec()
1476
1477  def _initialize_function_spec(self):
1478    """Updates `self._function_spec` to include varargs and bound variables.
1479
1480    Adds new positional arguments for any varargs (i.e., for args that are
1481    in `structured_input_signature`, but not in the original fullargspec.args).
1482
1483    Replaces `defaults` and `kwonlydefaults` with the `BOUND_VALUE`, for
1484    all args and kwargs in `structured_input_signature`.
1485
1486    Sets `varkw` and `varargs` to None.
1487    """
1488    if self._pre_initialized_function_spec is None:
1489      return  # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
1490    assert not self._function_spec, "already initialized"
1491    spec = self._pre_initialized_function_spec
1492    args = spec.fullargspec.args
1493    arg_specs, kwarg_specs = self.structured_input_signature
1494    vararg_indices = range(len(spec.arg_names), len(arg_specs))
1495    fullargspec = tf_inspect.FullArgSpec(
1496        args=list(args) + ["<arg{}>".format(i + 1) for i in vararg_indices],
1497        varargs=None,
1498        varkw=None,
1499        defaults=[function_spec.BOUND_VALUE] * len(arg_specs),
1500        kwonlyargs=list(sorted(kwarg_specs)),
1501        kwonlydefaults=dict(
1502            (k, function_spec.BOUND_VALUE) for k in kwarg_specs),
1503        annotations=spec.fullargspec.annotations)
1504    self._function_spec = function_spec.FunctionSpec(
1505        fullargspec,
1506        spec.is_method,
1507        spec.input_signature,
1508        spec.is_pure,
1509        name=self._func_graph.name)
1510
1511  @property
1512  def variables(self):
1513    """Sequence of variables for this function."""
1514    return tuple(self._func_graph.variables)
1515
1516  def set_variables(self, variables):
1517    self._func_graph.variables = variables
1518
1519  @property
1520  def trainable_variables(self):
1521    """Sequence of trainable variables for this function."""
1522    return tuple(self._func_graph.trainable_variables)
1523
1524  def __call__(self, *args, **kwargs):
1525    """Executes the wrapped function.
1526
1527    ConcreteFunctions have two signatures:
1528
1529    * The signature of the original function wrapped by this ConcreteFunction.
1530    * A flat signature, where each argument accepts a single Tensor.
1531
1532    The original function signature is generally preferred, but the flat input
1533    signature is supported for backward compatibility.
1534
1535    ### Original Function Signature
1536
1537    When calling a ConcreteFunction with the signature of the original function,
1538    each argument must match the type or value that was used when the
1539    ConcreteFunction's graph was traced.  In particular:
1540
1541    * Tensor arguments (including CompositeTensors, such as RaggedTensor) must
1542      have matching `TypeSpec`s.
1543    * Non-Tensor arguments (such as booleans or ints) must have equal values.
1544    * Nested arguments (such as lists, tuples, or dictionaries) must have the
1545      same nesting structure; and each nested value must have a matching type
1546      or value.
1547
1548    The default value for any arguments that were traced with non-Tensor values
1549    is the value that was used in the trace.  Arguments that were traced with
1550    tensor arguments do not have a default value (even if the original function
1551    had a default value for that argument).
1552
1553    ### Flat Signature
1554
1555    When calling a ConcreteFunction with the flat signature, the arguments
1556    correspond to the flattened component tensors of the arguments that were
1557    used to construct the ConcreteFunction.  Parameter names are assigned based
1558    on `TensorSpec.name` (when specified) or the original argument names (with
1559    suffixes automatically added for nested arguments or composite tensors with
1560    multiple components).
1561
1562    Args:
1563      *args: Positional arguments to the concrete function.
1564      **kwargs: Keyword arguments to the concrete function.
1565
1566    Returns:
1567      The result of applying the TF function on the given Tensors.
1568
1569    Raises:
1570      AssertionError: If this `ConcreteFunction` was not created through
1571        `get_concrete_function`.
1572      TypeError: If the arguments do not match the function's signature.
1573    """
1574    return self._call_impl(args, kwargs)
1575
1576  def _call_impl(self, args, kwargs, cancellation_manager=None):
1577    """See `__call__` for details."""
1578    with trace.Trace(self._func_graph.name, tf_function_call="concrete"):
1579      # Construct the list of input tensors: check if the structured signature
1580      # applies first; and if not, then use the flat signature.
1581      if self._function_spec is not None:
1582        try:
1583          return self._call_with_structured_signature(args, kwargs,
1584                                                      cancellation_manager)
1585        except TypeError as structured_err:
1586          try:
1587            return self._call_with_flat_signature(args, kwargs,
1588                                                  cancellation_manager)
1589          except TypeError:
1590            raise structured_err
1591
1592      return self._call_with_flat_signature(args, kwargs, cancellation_manager)
1593
1594  def _call_with_flat_signature(self, args, kwargs, cancellation_manager):
1595    """Executes the wrapped function with the flat signature.
1596
1597    Args:
1598      args: Positional arguments to the concrete function.
1599      kwargs: Keyword arguments to the concrete function.
1600      cancellation_manager: A `CancellationManager` that can be used to cancel
1601        function invocation.
1602
1603    Returns:
1604      The result of applying the function on the Tensors/Variables contained in
1605      `args` and `kwargs`.
1606    Raises:
1607      TypeError: if `args` and `kwargs` do not match the flat signature of this
1608        `ConcreteFunction`.
1609    """
1610    if len(args) > self._num_positional_args:
1611      raise TypeError(
1612          f"{self._flat_signature_summary()} takes {self._num_positional_args} "
1613          f"positional arguments, got {len(args)}.")
1614    args = list(args)
1615    kwargs = dict(kwargs)
1616    for keyword in self._arg_keywords[len(args):]:
1617      try:
1618        args.append(kwargs.pop(compat.as_str(keyword)))
1619      except KeyError:
1620        specified_keywords = (
1621            list(self._arg_keywords[:len(args)]) + list(kwargs.keys()))
1622        missing_required_args = sorted(
1623            set(self._arg_keywords) - set(specified_keywords))
1624        raise TypeError(f"{self._flat_signature_summary()} missing required "
1625                        f"arguments: {', '.join(missing_required_args)}.")
1626    if kwargs:
1627      positional_arg_keywords = set(self._arg_keywords[:len(args)])
1628      for unused_key in kwargs:
1629        if unused_key in positional_arg_keywords:
1630          raise TypeError(f"{self._flat_signature_summary()} got two values "
1631                          f"for '{unused_key}'.")
1632      raise TypeError(f"{self._flat_signature_summary()} got unexpected "
1633                      f"keyword arguments: {', '.join(sorted(kwargs))}.")
1634
1635    for i, arg in enumerate(args):
1636      if not isinstance(
1637          arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
1638        raise TypeError(f"{self._flat_signature_summary()}: expected argument "
1639                        f"#{i}(zero-based) to be a Tensor; "
1640                        f"got {type(arg).__name__} ({arg}).")
1641    return self._call_flat(args, self.captured_inputs, cancellation_manager)
1642
1643  def _call_with_structured_signature(self, args, kwargs, cancellation_manager):
1644    """Executes the wrapped function with the structured signature.
1645
1646    Args:
1647      args: Positional arguments to the concrete function.
1648      kwargs: Keyword arguments to the concrete function.
1649      cancellation_manager: A `CancellationManager` that can be used to cancel
1650        function invocation.
1651
1652    Returns:
1653      The result of applying the function on the Tensors/Variables contained in
1654      `args` and `kwargs`.
1655    Raises:
1656      TypeError: if `args` and `kwargs` do not match the structured signature
1657        of this `ConcreteFunction`.
1658    """
1659    args, kwargs, filtered_flat_args = (
1660        self._function_spec.canonicalize_function_inputs(args, kwargs))
1661    self._structured_signature_check_missing_args(args, kwargs)
1662    self._structured_signature_check_unexpected_args(args, kwargs)
1663    self._structured_signature_check_arg_types(args, kwargs)
1664    return self._call_flat(
1665        filtered_flat_args,
1666        captured_inputs=self.captured_inputs,
1667        cancellation_manager=cancellation_manager)
1668
1669  def _structured_signature_check_missing_args(self, args, kwargs):
1670    """Raises a TypeError if any args are missing."""
1671    arg_specs, kwarg_specs = self.structured_input_signature
1672    missing_arguments = []
1673    for i, (arg, spec) in enumerate(zip(args, arg_specs)):
1674      if arg is function_spec.BOUND_VALUE and _contains_type_spec(spec):
1675        missing_arguments.append(self._function_spec.arg_names[i])
1676    for (name, arg) in kwargs.items():
1677      if arg is function_spec.BOUND_VALUE and _contains_type_spec(
1678          kwarg_specs[name]):
1679        missing_arguments.append(name)
1680    if missing_arguments:
1681      raise TypeError(f"{self._structured_signature_summary()} missing "
1682                      "required arguments: "
1683                      f"{', '.join(sorted(missing_arguments))}.")
1684
1685  def _structured_signature_check_unexpected_args(self, args, kwargs):
1686    """Raises a TypeError if there are any extra args."""
1687    arg_specs, kwarg_specs = self.structured_input_signature
1688    if len(args) > len(arg_specs):
1689      raise TypeError(
1690          f"{self._structured_signature_summary()} takes "
1691          f"{len(self._function_spec.arg_names)} positional arguments but got "
1692          f"{len(args)}.")
1693    if len(kwargs) > len(kwarg_specs):
1694      extra_args = set(kwargs) - set(kwarg_specs)
1695      raise TypeError(f"{self._structured_signature_summary()} got unexpected "
1696                      f"keyword arguments: {', '.join(extra_args)}.")
1697
1698  def _structured_signature_check_arg_types(self, args, kwargs):
1699    """Raises a TypeError if any args have the wrong type."""
1700    # Check argument types
1701    arg_specs, kwarg_specs = self.structured_input_signature
1702    for i, (arg, spec) in enumerate(zip(args, arg_specs)):
1703      name = self._function_spec.arg_names[i]
1704      self._structured_signature_check_arg_type(arg, spec, name)
1705    for (name, arg) in kwargs.items():
1706      self._structured_signature_check_arg_type(arg, kwarg_specs[name], name)
1707
1708  def _structured_signature_check_arg_type(self, arg, spec, name):
1709    """Raise TypeError if `arg`'s type doesn't match `spec`."""
1710    if arg is function_spec.BOUND_VALUE:
1711      return
1712
1713    # Check the overall nested structure of the argument.
1714    try:
1715      nest.assert_same_structure(arg, spec, expand_composites=True)
1716    except (ValueError, TypeError):
1717      try:
1718        nest.assert_same_structure(arg, spec, expand_composites=False)
1719        expected, got = spec, arg
1720      except (ValueError, TypeError):
1721        expected, got = _structure_summary(spec), _structure_summary(arg)
1722      raise TypeError(f"{self._structured_signature_summary()}: argument "
1723                      f"{name} had incorrect type\n"
1724                      f"  expected: {expected}\n"
1725                      f"       got: {got}")
1726
1727    # Check the type for each leaf in the nested structure.
1728    arg_pieces = nest.flatten(arg, expand_composites=True)
1729    spec_pieces = nest.flatten(spec, expand_composites=True)
1730    for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces):
1731      # TODO(mdan): Use consistent error messages.
1732      if isinstance(spec_piece, tensor_spec.DenseSpec):
1733        # TODO(edloper): Consider calling convert_to_tensor on non-tensor
1734        # values here.  That would match the behavior of
1735        # _call_concrete_function() in function_deserialization.py.  If
1736        # we do, then we need to change the nest assert_same_structure and
1737        # flatten calls above to use shallow variants.
1738        tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
1739        if not isinstance(arg_piece, tensor_types):
1740          raise TypeError(f"{self._structured_signature_summary()} expected a "
1741                          f"Tensor in {name}, but got "
1742                          f"{type(arg_piece).__name__} value {arg_piece}.")
1743      elif arg_piece is not function_spec.BOUND_VALUE:
1744        try:
1745          arg_matches_spec = bool(arg_piece == spec_piece)
1746        except (ValueError, TypeError):
1747          logging.vlog(1, "Error matching value with spec", exc_info=True)
1748          arg_matches_spec = False
1749        if not arg_matches_spec:
1750          raise TypeError(
1751              f"ConcreteFunction {self._structured_signature_summary()} was "
1752              f"constructed with {type(spec_piece).__name__} value "
1753              f"{spec_piece} in {name}, but was called with "
1754              f"{type(arg_piece).__name__} value {arg_piece}.")
1755
  def _call_flat(self, args, captured_inputs, cancellation_manager=None):
    """Executes the wrapped function on flat Tensor/Variable inputs.

    Args:
      args: a list of Tensors or resource Variables. Structured Python
        arguments must already be filtered and flattened (CompositeTensors
        expanded into components) by the caller; any other object here raises
        a ValueError. A Variable is passed as its handle, and a Variable whose
        handle appears more than once is passed only once.
      captured_inputs: the captured inputs appended after `args` for the actual
        execution. It is concatenated to a list, so it must be a list; callers
        in this class pass `self.captured_inputs`.
      cancellation_manager: (Optional.) A `CancellationManager` that can be
        used to cancel function invocation. Only forwarded when executing
        eagerly.

    Returns:
      The result of applying the TF function to `args`, as assembled by
      `self._build_call_outputs`.

    Raises:
      ValueError: If `args` contains anything other than Tensors or Variables,
        or, when not executing eagerly, if a (non-resource, non-variant) input
        shape is incompatible with the shape the function was traced with.
    """
    ctx = context.context()
    executing_eagerly = ctx.executing_eagerly()

    # Copy saveable status of function's graph to current FuncGraph.
    default_graph = ops.get_default_graph()
    if default_graph.building_function and not self._func_graph.saveable:
      default_graph.mark_as_unsaveable(self._func_graph.saving_errors)

    # Mark every variable used by the wrapped graph as accessed when a tape
    # might record this call or the default graph can watch variables.
    if (tape.could_possibly_record() or
        hasattr(default_graph, "watch_variable")):
      for v in self._func_graph.variables:
        resource_variable_ops.variable_accessed(v)

    # Build the flat tensor list: variables contribute their resource handle,
    # tensors pass through unchanged.
    tensor_inputs = []
    variables_used = set([])
    for i, arg in enumerate(args):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # We can pass a variable more than once, and in this case we need to
        # pass its handle only once.
        if id(arg.handle) in variables_used:
          continue
        resource_variable_ops.variable_accessed(arg)
        tensor_inputs.append(arg.handle)
        variables_used.add(id(arg.handle))
      elif isinstance(arg, ops.Tensor):
        tensor_inputs.append(arg)
      else:
        raise ValueError(f"{i:d}-th input {arg} must be a Tensor, got "
                         f"{type(arg)} when calling {self._func_graph.name}.")

    if not executing_eagerly:
      # NOTE(review): `i` indexes the de-duplicated `tensor_inputs`, so it
      # lines up with `self._func_graph.inputs` only if the traced graph also
      # has one input per distinct variable handle -- confirm.
      for i, tensor_input in enumerate(tensor_inputs):
        # Can not compare shapes in these cases
        # TODO(b/216506654): Consider moving this check elsewhere and making it
        # work for all types (e.g. by including shape for Variables).
        if (tensor_input.dtype == dtypes.resource or
            tensor_input.dtype == dtypes.variant):
          continue

        # If we're graph building, shape inference is on. We check for input
        # compatibility up front to avoid hard to debug incompatibilities
        # later.
        graph_input_shape = tensor_shape.TensorShape(
            self._func_graph.inputs[i].shape)
        if not graph_input_shape.is_compatible_with(tensor_input.shape):
          raise ValueError(
              f"Tensor {tensor_input} is not compatible with the shape this "
              f"function was traced with. Expected shape "
              f"{self._func_graph.inputs[i].shape}, but got shape "
              f"{tensor_input.shape}.\n\nIf you called get_concrete_function, "
              f"you may need to pass a tf.TensorSpec(..., shape=...) with a "
              f"less specific shape, having None on axes which can vary.")

    # Captured inputs are always fed after the explicit inputs.
    args = tensor_inputs + captured_inputs
    possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
    if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
        and executing_eagerly):
      # No tape is watching; skip to running the function.
      return self._build_call_outputs(self._inference_function.call(
          ctx, args, cancellation_manager=cancellation_manager))
    # Otherwise pick the forward/backward pair matching the gradient type
    # (presumably a `_ForwardBackwardCall`, given the forward()/record() use).
    forward_backward = self._select_forward_and_backward_functions(
        args,
        possible_gradient_type,
        executing_eagerly)
    forward_function, args_with_tangents = forward_backward.forward()
    if executing_eagerly:
      flat_outputs = forward_function.call(
          ctx, args_with_tangents, cancellation_manager=cancellation_manager)
    else:
      # In graph mode, route gradients of the emitted call op through this
      # function's gradient function.
      with default_graph._override_gradient_function(  # pylint: disable=protected-access
          {"PartitionedCall": self._get_gradient_function(),
           "StatefulPartitionedCall": self._get_gradient_function()}):
        flat_outputs = forward_function.call(ctx, args_with_tangents)
    forward_backward.record(flat_outputs)
    return self._build_call_outputs(flat_outputs)
1850
1851  def _experimental_with_cancellation_manager(self, cancellation_manager):
1852    """Returns a callable that invokes a cancellable version of this function.
1853
1854    Args:
1855      cancellation_manager: A `CancellationManager` object that can be used to
1856        cancel function invocation.
1857
1858    Returns:
1859      A callable with the same signature as this concrete function.
1860    """
1861
1862    def cancellable_call(*args, **kwargs):
1863      return self._call_impl(
1864          args, kwargs, cancellation_manager=cancellation_manager)
1865
1866    return cancellable_call
1867
1868  @property
1869  def name(self):
1870    """`ConcreteFunction` name."""
1871    return self._delayed_rewrite_functions.forward().name
1872
1873  @property
1874  def graph(self):
1875    """Returns the graph from which this function was constructed."""
1876    return self._func_graph
1877
1878  @property
1879  def inputs(self):
1880    """Returns tensors in `self.graph` corresponding to arguments."""
1881    return self._func_graph.inputs
1882
1883  @property
1884  def structured_input_signature(self):
1885    """Returns structured signature for this concrete function.
1886
1887    Returns:
1888      A tuple `(args, kwargs)`, where:
1889
1890        * `args` is a tuple that specifies the expected type or value each for
1891          positional argument.
1892        * `kwargs` is a dictionary that specifies the expected type or value
1893          for each keyword-only argument.
1894
1895      The type or value for each argument is specified using one of the
1896      following:
1897
1898        * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
1899          value is expected.
1900        * A Python value, such as an integer, indicating that an equal value
1901          is expected.
1902        * A nested structure of `tf.TypeSpec`s and Python values, indicating
1903          that a corresponding nested structure is expected.
1904    """
1905    return self._func_graph.structured_input_signature
1906
1907  @property
1908  def outputs(self):
1909    """Returns tensors in `self.graph` corresponding to returned tensors."""
1910    return self._func_graph.outputs
1911
1912  @property
1913  def structured_outputs(self):
1914    """Returns outputs in `self.graph` as returned by the original function."""
1915    return self._func_graph.structured_outputs
1916
1917  def set_external_captures(self, captures):
1918    """Updates the function capture values.
1919
1920    The new values must have tensor types and shapes consistent with the
1921    original captures of the concrete function, but it is allowed to change a
1922    value captured with a deferred one and vice-versa.
1923
1924    Args:
1925      captures: A list of tensors or closures. Tensors are value captures, and
1926        closures are call-time (deferred captures).
1927    """
1928    # TODO(wxinyi): 1. verify that the new captures' type spec is compatible
1929    # with the original's. However, doing so requires MirroredVariable captures
1930    # initialized. 2. replace the original/new captures/deferred
1931    # captures in the wrapped graph. Doing such for a capture-to-deferred
1932    # capture replacement requires more arguments than the deferred capture
1933    # itself, e.g. default value, spec.
1934    self._captured_inputs = captures
1935
1936  def replace_capture_with_deferred_capture(self,
1937                                            tensor,
1938                                            closure,
1939                                            spec,
1940                                            placeholder=None,
1941                                            default_value=None):
1942    """Replaces existing capture `tensor` with a deferred capture `closure`.
1943
1944    This API replaces the capture `tensor` from the concrete function's captured
1945    inputs list, and places the deferred capture `closure` in
1946    its spot so the order of captured inputs is preserved. This is important
1947    because the old `tensor` and the new `closure` will have the same internal
1948    placeholder, which can be passed through the `placeholder` argument, or
1949    skipped, in which case we find the placeholder from internal inputs by
1950    indexing `tensor` in the external captured inputs list. Thus, it is
1951    important that the new deferred capture has output spec (specified by the
1952    `spec` argument) compatible with the internal placeholder (`placeholder`)
1953    and the original capture (`tensor`).
1954
1955    For example,
1956
1957    ```python
1958    bool_captured_tensor = tf.constant(True)
1959    float_captured_tensor = tf.constant([3.], dtype=tf.float32)
1960    value = tf.constant([2.], dtype=tf.float32)
1961
1962    @tf.function
1963    def fn():
1964      deferred_tensor = ops.get_default_graph().capture_call_time_value(
1965          lambda: value,
1966          tf.TensorSpec(shape=(1,), dtype=tf.float32))
1967      if bool_captured_tensor:
1968        return deferred_tensor
1969      else:
1970        return deferred_tensor + float_captured_tensor
1971
1972    concrete_fn = fn.get_concrete_function()
1973    print(concrete_fn())  # tf.Tensor([2.], shape=(1,), dtype=float32)
1974
1975    new_bool_captured_tensor = constant_op.constant(False)
1976    def bool_closure():
1977      return new_bool_captured_tensor
1978
1979    concrete_fn.replace_capture_with_deferred_capture(
1980        bool_captured_tensor,
1981        bool_closure,
1982        spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool))
1983
1984    print(concrete_fn())  # tf.Tensor([5.], shape=(1,), dtype=float32)
1985    ```
1986
1987    Args:
1988      tensor: Tensor already captured. This `tensor` should be listed in
1989        concrete_function.captured_inputs except when it's empty such as when
1990        the concrete function is restored from SavedModel.
1991      closure: function which takes no arguments, to be evaluated at function
1992        call time, returning a nest of tensors compatible with `spec`.
1993      spec: nest of TypeSpec for the value to capture.
1994      placeholder: optional. The internal placeholder corresponding to the
1995        captured `tensor` and the new `closure`.
1996      default_value: optional value to use in environments that cannot safely
1997        evaluate closure.
1998    """
1999    capture_index = None
2000    for i, capture in enumerate(self._captured_inputs):
2001      if id(tensor) == id(capture):
2002        capture_index = i
2003        break
2004
2005    if placeholder is None:
2006      if capture_index is None:
2007        raise ValueError(
2008            f"Did not find `tensor` argument {tensor} in the ConcreteFunction's"
2009            " captured inputs list, and did not receive a placeholder argument."
2010            " Thus we're unable to infer the internal placeholder. ")
2011
2012      placeholder = self.inputs[-len(self._captured_inputs) + capture_index]
2013
2014    if not (spec.is_compatible_with(tensor) or
2015            spec.is_compatible_with(placeholder)):
2016      raise ValueError(
2017          f"Attempting to substitute closure with spec {spec} that's "
2018          f"incompatible with the original capture {tensor} or the internal "
2019          f"placeholder {placeholder}.")
2020
2021    self._func_graph.replace_capture_with_deferred_capture(
2022        tensor=tensor,
2023        closure=closure,
2024        spec=spec,
2025        placeholder=placeholder,
2026        default_value=default_value)
2027
2028    if capture_index is not None:
2029      self._captured_inputs[capture_index] = closure
2030
2031  @property
2032  def captured_inputs(self):
2033    """Returns external Tensors captured by this function.
2034
2035    self.__call__(*args) passes `args + self.captured_inputs` to the function.
2036    """
2037    return nest.flatten(
2038        [x() if callable(x) else x for x in self._captured_inputs],
2039        expand_composites=True)
2040
2041  @property
2042  def function_def(self):
2043    """Returns a `FunctionDef` object representing this function."""
2044    return self._delayed_rewrite_functions.forward().definition
2045
2046  @property
2047  def output_shapes(self):
2048    """The function's output shapes."""
2049    return nest.map_structure(
2050        lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)),
2051        composite_tensor.replace_composites_with_components(
2052            self._func_graph.structured_outputs),
2053        expand_composites=False)
2054
2055  @property
2056  def output_dtypes(self):
2057    # TODO(akshayka): Consider removing this.
2058    return nest.map_structure(
2059        lambda x: x.dtype if x is not None else None,
2060        composite_tensor.replace_composites_with_components(
2061            self._func_graph.structured_outputs),
2062        expand_composites=False)
2063
2064  def add_to_graph(self, g=None):
2065    """Registers the function, adds it to the graph g or default graph.
2066
2067    Args:
2068      g: If specified, registers the function with this graph. Defaults to the
2069        current context (either the default graph or the eager context).
2070    """
2071    # If we are not executing eagerly, adds the function to default graph if no
2072    # graph is specified.
2073    # In case of eager execution, function definition gets added to context
2074    # during construction itself.
2075
2076    if not context.executing_eagerly() and not g:
2077      g = ops.get_default_graph()
2078    self._delayed_rewrite_functions.forward().add_to_graph(g)
2079
2080  def add_gradient_functions_to_graph(self, g=None):
2081    """Add forward/backward functions to graph `g` or the current context."""
2082    if not context.executing_eagerly() and not g:
2083      g = ops.get_default_graph()
2084    self._delayed_rewrite_functions.forward().add_to_graph(g)
2085    forward_function, backward_function = (
2086        self._delayed_rewrite_functions.forward_backward())
2087    forward_function.add_to_graph(g)
2088    backward_function.add_to_graph(g)
2089
2090  def _get_gradient_function(self):
2091    """Returns gradient function. It will be lazily created at first call."""
2092    return self._delayed_rewrite_functions._rewrite_forward_and_call_backward  # pylint: disable=protected-access
2093
2094  def _select_forward_and_backward_functions(
2095      self, args, possible_gradient_type, executing_eagerly):
2096    """Selects forward and backward functions based on the calling context.
2097
2098    The forward function computes the "real" function outputs, `self._outputs`,
2099    and any extra values needed by the corresponding backward function.
2100
2101    Args:
2102      args: A flat list of Tensors with all of the inputs to the forward
2103        function (including user-specified and captured inputs).
2104      possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
2105      executing_eagerly: Boolean, the value of context.executing_eagerly().
2106
2107    Returns:
2108      An object with a `forward` method returning a tuple of (forward_function :
2109      _EagerDefinedFunction, augmented_arguments : List), and a corresponding
2110      `record` method which takes outputs from the forward function and records
2111      the operation. forward_function should be called with augmented_arguments.
2112    """
2113    if executing_eagerly:
2114      input_tangents = forwardprop_util.pack_tangents(args)
2115    else:
2116      input_tangents = forwardprop_util.TangentInfo()
2117    need_gradients_for_jvps = tape.should_record_backprop(
2118        input_tangents.tangents)
2119    # Allows re-use of forward and backward function pairs depending on the
2120    # tapes and forward accumulators watching its inputs.
2121    cache_key = (need_gradients_for_jvps, input_tangents.indices)
2122    if (possible_gradient_type
2123        == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
2124      if input_tangents.indices or executing_eagerly:
2125        # There is a single non-persistent tape active, so the user can only
2126        # request first-order gradients from a tape. We can spend less time
2127        # graph building since we know this.
2128        #
2129        # We may still end up computing higher-order gradients, but that'd be
2130        # through `tf.gradients`, which can re-write the forward pass and so
2131        # needs no preparation here.
2132        functions = self._first_order_tape_functions.get(cache_key, None)
2133        if functions is None:
2134          functions = _FirstOrderTapeGradientFunctions(
2135              self._func_graph, self._attrs, self._garbage_collector,
2136              forwardprop_input_indices=input_tangents.indices,
2137              delayed_rewrite_functions=self._delayed_rewrite_functions,
2138              need_gradients_for_jvps=need_gradients_for_jvps)
2139          self._first_order_tape_functions[cache_key] = functions
2140        return _ForwardBackwardCall(
2141            functions, args, input_tangents.tangents, tape_watching=True)
2142      else:
2143        # We can avoid computing second-order gradients in some cases by doing a
2144        # delayed rewrite when graph building. Since we know we'll only compute
2145        # first-order tape gradients, the delayed rewrite is safe: we won't need
2146        # to tell the tape about side outputs.
2147        #
2148        # TODO(allenl): This case is really dirty. It would be better if we
2149        # could temporarily pop all of the current tapes to avoid
2150        # accidentally taking second-order gradients.
2151        return _ForwardBackwardCall(
2152            self._delayed_rewrite_functions, args, input_tangents.tangents,
2153            tape_watching=True)
2154    elif (possible_gradient_type
2155          == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
2156      # Either there's a persistent tape watching, or there are multiple nested
2157      # tapes. Either way, the user may request higher-order gradients. We'll
2158      # spend a bit more time and make sure higher-order gradients are correct.
2159      functions = self._higher_order_tape_functions.get(
2160          cache_key, None)
2161      if functions is None:
2162        functions = _HigherOrderTapeGradientFunctions(
2163            self._func_graph, self._attrs, self._garbage_collector,
2164            forwardprop_input_indices=input_tangents.indices,
2165            delayed_rewrite_functions=self._delayed_rewrite_functions,
2166            need_gradients_for_jvps=need_gradients_for_jvps)
2167        self._higher_order_tape_functions[cache_key] = functions
2168      return _ForwardBackwardCall(functions, args, input_tangents.tangents,
2169                                  tape_watching=True)
2170    # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
2171    # tape is recording.
2172    return _ForwardBackwardCall(
2173        self._delayed_rewrite_functions, args, input_tangents.tangents,
2174        tape_watching=False)
2175
2176  def _build_call_outputs(self, result):
2177    """Maps the fdef output list to actual output structure.
2178
2179    Args:
2180      result: Output lists defined by FunctionDef.
2181    Returns:
2182      The actual call output.
2183    """
2184    # TODO(jlchu): call C++ version in function.cc when speed is improved
2185    if self._func_graph.structured_outputs is None:
2186      return result
2187
2188    # Replace outputs with results, skipping over any 'None' values.
2189    outputs_list = nest.flatten(
2190        self._func_graph.structured_outputs, expand_composites=True)
2191    j = 0
2192    for i, o in enumerate(outputs_list):
2193      if o is not None:
2194        handle_data_util.copy_handle_data(self.outputs[j], result[j])
2195        outputs_list[i] = result[j]
2196        j += 1
2197    ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
2198                                outputs_list, expand_composites=True)
2199    return ret
2200
2201  @property
2202  def _as_name_attr_list(self):
2203    """Returns a `NameAttrList` representing this function."""
2204    ret = attr_value_pb2.NameAttrList(name=self.name)
2205    for name, value in self._attrs.items():
2206      ret.attr[name].CopyFrom(value)
2207    return ret
2208
2209  def _structured_signature_summary(self, default_values=False):
2210    """Returns a string summarizing this function's structured signature.
2211
2212    Args:
2213      default_values: If true, then include default values in the signature.
2214
2215    Returns:
2216      A `string`.
2217    """
2218    # Note: we can't just use self._funcion_spec.signature_summary(), because
2219    # that would show "BOUND_VALUE" as the default value for all arguments.
2220    assert self._function_spec is not None
2221    arg_specs, kwarg_specs = self.structured_input_signature
2222    arg_names = list(self._function_spec.arg_names)
2223
2224    # If an explicit input_signature is provided to @tf.function, then any
2225    # arguments with defaults that are not covered by that explicit signature
2226    # are simply dropped from the signature.
2227    # TODO(b/159639913) Look into whether dropping arguments with default values
2228    # from the signature is the right thing to do.
2229    arg_names = arg_names[:len(arg_specs)]
2230
2231    if default_values:
2232      for i in range(len(arg_names)):
2233        if not _contains_type_spec(arg_specs[i]):
2234          arg_names[i] += "={}".format(arg_specs[i])
2235    if kwarg_specs:
2236      arg_names.append("*")
2237      for name, spec in kwarg_specs.items():
2238        arg_names.append(name)
2239        if default_values and not _contains_type_spec(spec):
2240          arg_names[-1] += "={}".format(spec)
2241    signature = f"{self._func_graph.name}({', '.join(arg_names)})"
2242
2243    return signature
2244
2245  def _flat_signature_summary(self):
2246    """Returns a string summarizing this function's flat signature."""
2247    assert self._arg_keywords is not None
2248    assert self._num_positional_args is not None
2249    arg_names = self._arg_keywords
2250    if self._num_positional_args > len(arg_names):
2251      arg_names.extend(
2252          "<arg{}>".format(i + 1)
2253          for i in range(len(arg_names), self._num_positional_args))
2254    return f"{self._func_graph.name}({', '.join(arg_names)})"
2255
2256  def pretty_printed_signature(self, verbose=True):
2257    """Returns a string summarizing the signature of this concrete function."""
2258    if not verbose:
2259      return self._structured_signature_summary(default_values=True)
2260
2261    def pretty_print_spec(spec):
2262      """Returns a string describing the spec for a single argument."""
2263      if isinstance(spec, tensor_spec.TensorSpec):
2264        return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
2265      elif nest.is_nested(spec):
2266        pieces = nest.flatten(spec, expand_composites=False)
2267        markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
2268        structure = nest.pack_sequence_as(spec, markers)
2269        # Ensure dictionaries are sorted by key (for determinism)
2270        result = pprint.pformat(structure, width=10000)
2271        for (marker, piece) in zip(markers, pieces):
2272          result += "\n      {}: {}".format(marker, pretty_print_spec(piece))
2273        return result
2274      else:
2275        return repr(spec)
2276
2277    lines = [self._structured_signature_summary(default_values=True)]
2278    arg_specs, kwarg_specs = self.structured_input_signature
2279    names = list(self._function_spec.arg_names)
2280
2281    # If an explicit input_signature is provided to @tf.function, then any
2282    # arguments with defaults that are not covered by that explicit signature
2283    # are simply dropped from the signature.
2284    # TODO(b/159639913) Look into whether dropping arguments with default values
2285    # from the signature is the right thing to do.
2286
2287    # Note: we can skip bound args, since we already displayed their bound
2288    # value in the signature summary.
2289    arg_details = []
2290    for (name, spec) in zip(names[:len(arg_specs)], list(arg_specs)):
2291      if _contains_type_spec(spec):
2292        arg_details.append("    {}: {}".format(name, pretty_print_spec(spec)))
2293
2294    if kwarg_specs:
2295      for kwarg in sorted(kwarg_specs):
2296        spec = kwarg_specs[kwarg]
2297        if _contains_type_spec(spec):
2298          arg_details.append("    {}: {}".format(
2299              kwarg, pretty_print_spec(spec)))
2300
2301    if arg_details:
2302      lines.append("  Args:")
2303      lines.extend(arg_details)
2304    lines.append("  Returns:")
2305
2306    def spec_from_value(value):
2307      # For loaded function, structured_outputs are already specs.
2308      if isinstance(value, type_spec.TypeSpec):
2309        return value
2310      return type_spec.type_spec_from_value(value)
2311
2312    lines.append("    {}".format(
2313        pretty_print_spec(
2314            nest.map_structure(spec_from_value, self.structured_outputs))))
2315
2316    return "\n".join(lines)
2317
2318  def __repr__(self):
2319    if self._function_spec is not None:
2320      return "<ConcreteFunction {} at 0x{:X}>".format(
2321          self.pretty_printed_signature(verbose=False), id(self))
2322    elif not (self._num_positional_args is None or self._arg_keywords is None):
2323      return "<ConcreteFunction {} at 0x{:X}>".format(
2324          self._flat_signature_summary(), id(self))
2325    else:
2326      return object.__repr__(self)
2327
2328  def __str__(self):
2329    if self._function_spec is not None:
2330      return "ConcreteFunction {}".format(self.pretty_printed_signature())
2331    else:
2332      return self.__repr__()
2333
2334  def _trackable_children(self, save_type="checkpoint", **kwargs):
2335    """Implements `Trackable`."""
2336    if save_type == "checkpoint":
2337      # Checkpoint dependencies do not include functions at all. Users
2338      # expect the checkpointed variables to be saved using the model
2339      # architecture, e.g. `model.layers[1].kernel` or `model.variables`.
2340      return {}
2341
2342    captured_trackables = {}
2343    for n, (capture, _) in enumerate(self.graph.captures):
2344      if (capture.dtype not in (dtypes.variant, dtypes.resource) and
2345          not resource_variable_ops.is_resource_variable(capture)):
2346        # Variant/resource type tensors are skipped since we have no way of
2347        # getting the `Trackable` wrapper for these tensors. The wrappers are
2348        # expected to be elsewhere in the saved object graph.
2349        # TODO(b/223866972): Directly encode/decode tensor captures.
2350
2351        # Resource variable captures are also skipped at this time, to maintain
2352        # existing behavior.
2353        # TODO(b/217979389): Return the non-constant captures as children.
2354
2355        captured_trackables[f"capture_{n}"] = capture
2356
2357    return captured_trackables
2358
2359  def _deserialization_dependencies(self, children):
2360    return children
2361
2362  def _export_to_saved_model_graph(self, object_map, tensor_map,
2363                                   **unused_kwargs):
2364    if not self.graph.saveable:
2365      raise ValueError(
2366          (f"Unable to save function {self.name} for the following reason(s):\n"
2367           + "\n".join(self.graph.saving_errors)))
2368    self.add_to_graph()
2369    object_map[self] = function_saved_model_utils.ExportedConcreteFunction(
2370        self, tensor_map)
2371    return []
2372
2373
# Register these Python types with `_pywrap_utils` under the given names.
# NOTE(review): presumably this lets native helpers recognize instances of
# these classes without importing them from Python — confirm against
# _pywrap_utils.
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
_pywrap_utils.RegisterType("IndexedSlices", indexed_slices.IndexedSlices)
2377
2378
2379# TODO(mdan): Refactor this and clarify relationship with def_function.Function.
2380# Right now, def_function.Function is the higher level implementation.
class Function:
  """Wrapper class for the graph functions defined for a Python function.

  See the documentation for `defun` for more information on the semantics of
  defined functions.

  `Function` class is thread-compatible meaning that minimal usage of defuns
  (defining and calling) is thread-safe, but if users call other methods or
  invoke the base `python_function` themselves, external synchronization is
  necessary.
  In addition, Function is not reentrant, so recursive functions need to call
  the wrapped function, not the wrapper.

  Traced `ConcreteFunction`s are stored in a `function_cache.FunctionCache`.
  Cache keys are built from the call arguments together with a snapshot of
  the captures recorded so far. A new graph is traced only when lookup misses.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               attributes=None,
               autograph=True,
               autograph_options=None,
               reduce_retracing=False,
               capture_by_value=None,
               jit_compile=None,
               experimental_follow_type_hints=False):
    """Initializes a `Function`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: a possibly nested sequence of `TensorSpec` objects
        specifying the input signature of this function. If `None`, a separate
        function is instantiated for each inferred input signature.
      attributes: dict, extra keyword arguments that will be added as attribute
        of the function.
      autograph: whether to use autograph to compile
        `python_function`. See https://www.tensorflow.org/guide/autograph for
        more information.
      autograph_options: Experimental knobs to control behavior
        `when autograph=True`. See https://www.tensorflow.org/guide/autograph
        for more information.
      reduce_retracing: When True, `tf.function` uses
        `tf.types.experimental.TraceType` to trace supertypes of arguments to
        reduce the number of traces.
      capture_by_value: Experimental. Whether to capture resource variables by
        value or reference. If None, will inherit from a parent context or
        default to False.
      jit_compile: Force-compile the function with XLA, cf.
        def_function.Function doc on jit_compile.
      experimental_follow_type_hints: See the documentation for `tf.function`.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
    """
    self._python_function = python_function
    # A function carrying the IMPLEMENTS_ATTRIBUTE_NAME attribute is passed to
    # FunctionSpec as pure (`is_pure`).
    pure_function = attributes and IMPLEMENTS_ATTRIBUTE_NAME in attributes
    self._function_spec = function_spec.FunctionSpec.from_function_and_signature(
        python_function,
        input_signature,
        is_pure=pure_function,
        experimental_follow_type_hints=experimental_follow_type_hints)
    self._name = name
    self._autograph = autograph
    self._autograph_options = autograph_options
    self._reduce_retracing = reduce_retracing
    self._function_cache = function_cache.FunctionCache()
    self._function_attributes = attributes or {}
    self._capture_by_value = capture_by_value
    # Number of graphs traced so far; incremented by `_create_graph_function`.
    self.tracing_count = 0
    # Maintain a dict of all captures: identifier -> lambda function. It's used
    # to get runtime values for all captures during ConcreteFunction dispatch.
    self._captures_container = func_graph_module.CapturesContainer()
    # Held around cache lookup and tracing (see `_maybe_define_function`).
    self._lock = threading.RLock()
    # _descriptor_cache maps an instance of a class to an instance-specific
    # `Function`, used to make sure defun-decorated methods create different
    # functions for each instance.
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._jit_compile = jit_compile
    self._experimental_follow_type_hints = experimental_follow_type_hints

  def __call__(self, *args, **kwargs):
    """Calls a graph function specialized to the inputs.

    The matching `ConcreteFunction` is looked up (or traced) while holding
    `self._lock`. It is then called after the lock is released, with the
    filtered flat arguments and the concrete function's current captured
    inputs.
    """
    with self._lock:
      (graph_function,
       filtered_flat_args) = self._maybe_define_function(args, kwargs)
    return graph_function._call_flat(
        filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access

  @property
  def python_function(self):
    """Returns the wrapped Python function."""
    return self._python_function  # pylint: disable=protected-access

  @property
  def function_spec(self):
    """Returns the `FunctionSpec` built from the wrapped function and signature."""
    return self._function_spec

  @property
  def input_signature(self):
    """Returns the input signature."""
    return self._function_spec.input_signature

  def _maybe_define_concrete_function(self, args, kwargs):
    """Like `_maybe_define_function`, but defaults to the input signature.

    If an input signature is set and no arguments at all are supplied, the
    input signature itself is used as the positional arguments. The caller
    must hold self._lock, as `_maybe_define_function` requires.
    """
    if self.input_signature and not args and not kwargs:
      # TODO(b/215596825): Throw error here if multiple entries are defined.
      args = self.input_signature
      kwargs = {}

    return self._maybe_define_function(args, kwargs)

  def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
    """Returns a concrete function which cleans up its graph function."""
    with self._lock:
      graph_function, _ = self._maybe_define_concrete_function(args, kwargs)
    return graph_function

  def _get_concrete_function_internal(self, *args, **kwargs):
    """Bypasses error checking when getting a graph function."""
    graph_function = self._get_concrete_function_internal_garbage_collected(
        *args, **kwargs)
    # We're returning this concrete function to someone, and they may keep a
    # reference to the FuncGraph without keeping a reference to the
    # ConcreteFunction object. So we won't clean up the reference cycles
    # manually and instead will leave them to Python's garbage collector.
    graph_function._garbage_collector.release()  # pylint: disable=protected-access
    return graph_function

  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Unlike `get_concrete_function(...)`, the graph will be deleted when the
    returned function is deleted.  It's useful to avoid creating a reference
    cycle when you know for sure that the graph will be no longer used without
    the returned function.

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.
    """
    if self.input_signature:
      self._function_spec.validate_inputs_with_signature(args, kwargs)

    with self._lock:
      graph_function, _ = self._maybe_define_concrete_function(args, kwargs)
      seen_names = set()
      captured = object_identity.ObjectIdentitySet(
          graph_function.graph.internal_captures)
      # pylint: disable=protected-access
      graph_function._arg_keywords = []
      prefix_counts = {}
      # pylint: enable=protected-access
      num_positional = 0
      # Graph inputs before the first internal capture are the user-visible
      # arguments. Each is keyed by its `_user_specified_name`; repeated names
      # get `_1`, `_2`, ... suffixes so every keyword is unique.
      for arg in graph_function.graph.inputs:
        if arg in captured:
          break
        num_positional += 1
        user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
        proposal = user_arg_name
        while proposal in seen_names:
          index = prefix_counts.get(user_arg_name, 1)
          proposal = "{}_{}".format(user_arg_name, index)
          prefix_counts[user_arg_name] = index + 1
        seen_names.add(proposal)
        graph_function._arg_keywords.append(proposal)  # pylint: disable=protected-access
      # Anything can be a positional argument, in the same order as .inputs
      graph_function._num_positional_args = num_positional  # pylint: disable=protected-access
      return graph_function

  def get_concrete_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Args:
      *args: inputs to specialize on. Can be concrete values (e.g. 1) or
        `tf.Tensor` or `tf.TensorSpec`.
      **kwargs: keyword inputs to specialize on. Concrete values (e.g. 1) or
        `tf.Tensor` or `tf.TensorSpec`.
    """
    graph_function = self._get_concrete_function_garbage_collected(
        *args, **kwargs)
    graph_function._garbage_collector.release()  # pylint: disable=protected-access
    return graph_function

  def _list_all_concrete_functions(self) -> List[ConcreteFunction]:
    """Returns the `ConcreteFunction`s currently held in the function cache."""
    return self._function_cache.values()

  def __get__(self, instance, owner):
    """Makes it possible to defun instance methods."""
    del owner
    # `instance` here is the instance that this `Function` was accessed through
    # e.g., for
    #
    #   class Foo:
    #
    #     @function.defun
    #     def bar(self):
    #       ...
    #
    #   foo = Foo()
    #   foo.bar()  # `foo.bar` is a `Function` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`).  We create a
    # new instance of `Function` here to allow different instances each
    # to create variables once, thereby allowing methods to be decorated with
    # defun. Keeps a cache to avoid retracing the function every time the
    # descriptor is accessed.
    if instance not in self._descriptor_cache:
      if instance is None:
        return self
      # If there is no instance-specific `Function` in the cache, we construct
      # an instance-specific `Function` that uses a weak reference to the
      # instance (so that the instance will be correctly gc'd).

      # And finally add the wrapped function to the descriptor cache
      self._descriptor_cache[instance] = class_method_to_instance_method(
          self, instance)

    # Return the cached `Function` for the instance
    return self._descriptor_cache[instance]

  def _create_graph_function(self, args, kwargs):
    """Create a `ConcreteFunction` from `args` and `kwargs`.

    Traces `self._python_function` into a new FuncGraph and increments
    `tracing_count`. Positional arguments beyond the named parameters are
    named `<vararg_name>_0`, `<vararg_name>_1`, ... after the varargs
    parameter of the function spec.
    """
    self.tracing_count += 1

    arglen = len(args)
    base_arg_names = self._function_spec.arg_names[:arglen]
    # Zero or negative when every positional arg has a named parameter; the
    # list below is then empty.
    num_missing_args = arglen - len(self._function_spec.arg_names)
    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
    # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
    # where arg is based on the self._function_spec.vararg_name.
    missing_arg_names = [
        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
    ]
    arg_names = base_arg_names + missing_arg_names
    graph_function = ConcreteFunction(
        func_graph_module.func_graph_from_py_func(
            self._name,
            self._python_function,
            args,
            kwargs,
            None,
            autograph=self._autograph,
            autograph_options=self._autograph_options,
            arg_names=arg_names,
            capture_by_value=self._capture_by_value),
        self._function_attributes,
        spec=self.function_spec,
        # Tell the ConcreteFunction to clean up its graph once it goes out of
        # scope. This is not the default behavior since it gets used in some
        # places (like Keras) where the FuncGraph lives longer than the
        # ConcreteFunction.
        shared_func_graph=False)
    return graph_function

  def _maybe_define_function(self, args, kwargs):
    """Gets a function for these inputs, defining it if necessary.

    Caller must hold self._lock.

    Args:
      args: The varargs for the Python function.
      kwargs: The keyword args for the Python function.

    Returns:
      A graph function corresponding to the input signature implied by args and
      kwargs, as well as filtered flattened inputs (only Tensors and Variables)
      that the object should be called with.

    Raises:
      ValueError: If inputs are incompatible with the input signature.
      TypeError: If the function inputs include non-hashable objects
      RuntimeError: If there's an internal bug (inconsistency) in handling
        shape relaxation retracing.
    """
    args, kwargs, filtered_flat_args = (
        self._function_spec.canonicalize_function_inputs(args, kwargs))

    # With an explicit input signature, the cache key and trace are built from
    # the signature rather than the concrete call arguments.
    if self.input_signature is not None:
      args = self.input_signature
      kwargs = {}

    # Get runtime values of captures
    captures = self._captures_container.get_snapshot()

    # cache_key_deletion_observer is useless here. It's based on all captures.
    # A new cache key will be built later when saving ConcreteFunction because
    # only active captures should be saved.
    lookup_func_key, _ = function_context.make_cache_key((args, kwargs),
                                                         captures)
    graph_function = self._function_cache.lookup(lookup_func_key, True)
    # Cache hit: reuse the previously traced function.
    if graph_function is not None:
      return graph_function, filtered_flat_args

    with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()):
      with trace.Trace("tf.function-graph_building"):
        logging.vlog(1,
                     "Creating new FuncGraph for Python function %r (key: %r)",
                     self._python_function, lookup_func_key)
        logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
                     args, kwargs)
        ag_status = (
            ag_ctx.Status.ENABLED
            if self._autograph else ag_ctx.Status.DISABLED)
        with ag_ctx.ControlStatusCtx(
            status=ag_status, options=self._autograph_options):
          # With reduce_retracing and no explicit signature, trace the
          # generalized key's placeholder values instead of the exact args.
          if self.input_signature is None and self._reduce_retracing:
            generalized_func_key = self._function_cache.generalize(
                lookup_func_key)
            # Only get placeholders for arguments, not captures
            args, kwargs = generalized_func_key._placeholder_value()["args"]  # pylint: disable=protected-access

          graph_function = self._create_graph_function(args, kwargs)

          graph_capture_container = graph_function.graph._capture_func_lib  # pylint: disable=protected-access
          # Maintain the list of all captures
          self._captures_container.update(graph_capture_container)
          # Get current active captures snapshot
          captures = graph_capture_container.get_snapshot()

          # Create a cache_key with args and captures
          traced_func_key, traced_func_deletion_observer = (
              function_context.make_cache_key((args, kwargs), captures))

          self._function_cache.add(traced_func_key,
                                   traced_func_deletion_observer,
                                   graph_function)

          return graph_function, filtered_flat_args
2709
2710
def register(func, *args, **kwargs):
  """Registers a specialization of a `Function` into the graph.

  The Python function is not called with the given inputs; only the function
  definition (and its gradient functions) is added to the graph. Registering
  the same `Function` with different inputs adds one version of the function
  per specialization.

  Args:
    func: A `Function` instance, as produced by `defun`.
    *args: Positional inputs used to specialize the Python function.
    **kwargs: Keyword inputs used to specialize the Python function.

  Returns:
    The `ConcreteFunction` specialized to the inputs and execution context.

  Raises:
    ValueError: If `func` is not a `Function` instance.
  """
  if isinstance(func, Function):
    concrete = func.get_concrete_function(*args, **kwargs)
    concrete.add_to_graph()
    concrete.add_gradient_functions_to_graph()
    return concrete
  raise ValueError("Only defun function is allowed to be registered. "
                   f"Got {func} with type {type(func)}.")
2736
2737
def defun(func=None,
          input_signature=None,
          autograph=True,
          experimental_autograph_options=None,
          reduce_retracing=False):
  """Compiles a Python function into a callable TensorFlow graph.

  `defun` (short for "define function") compiles a Python function
  composed of TensorFlow operations into a callable that executes a `tf.Graph`
  containing those operations. The callable produced by `defun` contains only
  the subgraph of TensorFlow operations that were executed when the Python
  function was called with a particular input signature, defined as a list
  of the shapes and dtypes of the Python function's Tensor-valued arguments and
  the values of its non-Tensor Python objects.

  When eager execution is enabled, the ability to create graphs from Python
  functions makes it possible to incrementally trade off debuggability and
  interactivity for performance. Functions compiled with `defun` cannot be
  inspected with `pdb`; however, executing a graph
  generated by `defun` sometimes takes less time and memory than eagerly
  executing the corresponding Python function, since specifying computations as
  graphs allows for optimizations like automatic buffer reuse and
  parallelization among ops. Note that executing a `defun`-compiled function
  incurs a small constant overhead, so eagerly executing sufficiently small
  Python functions might take less time than executing their corresponding
  `defun`-generated graphs.

  For a Python function to be compatible with `defun`, all of its arguments must
  be hashable Python objects or lists thereof. The function itself may not
  modify the list/map structure of its arguments. Additionally, it must return
  zero or more `tf.Tensor` objects. If the Python function returns
  a `tf.Variable`, its compiled version will return the value of that variable
  as a `tf.Tensor`.

  Executing a graph generated by `defun` respects device annotations (i.e.,
  all `with tf.device` directives present in a Python function will also be
  present in its corresponding graph), but it is not yet possible to execute the
  generated graphs across multiple machines.

  _Example Usage_

  ```python
  import tensorflow as tf

  tf.compat.v1.enable_eager_execution()

  # A simple example.
  def f(x, y):
    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)

  g = tf.contrib.eager.defun(f)

  x = tf.constant([[2.0, 3.0]])
  y = tf.constant([[3.0, -2.0]])

  # `f` and `g` will return the same value, but `g` will be executed as a
  # TensorFlow graph.
  assert f(x, y).numpy() == g(x, y).numpy()

  # `defun` is capable of compiling Python functions that close over Python
  # objects, including Tensors and Variables.
  @tf.contrib.eager.defun
  def h():
    return f(x, y)

  assert (h().numpy() == f(x, y).numpy()).all()

  # `defun` automatically lifts variables out of the graphs it creates,
  # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
  # `tf.keras.Model` objects.
  class MyModel(tf.keras.Model):

    def __init__(self, keep_probability=0.2):
      super().__init__()
      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
      self.keep_probability = keep_probability

    @tf.contrib.eager.defun
    def call(self, inputs, training=True):
      x = self.dense2(self.dense1(inputs))
      if training:
        return tf.nn.dropout(x, self.keep_probability)
      else:
        return x

  model = MyModel()
  model(x, training=True)  # executes a graph, with dropout
  model(x, training=False) # executes a graph, without dropout

  # `defun`-compiled functions are differentiable.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
  with tf.GradientTape() as tape:
    outputs = model(x)
  gradient = tape.gradient(outputs, model.trainable_variables)
  optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
                            model.trainable_variables))
  ```

  When using `defun`, there are subtleties regarding inputs, Python control
  flow, and variable creation that one should be aware of. For concreteness, let
  `f` be a Python function that returns zero or more `tf.Tensor` objects and
  let `F = defun(f)`. `F` builds a graph for each unique input signature it
  sees, Python control flow is baked into graphs, and operations related to
  variable initialization are automatically lifted out of the graphs that `F`
  generates and placed in the eager context if executing eagerly or into an
  outer graph otherwise.

  _Input Signatures_

  By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
  for every unique sequence of the shapes and dtypes of Tensor arguments and
  the values of Python objects it is invoked with. For example, calling
  `F(tf.random.uniform([2]))` will execute a different graph than
  `F(tf.random.uniform([3]))` because the two inputs have different shapes.
  The first time that `F(*args, **kwargs)` is called with a particular sequence
  of Tensor shapes and dtypes and Python values, it constructs a graph by
  tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
  input signature inferred from `(*args, **kwargs)` and cached for future reuse.

  NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
  before being passed to `f`, and are treated as Tensors for caching. This
  allows a function to be called multiple times with NumPy arrays having
  different values but the same shape and dtype without re-tracing each time.

  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
  define TensorFlow functions without explicitly specifying their signatures.
  However, this policy is conservative and potentially expensive; for example,
  when different invocations of your function have differently-shaped Tensor
  inputs, this policy might generate more graph functions than necessary. To
  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
  optional `input_signature` argument specifying the shapes and dtypes of the
  inputs. In particular, the shapes may be partially unspecified, with `None`s
  in the unknown dimensions. When an input signature is provided,
  `tf.contrib.eager.defun` will only instantiate a single graph for the
  decorated Python function. The following is an example:

  ```python
  import tensorflow as tf

  # The first `TensorSpec` below describes the shape and dtype of `words`,
  # and the second describes the shape and dtype of `another_tensor`. Note that
  # the last dimension of the `words` `TensorSpec` is left unspecified.
  @tf.contrib.eager.defun(input_signature=[
    tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
    tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
  ])
  def my_sequence_model(words, another_tensor):
    ...

  # Note how the third dimension of the first input can vary freely.
  words = tf.random.uniform([50, 300, 10])
  second_input = tf.random.uniform([300, 100])
  my_sequence_model(words, second_input)

  words = tf.random.uniform([50, 300, 20])
  my_sequence_model(words, second_input)

  # Passing an input with an incompatible shape will raise an error.
  words = tf.random.uniform([50, 100, 20])
  my_sequence_model(words, second_input)  # <---- This will raise an error.

  ```

  Python functions that are compiled with an `input_signature` must only accept
  Tensors as arguments and must not take unnamed keyword arguments (**kwargs).

  _Tracing_

  Be aware that because `F` only logs TensorFlow operations, all the other
  Python code that `f` executes will only shape the _construction_ of the graphs
  that `F` executes: the Python code won't be executed when the graphs
  themselves are executed, though it will be executed every time the Python
  function is traced (and a given Python function might be traced multiple
  times, once for each input signature it is invoked with). For example, whereas
  the Python function

  ```python
  import tensorflow as tf
  import numpy as np

  tf.compat.v1.enable_eager_execution()

  def add_noise():
    return tf.eye(5) + np.random.randn(5, 5)
  ```

  will return a different output every time it is invoked, the compiled function
  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
  every time it is called, since a particular random offset generated by NumPy
  will be inserted into the graph as a TensorFlow constant. The solution is to
  replace the call to `np.random.randn` with `tf.random.normal((5, 5))`.

  _Python Side-Effects_

  A corollary of the previous discussion on tracing is the following: If a
  Python function `f` has Python side-effects, then executing `f` multiple times
  will not necessarily be semantically equivalent to executing `F =
  tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
  that `defun` only captures the subgraph of TensorFlow operations that is
  constructed when `f` is called in a graph-building context.

  _Python Control Flow_

  The structure of many machine learning computations depends upon whether one
  is training or validating, and it is common to nest specialized logic under
  `if training:` blocks. By mapping each input signature to a unique graph,
  `defun` lets users transparently compile such code, as the following code
  snippet demonstrates:

  ```python
  import tensorflow as tf

  tf.compat.v1.enable_eager_execution()

  @tf.contrib.eager.defun
  def lossy_matmul(W, x, training=True):
    outputs = tf.matmul(W, x)
    if training:
      outputs = tf.nn.dropout(outputs, rate=0.2)
    return outputs

  W = tf.random.normal((3, 5))
  x = tf.random.normal((5, 1))

  # Executes a graph that applies dropout.
  lossy_outputs = lossy_matmul(W, x, training=True)

  # Executes a graph that does not apply dropout.
  exact_outputs = lossy_matmul(W, x, training=False)
  ```

  _TensorFlow Control Flow_

  When `autograph` is `True`, data-dependent control flow is allowed as well.
  Control flow statements that depend on `Tensor` values are staged into
  corresponding TensorFlow ops. For example, the following code will work as
  expected:

  ```python
  @tf.contrib.eager.defun
  def dynamic_rnn_loop(cell, seq):
    state, output = cell.zero_state()
    for step_input in seq:
      state, output = cell(step_input, state)
    return output
  ```

  For more information see `tf.autograph`.

  _Variables_

  TensorFlow operations related to variable creation and initialization are
  automatically lifted out of the graphs generated by `defun`. In practice, this
  implies that variable creation and initialization only happen the first time
  `F` is called, and that variables are reused every time thereafter. Many
  TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
  first time they are called and reuse them thereafter. Automatic variable
  lifting makes it possible to compile these APIs without extra effort, at the
  cost of introducing a discrepancy between the semantics of executing Python
  functions and their corresponding compiled functions. For example:

  ```python
  import tensorflow as tf

  tf.compat.v1.enable_eager_execution()

  def fn():
    x = tf.Variable(0.0)
    x.assign_add(1.0)
    return x.read_value()

  # `fn` is a Python function, so x is created, initialized, and destroyed upon
  # every invocation
  assert fn().numpy() == fn().numpy() == 1.0

  compiled = tf.contrib.eager.defun(fn)

  # Compiling `fn` with `defun` hoists all variables outside of the generated
  # graph, so initialization happens exactly once.
  assert compiled().numpy() == 1.0
  assert compiled().numpy() == 2.0
  ```

  Finally, because each input signature is bound to a unique graph, if your
  Python function constructs `tf.Variable` objects, then each graph constructed
  for that Python function will reference a unique set of variables. To
  circumvent this problem, we recommend against compiling Python functions that
  create `tf.Variable` objects. Instead, Python functions should either
  lexically close over `tf.Variable` objects or accept them as arguments,
  preferably encapsulated in an object-oriented container. If you must create
  variables inside your Python function and you want each graph generated for it
  to reference the same set of variables, add logic to your Python function that
  ensures that variables are only created the first time it is called and are
  reused for every subsequent invocation; note that this is precisely what
  `tf.keras.layers.Layer` objects do, so we recommend using them to represent
  variable-bearing computations whenever possible.

  Args:
    func: function to be compiled. If `func` is None, returns a
      decorator that can be invoked with a single argument - `func`. The
      end result is equivalent to providing all the arguments up front.
      In other words, defun(input_signature=...)(func) is equivalent to
      defun(func, input_signature=...). The former allows
      the following use case:
        @tf.contrib.eager.defun(input_signature=...)
        def foo(...):
          ...

    input_signature: A possibly nested sequence of
      `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
      the Tensors that will be supplied to this function. If `None`, a separate
      function is instantiated for each inferred input signature. If a
      signature is specified, every input to `func` must be a `Tensor`, and
      `func` cannot accept `**kwargs`.
    autograph: Whether `func` should be converted with AutoGraph before
      constructing the graph. See https://www.tensorflow.org/guide/autograph
      for more information.
    experimental_autograph_options: Experimental knobs (in the form of a tuple
      of tensorflow.autograph.Feature values) to control behavior when
      autograph=True.
    reduce_retracing: When True, `defun` uses
      `tf.types.experimental.TraceType` to trace supertypes of arguments to
      reduce the number of traces.

  Returns:
    If `func` is not None, returns a callable that will execute the compiled
    function (and return zero or more `tf.Tensor` objects).
    If `func` is None, returns a decorator that, when invoked with a single
    `func` argument, returns a callable equivalent to the case above.

  Raises:
    TypeError: If `input_signature` is neither `None` nor a sequence of
      `tf.contrib.eager.TensorSpec` objects.
  """
  # All work is delegated; `defun` simply does not expose `attributes`,
  # `jit_compile` or `experimental_follow_type_hints`, which therefore take
  # their `defun_with_attributes` defaults.
  return defun_with_attributes(
      func=func,
      input_signature=input_signature,
      autograph=autograph,
      experimental_autograph_options=experimental_autograph_options,
      reduce_retracing=reduce_retracing)
3079
3080
@tf_export("__internal__.function.defun_with_attributes", v1=[])
def defun_with_attributes(func=None,
                          input_signature=None,
                          attributes=None,
                          autograph=True,
                          experimental_autograph_options=None,
                          jit_compile=None,
                          reduce_retracing=False,
                          experimental_follow_type_hints=False):
  """Compiles a Python function into a callable TensorFlow graph.

  This function supports adding extra function attributes. See detailed
  documentation in defun(). Currently this is not exposed in the public API
  since we don't expect users to use attributes directly, and attributes won't
  work by themselves. This assumption might change in the future.

  Args:
    func: function to be compiled. If None, a decorator is returned instead
      (see Returns).
    input_signature: same as defun()'s input_signature.
    attributes: A dictionary of arguments which will be added to function def as
      attributes. Currently only primitive types are supported as values, and
      only allowlisted attribute names are allowed. An unallowlisted attribute
      name or unsupported value will result in a ValueError. `func_name` is
      also one of the allowlisted arguments; it is a python string and sets the
      name for this `ConcreteFunction` in the graph. The caller's dictionary is
      not modified: each decorated function receives its own copy, with
      `func_name` removed.
    autograph: same as defun()'s autograph.
    experimental_autograph_options: same as defun()'s
      experimental_autograph_options.
    jit_compile: forwarded unchanged to `Function` as `jit_compile`. `defun()`
      itself does not expose this option.
    reduce_retracing: same as defun()'s reduce_retracing.
    experimental_follow_type_hints: see `tf.function`.

  Returns:
    Same as the return value of defun, with attributes added to the function in
    graph.
  """

  # TODO(apassos): deal with captured global state. Deal with control flow.
  def decorated(function):
    # Callables without `__name__` (e.g. some partials) fall back to a
    # generic name unless `func_name` overrides it below.
    name = getattr(function, "__name__", "function")
    function_attributes = attributes
    if function_attributes:
      # Work on a copy: popping "func_name" from the caller's dict would mutate
      # it and, when this decorator is applied to several functions, drop the
      # requested name for every function after the first.
      function_attributes = dict(function_attributes)
      name = function_attributes.pop("func_name", name)
    return tf_decorator.make_decorator(
        function,
        Function(
            function,
            name,
            input_signature=input_signature,
            attributes=function_attributes,
            autograph=autograph,
            autograph_options=experimental_autograph_options,
            jit_compile=jit_compile,
            reduce_retracing=reduce_retracing,
            experimental_follow_type_hints=experimental_follow_type_hints))

  # This code path is for the `foo = tfe.defun(foo, ...)` use case
  if func is not None:
    return decorated(func)

  # This code path is for the
  #
  # @tfe.defun(...)
  # def foo(...):
  #    ...
  #
  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
  return decorated
3152
3153
3154# When a method is bound to objects of this type, it allows AutoGraph to
# recover a weak reference to the original method's self pointer, so that it can
3156# execute it consistent with class_method_to_instance_method's
3157# bound_method_wrapper.
3158# TODO(b/119246461): This is not pretty. Use a descriptor instead?
class TfMethodTarget:
  """Binding target for methods replaced by function and defun.

  Instances stand in for `self` when a `Function`'s Python function is turned
  into a bound method. The original Python function is held through a weak
  reference, and the target is dereferenced by calling the stored `target`.
  """

  __slots__ = ("weakrefself_target__", "weakrefself_func__")

  def __init__(self, target, original_python_function):
    # `target` is stored as given and dereferenced by calling it, so callers
    # pass a weak reference (e.g. `weakref.ref(instance)`) here.
    self.weakrefself_target__ = target
    self.weakrefself_func__ = weakref.ref(original_python_function)

  @property
  def target(self):
    """The object the method is bound to, dereferenced on each access."""
    deref = self.weakrefself_target__
    return deref()

  @property
  def target_class(self):
    """The target's class, or the target itself when it is a class."""
    owner = self.weakrefself_target__()
    # A class method is bound to the class object itself.
    return owner if tf_inspect.isclass(owner) else owner.__class__

  def call(self, args, kwargs):
    """Calls the original function with the dereferenced target as `self`."""
    original_fn = self.weakrefself_func__()
    bound_self = self.weakrefself_target__()
    return original_fn(bound_self, *args, **kwargs)
3184
3185
def class_method_to_instance_method(original_function, instance):
  """Constructs a new `Function` with `self` bound.

  Within this function, `instance` is only referenced through
  `weakref.ref(instance)`, so the returned callable does not hold `instance`
  through any strong reference created here.

  Args:
    original_function: The `Function` (one of the two `Function` types, from
      function.py or def_function.py) whose `python_function` is the method
      to bind.
    instance: The object to bind as `self`. It must support weak references,
      since `weakref.ref(instance)` is taken.

  Returns:
    The result of `tf_decorator.make_decorator(bound_method, instance_func)`,
    where `instance_func` is a new object of the same type as
    `original_function` built around a weakly-bound method wrapper.
  """
  weak_instance = weakref.ref(instance)

  # Note: while we could bind to a weakref proxy instead, that causes the
  # bound method to be unhashable.
  bound_method = types_lib.MethodType(
      original_function.python_function,
      TfMethodTarget(weak_instance, original_function.python_function))

  # original_function is expected to be of one of the two `Function` types
  # (defined either in function.py or def_function.py).
  # These are internal sanity checks only; asserts are skipped under
  # `python -O`.
  assert hasattr(original_function, "_name")
  assert hasattr(original_function, "_autograph")
  assert hasattr(original_function, "_function_spec")
  assert hasattr(original_function, "python_function")

  # Placeholder so the closure below can refer to this name; it is rebound to a
  # weak reference to the wrapper immediately after the wrapper is defined.
  weak_bound_method_wrapper = None
  def bound_method_wrapper(*args, **kwargs):
    """Wraps either a dummy MethodType or a converted AutoGraph function."""
    # __wrapped__ allows AutoGraph to swap in a converted function.
    # NOTE(review): this assumes something else (presumably the decorator
    # built below) keeps this wrapper alive; otherwise the weak reference
    # returns None and the attribute access fails.
    strong_bound_method_wrapper = weak_bound_method_wrapper()
    wrapped_fn = strong_bound_method_wrapper.__wrapped__

    # `__original_wrapped__` is presumably set by `tf_decorator.make_decorator`
    # to the initially wrapped callable.
    if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__:
      # If __wrapped__ was not replaced, then call original_function.
      # TODO(mdan): For better consistency, use the wrapper's call().
      wrapped_fn = original_function.python_function
      # If `instance` has been garbage-collected, `weak_instance()` is None
      # and None is passed as `self`.
      return wrapped_fn(weak_instance(), *args, **kwargs)

    # If __wrapped__ was replaced, then it is always an unbound function.
    # However, the replacer is still responsible for attaching self properly.
    # TODO(mdan): Is it possible to do it here instead?
    return wrapped_fn(*args, **kwargs)
  # A weak self-reference avoids a cycle between the wrapper and its own
  # closure cell.
  weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)

  # pylint: disable=protected-access
  # We make a dummy MethodType object to generate the correct bound method
  # signature. The actual call is to a function with a weak reference to
  # `instance`.
  # NOTE(review): only name, autograph, input_signature, reduce_retracing and
  # jit_compile are forwarded; any other settings `original_function` may carry
  # are not copied -- confirm this is intentional.
  instance_func = type(original_function)(
      tf_decorator.make_decorator(bound_method, bound_method_wrapper),
      name=original_function._name,
      autograph=original_function._autograph,
      input_signature=original_function.input_signature,
      reduce_retracing=original_function._reduce_retracing,
      jit_compile=original_function._jit_compile)
  # pylint: enable=protected-access

  # We wrap the bound method with tf_decorator so inspection works correctly
  wrapped_instance_func = tf_decorator.make_decorator(bound_method,
                                                      instance_func)
  return wrapped_instance_func
3239
3240
class ConcreteFunctionGarbageCollector:
  """Cleans up reference cycles when a `ConcreteFunction` goes out of scope.

  When this object is finalized, it dismantles the `FuncGraph` it holds,
  unless `release()` was called first.
  """

  __slots__ = ["_func_graph"]

  def __init__(self, func_graph):
    self._func_graph = func_graph

  def release(self):
    """Call off the FuncGraph deletion."""
    self._func_graph = None

  def __del__(self):
    # Presumably guards against interpreter shutdown, when module globals may
    # already have been reset to None.
    modules_unavailable = func_graph_module is None or memory is None
    if modules_unavailable or self._func_graph is None:
      return
    try:
      func_graph_module.dismantle_func_graph(self._func_graph)
    except:  # pylint: disable=bare-except
      # Best-effort cleanup: nothing may escape from a finalizer.
      pass
3260
3261
3262class _Marker(object):
3263  """Markers used to pretty-print nested args in function signatures."""
3264
3265  __slots__ = ["_s"]
3266
3267  def __init__(self, s):
3268    self._s = s
3269
3270  def __repr__(self):
3271    return str(self._s)
3272
3273
def _structure_summary(structure):
  """Displays a summary of the nesting structure of the given value.

  Every leaf is replaced by the name of its type (for a `TypeSpec` leaf, the
  name of its `value_type`). The result is printed with the original nesting.
  """
  leaf_markers = []
  for leaf in nest.flatten(structure):
    if isinstance(leaf, type_spec.TypeSpec):
      leaf_markers.append(_Marker(leaf.value_type.__name__))
    else:
      leaf_markers.append(_Marker(type(leaf).__name__))
  return str(nest.pack_sequence_as(structure, leaf_markers))
3285
3286
def _contains_type_spec(value):
  """Returns True if any leaf of the nested `value` is a `TypeSpec`."""
  for leaf in nest.flatten(value):
    if isinstance(leaf, type_spec.TypeSpec):
      return True
  return False
3289