1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""Defun decorator for defining graph-mode functions."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import functools
24import itertools
25import pprint
26import threading
27import types as types_lib
28import weakref
29
30import numpy as np
31import six
32from six.moves import map
33
34from tensorflow.core.framework import attr_value_pb2
35from tensorflow.core.framework import function_pb2
36from tensorflow.python import pywrap_tfe
37from tensorflow.python.client import pywrap_tf_session
38from tensorflow.python.eager import backprop
39from tensorflow.python.eager import backprop_util
40from tensorflow.python.eager import context
41from tensorflow.python.eager import execute
42from tensorflow.python.eager import forwardprop_util
43from tensorflow.python.eager import monitoring
44from tensorflow.python.eager import tape
45from tensorflow.python.eager.graph_only_ops import graph_placeholder
46from tensorflow.python.framework import c_api_util
47from tensorflow.python.framework import composite_tensor
48from tensorflow.python.framework import constant_op
49from tensorflow.python.framework import device as pydev
50from tensorflow.python.framework import dtypes
51from tensorflow.python.framework import error_interpolation
52from tensorflow.python.framework import errors
53from tensorflow.python.framework import func_graph as func_graph_module
54from tensorflow.python.framework import ops
55from tensorflow.python.framework import tensor_shape
56from tensorflow.python.framework import tensor_spec
57from tensorflow.python.framework import type_spec
58from tensorflow.python.ops import array_ops
59from tensorflow.python.ops import control_flow_ops
60from tensorflow.python.ops import default_gradient
61from tensorflow.python.ops import functional_ops
62from tensorflow.python.ops import gradients_util
63from tensorflow.python.ops import handle_data_util
64from tensorflow.python.ops import resource_variable_ops
65from tensorflow.python.platform import tf_logging as logging
66from tensorflow.python.profiler import trace
67from tensorflow.python.saved_model import save_context
68from tensorflow.python.types import core
69from tensorflow.python.util import _pywrap_utils
70from tensorflow.python.util import compat
71from tensorflow.python.util import function_utils
72from tensorflow.python.util import lazy_loader
73from tensorflow.python.util import memory
74from tensorflow.python.util import nest
75from tensorflow.python.util import object_identity
76from tensorflow.python.util import tf_decorator
77from tensorflow.python.util import tf_inspect
78from tensorflow.python.util.tf_export import tf_export
79
80# Loaded lazily due to a circular dependency (roughly
81# tf.function->autograph->...->dataset->tf.function).
82# TODO(b/133251390): Use a regular import.
83ag_ctx = lazy_loader.LazyLoader(
84    "ag_ctx", globals(),
85    "tensorflow.python.autograph.core.ag_ctx")
86np_arrays = lazy_loader.LazyLoader(
87    "np_arrays", globals(),
88    "tensorflow.python.ops.numpy_ops.np_arrays")
89
90
91FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
92BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
93IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
94SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"
95
96_graph_building_time_counter = monitoring.Counter(
97    "/tensorflow/core/tf_function/graph_building_time_usecs",
98    "Time for tf.function to build a graph (us).")
99
100
101# TODO(b/195985838): cleanup this function.
102def _make_input_signature_hashable(elem):
103  """Rewrite input signature to be hashable.
104
105  We replace nested variables in the input signature with TensorSpec so that
106  the result is hashable.
107
108  Args:
109    elem: Input signature element
110
111  Returns:
112    A hashable object for the requested input signature
113  """
114  try:
115    hash(elem)
116  except TypeError:
117    # TODO(slebedev): consider using nest.
118    if isinstance(elem, tuple):
119      return tuple(map(_make_input_signature_hashable, elem))
120
121    # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
122    # all recognized types to be hashable.
123    assert isinstance(elem, weakref.ReferenceType)
124    v = elem()
125
126    if resource_variable_ops.is_resource_variable(v):
127      # We special case variables here to use unique_id as the cache key. This
128      # ensures we have to retrace whenever a different variable is passed in.
129      # This is needed to support cases where the user may use the id of a
130      # variable in the function, perhaps as a lookup key in a dictionary.
131      #
132      # This choice leads to more retracing when we could have possibly used the
133      # shape and dtype instead. However, we expect the number of variables in a
134      # program to be bounded, and correspondingly the number of retraces.
135      #
136      # Note we also include the class name to avoid collisions with strings.
137      return v.__class__, v._unique_id  # pylint: disable=protected-access
138
139    if _is_ndarray(v):
140      # Numpy arrays are not hashable, but when calling functions we treat them
141      # in the same way as tf.Tensors.
142      if not hasattr(v, "shape") or not hasattr(v, "dtype"):
143        # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
144        v = _as_ndarray(v)
145      return tensor_spec.TensorSpec(v.shape, v.dtype)
146
147    raise ValueError("Arguments to a tf.function must be a nested structure of "
148                     "Tensors, Variables, NumPy arrays, or hashable Python "
149                     f"objects, got {type(v)}.")
150
151  return elem
152
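# Editor's note: an illustrative sketch (not part of the original module) of the
# retracing behavior the comment above describes. Because variables are keyed by
# `_unique_id`, two distinct variables with identical shape and dtype produce
# two cache entries when passed to a tf.function (public `tf` API assumed):
#
#   @tf.function
#   def f(v):
#     return v + 1.
#
#   v1, v2 = tf.Variable(1.), tf.Variable(1.)
#   f(v1)
#   f(v2)  # retraces, because the cache key includes v2's unique id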
153
154CacheKey = collections.namedtuple("CacheKey", [
155    "input_signature",
156    "parent_graph",
157    "device_functions",
158    "colocation_stack",
159    "in_cross_replica_context",
160    "variable_policy",
161    "xla_context_id",
162])
163
164
165def _type_spec_for(x):
166  """Returns a TypeSpec for `x`, or `None` if `x` doesn't have a TensorSpec."""
167  if isinstance(x, ops.Tensor):
168    return tensor_spec.TensorSpec.from_tensor(x)
169  elif isinstance(x, type_spec.TypeSpec):
170    return x
171  elif isinstance(x, composite_tensor.CompositeTensor):
172    return x._type_spec  # pylint: disable=protected-access
173  else:
174    return None
175
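# Editor's note: a small illustrative sketch (not part of the original module)
# of `_type_spec_for`, using names already imported above:
#
#   _type_spec_for(constant_op.constant([1., 2.]))
#   # -> a TensorSpec with shape (2,) and dtype float32
#   _type_spec_for(tensor_spec.TensorSpec([None], dtypes.int32))
#   # -> returned unchanged (it is already a TypeSpec)
#   _type_spec_for(42)
#   # -> None (plain Python values have no TypeSpec)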
176
177def _is_type_subset(a, b):
178  """Returns true if TypeSpec `b` is a subset of type `a` (or if a is None.)"""
179  if a is None:
180    return True
181  else:
182    return a.most_specific_compatible_type(b) == a
183
184
185def _shape_relaxed_type_for_composite_tensor(x):
186  """Returns a shape-relaxed TypeSpec for x (if composite) or x (if not)."""
187  if isinstance(x, composite_tensor.CompositeTensor):
188    # pylint: disable=protected-access
189    return x._type_spec._with_tensor_ranks_only()
190  else:
191    return x
192
193
194def common_shape(x, y):
195  """Find a `TensorShape` that is compatible with both `x` and `y`."""
196  if (x is None) != (y is None):
197    raise RuntimeError(
198        "Cannot find a common shape when LHS shape is None but RHS shape "
199        f"is not (or vice versa): {x} vs. {y}.")
200  if x is None:
201    return None  # The associated input was not a Tensor, no shape generated.
202  if not isinstance(x, tensor_shape.TensorShape):
203    raise TypeError(f"`x` must be a TensorShape, got type {type(x)}.")
204  if not isinstance(y, tensor_shape.TensorShape):
205    raise TypeError(f"`y` must be a TensorShape, got type {type(y)}.")
206  if x.rank != y.rank or x.rank is None:
207    return tensor_shape.TensorShape(None)
208  dims = []
209  for dim_x, dim_y in zip(x.dims, y.dims):
210    if (dim_x != dim_y
211        or tensor_shape.dimension_value(dim_x) is None
212        or tensor_shape.dimension_value(dim_y) is None):
213      dims.append(None)
214    else:
215      dims.append(tensor_shape.dimension_value(dim_x))
216  return tensor_shape.TensorShape(dims)
217
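# Editor's note: an illustrative sketch (not part of the original module) of how
# `common_shape` relaxes dimensions. Dimensions that differ (or are unknown)
# become None, and a rank mismatch collapses to a fully unknown shape:
#
#   common_shape(tensor_shape.TensorShape([2, None, 3]),
#                tensor_shape.TensorShape([2, 5, 4]))
#   # -> TensorShape([2, None, None])
#
#   common_shape(tensor_shape.TensorShape([2, 3]),
#                tensor_shape.TensorShape([2, 3, 1]))
#   # -> TensorShape(None)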
218
219def is_same_structure(structure1,
220                      structure2,
221                      check_values=False):
222  """Check two structures for equality, optionally of types and of values."""
223  try:
224    nest.assert_same_structure(structure1, structure2, expand_composites=True)
225  except (ValueError, TypeError):
226    return False
227  if check_values:
228    flattened1 = nest.flatten(structure1, expand_composites=True)
229    flattened2 = nest.flatten(structure2, expand_composites=True)
230    # First check the types to avoid AttributeErrors.
231    if any(type(f1) != type(f2) for f1, f2 in zip(flattened1, flattened2)):
232      return False
233    return flattened1 == flattened2
234  return True
235
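# Editor's note: an illustrative sketch (not part of the original module) of
# `is_same_structure`:
#
#   is_same_structure({"a": 1, "b": 2}, {"a": 3, "b": 4})    # True (same nest)
#   is_same_structure({"a": 1}, {"a": 1, "b": 2})            # False (different nest)
#   is_same_structure((1, 2), (1, 2), check_values=True)     # True
#   is_same_structure((1, 2), (1, 3), check_values=True)     # False (values differ)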
236
237def _parse_func_attrs(attributes):
238  """Convert the keyword arguments into function_def attributes.
239
240  Currently only supports primitive types: bool, int, float and string.
241
242  Args:
243    attributes: the dictionary of attributes.
244  Returns:
245    A dict of attributes where the key is the name of attribute and the value
246      is the AttrValue proto.
247  Raises:
248    ValueError: If the kwargs contain an unallowlisted name or unsupported value
249      types.
250  """
251  attrs = {}
252  for key, value in attributes.items():
253    if isinstance(value, attr_value_pb2.AttrValue):
254      attrs[key] = value
255    # bool type check has to happen before int since bool is a subclass of int.
256    elif isinstance(value, bool):
257      attrs[key] = attr_value_pb2.AttrValue(b=value)
258    elif isinstance(value, int):
259      attrs[key] = attr_value_pb2.AttrValue(i=value)
260    elif isinstance(value, float):
261      attrs[key] = attr_value_pb2.AttrValue(f=value)
262    elif isinstance(value, (str, bytes, six.text_type)):
263      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
264    else:
265      raise ValueError(f"Attribute {key} must be bool, int, float, string, or "
266                       f"AttrValue. Got {type(value)}.")
267  return attrs
268
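# Editor's note: an illustrative sketch (not part of the original module) of the
# conversion done by `_parse_func_attrs`; the attribute names below are
# arbitrary examples, not a fixed set recognized by the runtime:
#
#   attrs = _parse_func_attrs(
#       {"_implements": "my_namespace.my_op", "_noinline": True, "_cost": 3})
#   # attrs["_implements"].s == b"my_namespace.my_op"
#   # attrs["_noinline"].b is True
#   # attrs["_cost"].i == 3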
269
270class _InterpolateFunctionError(object):
271  """Context Manager that interpolates the exception from 'top_level_func'."""
272
273  __slots__ = ["_func"]
274
275  def __init__(self, top_level_func):
276    self._func = top_level_func
277
278  def __enter__(self):
279    pass
280
281  def __exit__(self, typ, exc, tb):
282    if not exc or not isinstance(exc, errors.OpError):
283      return False
284    message = compat.as_text(exc.message)
285    _, tags = error_interpolation.parse_message(message)
286    g = None
287    func_stack = []
288    for t in tags:
289      if t.type == "function_node":
290        # TODO(mdan): Tests should cover this.
291        if t.name == compat.as_str(self._func.name):
292          g = self._func.graph
293        elif g:
294          next_func = g._get_function(t.name)  # pylint: disable=protected-access
295          if next_func is not None and isinstance(next_func,
296                                                  _EagerDefinedFunction):
297            g = next_func.graph
298        if g:
299          func_stack.append(g.name)
300        else:
301          func_stack.append("<unknown>")
302    if g:
303      message = error_interpolation.interpolate(message, g)
304      if len(func_stack) >= 2:
305        message += "\n\nFunction call stack:\n"
306        message += " -> ".join(func_stack)
307        message += "\n"
308      exc._message = message  # pylint: disable=protected-access
309    return False
310
311
312_function_callbacks = set()
313
314
315def add_function_callback(function_callback):
316  """Add a callback function for Function creation.
317
318  The callback function has the signature:
319
320    `def function_callback(function, name, graph, inputs, outputs):`
321
322  where:
323  - `function`: _EagerDefinedFunction being created before finalizing the graph.
324      Do not modify the function directly but instead modify the graph.
325  - `name`: name of the function.
326  - `graph`: Graph of the function.
327  - `inputs`: `tuple` of tensors used as inputs to the function.
328  - `outputs`: `tuple` of tensors used as outputs from the function.
329
330  The callback is invoked at the top of `_EagerDefinedFunction` construction,
331  giving it an opportunity to make the last edits to the graph. Do not modify
332  `graph`, `inputs`, or `outputs` manually; instead, set `graph` as the
333  default graph and then define ops.
334
335  Repeated registration of the same callback function is idempotent.
336  After a callback is added, it can be removed with the
337  `remove_function_callback()` method.
338
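  Example (editorial sketch; `log_function` below is a hypothetical callback,
  not part of TensorFlow):

    def log_function(function, name, graph, inputs, outputs):
      logging.info("Finalizing %s (%d inputs, %d outputs)",
                   name, len(inputs), len(outputs))

    add_function_callback(log_function)
    # ... build some tf.functions ...
    remove_function_callback(log_function)
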
339  Args:
340    function_callback: The callback to add.
341  """
342  _function_callbacks.add(function_callback)
343
344
345def remove_function_callback(function_callback):
346  """Remove an already-added function callback.
347
348  See the doc string of `add_function_callback()` for more information.
349
350  Args:
351    function_callback: The callback to remove.
352  """
353  _function_callbacks.remove(function_callback)
354
355
356def clear_function_callbacks():
357  """Clear all function callbacks, if any have been regisered."""
358  _function_callbacks.clear()
359
360
361_FORWARD_PREFIX = "__forward_"
362_BACKWARD_PREFIX = "__backward_"
363_INFERENCE_PREFIX = "__inference_"
364
365
366def _forward_name(n):
367  """The name of a generated forward defun named n."""
368  return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid())
369
370
371def _backward_name(n):
372  """The name of a generated backward defun named n."""
373  return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid())
374
375
376def _inference_name(n):
377  """The name of a forward-but-no-gradient defun named n."""
378  return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid())
379
380
381def _enclosing_xla_context():
382  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
383  graph = ops.get_default_graph()
384  while graph is not None:
385    # pylint: disable=protected-access
386    context_ = graph._get_control_flow_context()
387    # pylint: enable=protected-access
388    while context_ is not None:
389      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
390        return context_
391      context_ = context_.outer_context
392    # This may be a FuncGraph due to defuns or v2 control flow. We need to
393    # find the original graph with the XLAControlFlowContext.
394    graph = getattr(graph, "outer_graph", None)
395  return None
396
397
398class _EagerDefinedFunctionDeleter(object):
399  """Unregister function from eager context."""
400
401  __slots__ = ["name"]
402
403  def __init__(self, name):
404    self.name = name
405
406  def __del__(self):
407    try:
408      context.remove_function(self.name)
409    except TypeError:
410      # Suppress some exceptions, mainly for the case when we're running on
411      # module deletion. Things that can go wrong include the context module
412      # already being unloaded, self._handle._handle_data no longer being
413      # valid, and so on. Printing warnings in these cases is silly
414      # (exceptions raised from __del__ are printed as warnings to stderr).
415      pass  # 'NoneType' object is not callable when the handle has been
416      # partially unloaded.
417    except AttributeError:
418      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
419      # been unloaded. Will catch other module unloads as well.
420
421
422class FunctionAlreadyGarbageCollectedError(Exception):
423
424  def __init__(self, function_name):
425    super(FunctionAlreadyGarbageCollectedError, self).__init__(
426        "{} has already been garbage collected and cannot be called.".format(
427            function_name))
428
429
430# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
431# so it doesn't have the definition-generating logic and is just a container for
432# an already-defined function.
433class _EagerDefinedFunction(object):
434  """Callable with the interface of `framework.function._DefinedFunction`.
435
436  `_EagerDefinedFunction` encapsulates a function definition and its properties,
437  and it provides a method for calling the encapsulated function. Some Ops
438  take functions as attributes, which have type `func`; an instance of this
439  class may be provided as the value of these `func` attributes.
440  """
441
442  def __init__(self, name, graph, inputs, outputs, attrs):
443    """Initializes an eager defined function.
444
445    Args:
446      name: str, the name for the created function.
447      graph: Graph, the graph containing the operations in the function
448      inputs: the tensors in the graph to be used as inputs to the function
449      outputs: the tensors in the graph which will be outputs from the function
450      attrs: dict mapping names of attributes to their AttrValue values
451    """
452    for function_callback in _function_callbacks:
453      function_callback(self, name, graph, tuple(inputs), tuple(outputs))
454
455    input_ops = set(arg.op for arg in inputs)
456    operations = [op for op in graph.get_operations() if op not in input_ops]
457
458    graph_output_names = graph._output_names  # pylint: disable=protected-access
459    if (graph_output_names is not None and
460        all(ops.tensor_id(t) in graph_output_names for t in outputs)):
461      output_names = [
462          compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
463      ]
464      if len(set(output_names)) != len(output_names):
465        # There are duplicate names for some reason, probably an invalid
466        # signature. Revert to auto-naming.
467        output_names = []
468    else:
469      output_names = []
470    fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
471        graph._c_graph,  # pylint: disable=protected-access
472        compat.as_str(name),
473        False,
474        [o._c_op for o in operations],  # pylint: disable=protected-access
475        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
476        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
477        output_names,
478        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
479        [],  # control_output_names
480        None,
481        compat.as_str(""))
482
483    for name, attr_value in attrs.items():
484      serialized = attr_value.SerializeToString()
485      # TODO(iga): this creates and deletes a new TF_Status for every attr.
486      # It might be worth creating a convenient way to re-use status.
487      pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name),
488                                                     serialized)
489
490    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
491    # signature); in general it's nice not to depend on it.
492    with c_api_util.tf_buffer() as buffer_:
493      pywrap_tf_session.TF_FunctionToFunctionDef(fn, buffer_)
494      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
495    function_def = function_pb2.FunctionDef()
496    function_def.ParseFromString(compat.as_bytes(proto_data))
497    self._name = compat.as_bytes(function_def.signature.name)
498    with ops.init_scope():
499      if context.executing_eagerly():
500        context.ensure_initialized()
501        context.add_function(fn)
502        self._function_deleter = _EagerDefinedFunctionDeleter(self.name)
503        self._registered_on_context = True
504    self.definition = function_def
505    self.signature = function_def.signature
506    self._num_outputs = len(self.signature.output_arg)
507    self._output_types = [o.type for o in self.signature.output_arg]
508    self._output_shapes = [o.shape for o in outputs]
509    self._control_captures = graph.control_captures
510    # Shallow copy outputs since ConcreteFunction may mutate the list.
511    self._func_graph_outputs = list(outputs)
512    self.grad_func_name = None
513    self.python_grad_func = None
514    self._c_func = c_api_util.ScopedTFFunction(fn)
515    self._grad_func = None
516    self.graph = graph
517    self._stateful_ops = tuple(op for op in operations if op._is_stateful)  # pylint: disable=protected-access
518
519  def add_to_graph(self, g=None):
520    """Add the function to the current context or a graph, if supplied.
521
522    Args:
523      g: the graph to add the function to. If not supplied, the function will
524        be added to the current context.
525    """
526    # pylint: disable=protected-access
527    if not g and context.executing_eagerly():
528      ctx = context.context()
529      if not ctx.has_function(self.name):
530        ctx.add_function_def(self.definition)
531    else:
532      if not g._is_function(self.name):
533        g._add_function(self)
534      for f in self.graph._functions.values():
535        if not g._is_function(f.name):
536          g._add_function(f)
537    # pylint: enable=protected-access
538
539  @property
540  def name(self):
541    return self._name
542
543  @property
544  def stateful_ops(self):
545    return self._stateful_ops
546
547  def call(self, ctx, args, cancellation_manager=None):
548    """Calls this function with `args` as inputs.
549
550    `ConcreteFunction` execution respects device annotations only if the
551    function won't be compiled with XLA.
552
553    Args:
554      ctx: a Context object
555      args: a list of arguments to supply this function with.
556      cancellation_manager: a `CancellationManager` object that can be used to
557        cancel function execution.
558
559    Returns:
560      The outputs of the function call.
561
562    Raises:
563      ValueError: if the number of arguments is incorrect.
564      FunctionAlreadyGarbageCollectedError: if the function is no longer
565        available to be called because it has been garbage collected.
566    """
567    if len(args) != len(self.signature.input_arg):
568      raise ValueError(
569          f"Signature specifies {len(list(self.signature.input_arg))} "
570          f"arguments, got: {len(args)}.")
571
572    # If the `ScopedTFFunction` (accessed via `_c_func`) has already been
573    # cleaned up as a part of garbage collection, this `_EagerDefinedFunction`
574    # should also be garbage and is likely being called as part of a `__del__`
575    # elsewhere. In that case, there's nothing we can do, so we raise an
576    # exception for the caller to handle.
577    if self._c_func.has_been_garbage_collected:
578      raise FunctionAlreadyGarbageCollectedError(self.name)
579
580    function_call_options = ctx.function_call_options
581    if function_call_options.config_proto_serialized is None:
582      config = function_utils.get_disabled_rewriter_config()
583    else:
584      config = function_call_options.config_proto_serialized
585    executor_type = function_call_options.executor_type or ""
586
587    executing_eagerly = ctx.executing_eagerly()
588    attrs = ("executor_type", executor_type, "config_proto", config)
589    if executing_eagerly:
590      with _InterpolateFunctionError(self):
591        if cancellation_manager is None:
592          outputs = execute.execute(
593              str(self.signature.name),
594              num_outputs=self._num_outputs,
595              inputs=args,
596              attrs=attrs,
597              ctx=ctx)
598        else:
599          outputs = execute.execute_with_cancellation(
600              str(self.signature.name),
601              num_outputs=self._num_outputs,
602              inputs=args,
603              attrs=attrs,
604              ctx=ctx,
605              cancellation_manager=cancellation_manager)
606      # Replace empty list with None
607      outputs = outputs or None
608    else:
609      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
610      # creates `PartitionedCallOp` kernels by default, or remove the previous
611      # branch if a TPU kernel is registered for `PartitionedCall`.
612      with _InterpolateFunctionError(self):
613        with ops.control_dependencies(self._control_captures):
614          # The caller must use record_operation to record this operation in the
615          # eager case, so we enforce the same requirement for the non-eager
616          # case by explicitly pausing recording. We don't have a gradient
617          # registered for PartitionedCall, so recording this operation confuses
618          # forwardprop code (GradientTape manages to ignore it).
619          with tape.stop_recording():
620            outputs = functional_ops.partitioned_call(
621                args=args,
622                f=self,
623                tout=self._output_types,
624                executing_eagerly=executing_eagerly,
625                config=config,
626                executor_type=executor_type)
627
628    for i, func_graph_output in enumerate(self._func_graph_outputs):
629      handle_data_util.copy_handle_data(func_graph_output, outputs[i])
630    if executing_eagerly:
631      return outputs
632    else:
633      # TODO(b/128924522): This additional set_shape should not be
634      # necessary. ShapeRefiner likely needs to inspect handle_data. Remove this
635      # once that's done.
636      for i, shape in enumerate(self._output_shapes):
637        outputs[i].set_shape(shape)
638      return outputs
639
640
641def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph):
642  """Creates forward and backward functions from the function graphs."""
643  forward_function_name = _forward_name(forward_graph.name)
644  common_attributes = dict(attrs)
645  # NB: forward and backward functions need to drop the "_implements"
646  # attribute, because their signatures contain all the intermediate tensors
647  # that they compute. Thus they don't have a stable signature which can
648  # be directly optimized downstream.
649  # See for more details:
650  # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
651  common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
652  backward_function_attr = _parse_func_attrs(
653      {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
654  backward_function_attr.update(common_attributes)
655  backward_function = ConcreteFunction(
656      backwards_graph, attrs=backward_function_attr)
657  forward_function_attr = _parse_func_attrs({
658      BACKWARD_FUNCTION_ATTRIBUTE_NAME:
659      backward_function.name})
660  forward_function_attr.update(common_attributes)
661  forward_function = _EagerDefinedFunction(
662      forward_function_name, forward_graph, forward_graph.inputs,
663      forward_graph.outputs, forward_function_attr)
664  return forward_function, backward_function
665
666
667class _DelayedRewriteGradientFunctions(object):
668  """Caches forward/backward functions with a delayed forward rewrite."""
669
670  def __init__(self, func_graph, attrs, func_graph_deleter):
671    """Construct an inference function and initialize caches."""
672    # A map from the number of forward function outputs with accepted gradients
673    # to forward and backward functions, used to cache non-tape backward
674    # function generation.
675    self._cached_function_pairs = {}
676    self._func_graph = func_graph
677    self._inference_function = _EagerDefinedFunction(
678        _inference_name(self._func_graph.name), self._func_graph,
679        self._func_graph.inputs, self._func_graph.outputs, attrs)
680    self._attrs = attrs
681    self._gradient_name = None
682    # Note that the FuncGraph is mutated later, so we need to inspect it now to
683    # figure out the user-specified outputs of the inference function.
684    self._num_inference_outputs = len(self._func_graph.outputs)
685    self._func_graph_deleter = func_graph_deleter
686
687  def forward_backward(self, num_doutputs=None):
688    """A possibly-cached pair of forward and backward functions."""
689    if num_doutputs is None:
690      num_doutputs = self._num_inference_outputs
691    forward_backward = self._cached_function_pairs.get(num_doutputs)
692    if forward_backward is not None:
693      return forward_backward
694    forward, backward = self._construct_forward_backward(num_doutputs)
695    self._cached_function_pairs[num_doutputs] = (forward, backward)
696    return forward, backward
697
698  def _construct_forward_backward(self, num_doutputs):
699    """Constructs a pair of forward and backward functions.
700
701    Args:
702      num_doutputs: The constructed backprop function will take output gradients
703        for the first `num_doutputs` outputs of the forward function. Defaults
704        to the number of outputs for the inference function, but when
705        higher-order gradients are computed this will increase to include side
706        outputs.
707
708    Returns:
709      A pair of (forward_function, backward_function):
710        forward_function: A re-generated inference function (an
711          _EagerDefinedFunction) to account for new side outputs, if any extra
712          were required when building the backward pass.
713        backward_function: A ConcreteFunction that takes `num_doutputs`
714          arguments and returns gradients with respect to inputs of the forward
715          function.
716    """
717    trainable_outputs = [
718        output for output in self._func_graph.outputs[:num_doutputs]
719        if backprop_util.IsTrainable(output)]
720
721    signature = []
722    for t in trainable_outputs:
723      signature.append(
724          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))
725
726    def _backprop_function(*grad_ys):
727      with ops.device(None):
728        return gradients_util._GradientsHelper(  # pylint: disable=protected-access
729            trainable_outputs,
730            self._func_graph.inputs,
731            grad_ys=grad_ys,
732            src_graph=self._func_graph)
733
734    with self._func_graph.as_default():
735      backwards_graph = func_graph_module.FuncGraph(
736          _backward_name(self._func_graph.name))
737      func_graph_module.func_graph_from_py_func(
738          name=backwards_graph.name,
739          python_func=_backprop_function,
740          args=[], kwargs={},
741          signature=signature,
742          func_graph=backwards_graph)
743      backwards_graph_captures = backwards_graph.external_captures
744      captures_from_forward = [
745          c for c in backwards_graph_captures if
746          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
747
748      existing_outputs = object_identity.ObjectIdentitySet(
749          self._func_graph.outputs)
750      for capture in captures_from_forward:
751        if capture not in existing_outputs:
752          existing_outputs.add(capture)
753          self._func_graph.outputs.append(capture)
754
755      forward_function, backward_function = _create_forward_backward_with_graph(
756          self._attrs, self._func_graph, backwards_graph)
757      return forward_function, backward_function
758
759  def _rewrite_forward_and_call_backward(self, op, *doutputs):
760    """Add outputs to the forward call and feed them to the grad function."""
761    forward_function, backwards_function = self.forward_backward(len(doutputs))
762    if not backwards_function.outputs:
763      return backwards_function.structured_outputs
764    forward_function.add_to_graph(op.graph)
765
766    # pylint: disable=protected-access
767    # Rewrite an inference call op to be a forward call op
768    op._set_func_attr("f", forward_function.name)
769    op._set_type_list_attr("Tout", forward_function._output_types)
770    op._add_outputs(
771        forward_function._output_types[len(op.outputs):],
772        forward_function._output_shapes[len(op.outputs):])
773    for i in range(len(op.outputs)):
774      func_graph_output = forward_function._func_graph_outputs[i]
775      handle_data_util.copy_handle_data(func_graph_output, op.outputs[i])
776    # pylint: enable=protected-access
777
778    capture_mapping = dict(
779        zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
780    remapped_captures = [
781        capture_mapping.get(ops.tensor_id(capture), capture)
782        for capture in backwards_function.captured_inputs
783    ]
784
785    # Replace Nones with zeros since we're calling a graph function which
786    # expects numeric inputs.
787    cleaned_doutputs = []
788    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
789      if backprop_util.IsTrainable(placeholder):
790        if isinstance(doutput, ops.IndexedSlices):
791          # Gradient passed to a backward ConcreteFunction must be tf.Tensor,
792          # so we convert tf.IndexedSlices to tf.Tensor.
793          cleaned_doutputs.append(ops.convert_to_tensor(doutput))
794        elif doutput is not None:
795          cleaned_doutputs.append(doutput)
796        else:
797          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))
798
799    # Compute the gradients using the side outputs
800    return backwards_function._call_flat(  # pylint: disable=protected-access
801        cleaned_doutputs, remapped_captures)
802
803  def get_gradient_function(self):
804    """Returns gradient function.
805
806    The gradient rewrites an inference call op to a forward call op, but does
807    not modify a pre-existing forward call op. It then computes the gradient
808    from the output gradients and the side outputs of the forward op.
809    """
810    return self._rewrite_forward_and_call_backward
811
812  def forward(self, inference_args=None, input_tangents=None):
813    """A forward function with only user-specified outputs.
814
815    The call operation for the returned inference function can be rewritten into
816    a forward function. This only happens if the backward function (from the
817    `backward` method) ends up being used to compute gradients.
818
819    This approach avoids constructing unnecessary graphs, but it only works if
820    we are calling this function when not executing eagerly.
821
822    Args:
823      inference_args: A flat list of Tensors, arguments to the inference
824        function. Unused, but taken for compatibility with
825        _TapeGradientFunctions.
826      input_tangents: A flat list of Tensors, jvps associated with
827        `inference_args`. Unused; if required, tape functions must be used
828        instead.
829
830    Returns:
831      An _EagerDefinedFunction.
832    """
833    del inference_args  # unused
834    if input_tangents:
835      # This class does not support special-cased forwardprop. The arguments are
836      # here for compatibility with _TapeGradientFunctions.
837      raise errors.InternalError("unexpectedly got forwardprop information in "
838                                 "a class that does not support forwardprop.")
839    return self._inference_function
840
841  def _backward(self, outputs):
842    """Fetch a backward function for `outputs` from the forward function."""
843    def _backward_function(*args):
844      call_op = outputs[0].op
845      return self._rewrite_forward_and_call_backward(call_op, *args)
846    return _backward_function, outputs
847
848  def record(self, flat_outputs, inference_args, input_tangents):
849    """Record the function call operation.
850
851    _DelayedRewriteGradientFunctions supports only first-order backprop tape
852    gradients (and then only when graph building). It does not work with
853    higher-order tape gradients or forward autodiff, but does work with
854    higher-order symbolic gradients (tf.gradients).
855
856    Args:
857      flat_outputs: The result of running `forward`.
858      inference_args: A flat list of Tensors with inference inputs to the
859        operation.
860      input_tangents: A flat list of Tensors with input tangents consumed by the
861        operation.
862    """
863    backward_function, to_record = self._backward(flat_outputs)
864    tape.record_operation(self._inference_function.signature.name,
865                          to_record, inference_args + input_tangents,
866                          backward_function)
867
868
869# Contains information about a forward function wrapped to compute jvps.
870_ForwardWrapper = collections.namedtuple(
871    "_ForwardWrapper", (
872        # The wrapper Graph.
873        "graph",
874        # A flat list of non-tangent Tensor outputs from the wrapped forward
875        # function.
876        "outputs",
877        # Indices for output tangents, same format as
878        # forwardprop_util.pack_tangents.
879        "output_indices",
880        # A flat list of tangents for `outputs`.
881        "output_tangents"))
882
883
884class _TapeGradientFunctions(object):
885  """Caches forward and backward functions compatible with eager gradients.
886
887  In contrast to the delayed-rewrite approach in
888  `_DelayedRewriteGradientFunctions` which only works with delayed execution,
889  the forward function generated by this class has a fixed set of outputs which
890  may be preserved by a tape in order to compute gradients later.
891
892  This class is abstract; its child classes differ in how many side outputs of
893  the forward function their backward function accepts gradients for, which
894  determines whether higher-order tape gradients are possible.
895  """
896
897  def __init__(self, func_graph, attrs, func_graph_deleter,
898               forwardprop_input_indices, delayed_rewrite_functions,
899               need_gradients_for_jvps):
900    self._func_graph = func_graph
901    self._forward_graph = None
902    self._attrs = attrs
903    self._forward = None
904    self._backward = None
905    self._num_outputs = len(func_graph.outputs)
906    self._func_graph_deleter = func_graph_deleter
907    self._forwardprop_input_indices = forwardprop_input_indices
908    self._forwardprop_output_indices = None
909    self._num_forwardprop_outputs = 0
910    self._num_inference_outputs = len(func_graph.outputs)
911    self._num_trainable_inference_outputs = len(
912        [t for t in func_graph.outputs if backprop_util.IsTrainable(t)])
913    self._delayed_rewrite_functions = delayed_rewrite_functions
914    self._need_gradients_for_jvps = need_gradients_for_jvps
915
916  def _build_functions_for_outputs(
917      self, outputs, inference_args, input_tangents):
918    """Forward+backward functions where the backward function sees `outputs`."""
919    # First figure out which of `outputs` are trainable. We'll accept gradients
920    # for each of these in the backward function.
921    handles_to_variables = self._func_graph.variable_captures
922    trainable_outputs = []
923    trainable_indices = []
924    for index, output in enumerate(outputs):
925
926      if backprop_util.IsTrainable(output):
927        # Swap in the Variable object for resource handles if we can, so that
928        # sparse gradients work.
929        output = handles_to_variables.get(id(output), output)
930        trainable_outputs.append(output)
931        trainable_indices.append(index)
932
933    backwards_graph = func_graph_module.FuncGraph(
934        _backward_name(self._func_graph.name))
935    with backwards_graph.as_default():
936      gradients_wrt_outputs = []
937      for output in trainable_outputs:
938        gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
939            output)
940        gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
941        handle_data_util.copy_handle_data(output, gradient_placeholder)
942        gradients_wrt_outputs.append(gradient_placeholder)
943      with ops.device(None):
944        gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
945            trainable_outputs,
946            self._func_graph.inputs,
947            grad_ys=gradients_wrt_outputs,
948            src_graph=self._func_graph)
949
950      if input_tangents:
951        # Convert IndexedSlices to dense tensors (as we do elsewhere for
952        # function gradients). Our C++ bindings don't know how to handle them
953        # currently.
954        gradients_wrt_inputs = nest.map_structure(
955            lambda x: ops.convert_to_tensor(x) if x is not None else None,
956            gradients_wrt_inputs)
957      captures_from_forward = [
958          c for c in backwards_graph.external_captures
959          if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
960      ]
961      existing_outputs = object_identity.ObjectIdentitySet(
962          self._func_graph.outputs)
963      for capture in captures_from_forward:
964        if capture not in existing_outputs:
965          existing_outputs.add(capture)
966          self._func_graph.outputs.append(capture)
967
968    # The ordering of `backwards_graph.inputs` is important: inputs of
969    # `backward_function` correspond to outputs (including
970    # side outputs) of `self._tape_forward_function`.
971    backwards_graph.inputs = (
972        gradients_wrt_outputs + backwards_graph.internal_captures)
973    backwards_graph.outputs.extend(
974        grad
975        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
976        if grad is not None)
977    backwards_graph.structured_outputs = gradients_wrt_inputs
978
979    forward_function, backward_function = _create_forward_backward_with_graph(
980        self._attrs, self._func_graph, backwards_graph)
981
982    if not input_tangents:
983      # There is no need to special-case forwardprop, so we can return the
984      # forward+backward pair we've created without further wrapping.
985      return (forward_function, self._func_graph, backward_function,
986              # No forwardprop outputs.
987              None, 0)
988    forward_wrapper = self._wrap_forward_function_with_jvps(
989        forward_function, backward_function, inference_args, input_tangents)
990    (wrapped_backwards_graph,
991     forward_wrapper) = self._wrap_backward_function_with_jvp_backprop(
992         backward_function, gradients_wrt_outputs, forward_wrapper)
993    # Now that we've added new captures, we need to make sure forward outputs
994    # are in the same order the backward function expects them to be in:
995    # [inference outputs] + [jvps] + [side outputs] + [captures].
996    forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)
997    (wrapped_forward_function,
998     wrapped_backward_function) = _create_forward_backward_with_graph(
999         self._attrs, forward_wrapper.graph, wrapped_backwards_graph)
1000    if (len(inference_args) + len(input_tangents)
1001        != len(forward_wrapper.graph.inputs)):
1002      raise errors.InternalError(
1003          f"The forward graph had {len(forward_wrapper.graph.inputs)} inputs, "
1004          f"but we expected {len(inference_args) + len(input_tangents)} "
1005          f"({len(inference_args)} inference inputs and "
1006          f"{len(input_tangents)} input tangents).")
1007    return (wrapped_forward_function, forward_wrapper.graph,
1008            wrapped_backward_function, forward_wrapper.output_indices,
1009            len(forward_wrapper.output_tangents))
1010
1011  def _wrap_forward_function_with_jvps(
1012      self, forward_function, backward_function,
1013      inference_args, input_tangents):
1014    """Adds inline JVP computation to a forward function."""
1015    forward_wrapper_graph = func_graph_module.FuncGraph(
1016        _forward_name(self._func_graph.name))
1017    with forward_wrapper_graph.as_default():
1018      # Tell forward accumulators to free up space for new JVP computations,
1019      # since one may be in the process of computing a JVP (if that computation
1020      # triggered this function building).
1021      #
1022      # We'll make symbolic versions of input JVPs, run the forward function
1023      # under forward accumulators to get symbolic output JVPs, then set those
1024      # as outputs of the new wrapped forward function.
1025      with forwardprop_util.push_forwardprop_state():
1026        forward_captures = {
1027            ops.tensor_id(internal): external
1028            for external, internal in self._func_graph.captures}
1029        for input_index, real_input in enumerate(self._func_graph.inputs):
1030          # This loop is more or less equivalent to running tf.identity on each
1031          # of self._func_graph.inputs. However, doing that also captures jvps
1032          # for resource handles, which confuses the jvp capturing code below
1033          # (since primal inputs are interwoven with jvp inputs).
1034          input_placeholder = array_ops.placeholder(
1035              dtype=real_input.dtype,
1036              shape=real_input.shape)
1037          capture = forward_captures.get(ops.tensor_id(real_input))
1038          if capture is not None:
1039            forward_wrapper_graph.add_capture(capture, input_placeholder)
1040            if capture.dtype == dtypes.resource:
1041              handle_data_util.copy_handle_data(capture, input_placeholder)
1042          else:
1043            forward_wrapper_graph.inputs.append(input_placeholder)
1044        for inp, arg in zip(forward_wrapper_graph.inputs, inference_args):
1045          tape.record_operation(
1046              "captured_value", [inp], [arg],
1047              backward_function=lambda x: [x],
1048              forward_function=lambda x: [x])
1049        num_inference_inputs = len(inference_args)
1050        for tape_indices in self._forwardprop_input_indices:
1051          for input_index, jvp_index in tape_indices:
1052            input_placeholder = forward_wrapper_graph.inputs[input_index]
1053            if len(forward_wrapper_graph.inputs) != jvp_index:
1054              raise errors.InternalError(
1055                  f"Expected {jvp_index} forward graph inputs, "
1056                  f"got {len(forward_wrapper_graph.inputs)}.")
1057            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
1058                input_placeholder)
1059            jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
1060            external_jvp = input_tangents[jvp_index - num_inference_inputs]
1061            forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder)
1062            tensor_shape.TensorShape(
1063                external_jvp.shape).assert_is_compatible_with(
1064                    jvp_placeholder.shape)
1065            tape.record_operation(
1066                "captured_value",
1067                [jvp_placeholder],
1068                [external_jvp],
1069                backward_function=lambda x: [x],
1070                forward_function=lambda x: [x])
1071        forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs]
1072        gradient_function = (
1073            self._delayed_rewrite_functions._rewrite_forward_and_call_backward)  # pylint: disable=protected-access
1074        with ops.get_default_graph()._override_gradient_function(  # pylint: disable=protected-access
1075            {"PartitionedCall": gradient_function,
1076             "StatefulPartitionedCall": gradient_function}):
1077          forward_outputs = forward_function.call(context.context(),
1078                                                  forward_inputs)
1079          if isinstance(forward_outputs, ops.Operation):
1080            # _wrapped_backward_function expects a list, but if the function has
1081            # no outputs its call() returns an Operation. We need to undo that
1082            # so we don't cause problems later.
1083            forward_outputs = []
1084        py_backward, _ = self._wrap_backward_function(
1085            self._func_graph, backward_function, forward_outputs)
1086      # We will never request backward tape gradients for this operation
1087      # directly since we're wrapping the call; forwardprop will call the
1088      # backward function (and nested forward accumulators may build
1089      # higher-order gradients), but any watching GradientTapes should ignore
1090      # it.
1091      #
1092      # TODO(allenl): It might be better to explicitly stop backward recording
1093      # so we don't use the second-order tape cases unnecessarily.
1094      tape.record_operation_forwardprop_only(
1095          forward_function.signature.name,
1096          forward_outputs, forward_inputs, py_backward, None)
1097      output_indices, output_tangents = (
1098          pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
1099      output_tangents = [forward_wrapper_graph.capture(t)
1100                         for t in output_tangents]
1101    return _ForwardWrapper(
1102        graph=forward_wrapper_graph, outputs=forward_outputs,
1103        output_indices=output_indices, output_tangents=output_tangents)
1104
1105  def _wrap_backward_function_with_jvp_backprop(
1106      self, backward_function, gradients_wrt_outputs, forward_wrapper):
1107    """Wraps `backward_function` to include gradients for JVPs."""
1108    wrapped_backwards_graph = func_graph_module.FuncGraph(
1109        _backward_name(self._func_graph.name))
1110    with wrapped_backwards_graph.as_default():
1111      py_backward, recorded_outputs = self._wrap_backward_function(
1112          self._func_graph, backward_function, forward_wrapper.outputs)
1113      trainable_index = 0
1114      forward_doutputs = []
1115      doutput_args = []
1116      for output in recorded_outputs:
1117        if backprop_util.IsTrainable(output):
1118          doutput = gradients_wrt_outputs[trainable_index]
1119          doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape)
1120          doutput_args.append(doutput_placeholder)
1121          forward_doutputs.append(doutput_placeholder)
1122          trainable_index += 1
1123        else:
1124          doutput_args.append(None)
1125
1126      dinputs = py_backward(*doutput_args)
1127      existing_outputs = object_identity.ObjectIdentitySet(
1128          forward_wrapper.outputs + forward_wrapper.output_tangents)
1129      num_processed_output_tangents = 0
1130      gradients_wrt_output_tangents = []
1131      tangent_doutputs = []
1132      output_tangents = forward_wrapper.output_tangents
1133      output_indices = forward_wrapper.output_indices
1134      if self._need_gradients_for_jvps:
1135        # TODO(allenl): Consider using a throwaway graph to avoid extra gradient
1136        # evaluations; gradients for jvps may have common subgraphs.
1137        while num_processed_output_tangents != len(output_tangents):
1138          for output in output_tangents[num_processed_output_tangents:]:
1139            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
1140                output)
1141            placeholder = graph_placeholder(gradient_dtype, gradient_shape)
1142            gradients_wrt_output_tangents.append(placeholder)
1143            tangent_doutputs.append(placeholder)
1144          num_processed_output_tangents = len(output_tangents)
1145          with ops.device(None):
1146            gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
1147                output_tangents,
1148                forward_wrapper.graph.inputs,
1149                grad_ys=gradients_wrt_output_tangents,
1150                src_graph=forward_wrapper.graph)
1151          dinputs = [
1152              backprop.aggregate_indexed_slices_gradients((existing, new))
1153              for existing, new in zip(dinputs, gradients_wrt_inputs)
1154              if existing is not None or new is not None]
1155          dinputs.extend(gradients_wrt_inputs[len(dinputs):])
1156          captures_from_forward = [
1157              c for c in wrapped_backwards_graph.external_captures
1158              if (not isinstance(c, ops.EagerTensor)
1159                  and c.graph is forward_wrapper.graph)]
1160          for capture in captures_from_forward:
1161            if capture not in existing_outputs:
1162              existing_outputs.add(capture)
1163              forward_wrapper.outputs.append(capture)
1164          output_indices, output_tangents = (
1165              forwardprop_util.pack_tangents(forward_wrapper.outputs))
1166          output_tangents = [forward_wrapper.graph.capture(t)
1167                             for t in output_tangents]
1168          for t in output_tangents:
1169            existing_outputs.add(t)
1170    wrapped_backwards_graph.inputs = (
1171        forward_doutputs[:self._num_trainable_inference_outputs]
1172        + tangent_doutputs
1173        + forward_doutputs[self._num_trainable_inference_outputs:]
1174        + wrapped_backwards_graph.internal_captures)
1175    wrapped_backwards_graph.structured_outputs = dinputs
1176    wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None]
1177    return (wrapped_backwards_graph,
1178            forward_wrapper._replace(output_indices=output_indices,
1179                                     output_tangents=output_tangents))
1180
1181  def _shuffle_forward_outputs(self, forward_wrapper):
1182    """Reorders function outputs so captures are last."""
1183    def _index_map(original):
1184      if original < self._num_inference_outputs:
1185        return original
1186      if original >= len(forward_wrapper.outputs):
1187        return (original - len(forward_wrapper.outputs)
1188                + self._num_inference_outputs)
1189      return original + len(forward_wrapper.output_tangents)
1190    output_indices = nest.map_structure(
1191        _index_map, forward_wrapper.output_indices)
1192    forward_wrapper.graph.outputs = (
1193        forward_wrapper.outputs[:self._num_inference_outputs]
1194        + forward_wrapper.output_tangents
1195        + forward_wrapper.outputs[self._num_inference_outputs:])
1196    return forward_wrapper._replace(output_indices=output_indices)
1197
1198  def forward(self, inference_args, input_tangents):
1199    """Construct or fetch a forward function with side-outputs.
1200
1201    When graph building without a tape active, symbolic gradients rely on
1202    regenerating the backward function for higher-order gradients (to account
1203    for new side outputs of the rewritten forward function call). Thus there is
1204    no fixed backward function for this case. However, when a tape is active
1205    (eager or graph building), we generate fixed backward and forward functions
1206    at forward function call time.
1207
1208    This difference between the tape and non-tape cases is to avoid building
1209    unneeded backward functions while graph building (where we may or may not
1210    eventually need gradients).
1211
1212    Args:
1213      inference_args: A flat list of Tensors, arguments to the inference
1214        function.
1215      input_tangents: A flat list of Tensors, jvps associated with
1216        `inference_args`.
1217
1218    Returns:
1219      A forward _EagerDefinedFunction.
1220    """
1221    if self._forward is None:
1222      (self._forward, self._forward_graph, self._backward,
1223       self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
1224           self._forward_and_backward_functions(inference_args, input_tangents))
1225    return self._forward
1226
1227  def _wrap_backward_function(self, forward_graph, backward, outputs):
1228    """Create a backward function given `outputs` from the forward function."""
1229    capture_mapping = dict(
1230        zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
1231    captured_inputs = backward.captured_inputs
1232    remapped_captures = [
1233        capture_mapping.get(ops.tensor_id(capture), capture)
1234        for capture in captured_inputs
1235    ]
1236    if any(t.graph is forward_graph for t in remapped_captures
1237           if not isinstance(t, ops.EagerTensor)):
1238      incorrect_mapping = [t for t in remapped_captures
1239                           if (not isinstance(t, ops.EagerTensor) and
1240                               t.graph is not forward_graph)]
1241      raise errors.InternalError("Failed to map all backward graph captures to "
1242                                 "the forward graph. Incorrectly mapped: "
1243                                 f"{incorrect_mapping}.")
1244    # We may need to use zeros_like to get a zero for variant Tensors with
1245    # unconnected gradients. We do that in advance so we don't have to hold on
1246    # to the outputs themselves, which may not be needed otherwise.
1247    variant_zeros_like = {}
1248    backward_function_inputs = (len(backward.inputs) - len(captured_inputs))
1249    recorded_outputs = []
1250    trainable_recorded_outputs = 0
1251    skip_positions = []
1252    if self._num_forwardprop_outputs and not self._need_gradients_for_jvps:
1253      relevant_outputs = (
1254          outputs[:self._num_inference_outputs]
1255          + outputs[self._num_inference_outputs
1256                    + self._num_forwardprop_outputs:])
1257    else:
1258      relevant_outputs = outputs
1259    for output_index, output in enumerate(relevant_outputs):
1260      if trainable_recorded_outputs < backward_function_inputs:
1261        recorded_outputs.append(output)
1262      if backprop_util.IsTrainable(output):
1263        trainable_recorded_outputs += 1
1264      else:
1265        skip_positions.append(output_index)
1266      if output.dtype == dtypes.variant:
1267        variant_zeros_like[output_index] = default_gradient.zeros_like(output)
1268
1269    def _backward_function_wrapper(*args):
1270      """Process output gradients and call the backward function."""
1271      if not backward.outputs:
1272        return backward.structured_outputs
1273
1274      processed_args = []
1275      input_index = 0
1276      for output_index, arg in enumerate(args):
        # Convert IndexedSlices to dense tensors. The IndexedSlices
        # optimization is only really effective when doing tf.gather(variable),
        # as the adjoint functions for most operations are unlikely to
        # preserve the sparsity in IndexedSlices.
1281        if isinstance(arg, ops.IndexedSlices):
1282          arg = ops.convert_to_tensor(arg)
1283        if output_index in skip_positions:
1284          continue
1285        if arg is None:
1286          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
1287          # have a Tensor value for each Tensor we thought would be trainable
1288          # based on its dtype, even if it ended up being unconnected.
          input_placeholder = backward.inputs[input_index]
1291          if input_placeholder.dtype == dtypes.variant:
1292            arg = variant_zeros_like[output_index]
1293          else:
1294            arg = array_ops.zeros(
1295                *default_gradient.shape_and_dtype(input_placeholder))
1296        processed_args.append(arg)
1297        input_index += 1
1298        if input_index >= backward_function_inputs:
1299          break
1300      return backward._call_flat(  # pylint: disable=protected-access
1301          processed_args, remapped_captures)
1302
1303    return _backward_function_wrapper, recorded_outputs
1304
1305  def record(self, flat_outputs, inference_args, input_tangents):
1306    """Record the function call operation.
1307
1308    For backprop, indicates the backward function to use and which new Tensors
1309    must be watched. For forwardprop from eager, the function call itself will
1310    have produced tangents which need to be recorded.
1311
1312    Args:
1313      flat_outputs: The result of running `forward`.
1314      inference_args: A flat list of Tensors with inference inputs to the
1315        operation.
1316      input_tangents: A flat list of Tensors with input tangents consumed by the
1317        operation.
1318    """
1319    backward_function, to_record = self._wrap_backward_function(
1320        self._forward_graph, self._backward, flat_outputs)
1321    if self._forwardprop_output_indices:
1322      tape.record_operation_backprop_only(
1323          self._forward.signature.name,
1324          to_record, inference_args,
1325          backward_function)
1326      tape.record_operation_forwardprop_only(
1327          self._forward.signature.name,
1328          flat_outputs, inference_args + input_tangents,
1329          backward_function,
1330          self._forwardprop_output_indices)
1331    else:
1332      tape.record_operation(self._forward.signature.name,
1333                            to_record, inference_args + input_tangents,
1334                            backward_function)
1335
1336
1337class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions):
1338  """Caches tape-friendly functions for first-order gradients."""
1339
1340  def __init__(self, func_graph, attrs, func_graph_deleter,
1341               forwardprop_input_indices, delayed_rewrite_functions,
1342               need_gradients_for_jvps):
1343    super(_FirstOrderTapeGradientFunctions, self).__init__(
1344        func_graph, attrs, func_graph_deleter, forwardprop_input_indices,
1345        delayed_rewrite_functions, need_gradients_for_jvps)
1346    self._func_graph_deleter = func_graph_deleter
1347    self._forwardprop_input_indices = forwardprop_input_indices
1348
1349  def _forward_and_backward_functions(self, inference_args, input_tangents):
1350    """Shortcut for when only first-order gradients are required.
1351
    The returned backward function does not accept gradients with respect to
    side outputs of forward_function. This is fine as long as the user can't
    possibly request second-order tape gradients, as when they've used a single
    non-persistent GradientTape. Since we don't need the backward function to
    take gradients with respect to side outputs, we can skip some potentially
    slow graph building.
1358
1359    Args:
1360      inference_args: A flat list of Tensors, arguments to the inference
1361        function.
1362      input_tangents: A flat list of Tensors, jvps associated with
1363        `inference_args`.
1364
1365    Returns:
1366      A tuple of (forward_function, backward_function):
1367        forward_function: Takes the same inputs as the inference function, but
1368          returns side outputs used by backward_function in addition to the
1369          inference function's outputs.
1370        backward_function: Takes side outputs from forward_function and
1371          gradients with respect to the "real" outputs of forward_function and
1372          returns gradients with respect to the inputs.
1373    """
1374    outputs = self._func_graph.outputs[:self._num_inference_outputs]
1375    return self._build_functions_for_outputs(
1376        outputs, inference_args, input_tangents)
1377
1378
1379class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
1380  """Caches tape-friendly functions for higher-order gradients."""
1381
1382  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
1383  # generalizing if so.
1384  def _forward_and_backward_functions(self, inference_args, input_tangents):
1385    """Forward and backward functions suitable for higher-order gradients.
1386
    Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built
    by this method accepts gradients for all of the outputs of the returned
    forward function, including side outputs.
1390
1391    Args:
1392      inference_args: A flat list of Tensors, arguments to the inference
1393        function.
1394      input_tangents: A flat list of Tensors, jvps associated with
1395        `inference_args`.
1396
1397    Returns:
1398      A tuple of (forward_function, backward_function):
1399        forward_function: Takes the same inputs as the inference function, but
1400          returns side outputs used by backward_function in addition to the
1401          inference function's outputs.
1402        backward_function: Takes side outputs from forward_function and
1403          gradients with respect to all of its outputs, real and side. Returns
1404          gradients with respect to the inputs.
1405    """
1406    outputs = []
1407    iteration_count = 0
1408    # First we need to figure out how many side outputs from the forward pass
1409    # will be required. We do this in a temporary graph to avoid actually
1410    # running multiple copies of the backward pass (one per _GradientsHelper
1411    # call).
1412    #
1413    # While computing gradients, the backward function captures Tensors from
1414    # the forward function. We add these as side outputs of the original
1415    # function. However, we then need to accept output gradients with respect
1416    # to these side outputs for higher order gradients to work. Thus we loop
1417    # until the number of outputs of the function stabilizes. Note that this
1418    # is only required for tape gradients, where we need to declare in advance
1419    # all of the forward op's outputs: symbolic gradients with tf.gradients
1420    # instead rely on regenerating backward functions when higher-order
1421    # gradients are requested.
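    # Illustrative sketch of the fixed point (hypothetical tensors): if the
    # backward pass for outputs [y] captures forward tensors [h1, h2], the
    # forward outputs grow to [y, h1, h2]; if the backward pass for those
    # outputs captures nothing new, the loop below terminates with three
    # declared outputs.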
1422    while (len(outputs) < len(self._func_graph.outputs)
1423           # It's possible for gradient generation to add new ops to the forward
1424           # pass. If all of the new outputs are non-trainable, there's no
1425           # reason to continue.
1426           and any(backprop_util.IsTrainable(output)
1427                   for output in self._func_graph.outputs[len(outputs):])):
1428      iteration_count += 1
1429      if iteration_count >= 20 and iteration_count % 5 == 0:
1430        new_op_with_trainable_output = None
1431        num_new_trainable_outputs = 0
1432        for output in self._func_graph.outputs[len(outputs):]:
1433          if backprop_util.IsTrainable(output):
1434            num_new_trainable_outputs += 1
1435            new_op_with_trainable_output = output.op
1436        logging.warning(
1437            ("Determining side outputs for the function '{}' is taking longer "
1438             "than expected ({} iterations, typically this converges in 5 or "
1439             "so). This could indicate that a gradient registration is adding "
1440             "new ops to the forward pass every time gradients are generated. "
1441             "{} new trainable output(s) were added this iteration, one from "
1442             "the following op:\n {}\nThis may indicate a TensorFlow bug, or "
1443             "an issue in a tf.custom_gradient.")
1444            .format(
1445                self._func_graph.name, iteration_count,
1446                num_new_trainable_outputs, new_op_with_trainable_output))
1447      outputs = list(self._func_graph.outputs)
1448      self._build_functions_for_outputs(
1449          outputs, inference_args, input_tangents)
1450
1451    (forward_function, forward_graph,
1452     backward_function, output_indices, num_output_tangents) = (
1453         self._build_functions_for_outputs(
1454             outputs, inference_args, input_tangents))
1455    if (len(self._func_graph.outputs) > len(outputs)
1456        and any(backprop_util.IsTrainable(output)
1457                for output in self._func_graph.outputs[len(outputs):])):
1458      raise errors.InternalError(
1459          "Unexpectedly added new outputs to the forward function when "
1460          "building the backward function: "
1461          f"{self._func_graph.outputs[len(outputs):]}.")
1462    return (forward_function, forward_graph, backward_function, output_indices,
1463            num_output_tangents)
1464
1465
1466class _ForwardBackwardCall(object):
1467  """Holds the state of a function call between execution and recording."""
1468
1469  __slots__ = [
1470      "_functions", "_inference_args", "_input_tangents", "_tape_watching"
1471  ]
1472
1473  def __init__(self, functions, inference_args, input_tangents, tape_watching):
1474    """Collects information about the function call.
1475
1476    Args:
1477      functions: An object which produces forward and backward functions, either
1478        a _DelayedRewriteGradientFunctions or a _TapeGradientFunctions object.
1479      inference_args: A flat list of Tensors, arguments to the inference
1480        function.
1481      input_tangents: A flat list of Tensors, jvps associated with
1482        `inference_args`.
1483      tape_watching: Boolean, with True indicating that recording is necessary.
1484    """
1485    self._functions = functions
1486    self._inference_args = inference_args
1487    self._input_tangents = input_tangents
1488    self._tape_watching = tape_watching
1489
1490  def forward(self):
1491    """Builds or retrieves a forward function for this call."""
1492    forward_function = self._functions.forward(
1493        self._inference_args, self._input_tangents)
1494    return forward_function, self._inference_args + self._input_tangents
1495
1496  def record(self, flat_outputs):
1497    """Given outputs from the execution of `forward`, records the operation."""
1498    if (self._tape_watching
1499        and not isinstance(flat_outputs, ops.Operation)
1500        and flat_outputs is not None):
1501      # We only record function calls which have outputs, and then only when a
1502      # tape is watching.
1503      self._functions.record(
1504          flat_outputs, self._inference_args, self._input_tangents)
1505
1506
# Sentinel value used by ConcreteFunction's structured signature to
# indicate that a non-tensor parameter should use the value that was
# specified when the concrete function was created.
1510_BOUND_VALUE = object()
1511
1512
1513class ConcreteFunction(core.ConcreteFunction):
1514  """A `tf.types.experimental.ConcreteFunction` created from `tf.function`."""
1515
1516  def __init__(self,
1517               func_graph,
1518               attrs=None,
1519               shared_func_graph=True,
1520               function_spec=None):
1521    """Initialize a `ConcreteFunction`.
1522
1523    Args:
1524      func_graph: An instance of FuncGraph: the function body to wrap.
1525      attrs: (optional) dict mapping names of attributes to their AttrValue
1526        values. Attributes in `attrs` will be included in this function's
1527        definition.
      shared_func_graph: If False, the ConcreteFunction takes ownership of
        `func_graph` and will break reference cycles when it is deleted. This
        makes the FuncGraph inoperable.
      function_spec: FunctionSpec for the original function.  If not specified,
        then this ConcreteFunction may only be called using the flat signature.
1533
1534    Raises:
1535      ValueError: If number of input_placeholders is not equal to the number
1536        of function inputs.
1537    """
1538    # _arg_keywords and _num_positional_args define the flat signature.  They
1539    # are assigned after construction.
1540    self._arg_keywords = None
1541    self._num_positional_args = None
1542
1543    self._func_graph = func_graph
1544    self._captured_inputs = self._func_graph.external_captures
1545    self._captured_closures = self._func_graph.deferred_external_captures
1546
1547    # function_spec defines the structured signature.
1548    self._set_function_spec(function_spec)
1549
1550    if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
      # The alternative is to silently drop the "implements" tag,
      # but that would likely lead to hard-to-catch bugs.
      # Another alternative is to make func_body preserve the order
      # of arguments when variables are present. Yet another option
      # is to automatically replace variable arguments to the function
      # with v.read_value() whenever the "implements" tag is present.
      # Any time we annotate an existing function we probably want to
      # wrap it with a safe read_value for backward compatibility.
1559      has_resource_vars = any(inp.dtype == dtypes.resource
1560                              for inp in self.inputs)
1561
1562      assert not any(
1563          (has_resource_vars, self._captured_inputs, self._captured_closures)
1564      ), ('Function {name} has "{attr}={value}" attribute and thus can not '
1565          "depend on any tensors outside of its signature or modify variables. "
1566          "\n\nNote: variables are always captured and cause function "
1567          "re-tracing for every variable called.\n"
1568          "  inputs: {inputs}\n  captures: {captured}\n"
1569          "  closures: {closures}.\n\n"
          "To pass a variable to such a function, "
          "use variable.read_value().".format(
1572              name=func_graph.name,
1573              attr=IMPLEMENTS_ATTRIBUTE_NAME,
1574              value=attrs[IMPLEMENTS_ATTRIBUTE_NAME],
1575              inputs=self.inputs,
1576              captured=self._captured_inputs,
1577              closures=self._captured_closures))
1578    self._output_shapes = tuple(
1579        output.shape for output in self._func_graph.outputs)
1580    self._attrs = _parse_func_attrs(attrs or {})
1581
1582    if shared_func_graph:
1583      self._garbage_collector = None
1584    else:
1585      self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph)
1586
1587    # Pairs of forward and backward functions used for computing gradients.
1588    #
1589    # These each get a reference to the FuncGraph deleter since they use the
1590    # FuncGraph directly.
1591    self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions(
1592        func_graph, self._attrs, self._garbage_collector)
1593    self._first_order_tape_functions = {}
1594    self._higher_order_tape_functions = {}
1595    # Cache the inference function to avoid a (Python) function call when not
1596    # building gradients.
1597    self._inference_function = self._delayed_rewrite_functions.forward()
1598
1599  def _set_function_spec(self, function_spec):
1600    """Enables the structured signature by supplying a function_spec."""
1601    self._function_spec = None
1602    self._pre_initialized_function_spec = function_spec
1603
1604    # Note: when ConcreteFunctions are built by recreate_function() in
1605    # function_deserialization.py, they don't have a structured_input_signature
1606    # yet.  In that case, _initialize_function_spec() gets called by
1607    # _setup_functions_structures() in load.py.
1608    if (function_spec is not None and
1609        self.structured_input_signature is not None):
1610      self._initialize_function_spec()
1611
1612  def _initialize_function_spec(self):
1613    """Updates `self._function_spec` to include varargs and bound variables.
1614
1615    Adds new positional arguments for any varargs (i.e., for args that are
1616    in `structured_input_signature`, but not in the original fullargspec.args).
1617
1618    Replaces `defaults` and `kwonlydefaults` with the `_BOUND_VALUE`, for
1619    all args and kwargs in `structured_input_signature`.
1620
1621    Sets `varkw` and `varargs` to None.
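
    For example (an illustrative sketch, not drawn from a real trace): if
    `function_spec.arg_names` is `['x']` and `structured_input_signature`
    contains three positional TensorSpecs, the rewritten fullargspec has
    args `['x', '<arg2>', '<arg3>']`, with every default set to
    `_BOUND_VALUE`.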
1622    """
1623    if self._pre_initialized_function_spec is None:
1624      return  # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
1625    assert not self._function_spec, "already initialized"
1626    function_spec = self._pre_initialized_function_spec
1627    args = function_spec.fullargspec.args
1628    arg_specs, kwarg_specs = self.structured_input_signature
1629    vararg_indices = range(len(function_spec.arg_names), len(arg_specs))
1630    fullargspec = tf_inspect.FullArgSpec(
1631        args=list(args) + ["<arg{}>".format(i + 1) for i in vararg_indices],
1632        varargs=None,
1633        varkw=None,
1634        defaults=[_BOUND_VALUE] * len(arg_specs),
1635        kwonlyargs=list(sorted(kwarg_specs)),
1636        kwonlydefaults=dict((k, _BOUND_VALUE) for k in kwarg_specs),
1637        annotations=function_spec.fullargspec.annotations)
1638    self._function_spec = FunctionSpec(
1639        fullargspec,
1640        function_spec.is_method,
1641        function_spec.input_signature,
1642        function_spec.is_pure,
1643        name=self._func_graph.name)
1644
1645  @property
1646  def variables(self):
1647    """Sequence of variables for this function."""
1648    return tuple(self._func_graph.variables)
1649
1650  @property
1651  def trainable_variables(self):
1652    """Sequence of trainable variables for this function."""
1653    return tuple(self._func_graph.trainable_variables)
1654
1655  def __call__(self, *args, **kwargs):
1656    """Executes the wrapped function.
1657
1658    ConcreteFunctions have two signatures:
1659
1660    * The signature of the original function wrapped by this ConcreteFunction.
1661    * A flat signature, where each argument accepts a single Tensor.
1662
1663    The original function signature is generally preferred, but the flat input
1664    signature is supported for backward compatibility.
1665
1666    ### Original Function Signature
1667
1668    When calling a ConcreteFunction with the signature of the original function,
1669    each argument must match the type or value that was used when the
1670    ConcreteFunction's graph was traced.  In particular:
1671
1672    * Tensor arguments (including CompositeTensors, such as RaggedTensor) must
1673      have matching `TypeSpec`s.
1674    * Non-Tensor arguments (such as booleans or ints) must have equal values.
1675    * Nested arguments (such as lists, tuples, or dictionaries) must have the
1676      same nesting structure; and each nested value must have a matching type
1677      or value.
1678
1679    The default value for any arguments that were traced with non-Tensor values
1680    is the value that was used in the trace.  Arguments that were traced with
1681    tensor arguments do not have a default value (even if the original function
1682    had a default value for that argument).
1683
1684    ### Flat Signature
1685
1686    When calling a ConcreteFunction with the flat signature, the arguments
1687    correspond to the flattened component tensors of the arguments that were
1688    used to construct the ConcreteFunction.  Parameter names are assigned based
1689    on `TensorSpec.name` (when specified) or the original argument names (with
1690    suffixes automatically added for nested arguments or composite tensors with
1691    multiple components).
1692
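    For example, a concrete function traced from a two-argument function could
    plausibly be created and called as follows (a minimal sketch; `add` and
    its inputs are hypothetical, not part of this module):

      import tensorflow as tf

      @tf.function
      def add(a, b):
        return a + b

      cf = add.get_concrete_function(
          tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32))
      result = cf(tf.constant(1.0), tf.constant(2.0))      # positional
      result = cf(a=tf.constant(1.0), b=tf.constant(2.0))  # keyword
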
1693    Args:
1694      *args: Positional arguments to the concrete function.
1695      **kwargs: Keyword arguments to the concrete function.
1696
1697    Returns:
1698      The result of applying the TF function on the given Tensors.
1699
1700    Raises:
1701      AssertionError: If this `ConcreteFunction` was not created through
1702        `get_concrete_function`.
1703      TypeError: If the arguments do not match the function's signature.
1704    """
1705    return self._call_impl(args, kwargs)
1706
1707  def _call_impl(self, args, kwargs, cancellation_manager=None):
1708    """See `__call__` for details."""
1709    with trace.Trace(self._func_graph.name, tf_function_call="concrete"):
1710      # Construct the list of input tensors: check if the structured signature
1711      # applies first; and if not, then use the flat signature.
1712      if self._function_spec is not None:
1713        try:
1714          return self._call_with_structured_signature(args, kwargs,
1715                                                      cancellation_manager)
1716        except TypeError as structured_err:
1717          try:
1718            return self._call_with_flat_signature(args, kwargs,
1719                                                  cancellation_manager)
1720          except TypeError:
1721            raise structured_err
1722
1723      return self._call_with_flat_signature(args, kwargs, cancellation_manager)
1724
1725  def _call_with_flat_signature(self, args, kwargs, cancellation_manager):
1726    """Executes the wrapped function with the flat signature.
1727
1728    Args:
1729      args: Positional arguments to the concrete function.
1730      kwargs: Keyword arguments to the concrete function.
1731      cancellation_manager: A `CancellationManager` that can be used to cancel
1732        function invocation.
1733
1734    Returns:
1735      The result of applying the function on the Tensors/Variables contained in
1736      `args` and `kwargs`.
1737    Raises:
1738      TypeError: if `args` and `kwargs` do not match the flat signature of this
1739        `ConcreteFunction`.
1740    """
1741    if len(args) > self._num_positional_args:
1742      raise TypeError(
1743          f"{self._flat_signature_summary()} takes {self._num_positional_args} "
1744          f"positional arguments, got {len(args)}.")
1745    args = list(args)
1746    kwargs = dict(kwargs)
1747    for keyword in self._arg_keywords[len(args):]:
1748      try:
1749        args.append(kwargs.pop(compat.as_str(keyword)))
1750      except KeyError:
1751        specified_keywords = (
1752            list(self._arg_keywords[:len(args)]) + list(kwargs.keys()))
1753        missing_required_args = sorted(
1754            set(self._arg_keywords) - set(specified_keywords))
1755        raise TypeError(f"{self._flat_signature_summary()} missing required "
1756                        f"arguments: {', '.join(missing_required_args)}.")
1757    if kwargs:
1758      positional_arg_keywords = set(self._arg_keywords[:len(args)])
1759      for unused_key in kwargs:
1760        if unused_key in positional_arg_keywords:
1761          raise TypeError(f"{self._flat_signature_summary()} got two values "
1762                          f"for '{unused_key}'.")
1763      raise TypeError(f"{self._flat_signature_summary()} got unexpected "
1764                      f"keyword arguments: {', '.join(sorted(kwargs))}.")
1765
1766    for i, arg in enumerate(args):
1767      if not isinstance(
1768          arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
        raise TypeError(f"{self._flat_signature_summary()}: expected argument "
                        f"#{i} (zero-based) to be a Tensor; "
                        f"got {type(arg).__name__} ({arg}).")
1772    return self._call_flat(args, self.captured_inputs, cancellation_manager)
1773
1774  def _call_with_structured_signature(self, args, kwargs, cancellation_manager):
1775    """Executes the wrapped function with the structured signature.
1776
1777    Args:
1778      args: Positional arguments to the concrete function.
1779      kwargs: Keyword arguments to the concrete function.
1780      cancellation_manager: A `CancellationManager` that can be used to cancel
1781        function invocation.
1782
1783    Returns:
1784      The result of applying the function on the Tensors/Variables contained in
1785      `args` and `kwargs`.
1786    Raises:
1787      TypeError: if `args` and `kwargs` do not match the structured signature
1788        of this `ConcreteFunction`.
1789    """
    args, kwargs, _, filtered_flat_args = (
        self._function_spec.canonicalize_function_inputs(*args, **kwargs))
1792    self._structured_signature_check_missing_args(args, kwargs)
1793    self._structured_signature_check_unexpected_args(args, kwargs)
1794    self._structured_signature_check_arg_types(args, kwargs)
1795    return self._call_flat(
1796        filtered_flat_args,
1797        captured_inputs=self.captured_inputs,
1798        cancellation_manager=cancellation_manager)
1799
1800  def _structured_signature_check_missing_args(self, args, kwargs):
1801    """Raises a TypeError if any args are missing."""
1802    arg_specs, kwarg_specs = self.structured_input_signature
1803    missing_arguments = []
1804    for i, (arg, spec) in enumerate(zip(args, arg_specs)):
1805      if arg is _BOUND_VALUE and _contains_type_spec(spec):
1806        missing_arguments.append(self._function_spec.arg_names[i])
1807    for (name, arg) in kwargs.items():
1808      if arg is _BOUND_VALUE and _contains_type_spec(kwarg_specs[name]):
1809        missing_arguments.append(name)
1810    if missing_arguments:
1811      raise TypeError(f"{self._structured_signature_summary()} missing "
1812                      "required arguments: "
1813                      f"{', '.join(sorted(missing_arguments))}.")
1814
1815  def _structured_signature_check_unexpected_args(self, args, kwargs):
1816    """Raises a TypeError if there are any extra args."""
1817    arg_specs, kwarg_specs = self.structured_input_signature
1818    if len(args) > len(arg_specs):
1819      raise TypeError(
1820          f"{self._structured_signature_summary()} takes "
1821          f"{len(self._function_spec.arg_names)} positional arguments but got "
1822          f"{len(args)}.")
1823    if len(kwargs) > len(kwarg_specs):
1824      extra_args = set(kwargs) - set(kwarg_specs)
1825      raise TypeError(f"{self._structured_signature_summary()} got unexpected "
1826                      f"keyword arguments: {', '.join(extra_args)}.")
1827
1828  def _structured_signature_check_arg_types(self, args, kwargs):
1829    """Raises a TypeError if any args have the wrong type."""
1830    # Check argument types
1831    arg_specs, kwarg_specs = self.structured_input_signature
1832    for i, (arg, spec) in enumerate(zip(args, arg_specs)):
1833      name = self._function_spec.arg_names[i]
1834      self._structured_signature_check_arg_type(arg, spec, name)
1835    for (name, arg) in kwargs.items():
1836      self._structured_signature_check_arg_type(arg, kwarg_specs[name], name)
1837
1838  def _structured_signature_check_arg_type(self, arg, spec, name):
1839    """Raise TypeError if `arg`'s type doesn't match `spec`."""
1840    if arg is _BOUND_VALUE:
1841      return
1842
1843    # Check the overall nested structure of the argument.
1844    try:
1845      nest.assert_same_structure(arg, spec, expand_composites=True)
1846    except (ValueError, TypeError):
1847      try:
1848        nest.assert_same_structure(arg, spec, expand_composites=False)
1849        expected, got = spec, arg
1850      except (ValueError, TypeError):
1851        expected, got = _structure_summary(spec), _structure_summary(arg)
1852      raise TypeError(f"{self._structured_signature_summary()}: argument "
1853                      f"{name} had incorrect type\n"
1854                      f"  expected: {expected}\n"
1855                      f"       got: {got}")
1856
1857    # Check the type for each leaf in the nested structure.
1858    arg_pieces = nest.flatten(arg, expand_composites=True)
1859    spec_pieces = nest.flatten(spec, expand_composites=True)
1860    for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces):
1861      # TODO(mdan): Use consistent error messages.
1862      if isinstance(spec_piece, tensor_spec.DenseSpec):
1863        # TODO(edloper): Consider calling convert_to_tensor on non-tensor
1864        # values here.  That would match the behavior of
1865        # _call_concrete_function() in function_deserialization.py.  If
1866        # we do, then we need to change the nest assert_same_structure and
1867        # flatten calls above to use shallow variants.
1868        tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
1869        if not isinstance(arg_piece, tensor_types):
1870          raise TypeError(f"{self._structured_signature_summary()} expected a "
1871                          f"Tensor in {name}, but got "
1872                          f"{type(arg_piece).__name__} value {arg_piece}.")
1873      elif arg_piece is not _BOUND_VALUE:
1874        try:
1875          arg_matches_spec = bool(arg_piece == spec_piece)
1876        except (ValueError, TypeError):
1877          logging.vlog(1, "Error matching value with spec", exc_info=True)
1878          arg_matches_spec = False
1879        if not arg_matches_spec:
1880          raise TypeError(
1881              f"ConcreteFunction {self._structured_signature_summary()} was "
1882              f"constructed with {type(spec_piece).__name__} value "
1883              f"{spec_piece} in {name}, but was called with "
1884              f"{type(arg_piece).__name__} value {arg_piece}.")
1885
1886  def _call_flat(self, args, captured_inputs, cancellation_manager=None):
1887    """Executes the wrapped function.
1888
1889    Args:
1890      args: a list of Tensors or Variables. Arguments from the Python function
1891        should be filtered before calling this method: objects aside from
1892        Tensors, CompositeTensors, and Variables are ignored. Any
1893        CompositeTensors should be expanded before calling this method.
1894      captured_inputs: the captured inputs that are also part of the input args
1895        to the actual execution. By default, it should be self._captured_inputs.
1896      cancellation_manager: (Optional.) A `CancellationManager` that can be
1897        used to cancel function invocation.
1898
1899    Returns:
1900      The result of applying the TF function to `args`.
1901
1902    Raises:
1903      ValueError: If `args` contains anything other than Tensors or Variables.
1904    """
1905    ctx = context.context()
1906    executing_eagerly = ctx.executing_eagerly()
1907
1908    # Copy saveable status of function's graph to current FuncGraph.
1909    default_graph = ops.get_default_graph()
1910    if default_graph.building_function and not self._func_graph.saveable:
1911      default_graph.mark_as_unsaveable(self._func_graph.saving_errors)
1912
1913    if (tape.could_possibly_record() or
1914        hasattr(default_graph, "watch_variable")):
1915      for v in self._func_graph.variables:
1916        resource_variable_ops.variable_accessed(v)
1917
1918    tensor_inputs = []
    variables_used = set()
1920    for i, arg in enumerate(args):
1921      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
1922        # We can pass a variable more than once, and in this case we need to
1923        # pass its handle only once.
1924        if id(arg.handle) in variables_used:
1925          continue
1926        resource_variable_ops.variable_accessed(arg)
1927        tensor_inputs.append(arg.handle)
1928        variables_used.add(id(arg.handle))
1929      elif isinstance(arg, ops.Tensor):
1930        tensor_inputs.append(arg)
1931        if not executing_eagerly:
1932          # If we're graph building, shape inference is on. We check for input
1933          # compatibility up front to avoid hard to debug incompatibilities
1934          # later.
1935          graph_input_shape = tensor_shape.TensorShape(
1936              self._func_graph.inputs[i].shape)
1937          if not graph_input_shape.is_compatible_with(arg.shape):
1938            if self._arg_keywords:
1939              arg_name = "'{}'".format(self._arg_keywords[i])
1940            else:
1941              arg_name = "with index {}".format(i)
1942            raise ValueError(
1943                f"The argument {arg_name} (value {arg}) is not compatible with "
1944                "the shape this function was traced with. Expected shape "
1945                f"{self._func_graph.inputs[i].shape}, but got shape "
1946                f"{arg.shape}.\n\nIf you called get_concrete_function, you may "
1947                "need to pass a tf.TensorSpec(..., shape=...) with a less "
1948                "specific shape, having None on axes which can vary.")
1949      else:
1950        raise ValueError(f"{i:d}-th input {arg} must be a Tensor, got "
1951                         f"{type(arg)} when calling {self._func_graph.name}.")
1952    args = tensor_inputs + captured_inputs
1953    possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
1954    if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
1955        and executing_eagerly):
1956      # No tape is watching; skip to running the function.
1957      return self._build_call_outputs(self._inference_function.call(
1958          ctx, args, cancellation_manager=cancellation_manager))
1959    forward_backward = self._select_forward_and_backward_functions(
1960        args,
1961        possible_gradient_type,
1962        executing_eagerly)
1963    forward_function, args_with_tangents = forward_backward.forward()
1964    if executing_eagerly:
1965      flat_outputs = forward_function.call(
1966          ctx, args_with_tangents, cancellation_manager=cancellation_manager)
1967    else:
1968      with default_graph._override_gradient_function(  # pylint: disable=protected-access
1969          {"PartitionedCall": self._get_gradient_function(),
1970           "StatefulPartitionedCall": self._get_gradient_function()}):
1971        flat_outputs = forward_function.call(ctx, args_with_tangents)
1972    forward_backward.record(flat_outputs)
1973    return self._build_call_outputs(flat_outputs)
1974
1975  def _experimental_with_cancellation_manager(self, cancellation_manager):
1976    """Returns a callable that invokes a cancellable version of this function.
1977
1978    Args:
1979      cancellation_manager: A `CancellationManager` object that can be used to
1980        cancel function invocation.
1981
1982    Returns:
1983      A callable with the same signature as this concrete function.
1984    """
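
    A usage sketch (assuming `concrete_fn` is an existing `ConcreteFunction`;
    the surrounding setup is illustrative only):

      from tensorflow.python.eager import cancellation

      manager = cancellation.CancellationManager()
      cancellable_fn = (
          concrete_fn._experimental_with_cancellation_manager(manager))
      # Calling manager.start_cancel() from another thread causes any
      # in-flight cancellable_fn(...) call to fail with a CancelledError.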
1985
1986    def cancellable_call(*args, **kwargs):
1987      return self._call_impl(
1988          args, kwargs, cancellation_manager=cancellation_manager)
1989
1990    return cancellable_call
1991
1992  @property
1993  def name(self):
1994    """`ConcreteFunction` name."""
1995    return self._delayed_rewrite_functions.forward().name
1996
1997  @property
1998  def graph(self):
1999    """Returns the graph from which this function was constructed."""
2000    return self._func_graph
2001
2002  @property
2003  def inputs(self):
2004    """Returns tensors in `self.graph` corresponding to arguments."""
2005    return self._func_graph.inputs
2006
2007  @property
2008  def structured_input_signature(self):
2009    """Returns structured signature for this concrete function.
2010
2011    Returns:
2012      A tuple `(args, kwargs)`, where:
2013
        * `args` is a tuple that specifies the expected type or value for each
          positional argument.
2016        * `kwargs` is a dictionary that specifies the expected type or value
2017          for each keyword-only argument.
2018
2019      The type or value for each argument is specified using one of the
2020      following:
2021
2022        * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
2023          value is expected.
2024        * A Python value, such as an integer, indicating that an equal value
2025          is expected.
2026        * A nested structure of `tf.TypeSpec`s and Python values, indicating
2027          that a corresponding nested structure is expected.
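
      For example (an illustrative sketch; the argument names and specs are
      hypothetical), a function traced with one float32 Tensor and the Python
      value `True` might have:

        args == (tf.TensorSpec(shape=(3,), dtype=tf.float32, name='x'), True)
        kwargs == {}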
2028    """
2029    return self._func_graph.structured_input_signature
2030
2031  @property
2032  def outputs(self):
2033    """Returns tensors in `self.graph` corresponding to returned tensors."""
2034    return self._func_graph.outputs
2035
2036  @property
2037  def structured_outputs(self):
2038    """Returns outputs in `self.graph` as returned by the original function."""
2039    return self._func_graph.structured_outputs
2040
2041  @property
2042  def captured_inputs(self):
2043    """Returns external Tensors captured by this function.
2044
2045    self.__call__(*args) passes `args + self.captured_inputs` to the function.
2046    """
2047    from_closures = nest.flatten([x() for x in self._captured_closures],
2048                                 expand_composites=True)
2049    return self._captured_inputs + from_closures
2050
2051  @property
2052  def function_def(self):
2053    """Returns a `FunctionDef` object representing this function."""
2054    return self._delayed_rewrite_functions.forward().definition
2055
2056  @property
2057  def output_shapes(self):
2058    """The function's output shapes."""
2059    return nest.map_structure(
2060        lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)),
2061        composite_tensor.replace_composites_with_components(
2062            self._func_graph.structured_outputs),
2063        expand_composites=False)
2064
2065  @property
2066  def output_dtypes(self):
2067    # TODO(akshayka): Consider removing this.
2068    return nest.map_structure(
2069        lambda x: x.dtype if x is not None else None,
2070        composite_tensor.replace_composites_with_components(
2071            self._func_graph.structured_outputs),
2072        expand_composites=False)
2073
2074  def add_to_graph(self, g=None):
2075    """Registers the function, adds it to the graph g or default graph.
2076
2077    Args:
2078      g: If specified, registers the function with this graph. Defaults to the
2079        current context (either the default graph or the eager context).
2080    """
    # If we are not executing eagerly, add the function to the default graph
    # when no graph is specified.
    # In eager execution, the function definition is added to the context
    # during construction itself.
2085
2086    if not context.executing_eagerly() and not g:
2087      g = ops.get_default_graph()
2088    self._delayed_rewrite_functions.forward().add_to_graph(g)
2089
2090  def add_gradient_functions_to_graph(self, g=None):
2091    """Add forward/backward functions to graph `g` or the current context."""
2092    if not context.executing_eagerly() and not g:
2093      g = ops.get_default_graph()
2094    self._delayed_rewrite_functions.forward().add_to_graph(g)
2095    forward_function, backward_function = (
2096        self._delayed_rewrite_functions.forward_backward())
2097    forward_function.add_to_graph(g)
2098    backward_function.add_to_graph(g)
2099
2100  def _get_gradient_function(self):
2101    """Returns gradient function. It will be lazily created at first call."""
2102    return self._delayed_rewrite_functions._rewrite_forward_and_call_backward  # pylint: disable=protected-access
2103
2104  def _select_forward_and_backward_functions(
2105      self, args, possible_gradient_type, executing_eagerly):
2106    """Selects forward and backward functions based on the calling context.
2107
2108    The forward function computes the "real" function outputs, `self._outputs`,
2109    and any extra values needed by the corresponding backward function.
2110
2111    Args:
2112      args: A flat list of Tensors with all of the inputs to the forward
2113        function (including user-specified and captured inputs).
2114      possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
2115      executing_eagerly: Boolean, the value of context.executing_eagerly().
2116
2117    Returns:
2118      An object with a `forward` method returning a tuple of (forward_function :
2119      _EagerDefinedFunction, augmented_arguments : List), and a corresponding
2120      `record` method which takes outputs from the forward function and records
2121      the operation. forward_function should be called with augmented_arguments.
2122    """
2123    if executing_eagerly:
2124      input_tangents = forwardprop_util.pack_tangents(args)
2125    else:
2126      input_tangents = forwardprop_util.TangentInfo()
2127    need_gradients_for_jvps = tape.should_record_backprop(
2128        input_tangents.tangents)
2129    # Allows re-use of forward and backward function pairs depending on the
2130    # tapes and forward accumulators watching its inputs.
2131    cache_key = (need_gradients_for_jvps, input_tangents.indices)
2132    if (possible_gradient_type
2133        == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
2134      if input_tangents.indices or executing_eagerly:
2135        # There is a single non-persistent tape active, so the user can only
2136        # request first-order gradients from a tape. We can spend less time
2137        # graph building since we know this.
2138        #
2139        # We may still end up computing higher-order gradients, but that'd be
2140        # through `tf.gradients`, which can re-write the forward pass and so
2141        # needs no preparation here.
2142        functions = self._first_order_tape_functions.get(cache_key, None)
2143        if functions is None:
2144          functions = _FirstOrderTapeGradientFunctions(
2145              self._func_graph, self._attrs, self._garbage_collector,
2146              forwardprop_input_indices=input_tangents.indices,
2147              delayed_rewrite_functions=self._delayed_rewrite_functions,
2148              need_gradients_for_jvps=need_gradients_for_jvps)
2149          self._first_order_tape_functions[cache_key] = functions
2150        return _ForwardBackwardCall(
2151            functions, args, input_tangents.tangents, tape_watching=True)
2152      else:
2153        # We can avoid computing second-order gradients in some cases by doing a
2154        # delayed rewrite when graph building. Since we know we'll only compute
2155        # first-order tape gradients, the delayed rewrite is safe: we won't need
2156        # to tell the tape about side outputs.
2157        #
2158        # TODO(allenl): This case is really dirty. It would be better if we
2159        # could temporarily pop all of the current tapes to avoid
2160        # accidentally taking second-order gradients.
2161        return _ForwardBackwardCall(
2162            self._delayed_rewrite_functions, args, input_tangents.tangents,
2163            tape_watching=True)
2164    elif (possible_gradient_type
2165          == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
2166      # Either there's a persistent tape watching, or there are multiple nested
2167      # tapes. Either way, the user may request higher-order gradients. We'll
2168      # spend a bit more time and make sure higher-order gradients are correct.
2169      functions = self._higher_order_tape_functions.get(
2170          cache_key, None)
2171      if functions is None:
2172        functions = _HigherOrderTapeGradientFunctions(
2173            self._func_graph, self._attrs, self._garbage_collector,
2174            forwardprop_input_indices=input_tangents.indices,
2175            delayed_rewrite_functions=self._delayed_rewrite_functions,
2176            need_gradients_for_jvps=need_gradients_for_jvps)
2177        self._higher_order_tape_functions[cache_key] = functions
2178      return _ForwardBackwardCall(functions, args, input_tangents.tangents,
2179                                  tape_watching=True)
2180    # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
2181    # tape is recording.
2182    return _ForwardBackwardCall(
2183        self._delayed_rewrite_functions, args, input_tangents.tangents,
2184        tape_watching=False)
2185
2186  def _build_call_outputs(self, result):
2187    """Maps the fdef output list to actual output structure.
2188
2189    Args:
2190      result: Output lists defined by FunctionDef.
2191    Returns:
2192      The actual call output.
2193    """
2194    # TODO(jlchu): call C++ version in function.cc when speed is improved
2195    if self._func_graph.structured_outputs is None:
2196      return result
2197
2198    # Replace outputs with results, skipping over any 'None' values.
2199    outputs_list = nest.flatten(
2200        self._func_graph.structured_outputs, expand_composites=True)
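    # Illustrative sketch (hypothetical outputs): if structured_outputs is
    # {'a': t0, 'b': None, 'c': t1}, then outputs_list flattens (with keys
    # sorted) to [t0, None, t1]; `result` supplies values for the non-None
    # entries in order, and repacking yields
    # {'a': result[0], 'b': None, 'c': result[1]}.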
2201    j = 0
2202    for i, o in enumerate(outputs_list):
2203      if o is not None:
2204        handle_data_util.copy_handle_data(self.outputs[j], result[j])
2205        outputs_list[i] = result[j]
2206        j += 1
2207    ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
2208                                outputs_list, expand_composites=True)
2209    return ret
2210
2211  @property
2212  def _as_name_attr_list(self):
2213    """Returns a `NameAttrList` representing this function."""
2214    ret = attr_value_pb2.NameAttrList(name=self.name)
2215    for name, value in self._attrs.items():
2216      ret.attr[name].CopyFrom(value)
2217    return ret
2218
2219  def _structured_signature_summary(self, default_values=False):
2220    """Returns a string summarizing this function's structured signature.
2221
2222    Args:
2223      default_values: If true, then include default values in the signature.
2224
2225    Returns:
2226      A `string`.
2227    """
    # Note: we can't just use self._function_spec.signature_summary(), because
    # that would show "_BOUND_VALUE" as the default value for all arguments.
2230    assert self._function_spec is not None
2231    arg_specs, kwarg_specs = self.structured_input_signature
2232    arg_names = list(self._function_spec.arg_names)
2233
2234    # If an explicit input_signature is provided to @tf.function, then any
2235    # arguments with defaults that are not covered by that explicit signature
2236    # are simply dropped from the signature.
2237    # TODO(b/159639913) Look into whether dropping arguments with default values
2238    # from the signature is the right thing to do.
2239    arg_names = arg_names[:len(arg_specs)]
2240
2241    if default_values:
2242      for i in range(len(arg_names)):
2243        if not _contains_type_spec(arg_specs[i]):
2244          arg_names[i] += "={}".format(arg_specs[i])
2245    if kwarg_specs:
2246      arg_names.append("*")
2247      for name, spec in kwarg_specs.items():
2248        arg_names.append(name)
2249        if default_values and not _contains_type_spec(spec):
2250          arg_names[-1] += "={}".format(spec)
2251    signature = f"{self._func_graph.name}({', '.join(arg_names)})"
2252
2253    return signature
2254
2255  def _flat_signature_summary(self):
2256    """Returns a string summarizing this function's flat signature."""
2257    assert self._arg_keywords is not None
2258    assert self._num_positional_args is not None
    arg_names = list(self._arg_keywords)
2260    if self._num_positional_args > len(arg_names):
2261      arg_names.extend(
2262          "<arg{}>".format(i + 1)
2263          for i in range(len(arg_names), self._num_positional_args))
2264    return f"{self._func_graph.name}({', '.join(arg_names)})"
2265
2266  def pretty_printed_signature(self, verbose=True):
2267    """Returns a string summarizing the signature of this concrete function."""
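    # Illustrative output sketch for a hypothetical concrete function traced
    # from `add(a, b)` with scalar float32 inputs:
    #   add(a, b)
    #     Args:
    #       a: float32 Tensor, shape=()
    #       b: float32 Tensor, shape=()
    #     Returns:
    #       float32 Tensor, shape=()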
2268    if not verbose:
2269      return self._structured_signature_summary(default_values=True)
2270
2271    def pretty_print_spec(spec):
2272      """Returns a string describing the spec for a single argument."""
2273      if isinstance(spec, tensor_spec.TensorSpec):
2274        return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
2275      elif nest.is_sequence(spec):
2276        pieces = nest.flatten(spec, expand_composites=False)
2277        markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
2278        structure = nest.pack_sequence_as(spec, markers)
2279        # Ensure dictionaries are sorted by key (for determinism)
2280        result = pprint.pformat(structure, width=10000)
2281        for (marker, piece) in zip(markers, pieces):
2282          result += "\n      {}: {}".format(marker, pretty_print_spec(piece))
2283        return result
2284      else:
2285        return repr(spec)
2286
2287    lines = [self._structured_signature_summary(default_values=True)]
2288    arg_specs, kwarg_specs = self.structured_input_signature
2289    names = list(self._function_spec.arg_names)
2290
2291    # If an explicit input_signature is provided to @tf.function, then any
2292    # arguments with defaults that are not covered by that explicit signature
2293    # are simply dropped from the signature.
2294    # TODO(b/159639913) Look into whether dropping arguments with default values
2295    # from the signature is the right thing to do.
2296
2297    # Note: we can skip bound args, since we already displayed their bound
2298    # value in the signature summary.
2299    arg_details = []
2300    for (name, spec) in zip(names[:len(arg_specs)], list(arg_specs)):
2301      if _contains_type_spec(spec):
2302        arg_details.append("    {}: {}".format(name, pretty_print_spec(spec)))
2303
2304    if kwarg_specs:
2305      for kwarg in sorted(kwarg_specs):
2306        spec = kwarg_specs[kwarg]
2307        if _contains_type_spec(spec):
2308          arg_details.append("    {}: {}".format(
2309              kwarg, pretty_print_spec(spec)))
2310
2311    if arg_details:
2312      lines.append("  Args:")
2313      lines.extend(arg_details)
2314    lines.append("  Returns:")
2315
2316    def spec_from_value(value):
2317      # For loaded function, structured_outputs are already specs.
2318      if isinstance(value, type_spec.TypeSpec):
2319        return value
2320      return type_spec.type_spec_from_value(value)
2321
2322    lines.append("    {}".format(
2323        pretty_print_spec(
2324            nest.map_structure(spec_from_value, self.structured_outputs))))
2325
2326    return "\n".join(lines)
2327
2328  def __repr__(self):
2329    if self._function_spec is not None:
2330      return "<ConcreteFunction {} at 0x{:X}>".format(
2331          self.pretty_printed_signature(verbose=False), id(self))
2332    elif not (self._num_positional_args is None or self._arg_keywords is None):
2333      return "<ConcreteFunction {} at 0x{:X}>".format(
2334          self._flat_signature_summary(), id(self))
2335    else:
2336      return object.__repr__(self)
2337
2338  def __str__(self):
2339    if self._function_spec is not None:
2340      return "ConcreteFunction {}".format(self.pretty_printed_signature())
2341    else:
2342      return self.__repr__()
2343
2344
2345_pywrap_utils.RegisterType("Tensor", ops.Tensor)
2346_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
2347_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices)
2348
2349
2350def _deterministic_dict_values(dictionary):
2351  return tuple(dictionary[key] for key in sorted(dictionary))
2352
2353
2354class FunctionSpec(object):
2355  """Specification of how to bind arguments to a function."""
2356
2357  @staticmethod
2358  def from_function_and_signature(python_function,
2359                                  input_signature,
2360                                  is_pure=False,
2361                                  experimental_follow_type_hints=False,
2362                                  jit_compile=None):
2363    """Create a FunctionSpec instance given a python function and signature.
2364
2365    Args:
2366      python_function: a function to inspect
2367      input_signature: a signature of the function (None, if variable)
      is_pure: if True, all input arguments (including variables and constants)
        will be converted to tensors and no variable changes are allowed.
2370      experimental_follow_type_hints: see `tf.function`
2371      jit_compile: see `tf.function`
2372
2373    Returns:
2374      instance of FunctionSpec
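
    A minimal usage sketch (the function `f` below is hypothetical):

      def f(x, y=2):
        return x + y

      spec = FunctionSpec.from_function_and_signature(f, input_signature=None)
      # `spec` now describes how calls such as f(1) or f(1, y=3) are bound to
      # the traced function's parameters.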
2375    """
2376    fullargspec = tf_inspect.getfullargspec(python_function)
2377    # Checks if the `fullargspec` contains self or cls as its first argument.
2378    is_method = tf_inspect.isanytargetmethod(python_function)
2379
2380    # Treat a wrapped partial function as a special case. For all arguments that
2381    # were overridden with keywords in the partial:
2382    #   - remove the corresponding arguments,
2383    #   - remove the corresponding keywords.
2384    _, unwrapped = tf_decorator.unwrap(python_function)
2385    # TODO(b/131153379): Consider Python3's fullargspec.kwonlyargs and
2386    # fullargspec.kwonlydefaults.
2387    if isinstance(unwrapped, functools.partial):
2388      # Also consider the Python3 case with kwonlydefaults.
2389      if fullargspec.defaults or fullargspec.kwonlydefaults:
2390        new_defaults = fullargspec.defaults
2391        new_args = fullargspec.args
2392        if fullargspec.defaults:
2393          # To be able to canonicalize the function properly, we want to ignore
2394          # default values that are overridden via a partial kwarg. For example:
2395          #
2396          #   def func(a, b, c, d=5, e=7):
2397          #     return a, b, c, d, e
2398          #   p_func = tf.function(functools.partial(func, 10, e=9))
2399          #
2400          # Here we want to drop from the defaults the parameter `e`. If we
2401          # forwarded the call to the partial function with a default for `e`
2402          # we would get an error for passing two values for one parameter.
2403          #
2404          # Note that this has a limitation: we can only override parameters at
2405          # the end of the parameter list.
2406          #
2407          # In this case we want to end up with 3 arguments (b, c, d) and 1
2408          # default value (5). We do this by constructing a mask where 0 stands
2409          # for a value that was overridden by a partial kwarg. The seemingly
2410          # complicated logic below does just that - for arguments (b, c, d, e)
2411          # we would get a mask (1, 1, 1, 0).
2412          old_args = fullargspec.args
2413          old_defaults = fullargspec.defaults
2414
2415          no_default = object()
2416          num_args_without_defaults = len(old_args) - len(old_defaults)
2417          left_padding = tuple([no_default] * num_args_without_defaults)
2418
2419          args_with_defaults = zip(old_args, left_padding + old_defaults)
2420
2421          # Create a mask where 0 stands for args that had a partial kwarg
2422          # defined.
2423          non_keyword_defaults_mask = [
2424              0 if key in unwrapped.keywords else 1 for key in old_args
2425          ]
2426          # Keep only arguments and defaults that were not kwargs of partial.
2427          new_args_with_defaults = list(
2428              itertools.compress(args_with_defaults, non_keyword_defaults_mask))
2429          # Keep all args.
2430          new_args = [arg for arg, _ in new_args_with_defaults]
2431          # Keep only real default values.
2432          new_defaults = [
2433              default for _, default in new_args_with_defaults
2434              if default is not no_default
2435          ]
2436        fullargspec = tf_inspect.FullArgSpec(
2437            args=new_args,
2438            varargs=fullargspec.varargs,
2439            varkw=fullargspec.varkw,
2440            defaults=new_defaults,
2441            kwonlyargs=[],
2442            kwonlydefaults={},
2443            annotations=fullargspec.annotations)
2444
2445    # Get the function's name.  Remove functools.partial wrappers if necessary.
2446    while isinstance(python_function, functools.partial):
2447      python_function = python_function.func
2448    name = getattr(python_function, "__name__", "f")
2449
2450    return FunctionSpec(
2451        fullargspec,
2452        is_method,
2453        input_signature,
2454        is_pure=is_pure,
2455        jit_compile=jit_compile,
2456        experimental_follow_type_hints=experimental_follow_type_hints,
2457        name=name)
2458
2459  def __init__(self,
2460               fullargspec,
2461               is_method,
2462               input_signature,
2463               is_pure=False,
2464               experimental_follow_type_hints=False,
2465               name=None,
2466               jit_compile=None):
2467    """Constructs a FunctionSpec describing a python function.
2468
2469    Args:
2470      fullargspec: `tf_inspect.FullArgSpec` object describing the function.
2471      is_method: True if the function is a method.
2472      input_signature: a signature of the function (None, if variable)
2473      is_pure: if True all input arguments (including variables and constants)
2474        will be converted to tensors and no variable changes allowed.
2475      experimental_follow_type_hints: see `tf.function`.
2476      name: Name of the function
2477      jit_compile: see `tf.function`.
2478    """
2479    self._fullargspec = fullargspec
2480    self._is_method = is_method
2481    self._is_pure = is_pure
2482    self._jit_compile = jit_compile
2483    self._experimental_follow_type_hints = experimental_follow_type_hints
2484
2485    # TODO(edloper): Include name when serializing for SavedModel?
2486    self._name = name or "f"
2487
2488    if self._is_method:
2489      # Remove `self`: default arguments shouldn't be matched to it.
2490      # TODO(b/127938157): Should this error out if there is no arg to
2491      # be removed?
2492      args = fullargspec.args[1:]
2493    else:
2494      args = fullargspec.args
2495
2496    # A cache mapping from argument name to index, for canonicalizing
2497    # arguments that are called in a keyword-like fashion.
2498    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
2499    self._arg_names = args
2500
2501    # A cache mapping from arg index to default value, for canonicalization.
2502    default_values = fullargspec.defaults
2503    offset = len(args) - len(default_values or [])
2504    self._arg_indices_to_default_values = {
2505        offset + index: default
2506        for index, default in enumerate(default_values or [])
2507    }
2508    self._arg_indices_no_default_values = set(range(len(args))) - set(
2509        self._arg_indices_to_default_values)
2510    if input_signature is None:
2511      self._input_signature = None
2512    else:
2513      if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()):
2514        raise ValueError("Cannot define a TensorFlow function from a Python "
2515                         "function with keyword-only arguments when "
2516                         "input_signature is provided.")
2517
2518      if not isinstance(input_signature, (tuple, list)):
2519        raise TypeError(f"input_signature must be either a tuple or a "
2520                        f"list, got {type(input_signature)}.")
2521
2522      self._input_signature = tuple(input_signature)
2523      self._flat_input_signature = tuple(nest.flatten(input_signature,
2524                                                      expand_composites=True))
2525
2526  @property
2527  def fullargspec(self):
2528    return self._fullargspec
2529
2530  @property
2531  def is_method(self):
2532    return self._is_method
2533
2534  @property
2535  def args_to_indices(self):
2536    return self._args_to_indices
2537
2538  @property
2539  def kwargs_to_include(self):
2540    return self._kwargs_to_include
2541
2542  @property
2543  def input_signature(self):
2544    return self._input_signature
2545
2546  @property
2547  def flat_input_signature(self):
2548    return self._flat_input_signature
2549
2550  @property
2551  def is_pure(self):
2552    return self._is_pure
2553
2554  @property
2555  def jit_compile(self):
2556    return self._jit_compile
2557
2558  @property
2559  def arg_names(self):
2560    return self._arg_names
2561
2562  @property
2563  def vararg_name(self):
2564    return self._fullargspec.varargs
2565
2566  @property
2567  def varkw_name(self):
2568    return self._fullargspec.varkw
2569
2570  def signature_summary(self, default_values=False):
2571    """Returns a string summarizing this function's signature.
2572
2573    Args:
2574      default_values: If true, then include default values in the signature.
2575
2576    Returns:
2577      A `string`.
2578    """
2579    args = list(self._arg_names)
2580    if default_values:
2581      for (i, default) in self._arg_indices_to_default_values.items():
2582        args[i] += "={}".format(default)
2583    if self._fullargspec.kwonlyargs:
2584      args.append("*")
2585      for arg_name in self._fullargspec.kwonlyargs:
2586        args.append(arg_name)
2587        if default_values and arg_name in self._fullargspec.kwonlydefaults:
2588          args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
2589    return f"{self._name}({', '.join(args)})"
2590
2591  def _to_tensor_or_tensor_spec(self, x):
2592    return (x if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec))
2593            else ops.convert_to_tensor(x))
2594
2595  def _convert_variables_to_tensors(self, args, kwargs):
2596    args = [self._to_tensor_or_tensor_spec(x) for x in args]
2597    kwargs = {kw: self._to_tensor_or_tensor_spec(x)
2598              for kw, x in kwargs.items()}
2599    return tuple(args), kwargs
2600
2601  def _convert_annotated_args_to_tensors(self, args, kwargs):
2602    """Attempts to autobox arguments annotated as tf.Tensor."""
2603    if self.input_signature is not None:
2604      return
2605
2606    args = list(args)
2607    for i, arg in enumerate(args):
2608      # See
2609      # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
2610      if i < len(self._fullargspec.args):
2611        annotation_key = self._fullargspec.args[i]
2612      else:
2613        annotation_key = self._fullargspec.varargs
2614      arg_annotation = self._fullargspec.annotations.get(annotation_key, None)
2615
2616      # TODO(rahulkamat): Change to TensorLike (here and below)
2617      if arg_annotation == ops.Tensor:
2618        args[i] = self._to_tensor_or_tensor_spec(arg)
2619
2620    for kw, v in kwargs.items():
2621      if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args:
2622        annotation_key = kw
2623      else:
2624        annotation_key = self._fullargspec.varkw
2625      kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None)
2626      if kwarg_annotation == ops.Tensor:
2627        kwargs[kw] = self._to_tensor_or_tensor_spec(v)
2628    return tuple(args), kwargs
2629
2630  def _validate_inputs(self, flat_inputs):
2631    """Raises an error if inputs contain illegal values."""
2632    for inp in flat_inputs:
2633      # TODO(b/183107079): Allow these once they're handled properly.
2634      if isinstance(inp, weakref.ref):
2635        raise ValueError(
2636            f"weakref input {inp} not supported for function {self._name}")
2637
2638  def canonicalize_function_inputs(self, *args, **kwargs):
2639    """Canonicalizes `args` and `kwargs`.
2640
2641    Canonicalize the inputs to the Python function using a `FunctionSpec`
2642    instance. In particular, we parse the varargs and kwargs that the
2643    original function was called with into a tuple corresponding to the
2644    Python function's positional (named) arguments and a dictionary
2645    corresponding to its kwargs.  Missing default arguments are added.
2646
2647    If this `FunctionSpec` has an input signature, then it is used to convert
2648    arguments to tensors; otherwise, any inputs containing numpy arrays are
2649    converted to tensors.
2652
2653    Args:
2654      *args: The varargs this object was called with.
2655      **kwargs: The keyword args this function was called with.
2656
2657    Returns:
2658      A canonicalized ordering of the inputs, as well as full and filtered
2659      (Tensors and Variables only) versions of their concatenated flattened
2660      representations, represented by a tuple in the form (args, kwargs,
2661      flat_args, filtered_flat_args). Here: `args` is a full list of bound
2662      arguments, and `kwargs` contains only true keyword arguments, as opposed
2663      to named arguments called in a keyword-like fashion.
2664
2665    Raises:
2666      ValueError: If a keyword in `kwargs` cannot be matched with a positional
2667        argument when an input signature is specified, or when the inputs
2668        do not conform to the input signature.
2669    """
2670    if self._is_pure:
2671      args, kwargs = self._convert_variables_to_tensors(args, kwargs)
2672    if self._experimental_follow_type_hints:
2673      args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
2674    # Pre-calculate to reduce overhead
2675    arglen = len(args)
2676    if self._input_signature is not None:
2677      if arglen > len(self._input_signature):
2678        raise TypeError(f"{self.signature_summary()} specifies "
2679                        f"{len(self._input_signature)} positional arguments, "
2680                        f"but got {arglen}.")
2681      for arg in six.iterkeys(kwargs):
2682        index = self._args_to_indices.get(arg, None)
2683        if index is None:
2684          raise TypeError(f"{self.signature_summary()} got unexpected keyword "
2685                          f"argument `{arg}`.")
2686        if index >= len(self._input_signature):
2687          raise TypeError(
2688              f"{self.signature_summary()} got keyword argument `{arg}` that "
2689              "was not included in input_signature.")
2690
2691    if not kwargs:
2692      inputs = args
2693      if self._arg_indices_to_default_values:
2694        try:
2695          inputs += tuple(self._arg_indices_to_default_values[i]
2696                          for i in range(arglen, len(self._arg_names)))
2697        except KeyError:
2698          missing_args = [
2699              self._arg_names[i]
2700              for i in range(arglen, len(self._arg_names))
2701              if i not in self._arg_indices_to_default_values
2702          ]
2703          raise TypeError(f"{self.signature_summary()} missing required "
2704                          f"arguments: {', '.join(missing_args)}.")
2705
2706      if self._fullargspec.kwonlydefaults:
2707        kwargs.update(self._fullargspec.kwonlydefaults)
2708    else:
2709      # Maps from index of arg to its corresponding value, according to `args`
2710      # and `kwargs`; seeded with the default values for the named args that
2711      # aren't in `args`.
2712      arg_indices_to_values = {
2713          index: default for index, default in six.iteritems(
2714              self._arg_indices_to_default_values) if index >= arglen
2715      }
2716      consumed_args = []
2717      missing_arg_indices = self._arg_indices_no_default_values - set(
2718          range(arglen))
2719      for arg, value in six.iteritems(kwargs):
2720        index = self._args_to_indices.get(arg, None)
2721        if index is not None:
2722          if index < arglen:
2723            raise TypeError(f"{self.signature_summary()} got two values for "
2724                            f"{arg!r}.")
2725          arg_indices_to_values[index] = value
2726          # These arguments in 'kwargs' might also belong to
2727          # positional arguments
2728          missing_arg_indices.discard(index)
2729          consumed_args.append(arg)
2730      for arg in consumed_args:
2731        # After this loop, `kwargs` will only contain keyword_only arguments,
2732        # and all positional_or_keyword arguments have been moved to `inputs`.
2733        kwargs.pop(arg)
2734      inputs = args + _deterministic_dict_values(arg_indices_to_values)
2735      # Exclude positional args with values
2736      if missing_arg_indices:
2737        missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
2738        if len(missing_args) == 1:
2739          raise TypeError(f"{self.signature_summary()} missing 1 required "
2740                          f"argument: {missing_args[0]}.")
2741        else:
2742          raise TypeError(f"{self.signature_summary()} missing required "
2743                          f"arguments: {', '.join(missing_args)}.")
2744
2745      if kwargs and self._input_signature is not None:
2746        raise TypeError("Keyword arguments are not supported when "
2747                        "input_signature is provided. Signature: "
2748                        f"{self.signature_summary()}.")
2749
2750      if self._fullargspec.kwonlydefaults:
2751        for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
2752          kwargs.setdefault(kwarg, default)
2753
2754    if self._input_signature is None:
2755      inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs)
2756      kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs)
2757      flat_inputs += flat_kwargs
2758      filtered_flat_inputs += filtered_flat_kwargs
2759    else:
2760      assert not kwargs
2761      inputs, flat_inputs, filtered_flat_inputs = _convert_inputs_to_signature(
2762          inputs, self._input_signature, self._flat_input_signature)
2763
2764    self._validate_inputs(flat_inputs)
2765
2766    return inputs, kwargs, flat_inputs, filtered_flat_inputs
2767
2768
2769def _as_ndarray(value):
2770  """Converts value to an ndarray, assumes _is_ndarray(value)."""
2771  # TODO(tomhennigan) Support __array_interface__ too.
2772  return value.__array__()
2773
2774
2775def _is_ndarray(value):
2776  """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
2777  # TODO(tomhennigan) Support __array_interface__ too.
2778  return hasattr(value, "__array__") and not (
2779      isinstance(value, ops.Tensor)
2780      or isinstance(value, resource_variable_ops.BaseResourceVariable)
2781      or hasattr(value, "_should_act_as_resource_variable")
2782
2783      # For legacy reasons we do not automatically promote Numpy strings.
2784      or isinstance(value, np.str_)
2785      # NumPy dtypes have __array__ as unbound methods.
2786      or isinstance(value, type)
2787      # CompositeTensors should be flattened instead.
2788      or isinstance(value, composite_tensor.CompositeTensor))
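# For example, _is_ndarray(np.zeros([2])) is True, while a tf.Tensor such as
# constant_op.constant([1.0]) and a dtype class such as np.float32 are both
# rejected.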
2789
2790
2791def _convert_numpy_inputs(inputs):
2792  """Convert numpy array inputs to tensors."""
2793  # We assume that any CompositeTensors have already converted their components
2794  # from numpy arrays to Tensors, so we don't need to expand composites here for
2795  # the numpy array conversion. Instead, we do so because the flattened inputs
2796  # are eventually passed to ConcreteFunction()._call_flat, which requires
2797  # expanded composites.
2798  flat_inputs = nest.flatten(inputs, expand_composites=True)
2799
2800  # Check for NumPy arrays in arguments and convert them to Tensors.
2801  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
2802  # finding a way to store them directly in the cache key (currently not
2803  # possible since ndarrays are not hashable).
2804  need_packing = False
2805  filtered_flat_inputs = []
2806  for index, value in enumerate(flat_inputs):
2807    if isinstance(value,
2808                  (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
2809      filtered_flat_inputs.append(value)
2810    elif hasattr(value, "__array__") and not (
2811        hasattr(value, "_should_act_as_resource_variable") or
2812        isinstance(value, (np.str_, type, composite_tensor.CompositeTensor))):
2813      # This case is equivalent to _is_ndarray(value) == True
2814      a = _as_ndarray(value)
2815      if not isinstance(a, np.ndarray):
2816        raise TypeError(f"The output of __array__ must be an np.ndarray, "
2817                        f"got {type(a)} from {value}.")
2818      flat_inputs[index] = constant_op.constant(a)
2819      filtered_flat_inputs.append(flat_inputs[index])
2820      need_packing = True
2821  if need_packing:
2822    return (nest.pack_sequence_as(
2823        structure=inputs, flat_sequence=flat_inputs,
2824        expand_composites=True), flat_inputs, filtered_flat_inputs)
2825  else:
2826    return inputs, flat_inputs, filtered_flat_inputs
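# Rough example (sketch): _convert_numpy_inputs((np.ones([2]), 3)) returns the
# inputs with the ndarray replaced by a constant Tensor, the flattened inputs,
# and the Tensor/Variable-only subset (here just the converted Tensor).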
2827
2828
2829def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
2830  """Convert inputs to pass into a function with an explicit signature."""
2831
2832  def format_error_message(inputs, input_signature):
2833    return ("  inputs: (\n" + "    " + ",\n    ".join(str(i) for i in inputs) +
2834            ")\n" + "  input_signature: (\n" + "    " +
2835            ",\n    ".join(str(i) for i in input_signature) + ")")
2836
2837  try:
2838    flatten_inputs = nest.flatten_up_to(
2839        input_signature,
2840        inputs[:len(input_signature)],
2841        expand_composites=True,
2842        check_types=False)  # lists are converted to tuples for `tf.data`.
2843  except ValueError:
2844    raise ValueError("Structure of Python function inputs does not match "
2845                     "input_signature:\n"
2846                     f"{format_error_message(inputs, input_signature)}.")
2847
2848  need_packing = False
2849  for index, (value, spec) in enumerate(zip(flatten_inputs,
2850                                            flat_input_signature)):
2851    if (isinstance(spec, tensor_spec.TensorSpec) and
2852        not _pywrap_utils.IsTensor(value)):
2853      try:
2854        flatten_inputs[index] = ops.convert_to_tensor(
2855            value, dtype_hint=spec.dtype)
2856        need_packing = True
2857      except ValueError:
2858        raise ValueError("When input_signature is provided, all inputs to "
2859                         "the Python function must be convertible to "
2860                         "tensors:\n"
2861                         f"{format_error_message(inputs, input_signature)}.")
2862
2863  if any(not spec.is_compatible_with(other) for spec, other in zip(
2864      flat_input_signature,
2865      flatten_inputs)):
2866    raise ValueError("Python inputs incompatible with input_signature:\n"
2867                     f"{format_error_message(inputs, input_signature)}.")
2868
2869  if need_packing:
2870    inputs = nest.pack_sequence_as(
2871        structure=input_signature,
2872        flat_sequence=flatten_inputs,
2873        expand_composites=True)
2874
2875  flat_inputs = nest.flatten(inputs, expand_composites=True)
2876
2877  return (inputs, flat_inputs, [
2878      t for t in flat_inputs
2879      if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
2880  ])
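# Rough example (sketch): with an input_signature of
# (tensor_spec.TensorSpec([None], dtypes.float32),), a plain Python input such
# as [1.0, 2.0] is converted to a float32 Tensor and then checked for
# compatibility against the corresponding spec.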
2881
2882
2883class FunctionCache(object):
2884  """A lightweight container for cached functions.
2885  """
2886
2887  __slots__ = [
2888      "missed", "primary", "arg_relaxed_specs", "arg_relaxed",
2889      "_garbage_collectors"
2890  ]
2891
2892  def __init__(self):
2893    # The set of functions that have been missed; entries are CacheKey with
2894    # input_signature `None` (e.g. a "call context key")
2895    self.missed = set()
2896    # The primary cache, mapping a fully shaped CacheKey to a function.
2897    self.primary = collections.OrderedDict()
2898    # A cache key lookup, mapping a CacheKey generated without shape info to a
2899    # flat list of `TypeSpec`s with relaxed shapes (one for each flattened
2900    # argument). Arguments that are not Tensors or `CompositeTensor`s contain a
2901    # `None` for the corresponding relaxed spec.
2902    self.arg_relaxed_specs = collections.OrderedDict()
2903    # The secondary cache, mapping a CacheKey generated without shape info to a
2904    # function.
2905    self.arg_relaxed = collections.OrderedDict()
2906    # All OrderedDicts require manual garbage collection.
2907    self._garbage_collectors = [
2908        _FunctionGarbageCollector(self.primary),
2909        _FunctionGarbageCollector(self.arg_relaxed),
2910        _FunctionGarbageCollector(self.arg_relaxed_specs)]
2911
2912  def all_values(self):
2913    """A list of all `ConcreteFunction` instances held by this cache."""
2914    # We need to simultaneously make sure our returned concrete functions are
2915    # unique *and* make sure they are returned in a deterministic order for
2916    # serialization.
2917    #
2918    # TODO(b/174215821): It's likely that we ultimately would just prefer to
2919    # choose the most specific concrete function shape given a set of
2920    # arguments. If and when that is implemented, this logic can be revisited.
2921    primary_functions = set(self.primary.values())
2922    return list(self.primary.values()) + [
2923        v for v in self.arg_relaxed.values() if v not in primary_functions]
2924
2925
2926# TODO(mdan): Refactor this and clarify relationship with def_function.Function.
2927# Right now, def_function.Function is the higher level implementation.
2928class Function(object):
2929  """Wrapper class for the graph functions defined for a Python function.
2930
2931  See the documentation for `defun` for more information on the semantics of
2932  defined functions.
2933
2934  The `Function` class is thread-compatible, meaning that minimal usage of
2935  defuns (defining and calling) is thread-safe, but if users call other methods
2936  or invoke the base `python_function` themselves, external synchronization is
2937  necessary.
2938  In addition, `Function` is not reentrant, so recursive functions need to call
2939  the wrapped function, not the wrapper.
2940  """
2941
2942  def __init__(self,
2943               python_function,
2944               name,
2945               input_signature=None,
2946               attributes=None,
2947               autograph=True,
2948               autograph_options=None,
2949               experimental_relax_shapes=False,
2950               capture_by_value=None,
2951               jit_compile=None,
2952               experimental_follow_type_hints=False):
2953    """Initializes a `Function`.
2954
2955    Args:
2956      python_function: the function to be wrapped.
2957      name: the name given to it.
2958      input_signature: a possibly nested sequence of `TensorSpec` objects
2959        specifying the input signature of this function. If `None`, a separate
2960        function is instantiated for each inferred input signature.
2961      attributes: dict, extra keyword arguments that will be added as attributes
2962        of the function.
2963      autograph: whether to use autograph to compile
2964        `python_function`. See https://www.tensorflow.org/guide/autograph for
2965        more information.
2966      autograph_options: Experimental knobs to control behavior
2967        when `autograph=True`. See https://www.tensorflow.org/guide/autograph
2968        for more information.
2969      experimental_relax_shapes: When true, argument shapes may be relaxed to
2970        avoid unnecessary retracing.
2971      capture_by_value: Experimental. Whether to capture resource variables by
2972        value or reference. If None, will inherit from a parent context or
2973        default to False.
2974      jit_compile: Force-compile the function with XLA, cf.
2975        def_function.Function doc on jit_compile.
2976      experimental_follow_type_hints: See the documentation for `tf.function`.
2977
2978    Raises:
2979      ValueError: if `input_signature` is not None and the `python_function`'s
2980        argspec has keyword arguments.
2981    """
2982    self._python_function = python_function
2983    pure_function = attributes and IMPLEMENTS_ATTRIBUTE_NAME in attributes
2984    self._function_spec = FunctionSpec.from_function_and_signature(
2985        python_function,
2986        input_signature,
2987        is_pure=pure_function,
2988        experimental_follow_type_hints=experimental_follow_type_hints)
2989    self._name = name
2990    self._autograph = autograph
2991    self._autograph_options = autograph_options
2992    self._experimental_relax_shapes = experimental_relax_shapes
2993    self._function_cache = FunctionCache()
2994    self._function_attributes = attributes or {}
2995    self._capture_by_value = capture_by_value
2996    self.tracing_count = 0
2997    if self.input_signature is not None:
2998      self._hashable_input_signature = _make_input_signature_hashable(
2999          self.flat_input_signature)
3000
3001    self._lock = threading.Lock()
3002    # _descriptor_cache maps an instance of a class to an instance-specific
3003    # `Function`, used to make sure defun-decorated methods create different
3004    # functions for each instance.
3005    self._descriptor_cache = weakref.WeakKeyDictionary()
3006    self._jit_compile = jit_compile
3007    self._experimental_follow_type_hints = experimental_follow_type_hints
3008
3009  def __call__(self, *args, **kwargs):
3010    """Calls a graph function specialized to the inputs."""
3011    with self._lock:
3012      (graph_function,
3013       filtered_flat_args) = self._maybe_define_function(args, kwargs)
3014    return graph_function._call_flat(
3015        filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
3016
3017  @property
3018  def python_function(self):
3019    """Returns the wrapped Python function."""
3020    return self._python_function  # pylint: disable=protected-access
3021
3022  @property
3023  def function_spec(self):
3024    return self._function_spec
3025
3026  @property
3027  def input_signature(self):
3028    """Returns the input signature."""
3029    return self._function_spec.input_signature
3030
3031  @property
3032  def flat_input_signature(self):
3033    """Returns the flattened input signature."""
3034    return self._function_spec.flat_input_signature
3035
3036  def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
3037    """Returns a concrete function which cleans up its graph function."""
3038    if self.input_signature:
3039      args, kwargs = None, None
3040    with self._lock:
3041      graph_function, _ = self._maybe_define_function(args, kwargs)
3042    return graph_function
3043
3044  def _get_concrete_function_internal(self, *args, **kwargs):
3045    """Bypasses error checking when getting a graph function."""
3046    graph_function = self._get_concrete_function_internal_garbage_collected(
3047        *args, **kwargs)
3048    # We're returning this concrete function to someone, and they may keep a
3049    # reference to the FuncGraph without keeping a reference to the
3050    # ConcreteFunction object. So we won't clean up the reference cycles
3051    # manually and instead will leave them to Python's garbage collector.
3052    graph_function._garbage_collector.release()  # pylint: disable=protected-access
3053    return graph_function
3054
3055  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
3056    """Returns a `ConcreteFunction` specialized to inputs and execution context.
3057
3058    Unlike `get_concrete_function(...)`, the graph will be deleted when the
3059    returned function is deleted.  It's useful to avoid creating a reference
3060    cycle when you know for sure that the graph will no longer be used without
3061    the returned function.
3062
3063    Args:
3064      *args: inputs to specialize on.
3065      **kwargs: inputs to specialize on.
3066    """
3067    if self.input_signature:
3068      if kwargs:
3069        raise ValueError("Cannot define a TensorFlow function from a Python "
3070                         "function with keyword arguments when "
3071                         "input_signature is provided, got keyword arguments "
3072                         f"({kwargs}) with input_signature "
3073                         f"({self.input_signature}).")
3074      if args:
3075        # If args are provided, they must match the input signature.
3076        if not is_same_structure(self.input_signature, args):
3077          raise ValueError("Structure of Python function inputs does not match "
3078                           f"input_signature: inputs ({args}), "
3079                           f"input_signature ({self.input_signature}).")
3080        flat_inputs = nest.flatten(args, expand_composites=True)
3081        if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec,
3082                                    resource_variable_ops.BaseResourceVariable))
3083               for arg in flat_inputs):
3084          raise ValueError("When input_signature is provided, all inputs to "
3085                           "the Python function must be Tensors, Variables, "
3086                           "tf.TensorSpec or tf.VariableSpec objects.")
3087        if any(not spec.is_compatible_with(other)
3088               for spec, other in zip(self.flat_input_signature, flat_inputs)):
3089          raise ValueError("Python inputs incompatible with input_signature: "
3090                           f"inputs ({args}), input_signature "
3091                           f"({self.input_signature}).")
3092      args, kwargs = None, None
3093    with self._lock:
3094      graph_function, _ = self._maybe_define_function(args, kwargs)
3095      seen_names = set()
3096      captured = object_identity.ObjectIdentitySet(
3097          graph_function.graph.internal_captures)
3098      # pylint: disable=protected-access
3099      graph_function._arg_keywords = []
3100      prefix_counts = {}
3101      # pylint: enable=protected-access
3102      num_positional = 0
3103      for arg in graph_function.graph.inputs:
3104        if arg in captured:
3105          break
3106        num_positional += 1
3107        user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
3108        proposal = user_arg_name
3109        while proposal in seen_names:
3110          index = prefix_counts.get(user_arg_name, 1)
3111          proposal = "{}_{}".format(user_arg_name, index)
3112          prefix_counts[user_arg_name] = index + 1
3113        seen_names.add(proposal)
3114        graph_function._arg_keywords.append(proposal)  # pylint: disable=protected-access
3115      # Anything can be a positional argument, in the same order as .inputs
3116      graph_function._num_positional_args = num_positional  # pylint: disable=protected-access
3117      return graph_function
3118
3119  def get_concrete_function(self, *args, **kwargs):
3120    """Returns a `ConcreteFunction` specialized to inputs and execution context.
3121
3122    Args:
3123      *args: inputs to specialize on. Can be concrete values (e.g. 1)
3124         or `tf.Tensor` or `tf.TensorSpec`.
3125      **kwargs: keyword inputs to specialize on. Concrete values (e.g. 1)
3126         or `tf.Tensor` or `tf.TensorSpec`.
3127    """
3128    graph_function = self._get_concrete_function_garbage_collected(
3129        *args, **kwargs)
3130    graph_function._garbage_collector.release()  # pylint: disable=protected-access
3131    return graph_function
3132
3133  def __get__(self, instance, owner):
3134    """Makes it possible to defun instance methods."""
3135    del owner
3136    # `instance` here is the instance that this `Function` was accessed through
3137    # e.g., for
3138    #
3139    #   class Foo(object):
3140    #
3141    #     @function.defun
3142    #     def bar(self):
3143    #       ...
3144    #
3145    #   foo = Foo()
3146    #   foo.bar()  # `foo.bar` is a `Function` instance
3147    #
3148    # then `instance` will be `foo` (and `owner` will be `Foo`).  We create a
3149    # new instance of `Function` here to allow different instances each
3150    # to create variables once, thereby allowing methods to be decorated with
3151    # defun. Keeps a cache to avoid retracing the function every time the
3152    # descriptor is accessed.
3153    if instance not in self._descriptor_cache:
3154      if instance is None:
3155        return self
3156      # If there is no instance-specific `Function` in the cache, we construct
3157      # an instance-specific `Function` that uses a weak reference to the
3158      # instance (so that the instance will be correctly gc'd).
3159
3160      # And finally add the wrapped function to the descriptor cache.
3161      self._descriptor_cache[instance] = class_method_to_instance_method(
3162          self, instance)
3163
3164    # Return the cached `Function` for the instance
3165    return self._descriptor_cache[instance]
3166
3167  def _cache_key(self,
3168                 args,
3169                 kwargs,
3170                 cache_key_context,
3171                 include_tensor_ranks_only=False):
3172    """Computes the cache key given inputs and execution context."""
3173    if self.input_signature is None:
3174      # We always use both args and kwargs to form input even if one is empty.
3175      # This reduces ambiguity, for example, when args contains a dict and
3176      # kwargs is empty.
3177      inputs = (args, kwargs)
3178      input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
3179                                                    include_tensor_ranks_only)
3180      hashable_input_signature = _make_input_signature_hashable(input_signature)
3181    else:
3182      del args, kwargs
3183      assert not include_tensor_ranks_only
3184      hashable_input_signature = self._hashable_input_signature
3185
3186    (parent_graph, device_functions, colocation_stack, in_cross_replica_context,
3187     variable_policy, xla_context_id) = cache_key_context
3188
3189    return CacheKey(hashable_input_signature, parent_graph, device_functions,
3190                    colocation_stack, in_cross_replica_context, variable_policy,
3191                    xla_context_id)
3192
3193  def _cache_key_context(self):
3194    """Returns execution context."""
3195    ctx = context.context()
3196
3197    # Don't need to open an init_scope if the _cache_key call is in eager mode
3198    # already.
3199    executing_eagerly = ctx.executing_eagerly()
3200    parent_graph = None
3201    xla_context_id = 0
3202    if not executing_eagerly:
3203      # We want to force function retracing for each different
3204      # XLAControlFlowContext, so add `xla_context_id` to the cache key.
3205      xla_context = _enclosing_xla_context()
3206      if (xla_context is not None and
3207          xla_context.RequiresUniqueFunctionRetracing()):
3208        xla_context_id = id(xla_context)
3209
3210      with ops.init_scope():
3211        # The graph, or whether we're executing eagerly, should be a part of the
3212        # cache key so we don't improperly capture tensors such as variables.
3213        executing_eagerly = ctx.executing_eagerly()
3214        parent_graph = None if executing_eagerly else ops.get_default_graph()
3215
3216    # pylint: disable=protected-access
3217    default_graph = ops.get_default_graph()
3218    # TODO(b/117617952): The current distribution strategy will affect graph
3219    # building (e.g. accessing different variables from different devices) and
3220    # so requires retracing for each device.
3221    strategy_stack = default_graph._distribution_strategy_stack
3222    uses_distribution_strategy = (
3223        strategy_stack and
3224        strategy_stack[-1].strategy.extended._retrace_functions_for_each_device
3225    )
3226    if executing_eagerly:
3227      colocation_stack = ()
3228      if uses_distribution_strategy:
3229        device_functions = (pydev.merge_device(ctx.device_name),)
3230      else:
3231        device_functions = ()
3232    else:
3233      colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
3234      if (uses_distribution_strategy
3235          or func_graph_module.device_stack_has_callable(
3236              default_graph._device_function_stack)):
3237        # Putting the device in the cache key ensures that call-site device
3238        # annotations are respected.
3239        device_functions = tuple(default_graph._device_functions_outer_to_inner)
3240      else:
3241        device_functions = ()
3242
3243    in_cross_replica_context = False
3244    try:
3245      in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
3246    except (AttributeError, IndexError):
3247      pass
3248
3249    if save_context.in_save_context():
3250      variable_policy = (
3251          save_context.get_save_options().experimental_variable_policy)
3252    else:
3253      variable_policy = None
3254
3255    return (parent_graph, device_functions, colocation_stack,
3256            in_cross_replica_context, variable_policy, xla_context_id)
3257
3258  def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
3259    """Create a `ConcreteFunction` from `args` and `kwargs`."""
3260    self.tracing_count += 1
3261
3262    if self.input_signature is None:
3263      arglen = len(args)
3264    else:
3265      arglen = len(self.input_signature)
3266    base_arg_names = self._function_spec.arg_names[:arglen]
3267    num_missing_args = arglen - len(self._function_spec.arg_names)
3268    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
3269    # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
3270    # where `arg` is based on `self._function_spec.vararg_name`.
3271    missing_arg_names = [
3272        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
3273    ]
3274    arg_names = base_arg_names + missing_arg_names
3275    graph_function = ConcreteFunction(
3276        func_graph_module.func_graph_from_py_func(
3277            self._name,
3278            self._python_function,
3279            args,
3280            kwargs,
3281            self.input_signature,
3282            autograph=self._autograph,
3283            autograph_options=self._autograph_options,
3284            arg_names=arg_names,
3285            override_flat_arg_shapes=override_flat_arg_shapes,
3286            capture_by_value=self._capture_by_value),
3287        self._function_attributes,
3288        function_spec=self.function_spec,
3289        # Tell the ConcreteFunction to clean up its graph once it goes out of
3290        # scope. This is not the default behavior since it gets used in some
3291        # places (like Keras) where the FuncGraph lives longer than the
3292        # ConcreteFunction.
3293        shared_func_graph=False)
3294    return graph_function
3295
3296  def _define_function_with_shape_relaxation(self, args, kwargs, flat_args,
3297                                             filtered_flat_args,
3298                                             cache_key_context):
3299    """Define a function, relaxing arg shapes to avoid unnecessary retracing."""
3300    flat_no_comp = nest.flatten((args, kwargs), expand_composites=False)
3301
3302    any_composite_args = any(
3303        isinstance(x, composite_tensor.CompositeTensor) for x in flat_no_comp)
3304
3305    # Build a cache key where TensorShapes include only rank information (and
3306    # not information about the size of each dimension).
3307    if not any_composite_args:
3308      rank_only_cache_key = self._cache_key(
3309          args, kwargs, cache_key_context, include_tensor_ranks_only=True)
3310    else:
3311      # For the rank-only cache key, replace any composite tensors with
3312      # shape-relaxed TypeSpecs.
3313      (cache_key_args, cache_key_kwargs) = nest.map_structure(
3314          _shape_relaxed_type_for_composite_tensor, (args, kwargs))
3315      rank_only_cache_key = self._cache_key(
3316          cache_key_args,
3317          cache_key_kwargs,
3318          cache_key_context,
3319          include_tensor_ranks_only=True)
3320
3321    arg_specs = [_type_spec_for(x) for x in flat_no_comp]
3322    relaxed_arg_specs = self._function_cache.arg_relaxed_specs.get(
3323        rank_only_cache_key, None)
3324    relaxed_arg_function = self._function_cache.arg_relaxed.get(
3325        rank_only_cache_key, None)
3326
3327    if (relaxed_arg_function is not None
3328        and all(_is_type_subset(x, y) for (x, y) in
3329                zip(relaxed_arg_specs, arg_specs))):
3330      return relaxed_arg_function, filtered_flat_args
3331
3332    if relaxed_arg_specs is None:
3333      relaxed_arg_specs = arg_specs
3334    else:
3335      if len(arg_specs) != len(relaxed_arg_specs):
3336        raise RuntimeError("Expected arg_specs len to match relaxed_arg_specs "
3337                           f"len: {len(arg_specs):d} vs. "
3338                           f"{len(relaxed_arg_specs):d}.")
3339      relaxed_arg_specs = [
3340          x if x is None else x.most_specific_compatible_type(y)
3341          for (x, y) in zip(arg_specs, relaxed_arg_specs)]
3342    self._function_cache.arg_relaxed_specs[rank_only_cache_key] = (
3343        relaxed_arg_specs)
3344    relaxed_arg_shapes = [
3345        x if x is None else x.shape
3346        for x in nest.flatten(relaxed_arg_specs, expand_composites=True)]
3347
3348    if any_composite_args:
3349      # Rebuild composite tensors with the relaxed TypeSpecs.  For example,
3350      # if a tf.data iterator is passed as an argument, then we need to relax
3351      # the TensorShapes in its element_spec.
3352      (relaxed_arg_specs, relaxed_kwarg_specs) = nest.pack_sequence_as(
3353          (args, kwargs), relaxed_arg_specs, expand_composites=False)
3354      (args, kwargs) = nest.pack_sequence_as(
3355          (relaxed_arg_specs, relaxed_kwarg_specs),
3356          flat_args,
3357          expand_composites=True)
3358
3359    graph_function = self._create_graph_function(
3360        args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
3361    self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function
3362
3363    return (graph_function, [
3364        t for t in nest.flatten((args, kwargs), expand_composites=True)
3365        if isinstance(t, (ops.Tensor,
3366                          resource_variable_ops.BaseResourceVariable))
3367    ])
3368
3369  def _maybe_define_function(self, args, kwargs):
3370    """Gets a function for these inputs, defining it if necessary.
3371
3372    `args` and `kwargs` can be None if this `Function` was created with an
3373    `input_signature`.
3374
3375    Caller must hold self._lock.
3376
3377    Args:
3378      args: The varargs for the Python function.
3379      kwargs: The keyword args for the Python function.
3380
3381    Returns:
3382      A graph function corresponding to the input signature implied by args and
3383      kwargs, as well as filtered flattened inputs (only Tensors and Variables)
3384      that the object should be called with.
3385
3386    Raises:
3387      ValueError: If inputs are incompatible with the input signature.
3388      TypeError: If the function inputs include non-hashable objects.
3389      RuntimeError: If there's an internal bug (inconsistency) in handling
3390        shape relaxation retracing.
3391    """
3392    if self.input_signature is None or args is not None or kwargs is not None:
3393      args, kwargs, flat_args, filtered_flat_args = \
3394          self._function_spec.canonicalize_function_inputs(*args, **kwargs)
3395    else:
3396      flat_args, filtered_flat_args = [None], []
3397
3398    cache_key_context = self._cache_key_context()
3399    cache_key = self._cache_key(args, kwargs, cache_key_context)
3400
3401    try:
3402      hash(cache_key)
3403    except TypeError as e:
3404      raise TypeError(
3405          "Arguments supplied to `defun`-generated functions must be "
3406          f"hashable.  Original error: {e}.")
3407
3408    graph_function = self._function_cache.primary.get(cache_key, None)
3409    if graph_function is not None:
3410      return graph_function, filtered_flat_args
3411
3412    with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()):
3413      with trace.Trace("tf.function-graph_building"):
3414        logging.vlog(1,
3415                     "Creating new FuncGraph for Python function %r (key: %r)",
3416                     self._python_function, cache_key)
3417        logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
3418                     args, kwargs)
3419
3420        # pylint: disable=protected-access
3421        call_context_key = cache_key._replace(input_signature=None)
3422        # pylint: enable=protected-access
3423
3424        ag_status = (
3425            ag_ctx.Status.ENABLED
3426            if self._autograph else ag_ctx.Status.DISABLED)
3427        with ag_ctx.ControlStatusCtx(
3428            status=ag_status, options=self._autograph_options):
3429
3430          # Build a function with shape relaxation retracing if:
3431          # 1. shape relaxation is explicitly enabled
3432          # and 2. there's no provided input signature
3433          # and 3. there's been a cache miss for this calling context
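          # (Sketch: a first call with shape [2] records a miss and traces a
          # specialized graph; subsequent calls with shapes [3], [4], ... may
          # then share one rank-only relaxed function rather than tracing a
          # new graph per shape.)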
3434          if (self._experimental_relax_shapes and
3435              self.input_signature is None and
3436              call_context_key in self._function_cache.missed):
3437            return self._define_function_with_shape_relaxation(
3438                args, kwargs, flat_args, filtered_flat_args, cache_key_context)
3439
3440          self._function_cache.missed.add(call_context_key)
3441          graph_function = self._create_graph_function(args, kwargs)
3442          self._function_cache.primary[cache_key] = graph_function
3443
3444          return graph_function, filtered_flat_args
3445
3446
3447def register(func, *args, **kwargs):
3448  """Register a specialization of a `Function` into the graph.
3449
3450  This does not actually call the function with the inputs; it only puts the
3451  function definition into the graph. Registering the function with different
3452  input parameters results in multiple versions being registered in the graph.
3453
3454  Args:
3455    func: the `Function` instance generated by a @defun decorator.
3456    *args: input arguments for the Python function.
3457    **kwargs: input keyword arguments for the Python function.
3458
3459  Returns:
3460    a `ConcreteFunction` object specialized to inputs and execution context.
3461
3462  Raises:
3463    ValueError: When the input function is not a defun-wrapped Python function.
3464  """
3465  if not isinstance(func, Function):
3466    raise ValueError("Only a defun-wrapped function can be registered. "
3467                     f"Got {func} with type {type(func)}.")
3468  concrete_func = func.get_concrete_function(*args, **kwargs)
3469  concrete_func.add_to_graph()
3470  concrete_func.add_gradient_functions_to_graph()
3471  return concrete_func
3472
3473
3474def validate_signature(signature):
3475  if any(not isinstance(arg, tensor_spec.DenseSpec)
3476         for arg in nest.flatten(signature, expand_composites=True)):
3477    bad_args = [arg for arg in nest.flatten(signature, expand_composites=True)
3478                if not isinstance(arg, tensor_spec.DenseSpec)]
3479    raise TypeError("input_signature must be a possibly nested sequence of "
3480                    f"TensorSpec objects, got invalid args {bad_args} with "
3481                    f"types {list(map(type, bad_args))}.")
3482
3483
3484def validate_python_function(python_function):
3485  if not callable(python_function):
3486    raise TypeError(f"{python_function} is not a callable object.")
3487
3488
3489def defun(func=None,
3490          input_signature=None,
3491          autograph=True,
3492          experimental_autograph_options=None,
3493          experimental_relax_shapes=False):
3494  """Compiles a Python function into a callable TensorFlow graph.
3495
3496  `defun` (short for "define function") compiles a Python function
3497  composed of TensorFlow operations into a callable that executes a `tf.Graph`
3498  containing those operations. The callable produced by `defun` contains only
3499  the subgraph of TensorFlow operations that were executed when the Python
3500  function was called with a particular input signature, defined as a list
3501  of the shapes and dtypes of the Python function's Tensor-valued arguments and
3502  the values of its non-Tensor Python objects.
3503
3504  When eager execution is enabled, the ability to create graphs from Python
3505  functions makes it possible to incrementally trade off debuggability and
3506  interactivity for performance.  Functions compiled with `defun` cannot be
3507  inspected with `pdb`; however, executing a graph
3508  generated by `defun` sometimes takes less time and memory than eagerly
3509  executing the corresponding Python function, since specifying computations as
3510  graphs allows for optimizations like automatic buffer reuse and
3511  parallelization among ops. Note that executing a `defun`-compiled function
3512  incurs a small constant overhead, so eagerly executing sufficiently small
3513  Python functions might take less time than executing their corresponding
3514  `defun`-generated graphs.
3515
3516  For a Python function to be compatible with `defun`, all of its arguments must
3517  be hashable Python objects or lists thereof. The function itself may not
3518  modify the list/map structure of its arguments. Additionally, it must return
3519  zero or more `tf.Tensor` objects. If the Python function returns
3520  a `tf.Variable`, its compiled version will return the value of that variable
3521  as a `tf.Tensor`.
3522
3523  Executing a graph generated by `defun` respects device annotations (i.e.,
3524  all `with tf.device` directives present in a Python function will also be
3525  present in its corresponding graph), but it is not yet possible to execute the
3526  generated graphs across multiple machines.
3527
3528  _Example Usage_
3529
3530  ```python
3531  import tensorflow as tf
3532
3533  tf.compat.v1.enable_eager_execution()
3534
3535  # A simple example.
3536  def f(x, y):
3537    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
3538
3539  g = tf.contrib.eager.defun(f)
3540
3541  x = tf.constant([[2.0, 3.0]])
3542  y = tf.constant([[3.0, -2.0]])
3543
3544  # `f` and `g` will return the same value, but `g` will be executed as a
3545  # TensorFlow graph.
3546  assert f(x, y).numpy() == g(x, y).numpy()
3547
3548  # `defun` is capable of compiling Python functions that close over Python
3549  # objects, including Tensors and Variables.
3550  @tf.contrib.eager.defun
3551  def h():
3552    return f(x, y)
3553
3554  assert (h().numpy() == f(x, y).numpy()).all()
3555
3556  # `defun` automatically lifts variables out of the graphs it creates,
3557  # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
3558  # `tf.keras.Model` objects.
3559  class MyModel(tf.keras.Model):
3560
3561    def __init__(self, keep_probability=0.2):
3562      super(MyModel, self).__init__()
3563      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
3564      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
3565      self.keep_probability = keep_probability
3566
3567    @tf.contrib.eager.defun
3568    def call(self, inputs, training=True):
3569      x = self.dense2(self.dense1(inputs))
3570      if training:
3571        return tf.nn.dropout(x, self.keep_probability)
3572      else:
3573        return x
3574
3575  model = MyModel()
3576  model(x, training=True)  # executes a graph, with dropout
3577  model(x, training=False) # executes a graph, without dropout
3578
3579  # `defun`-compiled functions are differentiable.
3580  optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
3581  with tf.GradientTape() as tape:
3582    outputs = model(x)
3583  gradient = tape.gradient(outputs, model.trainable_variables)
3584  optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
3585                            model.trainable_variables))
3586  ```
3587
3588  When using `defun`, there are subtleties regarding inputs, Python control
3589  flow, and variable creation that one should be aware of. For concreteness, let
3590  `f` be a Python function that returns zero or more `tf.Tensor` objects and
3591  let `F = defun(f)`. `F` builds a graph for each unique input signature it
3592  sees, Python control flow is baked into graphs, and operations related to
3593  variable initialization are automatically lifted out of the graphs that `F`
3594  generates and placed in the eager context if executing eagerly or into an
3595  outer graph otherwise.
3596
3597  _Input Signatures_
3598
3599  By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
3600  for every unique sequence of the shapes and dtypes of Tensor arguments and
3601  the values of Python objects it is invoked with. For example, calling
3602  `F(tf.random.uniform([2]))` will execute a different graph than
3603  `F(tf.random.uniform([3]))` because the two inputs have different shapes.
3604  The first time that `F(*args, **kwargs)` is called with a particular sequence
3605  of Tensor shapes and dtypes and Python values, it constructs a graph by
3606  tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
3607  input signature inferred from `(*args, **kwargs)` and cached for future reuse.
3608
3609  NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
3610  before being passed to `f`, and are treated as Tensors for caching. This
3611  allows a function to be called multiple times with NumPy arrays having
3612  different values but the same shape and dtype without re-tracing each time.
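
  For example, a minimal sketch of this caching behavior (the lambda is only
  for illustration):

  ```python
  import numpy as np
  import tensorflow as tf

  F = tf.contrib.eager.defun(lambda x: x * 2)
  F(np.ones([2], dtype=np.float32))   # traces a graph the first time
  F(np.zeros([2], dtype=np.float32))  # same shape/dtype, so the graph is reused
  ```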
3613
3614  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
3615  define TensorFlow functions without explicitly specifying their signatures.
3616  However, this policy is conservative and potentially expensive; for example,
3617  when different invocations of your function have differently-shaped Tensor
3618  inputs, this policy might generate more graph functions than necessary. To
3619  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
3620  optional `input_signature` argument specifying the shapes and dtypes of the
3621  inputs. In particular, the shapes may be partially unspecified, with `None`s
3622  in the unknown dimensions.  When an input signature is provided,
3623  `tf.contrib.eager.defun` will only instantiate a single graph for the
3624  decorated Python function. The following is an example:
3625
3626  ```python
3627  import tensorflow as tf
3628
3629  # The first `TensorSpec` below describes the shape and dtype of `words`,
3630  # and the second describes the shape and dtype of `another_tensor`. Note that
3631  # the last dimension of the `words` `TensorSpec` is left unspecified.
3632  @tf.contrib.eager.defun(input_signature=[
3633    tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
3634    tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
3635  ])
3636  def my_sequence_model(words, another_tensor):
3637    ...
3638
3639  # Note how the third dimension of the first input can vary freely.
  words = tf.random.uniform([50, 300, 10])
3641  second_input = tf.random.uniform([300, 100])
3642  my_sequence_model(words, second_input)
3643
  words = tf.random.uniform([50, 300, 20])
3645  my_sequence_model(words, second_input)
3646
3647  # Passing an input with an incompatible shape will raise an error.
  words = tf.random.uniform([50, 100, 20])
3649  my_sequence_model(words, second_input)  # <---- This will raise an error.
3650
3651  ```
3652
  Python functions that are compiled with an `input_signature` must only accept
  Tensors as arguments and must not accept arbitrary keyword arguments
  (`**kwargs`).
3655
3656  _Tracing_
3657
3658  Be aware that because `F` only logs TensorFlow operations, all the other
3659  Python code that `f` executes will only shape the _construction_ of the graphs
3660  that `F` executes: the Python code won't be executed when the graphs
3661  themselves are executed, though it will be executed every time the Python
3662  function is traced (and a given Python function might be traced multiple
3663  times, once for each input signature it is invoked with). For example, whereas
3664  the Python function
3665
3666  ```python
3667  import tensorflow as tf
3668  import numpy as np
3669
3670  tf.compat.v1.enable_eager_execution()
3671
3672  def add_noise():
3673    return tf.eye(5) + np.random.randn(5, 5)
3674  ```
3675
  will return a different output every time it is invoked, the compiled function
3677  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
3678  every time it is called, since a particular random offset generated by NumPy
3679  will be inserted into the graph as a TensorFlow constant. The solution is to
3680  replace the call to `np.random.randn` with `tf.random.normal((5, 5))`.
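
  For example, the fixed function can be written as follows (a minimal sketch
  of the fix described above):

  ```python
  @tf.contrib.eager.defun
  def add_noise():
    # The noise is now generated by a TensorFlow op, so a fresh offset is
    # computed every time the compiled graph is executed.
    return tf.eye(5) + tf.random.normal((5, 5))
  ```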
3681
3682  _Python Side-Effects_
3683
3684  A corollary of the previous discussion on tracing is the following: If a
3685  Python function `f` has Python side-effects, then executing `f` multiple times
3686  will not necessarily be semantically equivalent to executing `F =
3687  tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
3688  that `defun` only captures the subgraph of TensorFlow operations that is
3689  constructed when `f` is called in a graph-building context.
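
  For instance, a Python `print` call runs only while `f` is being traced, not
  every time the compiled graph executes (an illustrative sketch, continuing
  the setup from the earlier examples):

  ```python
  @tf.contrib.eager.defun
  def f(x):
    print("Tracing f")  # A Python side effect; it is not staged in the graph.
    return x + 1

  f(tf.constant(1))    # Traces a graph and prints "Tracing f".
  f(tf.constant(2))    # Reuses the cached graph; nothing is printed.
  f(tf.constant(1.0))  # New dtype, so `f` is traced (and prints) again.
  ```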
3690
3691  _Python Control Flow_
3692
  The structure of many machine learning computations depends upon whether one is
3694  training or validating, and it is common to nest specialized logic under `if
3695  training:` blocks. By mapping each input signature to a unique graph, `defun`
3696  lets users transparently compile such code, as the following code snippet
3697  demonstrates:
3698
3699  ```python
3700  import tensorflow as tf
3701
3702  tf.compat.v1.enable_eager_execution()
3703
3704  @tf.contrib.eager.defun
3705  def lossy_matmul(W, x, training=True):
3706    outputs = tf.matmul(W, x)
3707    if training:
      outputs = tf.nn.dropout(outputs, keep_prob=0.2)
3709    return outputs
3710
3711  W = tf.random.normal((3, 5))
3712  x = tf.random.normal((5, 1))
3713
3714  # Executes a graph that applies dropout.
3715  lossy_outputs = lossy_matmul(W, x, training=True)
3716
3717  # Executes a graph that does not apply dropout.
3718  exact_outputs = lossy_matmul(W, x, training=False)
3719  ```
3720
3721  _TensorFlow Control Flow_
3722
3723  When `autograph` is `True`, data-dependent control flow is allowed as well.
3724  Control flow statements that depend on `Tensor` values are staged into
3725  corresponding TensorFlow ops. For example, the following code will work as
3726  expected:
3727
3728  ```python
3729  @tf.contrib.eager.defun
3730  def dynamic_rnn_loop(cell, seq):
3731    state, output = cell.zero_state()
3732    for input in seq:
3733      state, output = cell(input, state)
3734    return output
3735  ```
3736
3737  For more information see `tf.autograph`.
3738
3739  _Variables_
3740
3741  TensorFlow operations related to variable creation and initialization are
3742  automatically lifted out of the graphs generated by `defun`. In practice, this
3743  implies that variable creation and initialization only happen the first time
3744  `F` is called, and that variables are reused every time thereafter. Many
3745  TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
3746  first time they are called and reuse them thereafter. Automatic variable
3747  lifting makes it possible to compile these APIs without extra effort, at the
3748  cost of introducing a discrepancy between the semantics of executing Python
3749  functions and their corresponding compiled functions. For example:
3750
3751  ```python
3752  import tensorflow as tf
3753
3754  tf.compat.v1.enable_eager_execution()
3755
3756  def fn():
3757    x = tf.Variable(0.0)
3758    x.assign_add(1.0)
3759    return x.read_value()
3760
3761  # `fn` is a Python function, so x is created, initialized, and destroyed upon
3762  # every invocation
3763  assert fn().numpy() == fn().numpy() == 1.0
3764
3765  compiled = tf.contrib.eager.defun(fn)
3766
3767  # Compiling `fn` with `defun` hoists all variables outside of the generated
3768  # graph, so initialization happens exactly once.
3769  assert compiled().numpy() == 1.0
3770  assert compiled().numpy() == 2.0
3771  ```
3772
3773  Finally, because each input signature is bound to a unique graph, if your
3774  Python function constructs `tf.Variable` objects, then each graph constructed
3775  for that Python function will reference a unique set of variables. To
3776  circumvent this problem, we recommend against compiling Python functions that
3777  create `tf.Variable` objects. Instead, Python functions should either
3778  lexically close over `tf.Variable` objects or accept them as arguments,
3779  preferably encapsulated in an object-oriented container. If you must create
3780  variables inside your Python function and you want each graph generated for it
3781  to reference the same set of variables, add logic to your Python function that
3782  ensures that variables are only created the first time it is called and are
3783  reused for every subsequent invocation; note that this is precisely what
3784  `tf.keras.layers.Layer` objects do, so we recommend using them to represent
3785  variable-bearing computations whenever possible.
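
  For example, a function that lexically closes over a `tf.Variable` references
  that same variable from every graph generated for it (an illustrative sketch,
  continuing the setup from the earlier examples):

  ```python
  v = tf.Variable(1.0)

  @tf.contrib.eager.defun
  def scale(x):
    # `v` is captured by closure, so every graph traced for `scale` reads
    # this same variable.
    return v * x

  scale(tf.constant(2.0))  # ==> 2.0
  v.assign(3.0)
  scale(tf.constant(2.0))  # ==> 6.0
  ```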
3786
3787  Args:
3788    func: function to be compiled. If `func` is None, returns a
3789      decorator that can be invoked with a single argument - `func`. The
3790      end result is equivalent to providing all the arguments up front.
3791      In other words, defun(input_signature=...)(func) is equivalent to
3792      defun(func, input_signature=...). The former allows
3793      the following use case:
3794        @tf.contrib.eager.defun(input_signature=...)
3795        def foo(...):
3796          ...
3797
3798    input_signature: A possibly nested sequence of
3799      `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
3800      the Tensors that will be supplied to this function. If `None`, a separate
3801      function is instantiated for each inferred input signature.  If a
3802      signature is specified, every input to `func` must be a `Tensor`, and
3803      `func` cannot accept `**kwargs`.
3804    autograph: Whether `func` should be compiled before
3805      constructing the graph. See https://www.tensorflow.org/guide/autograph
3806      for more information.
3807    experimental_autograph_options: Experimental knobs (in the form of a tuple
3808      of tensorflow.autograph.Feature values) to control behavior when
3809      autograph=True.
3810    experimental_relax_shapes: When true, argument shapes may be relaxed to
3811      avoid unnecessary retracing.
3812
3813  Returns:
3814     If `func` is not None, returns a callable that will execute the compiled
3815     function (and return zero or more `tf.Tensor` objects).
3816     If `func` is None, returns a decorator that, when invoked with a single
3817     `func` argument, returns a callable equivalent to the case above.
3818
3819  Raises:
3820    TypeError: If `input_signature` is neither `None` nor a sequence of
3821      `tf.contrib.eager.TensorSpec` objects.
3822  """
3823  return defun_with_attributes(
3824      func=func,
3825      input_signature=input_signature,
3826      autograph=autograph,
3827      experimental_autograph_options=experimental_autograph_options,
3828      experimental_relax_shapes=experimental_relax_shapes)
3829
3830
3831@tf_export("__internal__.function.defun_with_attributes", v1=[])
3832def defun_with_attributes(func=None,
3833                          input_signature=None,
3834                          attributes=None,
3835                          autograph=True,
3836                          experimental_autograph_options=None,
3837                          jit_compile=None,
3838                          experimental_relax_shapes=False,
3839                          experimental_follow_type_hints=False):
3840  """Compiles a Python function into a callable TensorFlow graph.
3841
3842  This function supports adding extra function attributes. See detailed
  documentation in defun(). Currently this is not exposed in the public API,
  since we do not expect users to use attributes directly and attributes will
  not work by themselves. This assumption might change in the future.
3846
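  For example (an illustrative sketch; the attribute value is hypothetical and
  assumes `import tensorflow as tf`):

  ```python
  @defun_with_attributes(attributes={"func_name": "my_matmul"})
  def matmul(a, b):
    return tf.matmul(a, b)
  ```
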
3847  Args:
3848    func: function to be compiled.
3849    input_signature: same as defun()'s input_signature.
    attributes: A dictionary of arguments which will be added to the function
      def as attributes. Currently only primitive types are supported as
      values, and only allowlisted attribute names are allowed. An
      unallowlisted attribute name or an unsupported value will result in a
      ValueError. `func_name` is also one of the allowlisted arguments; it is a
      Python string that sets the name for this `ConcreteFunction` in the
      graph.
3856    autograph: same as defun()'s autograph.
3857    experimental_autograph_options: same as defun()'s
3858      experimental_autograph_options.
3859    jit_compile: same as defun()'s jit_compile.
    experimental_relax_shapes: same as defun()'s experimental_relax_shapes.
3861    experimental_follow_type_hints: see `tf.function`.
3862
3863  Returns:
3864    Same as the return value of defun, with attributes added to the function in
3865    graph.
3866  """
3867  if input_signature is not None:
3868    validate_signature(input_signature)
3869
3870  # TODO(apassos): deal with captured global state. Deal with control flow.
3871  def decorated(function):
3872    try:
3873      if attributes:
3874        name = attributes.pop("func_name", function.__name__)
3875      else:
3876        name = function.__name__
3877    except AttributeError:
3878      name = "function"
3879    return tf_decorator.make_decorator(
3880        function,
3881        Function(
3882            function,
3883            name,
3884            input_signature=input_signature,
3885            attributes=attributes,
3886            autograph=autograph,
3887            autograph_options=experimental_autograph_options,
3888            jit_compile=jit_compile,
3889            experimental_relax_shapes=experimental_relax_shapes,
3890            experimental_follow_type_hints=experimental_follow_type_hints))
3891
3892  # This code path is for the `foo = tfe.defun(foo, ...)` use case
3893  if func is not None:
3894    return decorated(func)
3895
3896  # This code path is for the
3897  #
3898  # @tfe.defun(...)
3899  # def foo(...):
3900  #    ...
3901  #
3902  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
3903  return decorated
3904
3905
3906# When a method is bound to objects of this type, it allows AutoGraph to
# recover a weak reference to the original method's self pointer, so that it
# can execute it consistently with class_method_to_instance_method's
3909# bound_method_wrapper.
3910# TODO(b/119246461): This is not pretty. Use a descriptor instead?
3911class TfMethodTarget(object):
3912  """Binding target for methods replaced by function and defun."""
3913
3914  __slots__ = ("weakrefself_target__", "weakrefself_func__")
3915
3916  def __init__(self, target, original_python_function):
3917    self.weakrefself_target__ = target
3918    self.weakrefself_func__ = weakref.ref(original_python_function)
3919
3920  @property
3921  def target(self):
3922    return self.weakrefself_target__()
3923
3924  @property
3925  def target_class(self):
3926    true_self = self.weakrefself_target__()
3927    if tf_inspect.isclass(true_self):
3928      # Class method
3929      return true_self
3930    else:
3931      return true_self.__class__
3932
3933  def call(self, args, kwargs):
3934    wrapped_fn = self.weakrefself_func__()
3935    if tf_inspect.ismethod(wrapped_fn):
3936      wrapped_fn = six.get_unbound_function(wrapped_fn)
3937    return wrapped_fn(self.weakrefself_target__(), *args, **kwargs)
3938
3939
3940def class_method_to_instance_method(original_function, instance):
3941  """Constructs a new `Function` with `self` bound."""
3942  weak_instance = weakref.ref(instance)
3943
3944  # Note: while we could bind to a weakref proxy instead, that causes the
3945  # bound method to be unhashable.
3946  bound_method = types_lib.MethodType(
3947      original_function.python_function,
3948      TfMethodTarget(weak_instance, original_function.python_function))
3949
3950  # original_function is expected to be of one of the two `Function` types
3951  # (defined either in function.py or def_function.py).
3952  assert hasattr(original_function, "_name")
3953  assert hasattr(original_function, "_autograph")
3954  assert hasattr(original_function, "_function_spec")
3955  assert hasattr(original_function, "python_function")
3956
3957  weak_bound_method_wrapper = None
3958  def bound_method_wrapper(*args, **kwargs):
3959    """Wraps either a dummy MethodType or a converted AutoGraph function."""
3960    # __wrapped__ allows AutoGraph to swap in a converted function.
3961    strong_bound_method_wrapper = weak_bound_method_wrapper()
3962    wrapped_fn = strong_bound_method_wrapper.__wrapped__
3963
3964    if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__:
3965      # If __wrapped__ was not replaced, then call original_function.
3966      # TODO(mdan): For better consistency, use the wrapper's call().
3967      wrapped_fn = original_function.python_function
3968      if tf_inspect.ismethod(wrapped_fn):
3969        wrapped_fn = six.get_unbound_function(wrapped_fn)
3970      return wrapped_fn(weak_instance(), *args, **kwargs)
3971
3972    # If __wrapped__ was replaced, then it is always an unbound function.
3973    # However, the replacer is still responsible for attaching self properly.
3974    # TODO(mdan): Is it possible to do it here instead?
3975    return wrapped_fn(*args, **kwargs)
3976  weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
3977
3978  # pylint: disable=protected-access
3979  # We make a dummy MethodType object to generate the correct bound method
3980  # signature. The actual call is to a function with a weak reference to
3981  # `instance`.
3982  instance_func = type(original_function)(
3983      tf_decorator.make_decorator(bound_method, bound_method_wrapper),
3984      name=original_function._name,
3985      autograph=original_function._autograph,
3986      input_signature=original_function.input_signature,
3987      experimental_relax_shapes=original_function._experimental_relax_shapes,
3988      jit_compile=original_function._jit_compile)
3989  # pylint: enable=protected-access
3990
  # We wrap the bound method with tf_decorator so inspection works correctly.
3992  wrapped_instance_func = tf_decorator.make_decorator(bound_method,
3993                                                      instance_func)
3994  return wrapped_instance_func
3995
3996
3997class _FunctionGarbageCollector(object):
3998  """Cleans up cycles when a defun goes out of scope."""
3999
4000  __slots__ = ["_cache"]
4001
4002  def __init__(self, cache):
4003    self._cache = cache
4004
4005  def __del__(self):
4006    if func_graph_module is None or memory is None:
4007      return
4008    try:
4009      while self._cache:
4010        self._cache.popitem()
4011      memory.dismantle_ordered_dict(self._cache)
4012    except:  # pylint: disable=bare-except
4013      pass
4014
4015
4016class ConcreteFunctionGarbageCollector(object):
4017  """Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""
4018
4019  __slots__ = ["_func_graph"]
4020
4021  def __init__(self, func_graph):
4022    self._func_graph = func_graph
4023
4024  def release(self):
4025    """Call off the FuncGraph deletion."""
4026    self._func_graph = None
4027
4028  def __del__(self):
4029    if func_graph_module is None or memory is None or self._func_graph is None:
4030      return
4031    try:
4032      func_graph_module.dismantle_func_graph(self._func_graph)
4033    except:  # pylint: disable=bare-except
4034      pass
4035
4036
4037class _Marker(object):
4038  """Markers used to pretty-print nested args in function signatures."""
4039
4040  __slots__ = ["_s"]
4041
4042  def __init__(self, s):
4043    self._s = s
4044
4045  def __repr__(self):
4046    return str(self._s)
4047
4048
4049def _structure_summary(structure):
4050  """Displays a summary of the nesting structure of the given value."""
4051
4052  def type_name(x):
4053    if isinstance(x, type_spec.TypeSpec):
4054      return x.value_type.__name__
4055    else:
4056      return type(x).__name__
4057
4058  markers = [_Marker(type_name(v)) for v in nest.flatten(structure)]
4059  return str(nest.pack_sequence_as(structure, markers))
4060
4061
4062def _contains_type_spec(value):
4063  return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value))
4064