# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import threading
import types as types_lib
import weakref

import numpy as np
import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"


CacheKey = collections.namedtuple("CacheKey", [
    "input_signature", "parent_graph", "device_functions",
    "colocation_stack"])

CacheKey.replace = CacheKey._replace  # pylint: disable=protected-access


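# Illustrative sketch for `CacheKey` above (not part of the original module,
# never called; the field values are placeholders): `CacheKey.replace` is an
# alias for `namedtuple._replace`, so a key can be copied with one field
# changed.
def _example_cache_key_replace():
  key = CacheKey(input_signature=None, parent_graph=None,
                 device_functions=(), colocation_stack=())
  return key.replace(device_functions=("device_fn_stub",))
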
def _flat_shape_list(*params):
  """Return a flat list of TensorShapes, one for each tensor[spec] in `*params`.

  Args:
    *params: Set of nested entries containing Tensors, TensorSpec, and
      non-tensors.

  Returns:
    A list of entries containing either `None` or `TensorShape`.
  """
  return [tensor_shape.TensorShape(x.shape)
          if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else None
          for x in nest.flatten(params)]


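# Illustrative sketch (not part of the original module, never called): shows
# how `_flat_shape_list` maps Tensor/TensorSpec entries to their shapes and
# every other entry to None.
def _example_flat_shape_list():
  t = constant_op.constant([[1.0, 2.0, 3.0]])         # shape (1, 3)
  spec = tensor_spec.TensorSpec([None, 3], t.dtype)   # partially known shape
  # Returns [TensorShape([1, 3]), TensorShape([None, 3]), None].
  return _flat_shape_list(t, spec, "not a tensor")
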
def _shape_less_specific_than(relaxed, to_check):
  """Checks if `relaxed` is less specific than `to_check`.

  This is an asymmetric check, unlike `TensorShape.is_compatible_with`. If
  `to_check` has a dimension with an undefined shape, `relaxed` must also have
  an undefined shape for that dimension.

  Args:
    relaxed: A `TensorShape` to check against.
    to_check: A second `TensorShape`.

  Returns:
    True if `to_check` represents a set of shapes which is a subset of
    `relaxed`'s shapes and False otherwise.
  """
  if to_check.dims is not None and relaxed.dims is not None:
    if to_check.rank != relaxed.rank:
      return False
    for check_dim, relaxed_dim in zip(to_check.dims, relaxed.dims):
      if check_dim.value is None and relaxed_dim.value is not None:
        return False
      if not relaxed_dim.is_compatible_with(check_dim):
        return False
  return True


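# Illustrative sketch (not part of the original module, never called): the
# check above is asymmetric -- a relaxed shape may not pin down a dimension
# that the shape being checked leaves undefined.
def _example_shape_less_specific_than():
  relaxed = tensor_shape.TensorShape([None, 3])
  to_check = tensor_shape.TensorShape([2, 3])
  looser = _shape_less_specific_than(relaxed, to_check)   # True
  tighter = _shape_less_specific_than(to_check, relaxed)  # False
  return looser, tighter
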
def _compatible_shapes(flat_relaxed, flat_to_check):
  """Check if lists of TensorShapes contain compatible shapes.

  Checks that each `flat_relaxed` shape covers a superset of the shapes of the
  corresponding `flat_to_check` shape.

  Args:
    flat_relaxed: List of TensorShape or None.
    flat_to_check: List of TensorShape or None.

  Returns:
    A python bool.

  Raises:
    RuntimeError:
      if `len(flat_relaxed) != len(flat_to_check)`.
    RuntimeError:
      if `(flat_relaxed[i] is None) != (flat_to_check[i] is None)` for any `i`.
  """

  if len(flat_relaxed) != len(flat_to_check):
    raise RuntimeError("Expected shape lists of identical lengths, but saw: "
                       "%s and %s" % (flat_relaxed, flat_to_check))
  def is_compatible(relaxed, to_check):
    """Internal helper function.

    Args:
      relaxed: TensorShape or None.
      to_check: TensorShape or None.

    Returns:
      Python bool.

    Raises:
      RuntimeError: If `(relaxed is None) != (to_check is None)`.
    """
    # If both x and y are None, there is no shape to compare.  Otherwise check
    # if they are compatible with each other.  Either way, both input signatures
    # must have Tensors in the same entries.  If not, raise an error.
    if (relaxed is None) != (to_check is None):
      raise RuntimeError(
          "Expected signature type matches between flattened input shapes "
          "%s and %s; but saw that (%s is None) != (%s is None)"
          % (flat_relaxed, flat_to_check, relaxed, to_check))
    return relaxed is None or _shape_less_specific_than(relaxed, to_check)
  return all(is_compatible(relaxed, to_check)
             for relaxed, to_check in zip(flat_relaxed, flat_to_check))


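# Illustrative sketch (not part of the original module, never called):
# element-wise application of the relaxation check, with `None` entries
# (non-Tensor arguments) required to line up in both lists.
def _example_compatible_shapes():
  relaxed = [tensor_shape.TensorShape([None]), None]
  to_check = [tensor_shape.TensorShape([4]), None]
  return _compatible_shapes(relaxed, to_check)  # True
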
def _common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`."""
  if (x is None) != (y is None):
    raise RuntimeError(
        "Cannot find a common shape when LHS shape is None but RHS shape "
        "is not (or vice versa): %s vs. %s" % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError("Expected x to be a TensorShape but saw %s" % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError("Expected y to be a TensorShape but saw %s" % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)


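# Illustrative sketch (not part of the original module, never called): the
# common shape keeps the dimensions on which both shapes agree and relaxes
# every other dimension to None.
def _example_common_shape():
  x = tensor_shape.TensorShape([2, 3])
  y = tensor_shape.TensorShape([2, 4])
  return _common_shape(x, y)  # TensorShape([2, None])
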
def is_same_structure(structure1,
                      structure2,
                      check_values=False):
  """Check two structures for equality, optionally of types and of values."""
  try:
    nest.assert_same_structure(structure1, structure2)
  except (ValueError, TypeError):
    return False
  if check_values:
    flattened1 = nest.flatten(structure1)
    flattened2 = nest.flatten(structure2)
    # First check the types to avoid AttributeErrors.
    if any(type(f1) != type(f2) for f1, f2 in zip(flattened1, flattened2)):
      return False
    return flattened1 == flattened2
  return True


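# Illustrative sketch (not part of the original module, never called):
# structure comparison is about nesting, not values, unless
# `check_values=True` is passed.
def _example_is_same_structure():
  same_nesting = is_same_structure({"a": 1, "b": 2}, {"a": 3, "b": 4})  # True
  # False: the flattened leaves 1 and 1.0 have different types.
  same_values = is_same_structure({"a": 1}, {"a": 1.0}, check_values=True)
  return same_nesting, same_values
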
def _parse_func_attrs(attributes):
  """Convert the keyword arguments into function_def attributes.

  Currently only supports primitive types: bool, int, float and string.

  Args:
    attributes: the dictionary of attributes.
  Returns:
    A dict of attributes where the key is the name of attribute and the value
      is the AttrValue proto.
  Raises:
    ValueError: If the kwargs contain an unwhitelisted name or unsupported
      value types.
  """
  attrs = {}
  for key, value in attributes.items():
    if isinstance(value, attr_value_pb2.AttrValue):
      attrs[key] = value
    # bool type check has to happen before int since bool is a subclass of int.
    elif isinstance(value, bool):
      attrs[key] = attr_value_pb2.AttrValue(b=value)
    elif isinstance(value, int):
      attrs[key] = attr_value_pb2.AttrValue(i=value)
    elif isinstance(value, float):
      attrs[key] = attr_value_pb2.AttrValue(f=value)
    elif isinstance(value, (str, bytes, six.text_type)):
      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
    else:
      raise ValueError("Unsupported attribute type for %s with type %s" %
                       (key, type(value)))
  return attrs


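# Illustrative sketch (not part of the original module, never called; the
# attribute names are placeholders): keyword attributes become `AttrValue`
# protos keyed by name, and the bool branch runs before the int branch
# because bool is a subclass of int.
def _example_parse_func_attrs():
  attrs = _parse_func_attrs({"_noinline": True, "_label": "my_func"})
  # attrs["_noinline"].b == True and attrs["_label"].s == b"my_func".
  return attrs
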
class _InterpolateFunctionError(object):
  """Context Manager that interpolates the exception from 'top_level_func'."""

  def __init__(self, top_level_func):
    self._func = top_level_func

  def __enter__(self):
    pass

  def __exit__(self, typ, exc, tb):
    if not exc or not isinstance(exc, errors.OpError):
      return False
    message = compat.as_text(exc.message)
    _, tags = error_interpolation.parse_message(message)
    g = None
    func_stack = []
    # pylint: disable=protected-access
    for t in tags:
      if t.type == "function_node":
        if t.name == compat.as_str(self._func.name):
          g = self._func._graph
        elif g:
          next_func = g._get_function(t.name)
          if next_func is not None and isinstance(next_func,
                                                  _EagerDefinedFunction):
            g = next_func._graph
        if g:
          func_stack.append(g.name)
        else:
          func_stack.append("<unknown>")
    # pylint: enable=protected-access
    if g:
      message = error_interpolation.interpolate(message, g)
      message += "\n\nFunction call stack:\n"
      message += " -> ".join(func_stack)
      message += "\n"
      exc._message = message  # pylint: disable=protected-access
    return False


def _forward_name(n):
  """The name of a generated forward defun named n."""
  return "__forward_%s_%s" % (n, ops.uid())


def _backward_name(n):
  """The name of a generated backward defun named n."""
  return "__backward_%s_%s" % (n, ops.uid())


def _inference_name(n):
  """The name of a forward-but-no-gradient defun named n."""
  return "__inference_%s_%s" % (n, ops.uid())


# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
class _EagerDefinedFunction(object):
  """Callable with the interface of `framework.function._DefinedFunction`.

  `_EagerDefinedFunction` encapsulates a function definition and its properties,
  and it provides a method for calling the encapsulated function. Some Ops
  take functions as attributes, which have type `func`; an instance of this
  class may be provided as the value of these `func` attributes.
  """

  def __init__(self, name, graph, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function.
      inputs: the tensors in the graph to be used as inputs to the function.
      outputs: the tensors in the graph which will be outputs of the function.
      attrs: dict mapping names of attributes to their AttrValue values.
    """
    input_ops = set(arg.op for arg in inputs)
    operations = [op for op in graph.get_operations() if op not in input_ops]

    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        [],
        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
        [],  # control_output_names
        None,
        compat.as_str(""))

    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
          fn, compat.as_str(name), serialized)

    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    # signature), though in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    with ops.init_scope():
      if context.executing_eagerly():
        context.add_function(fn)
    self.definition = function_def
    self.name = compat.as_bytes(function_def.signature.name)
    self.signature = function_def.signature
    self._num_outputs = len(self.signature.output_arg)
    self._output_types = [o.type for o in self.signature.output_arg]
    self._output_shapes = [o.shape for o in outputs]
    self._control_captures = graph.control_captures
    self._func_graph_outputs = outputs
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = c_api_util.ScopedTFFunction(fn)
    self._grad_func = None
    self._graph = graph
    self._stateful_ops = tuple(op for op in operations if op.op_def.is_stateful)

  def add_to_graph(self, g=None):
    # pylint: disable=protected-access
    if not g and context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      if self.name not in g._functions:
        g._add_function(self)
      for f in self._graph._functions.values():
        if f.name not in g._functions:
          g._add_function(f)
    # pylint: enable=protected-access

  @property
  def stateful_ops(self):
    return self._stateful_ops

  def call(self, ctx, args):
    """Calls this function with `args` as inputs.

    `ConcreteFunction` execution respects device annotations only if the
    function won't be compiled with xla.

    Args:
      ctx: a Context object
      args: a list of arguments to supply this function with.

    Returns:
      The outputs of the function call.

    Raises:
      ValueError: if the number of arguments is incorrect.
    """
    if len(args) != len(self.signature.input_arg):
      raise ValueError(
          "Arguments and signature arguments do not match: %s %s " %
          (len(args), len(list(self.signature.input_arg))))

    function_call_options = ctx.function_call_options
    if function_call_options.config_proto_serialized is None:
      config = function_utils.get_disabled_rewriter_config()
    else:
      config = function_call_options.config_proto_serialized
    executor_type = function_call_options.executor_type or ""

    executing_eagerly = ctx.executing_eagerly()
    if executing_eagerly:
      with _InterpolateFunctionError(self):
        outputs = execute.execute(
            str(self.signature.name),
            num_outputs=self._num_outputs,
            inputs=args,
            attrs=("executor_type", executor_type,
                   "config_proto", config),
            ctx=ctx)
      # Replace empty list with None
      outputs = outputs or None
    else:
      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
      # creates `PartitionedCallOp` kernels by default, or remove the previous
      # branch if a TPU kernel is registered for `PartitionedCall`.
      with _InterpolateFunctionError(self):
        with ops.control_dependencies(self._control_captures):
          outputs = functional_ops.partitioned_call(
              args=args,
              f=self,
              tout=self._output_types,
              executing_eagerly=executing_eagerly,
              config=config,
              executor_type=executor_type)

    if executing_eagerly:
      return outputs
    else:
      for i, shape in enumerate(self._output_shapes):
        outputs[i].set_shape(shape)
      for i, func_graph_output in enumerate(self._func_graph_outputs):
        custom_gradient.copy_handle_data(func_graph_output, outputs[i])
      return outputs


class ConcreteFunction(object):
  """Callable object encapsulating a function definition and its gradient.

  `ConcreteFunction` is a callable that encapsulates a function definition and
  is differentiable under `tf.GradientTape` objects.
  """

  def __init__(self, func_graph, attrs=None, signature=None):
    """Initialize a `ConcreteFunction`.

    Args:
      func_graph: An instance of FuncGraph: the function body to wrap.
      attrs: (optional) dict mapping names of attributes to their AttrValue
        values. Attributes in `attrs` will be included in this function's
        definition.
      signature: a nested sequence of `TensorSpec` objects specifying the input
        signature of this function.

    Raises:
      ValueError: If number of input_placeholders is not equal to the number
        of function inputs.
    """
    self._arg_keywords = None
    self._num_positional_args = None
    self._func_graph = func_graph
    self._captured_inputs = list(self._func_graph.captures.keys())
    self._num_outputs = len(self._func_graph.outputs)
    self._output_shapes = tuple(
        output.shape for output in self._func_graph.outputs)
    self._attrs = _parse_func_attrs(attrs or {})

    self._inference_function = _EagerDefinedFunction(
        _inference_name(self._func_graph.name), self._func_graph,
        self._func_graph.inputs, self._func_graph.outputs, self._attrs)
    self._backward_graph_function = None
    self._signature = signature
    self._gradient_name = None

  def __call__(self, *args, **kwargs):
    """Executes the wrapped function.

    Args:
      *args: Tensors or Variables. Positional arguments are only accepted when
        they correspond one-to-one with arguments of the traced Python function.
      **kwargs: Tensors or Variables specified by name. When
        `get_concrete_function` was called to create this `ConcreteFunction`,
        each Tensor input was given a name, defaulting to the name of the Python
        function's argument but possibly overridden by the `name=` argument to
        `tf.TensorSpec`. These names become the argument names for the concrete
        function.

    Returns:
      The result of applying the TF function on the given Tensors.

    Raises:
      AssertionError: If this `ConcreteFunction` was not created through
        `get_concrete_function`.
      ValueError: If arguments contain anything other than Tensors or
        Variables.
      TypeError: For invalid positional/keyword argument combinations.
    """
    if self._arg_keywords is None or self._num_positional_args is None:
      if self._signature is not None:
        if kwargs:
          raise NotImplementedError(
              "Keyword arguments not supported when calling a "
              "wrap_function-decorated function.")
        return self._call_flat(args)
      raise AssertionError(
          "Tried to call a concrete function obtained from an internal API "
          "through the public interface. Use get_concrete_function instead.")
    if len(args) > self._num_positional_args:
      raise TypeError(
          ("Expected at most {} positional arguments (and the rest keywords, "
           "of {}), got {}. When calling a concrete function, positional "
           "arguments may not be bound to Tensors within nested structures."
          ).format(self._num_positional_args, self._arg_keywords, args))
    args = list(args)
    for keyword in self._arg_keywords[len(args):]:
      try:
        args.append(kwargs.pop(compat.as_str(keyword)))
      except KeyError:
        specified_keywords = (list(self._arg_keywords[:len(args)])
                              + list(kwargs.keys()))
        raise TypeError(
            "Expected argument names {} but got values for {}. Missing: {}."
            .format(
                list(self._arg_keywords),
                specified_keywords,
                list(set(self._arg_keywords) - set(specified_keywords))))
    if kwargs:
      positional_arg_keywords = set(self._arg_keywords[:len(args)])
      for unused_key in kwargs:
        if unused_key in positional_arg_keywords:
          raise TypeError("Got two values for keyword '{}'.".format(unused_key))
      raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
          list(kwargs.keys()), list(self._arg_keywords)))
    return self._call_flat(args)

  def _filtered_call(self, args, kwargs):
    """Executes the function, filtering arguments from the Python function.

    Objects aside from Tensors and Variables are ignored.

    Args:
      args: Canonicalized positional arguments of the Python function.
      kwargs: Canonicalized keyword arguments of the Python function.

    Returns:
      The result of applying the function on the Tensors/Variables contained in
      `args` and `kwargs`.
    """
    return self._call_flat(
        (t for t in nest.flatten((args, kwargs))
         if isinstance(t, (ops.Tensor,
                           resource_variable_ops.ResourceVariable))))

  def _call_flat(self, args):
    """Executes the wrapped function.

    Args:
      args: a list of Tensors or Variables.

    Returns:
      The result of applying the TF function to `args`.

    Raises:
      ValueError: If `args` contains anything other than Tensors or Variables.
    """
    ctx = context.context()

    tape.variables_accessed(self._func_graph.variables)

    tensor_inputs = []
    variables_used = set([])
    for i, arg in enumerate(args):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        # We can pass a variable more than once, and in this case we need to
        # pass its handle only once.
        if arg.handle in variables_used:
          continue
        if arg.trainable:
          tape.variable_accessed(arg)
        tensor_inputs.append(arg.handle)
        variables_used.add(arg.handle)
      elif isinstance(arg, ops.Tensor):
        tensor_inputs.append(arg)
      elif (self._signature is not None and
            isinstance(self._signature[i], tensor_spec.TensorSpec)):
        tensor_inputs.append(
            ops.convert_to_tensor(arg, self._signature[i].dtype))
      else:
        raise ValueError("All inputs to `ConcreteFunction`s must be Tensors; "
                         "on invocation of %s, the %d-th input (%s) was not a "
                         "Tensor." % (self._func_graph.name, i, str(arg)))
    args = tensor_inputs + self._captured_inputs

    if (tape.should_record(tensor_inputs) or
        tape.should_record(self._captured_inputs)):
      if context.executing_eagerly():
        return self._eager_backprop_call(args)
      else:
        return self._backprop_call_with_delayed_rewrite(args)

    # Only need to override the gradient in graph mode and when we have outputs.
    if context.executing_eagerly() or not self.outputs:
      outputs = self._inference_function.call(ctx, args)
    else:
      self._register_gradient()
      with ops.get_default_graph().gradient_override_map(
          {"PartitionedCall": self._gradient_name,
           "StatefulPartitionedCall": self._gradient_name}):
        outputs = self._inference_function.call(ctx, args)
    return self._build_call_outputs(outputs)

  def _register_gradient(self):
    """Registers the gradient for this `ConcreteFunction`.

    The gradient rewrites an inference call op to a forward call op, but does
    not modify a pre-existing forward call op. It then computes the gradient
    from the output's gradients and the side outputs of the forward op.
    """
    if self._gradient_name:
      return
    self._gradient_name = "PartitionedCall-%s" % ops.uid()

    @ops.RegisterGradient(self._gradient_name)
    def _registered_grad_fn(op, *doutputs):  # pylint: disable=unused-variable
      return self._grad_fn(op, *doutputs)

  def _grad_fn(self, op, *doutputs):
    """Gradients of this function."""
    if self._backward_graph_function is None:
      self._construct_backprop_function()

    # pylint: disable=protected-access
    self._forward_function.add_to_graph(op.graph)
    num_inference_outputs = self._inference_function._num_outputs

    # Rewrite an inference call op to be a forward call op
    if op.get_attr("f").name.encode() == self._inference_function.name:
      op._set_func_attr("f", self._forward_function.name)
      op._set_type_list_attr("Tout", self._forward_function._output_types)
      op._add_outputs(
          self._forward_function._output_types[num_inference_outputs:],
          self._forward_function._output_shapes[num_inference_outputs:])
      for i in range(num_inference_outputs, len(op.outputs)):
        func_graph_output = self._forward_function._func_graph_outputs[i]
        custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
    # pylint: enable=protected-access
    # Compute the gradients using the side outputs
    side_outputs = op.outputs[num_inference_outputs:]
    args = list(doutputs[:num_inference_outputs]) + list(side_outputs)
    return self._backward_graph_function._call_flat(  # pylint: disable=protected-access
        (a for a in args if a is not None))

  @property
  def name(self):
    """`ConcreteFunction` name."""
    return self._inference_function.name

  @property
  def graph(self):
    """Returns the graph from which this function was constructed."""
    return self._func_graph

  @property
  def inputs(self):
    """Returns tensors in `self.graph` corresponding to arguments."""
    return self._func_graph.inputs

  @property
  def structured_input_signature(self):
    """Returns structured signature of the original function."""
    return self._func_graph.structured_input_signature

  @property
  def outputs(self):
    """Returns tensors in `self.graph` corresponding to returned tensors."""
    return self._func_graph.outputs

  @property
  def structured_outputs(self):
    """Returns outputs in `self.graph` as returned by the original function."""
    return self._func_graph.structured_outputs

  @property
  def captured_inputs(self):
    """Returns external Tensors captured by this function.

    self.__call__(*args) passes `args + self.captured_inputs` to the function.
    """
    return self._captured_inputs

  @property
  def function_def(self):
    """Returns a `FunctionDef` object representing this function."""
    return self._inference_function.definition

  @property
  def output_shapes(self):
    """The function's output shapes."""
    # TODO(ebrevdo): Should we only keep the output shapes associated
    # with len(self._python_returns) outputs?
    # TODO(akshayka): Consider removing this.
    outputs_list = nest.flatten(self._func_graph.structured_outputs)
    j = 0
    for i, o in enumerate(outputs_list):
      if o is not None:
        if isinstance(o, ops.IndexedSlices):
          # Extract the shape of the `IndexedSlices` object's `values` field.
          outputs_list[i] = self._output_shapes[j]  # the `values` shape
          if o.dense_shape is not None:
            j += 3  # skip over shapes for `values`, `indices`, `dense_shape`
          else:
            j += 2  # skip over shapes for `values`, `indices`
        else:
          outputs_list[i] = self._output_shapes[j]
          j += 1
    return nest.pack_sequence_as(self._func_graph.structured_outputs,
                                 outputs_list)

  @property
  def output_dtypes(self):
    # TODO(akshayka): Consider removing this.
    return nest.map_structure(lambda x: x.dtype if x is not None else None,
                              self._func_graph.structured_outputs)

  def add_to_graph(self, g=None, register_gradient_functions=False):
    """Registers the function, adds it to the graph g or default graph."""
    # If we are not executing eagerly, add the function to the default graph
    # when no graph is specified.
    # In case of eager execution, the function definition gets added to the
    # context during construction itself.

    # TODO(allenl/shivaniagrawal): rename this to register to reflect the
    # method's functionality better. Remove register_gradient_functions argument
    # and figure out if these need to be registered.

    if not context.executing_eagerly() and not g:
      g = ops.get_default_graph()
    self._inference_function.add_to_graph(g)  # pylint: disable=protected-access

    # pylint: disable=protected-access
    if register_gradient_functions:
      # There are two situations for the actual call of a defun:
      # 1. If none of the input args are resource variables or watched by any
      #   tape, the _inference_function of concrete_func is run for the forward
      #   pass, and the gradient will be generated by the standard mechanism.
      # 2. Otherwise, defun will create two functions, one for the forward pass,
      #   and the backward pass will be created via tape.
      #   When registering the function, we register both cases.
      if self._backward_graph_function is None:
        self._construct_backprop_function()
      forward_function = self._forward_function
      backward_function = self._backward_graph_function._inference_function
      # pylint: enable=protected-access
      forward_function.add_to_graph(g)
      backward_function.add_to_graph(g)

  def _construct_backprop_function(self):
    """Constructs the backprop function object for this function."""
    backwards_graph = func_graph_module.FuncGraph(
        _backward_name(self._func_graph.name))
    forward_function_name = _forward_name(self._func_graph.name)
    outputs = [x for x in self._func_graph.outputs
               if gradients_util.IsTrainable(x)]
    with backwards_graph.as_default():
      gradients_wrt_outputs = [
          graph_placeholder(x.dtype, x.shape) for x in outputs
      ]
      gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
          outputs,
          self._func_graph.inputs,
          grad_ys=gradients_wrt_outputs,
          src_graph=self._func_graph)

    backwards_graph_captures = list(backwards_graph.captures.keys())

    backward_function_attr = _parse_func_attrs(
        {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
    backward_function_attr.update(self._attrs)

    # The ordering of `backwards_graph.inputs` is important: inputs of
    # `self._backward_graph_function` correspond to outputs of
    # `self._forward_function`.
    backwards_graph.inputs = gradients_wrt_outputs + list(
        backwards_graph.captures.values())
    # Clear captures, since we pass them in as inputs.
    backwards_graph.captures = {}
    backwards_graph.outputs.extend(
        grad
        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
        if grad is not None)
    backwards_graph.structured_outputs = gradients_wrt_inputs
    self._backward_graph_function = ConcreteFunction(
        backwards_graph, attrs=backward_function_attr)

    forward_function_attr = _parse_func_attrs({
        BACKWARD_FUNCTION_ATTRIBUTE_NAME:
            self._backward_graph_function._inference_function.name})  # pylint: disable=protected-access
    forward_function_attr.update(self._attrs)
    self._forward_function = _EagerDefinedFunction(
        forward_function_name, self._func_graph, self._func_graph.inputs,
        self._func_graph.outputs + backwards_graph_captures,
        forward_function_attr)

  def _eager_backprop_call(self, args):
    """Calls the forward function and records the result on a tape.

    This method fully constructs the forward and backward functions before
    calling the function and recording them on the tape.

    (Only records results on a tape if the function has outputs).

    Args:
      args: All inputs to the function, including resolved captured inputs

    Returns:
      The call output.
    """
    if self._backward_graph_function is None:
      self._construct_backprop_function()

    ctx = context.context()

    self._register_gradient()
    with ops.get_default_graph().gradient_override_map(
        {"PartitionedCall": self._gradient_name,
         "StatefulPartitionedCall": self._gradient_name}):
      outputs = self._forward_function.call(ctx, args)

    if isinstance(outputs, ops.Operation) or outputs is None:
      return outputs

    # `real_outputs` are the actual outputs of the inference graph function;
    # `side_outputs` are the intermediate Tensors that were added as outputs to
    # the forward graph function so that we can compute its gradient.
    real_outputs = outputs[:self._num_outputs]
    skip_positions = [i for i, t in enumerate(real_outputs)
                      if not gradients_util.IsTrainable(t)]
    side_outputs = outputs[self._num_outputs:]

    def backward_function(*args):
      args = [a for i, a in enumerate(args)
              if a is not None and i not in skip_positions]
      return self._backward_graph_function._call_flat(  # pylint: disable=protected-access
          list(args) + side_outputs)

    tape.record_operation(self._forward_function.signature.name, real_outputs,
                          args, backward_function)
    return self._build_call_outputs(real_outputs)

  def _backprop_call_with_delayed_rewrite(self, args):
    """Calls the inference function and records the result on a tape.

    The recorded backwards function will construct the backwards graph and
    rewrite the inference function to the forward function. This only happens
    if the recorded backwards function ends up being used to compute gradients.

    This approach avoids constructing unnecessary graphs, but it only works if
    we are calling this function when not executing eagerly.

    (Only records results on a tape if the function has outputs)

    Args:
      args: All inputs to the function, including resolved captured inputs

    Returns:
      The call output.
    """
    ctx = context.context()

    self._register_gradient()
    with ops.get_default_graph().gradient_override_map(
        {"PartitionedCall": self._gradient_name,
         "StatefulPartitionedCall": self._gradient_name}):
      outputs = self._inference_function.call(ctx, args)

    if isinstance(outputs, ops.Operation) or outputs is None:
      return outputs

    call_op = outputs[0].op

    def backward_function(*args):
      return self._grad_fn(call_op, *args)

    tape.record_operation(self._inference_function.signature.name, outputs,
                          args, backward_function)
    return self._build_call_outputs(outputs)

  def _build_call_outputs(self, result):
    """Maps the fdef output list to actual output structure.

    Args:
      result: Output lists defined by FunctionDef.
    Returns:
      The actual call output.
    """
    if self._func_graph.structured_outputs is None:
      return result

    # Use `nest.flatten` instead of `func_graph_module.flatten` in order to
    # preserve any IndexedSlices in `self._func_graph.structured_outputs`.
    outputs_list = nest.flatten(self._func_graph.structured_outputs)
    j = 0
    for i, o in enumerate(outputs_list):
      if o is not None:
        if isinstance(o, ops.IndexedSlices):
          # Repack Tensors for IndexedSlices.
          if o.dense_shape is not None:
            outputs_list[i] = ops.IndexedSlices(
                values=result[j],
                indices=result[j + 1],
                dense_shape=result[j + 2])
            j += 3
          else:
            outputs_list[i] = ops.IndexedSlices(
                values=result[j], indices=result[j + 1])
            j += 2
        else:
          outputs_list[i] = result[j]
          j += 1
    ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
                                outputs_list)
    return ret


pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)


def _deterministic_dict_values(dictionary):
  return tuple(dictionary[key] for key in sorted(dictionary))


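# Illustrative sketch (not part of the original module, never called): values
# come back in sorted-key order, so the result is deterministic regardless of
# dict insertion order.
def _example_deterministic_dict_values():
  return _deterministic_dict_values({"b": 2, "a": 1, "c": 3})  # (1, 2, 3)
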
class FunctionSpec(object):
  """Specification of how to bind arguments to a function."""

  @staticmethod
  def from_function_and_signature(python_function, input_signature):
    """Create a FunctionSpec instance given a python function and signature."""
    if isinstance(python_function, functools.partial):
      python_function_to_inspect = python_function.func
      args_to_prepend = python_function.args or tuple()
      kwargs_to_include = python_function.keywords or {}
      if input_signature is not None:
        # TODO(b/124441704): Add support for input_signature + partial.
        raise NotImplementedError(
            "Missing support for input_signature when using partial functions.")
    else:
      python_function_to_inspect = python_function
      args_to_prepend = tuple()
      kwargs_to_include = {}

    fullargspec = tf_inspect.getfullargspec(python_function_to_inspect)
    is_method = tf_inspect.ismethod(python_function_to_inspect)

    return FunctionSpec(fullargspec, is_method, args_to_prepend,
                        kwargs_to_include, input_signature)

  def __init__(self, fullargspec, is_method, args_to_prepend, kwargs_to_include,
               input_signature):
    self._fullargspec = fullargspec
    self._is_method = is_method
    self._args_to_prepend = args_to_prepend
    self._kwargs_to_include = kwargs_to_include
    self._default_values = fullargspec.defaults

    if self._is_method:
      # Remove `self`: default arguments shouldn't be matched to it.
      args = fullargspec.args[1:]
    else:
      args = fullargspec.args

    # A cache mapping from argument name to index, for canonicalizing
    # arguments that are called in a keyword-like fashion.
    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
    self.arg_names = args
    self.vararg_name = fullargspec.varargs

    # A cache mapping from arg index to default value, for canonicalization.
    offset = len(args) - len(fullargspec.defaults or [])
    self._arg_indices_to_default_values = {
        offset + index: default
        for index, default in enumerate(fullargspec.defaults or [])
    }
    self._default_values_start_index = offset
    if input_signature is None:
      self._input_signature = None
    else:
      if fullargspec.varkw is not None or fullargspec.kwonlyargs:
        raise ValueError("Cannot define a TensorFlow function from a Python "
                         "function with keyword arguments when "
                         "input_signature is provided.")

      if not isinstance(input_signature, (tuple, list)):
        raise TypeError("input_signature must be either a tuple or a "
                        "list, received " + str(type(input_signature)))

      self._input_signature = tuple(input_signature)
      self._flat_input_signature = tuple(nest.flatten(input_signature))

  @property
  def fullargspec(self):
    return self._fullargspec

  @property
  def is_method(self):
    return self._is_method

  @property
  def args_to_prepend(self):
    return self._args_to_prepend

  @property
  def kwargs_to_include(self):
    return self._kwargs_to_include

  @property
  def input_signature(self):
    return self._input_signature

  @property
  def flat_input_signature(self):
    return self._flat_input_signature

  def canonicalize_function_inputs(self, *args, **kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalize the inputs to the Python function using a `FunctionSpec`
    instance. In particular, we parse the varargs and kwargs that the
    original function was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs.

    Args:
      *args: The varargs this object was called with.
      **kwargs: The keyword args this function was called with.

    Returns:
      A canonicalized ordering of the inputs represented by a tuple in the
      form (args, kwargs). Here: `args` is a full list of bound arguments, and
      `kwargs` contains only true keyword arguments, as opposed to named
      arguments called in a keyword-like fashion.

    Raises:
      ValueError: If a keyword in `kwargs` cannot be matched with a positional
        argument when an input signature is specified, or when the inputs
        do not conform to the input signature.
    """
    if self._input_signature is not None:
      if len(args) > len(self._input_signature):
        raise TypeError(
            "When input_signature is provided, only pass arguments "
            "covered by it. Received %d argument(s)." % len(args))
      for arg in six.iterkeys(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is None:
          raise TypeError(
              "Function got an unexpected keyword argument %s" % arg)
        if index >= len(self._input_signature):
          raise TypeError(
              "When input_signature is provided, only pass arguments "
              "covered by it. Received argument %s." % arg)

    args = self._args_to_prepend + args
    kwargs = dict(kwargs, **self._kwargs_to_include)
    if not kwargs:
      if self._default_values:
        inputs = args + self._default_values[
            len(args) - self._default_values_start_index:]
      else:
        inputs = args
    else:
      # Maps from index of arg to its corresponding value, according to `args`
      # and `kwargs`; seeded with the default values for the named args that
      # aren't in `args`.
      arg_indices_to_values = {
          index: default for index, default in six.iteritems(
              self._arg_indices_to_default_values) if index >= len(args)
      }
      consumed_args = []
      for arg, value in six.iteritems(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is not None:
          arg_indices_to_values[index] = value
          consumed_args.append(arg)
        elif self._input_signature is not None:
          raise ValueError("Cannot define a TensorFlow function from a Python "
                           "function with keyword arguments when "
                           "input_signature is provided.")
      for arg in consumed_args:
        # After this loop, `kwargs` will only contain true keyword arguments, as
        # opposed to named arguments called in a keyword-like fashion.
        kwargs.pop(arg)
      inputs = args + _deterministic_dict_values(arg_indices_to_values)

    if self._input_signature is None:
      inputs = _convert_numpy_inputs(inputs)
      return inputs, kwargs
    else:
      assert not kwargs
      inputs = _convert_inputs_to_signature(
          inputs,
          self._input_signature,
          self._flat_input_signature)
      return inputs, {}


def _convert_numpy_inputs(inputs):
  """Convert numpy array inputs to tensors."""
  flat_inputs = nest.flatten(inputs)

  # Check for NumPy arrays in arguments and convert them to Tensors.
  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
  # finding a way to store them directly in the cache key (currently not
  # possible since ndarrays are not hashable).
  need_packing = False
  for index, value in enumerate(flat_inputs):
    if type(value) == np.ndarray:
      flat_inputs[index] = constant_op.constant(value)
      need_packing = True
  if need_packing:
    return nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)
  else:
    return inputs


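# Illustrative sketch (not part of the original module, never called): only
# exact `np.ndarray` instances are converted to constant Tensors; other
# entries pass through unchanged and the original nesting is preserved.
def _example_convert_numpy_inputs():
  converted = _convert_numpy_inputs({"x": np.zeros([2, 2]), "y": 3})
  # converted["x"] is now a constant Tensor; converted["y"] is still 3.
  return converted
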
def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
  """Convert inputs to pass into a function with an explicit signature."""
  try:
    # TODO(b/124370185): Use all elements as inputs to throw an error if there
    # are ignored arguments. Calling with arguments that are not part of the
    # signature should throw an error.
    flatten_inputs = nest.flatten_up_to(
        input_signature,
        inputs[:len(input_signature)])
  except ValueError:
    raise ValueError("Structure of Python function inputs does not match "
                     "input_signature. Inputs (%s), input_signature(%s)." %
                     (str(inputs), str(input_signature)))

  need_packing = False
  for index, (value, spec) in enumerate(zip(flatten_inputs,
                                            flat_input_signature)):
    if not pywrap_tensorflow.IsTensor(value):
      try:
        flatten_inputs[index] = ops.convert_to_tensor(
            value, dtype_hint=spec.dtype)
        need_packing = True
      except ValueError:
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be convertible to tensors. "
                         "Inputs (%s), input_signature(%s)." %
                         (str(inputs), str(input_signature)))

  if any(not spec.is_compatible_with(other) for spec, other in zip(
      flat_input_signature,
      flatten_inputs)):
    raise ValueError("Python inputs incompatible with input_signature: "
                     "inputs (%s), input_signature (%s)" %
                     (str(inputs), str(input_signature)))

  if need_packing:
    inputs = nest.pack_sequence_as(
        structure=input_signature,
        flat_sequence=flatten_inputs)

  return inputs


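# Illustrative sketch (not part of the original module, never called; the
# signature below is a placeholder): plain Python inputs are converted to
# Tensors that must then satisfy the corresponding TensorSpec.
def _example_convert_inputs_to_signature():
  signature = (tensor_spec.TensorSpec([2], np.float32),)
  flat_signature = tuple(nest.flatten(signature))
  # Returns a tuple holding a float32 Tensor of shape [2].
  return _convert_inputs_to_signature(([1.0, 2.0],), signature, flat_signature)
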
class FunctionCache(object):
  """A lightweight container for cached functions.
  """

  def __init__(self):
    # The set of functions that have been missed; entries are CacheKey with
    # input_signature `None` (e.g. a "call context key")
    self.missed = set()
    # The primary cache, mapping a fully shaped CacheKey to a function.
    self.primary = collections.OrderedDict()
    # A cache key lookup, mapping a CacheKey generated without shape info to a
    # flat list of relaxed shapes (one for each argument).  Arguments that are
    # not Tensors contain a `None` for the corresponding relaxed shape.
    self.arg_relaxed_shapes = collections.OrderedDict()
    # The secondary cache, mapping a CacheKey generated without shape info to a
    # function.
    self.arg_relaxed = collections.OrderedDict()
    # All OrderedDicts require manual garbage collection.
    self._garbage_collectors = [
        _FunctionGarbageCollector(self.primary),
        _FunctionGarbageCollector(self.arg_relaxed),
        _FunctionGarbageCollector(self.arg_relaxed_shapes)]

  def all_values(self):
    """A set of all `ConcreteFunction` instances held by this cache."""
    return set(self.primary.values()) | set(self.arg_relaxed.values())


class Function(object):
  """Wrapper class for the graph functions defined for a Python function.

  See the documentation for `defun` for more information on the semantics of
  defined functions.

  The `Function` class is thread-compatible, meaning that minimal usage of
  defuns (defining and calling) is thread-safe, but if users call other methods
  or invoke the base `python_function` themselves, external synchronization is
  necessary.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               attributes=None,
               autograph=True,
               autograph_options=None,
               capture_by_value=None):
    """Initializes a `Function`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: a possibly nested sequence of `TensorSpec` objects
        specifying the input signature of this function. If `None`, a separate
        function is instantiated for each inferred input signature.
      attributes: dict, extra keyword arguments that will be added as attributes
        of the function.
      autograph: whether to use autograph to compile
        `python_function`. See https://www.tensorflow.org/guide/autograph for
        more information.
      autograph_options: Experimental knobs to control behavior when
        `autograph=True`. See https://www.tensorflow.org/guide/autograph
        for more information.
      capture_by_value: Experimental. Whether to capture resource variables by
        value or reference. If None, will inherit from a parent context or
        default to False.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
    """
    if isinstance(python_function, functools.partial):
      self._python_function = python_function.func
    else:
      self._python_function = python_function
    self._function_spec = FunctionSpec.from_function_and_signature(
        python_function, input_signature)
    self._name = name
    self._autograph = autograph
    self._autograph_options = autograph_options
    self._function_cache = FunctionCache()
    self._function_attributes = attributes or {}
    self._capture_by_value = capture_by_value

    self._lock = threading.Lock()
    # `_descriptor_cache` is a mapping from an instance of a class to an
    # instance-specific `Function`, used to make sure defun-decorated methods
    # create different functions for each instance.
    self._descriptor_cache = weakref.WeakKeyDictionary()

1267  def __call__(self, *args, **kwargs):
1268    """Calls a graph function specialized to the inputs."""
1269    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
1270    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
1271
1272  @property
1273  def python_function(self):
1274    """Returns the wrapped Python function."""
1275    return self._python_function  # pylint: disable=protected-access
1276
1277  @property
1278  def function_spec(self):
1279    return self._function_spec
1280
1281  @property
1282  def input_signature(self):
1283    """Returns the input signature."""
1284    return self._function_spec.input_signature
1285
1286  @property
1287  def flat_input_signature(self):
1288    """Returns the flattened input signature."""
1289    return self._function_spec.flat_input_signature
1290
1291  def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
1292    """Returns a concrete function which cleans up its graph function."""
1293    if self.input_signature:
1294      args, kwargs = None, None
1295    graph_function, _, _ = self._maybe_define_function(args, kwargs)
1296    return graph_function
1297
1298  def _get_concrete_function_internal(self, *args, **kwargs):
1299    """Bypasses error checking when getting a graph function."""
1300    graph_function = self._get_concrete_function_internal_garbage_collected(
1301        *args, **kwargs)
1302    # We're returning this concrete function to someone, and they may keep a
1303    # reference to the FuncGraph without keeping a reference to the
1304    # ConcreteFunction object. So we won't clean up the reference cycles
1305    # manually and instead will leave them to Python's garbage collector.
1306    graph_function._garbage_collector.release()  # pylint: disable=protected-access
1307    return graph_function
1308
1309  def get_concrete_function(self, *args, **kwargs):
1310    """Returns a `ConcreteFunction` specialized to inputs and execution context.
1311
1312    Args:
1313      *args: inputs to specialize on.
1314      **kwargs: inputs to specialize on.
1315    """
1316    if self.input_signature:
1317      if kwargs:
1318        raise ValueError("Cannot define a TensorFlow function from a Python "
1319                         "function with keyword arguments when "
1320                         "input_signature is provided.")
1321      if args:
1322        # If args are provided, they must match the input signature.
1323        if not is_same_structure(self.input_signature, args):
1324          raise ValueError("Structure of Python function inputs does not match "
1325                           "input_signature.")
1326        flat_inputs = nest.flatten(args)
1327        if any(not isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
1328               for arg in flat_inputs):
1329          raise ValueError("When input_signature is provided, all inputs to "
1330                           "the Python function must be Tensors or "
1331                           "tf.TensorSpec objects.")
1332        if any(not spec.is_compatible_with(other)
1333               for spec, other in zip(self.flat_input_signature, flat_inputs)):
1334          raise ValueError("Python inputs incompatible with input_signature: "
1335                           "inputs (%s), input_signature (%s)" %
1336                           (str(args), str(self.input_signature)))
1337      args, kwargs = None, None
1338    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
1339    if self.input_signature:
1340      args = self.input_signature
1341      kwargs = {}
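    # Compute how many leading arguments may be passed positionally to the
    # concrete function; the remaining non-captured graph inputs are exposed as
    # keyword arguments named by their `_user_specified_name` attribute.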
1342    seen_names = set()
1343    captured = frozenset(graph_function.graph.internal_captures)
1344    allowed_positional = 0
1345    if args:
1346      for outer_arg in args:
1347        # TODO(allenl): Consider allowing arguments with defaults in the Python
1348        # function's signature to be passed as positional arguments to the
1349        # concrete function.
1350        if not isinstance(
1351            outer_arg,
1352            (ops.Tensor, resource_variable_ops.ResourceVariable,
1353             tensor_spec.TensorSpec)):
1354          break
1355        allowed_positional += 1
1356    # pylint: disable=protected-access
1357    graph_function._num_positional_args = allowed_positional
1358    graph_function._arg_keywords = []
1359    # pylint: enable=protected-access
1360    for arg in graph_function.graph.inputs:
1361      if arg in captured:
1362        break
1363      user_arg_name = arg.op.get_attr("_user_specified_name")
1364      if user_arg_name in seen_names:
1365        raise ValueError(
1366            ("Unable to construct a concrete function for {} since some "
1367             "arguments do not have unique names. Got two arguments named "
1368             "'{}'. When constructing a concrete TensorFlow function from a "
1369             "Python function which takes nested structures or variadic "
1370             "positional arguments, pass unique names to tf.TensorSpec objects "
1371             "used to identify these Tensor inputs. These names may then be "
1372             "used as keyword arguments to the concrete function.")
1373            .format(
1374                self._python_function,
1375                compat.as_str(arg.op.get_attr("_user_specified_name"))))
1376      seen_names.add(user_arg_name)
1377      graph_function._arg_keywords.append(user_arg_name)  # pylint: disable=protected-access
1378    return graph_function
1379
1380  def __get__(self, instance, owner):
1381    """Makes it possible to defun instance methods."""
1382    del owner
1383    # `instance` here is the instance that this `Function` was accessed through
1384    # e.g., for
1385    #
1386    #   class Foo(object):
1387    #
1388    #     @function.defun
1389    #     def bar(self):
1390    #       ...
1391    #
1392    #   foo = Foo()
1393    #   foo.bar()  # `foo.bar` is a `Function` instance
1394    #
1395    # then `instance` will be `foo` (and `owner` will be `Foo`).  We create a
1396    # new instance of `Function` here so that each instance creates its
1397    # variables only once, which allows methods to be decorated with defun.
1398    # A cache avoids retracing the function every time the descriptor is
1399    # accessed.
1400    if instance not in self._descriptor_cache:
1401      if instance is None:
1402        return self
1403      # If there is no instance-specific `Function` in the cache, we construct
1404      # an instance-specific `Function` that uses a weak reference to the
1405      # instance (so that the instance will be correctly gc'd).
1406
1407      # And finally add the wrapped function to the descriptor cache
1408      self._descriptor_cache[instance] = class_method_to_instance_method(
1409          self, instance)
1410
1411    # Return the cached `Function` for the instance
1412    return self._descriptor_cache[instance]
1413
1414  def _cache_key(self, args, kwargs, include_tensor_ranks_only=False):
1415    """Computes the cache key given inputs and execution context."""
1416    if self.input_signature is None:
1417      inputs = (args, kwargs) if kwargs else args
1418      input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(
1419          inputs, include_tensor_ranks_only)
1420    else:
1421      del args, kwargs
1422      assert not include_tensor_ranks_only
1423      input_signature = self.flat_input_signature
1424
1425    ctx = context.context()
1426
1427    # Don't need to open an init_scope if the _cache_key call is in eager mode
1428    # already.
1429    executing_eagerly = ctx.executing_eagerly()
1430    parent_graph = None
1431    if not executing_eagerly:
1432      with ops.init_scope():
1433        # The graph, or whether we're executing eagerly, should be a part of the
1434        # cache key so we don't improperly capture tensors such as variables.
1435        executing_eagerly = ctx.executing_eagerly()
1436        parent_graph = None if executing_eagerly else ops.get_default_graph()
1437
1438    # pylint: disable=protected-access
1439    default_graph = ops.get_default_graph()
1440    # TODO(b/117617952): The current distribution strategy will affect graph
1441    # building (e.g. accessing different variables from different devices) and
1442    # so requires retracing for each device.
1443    uses_distribution_strategy = bool(
1444        default_graph._distribution_strategy_stack)
1445    if executing_eagerly:
1446      colocation_stack = ()
1447      if uses_distribution_strategy:
1448        device_functions = (pydev.merge_device(ctx.device_name),)
1449      else:
1450        device_functions = ()
1451    else:
1452      colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
1453      if (uses_distribution_strategy
1454          or func_graph_module.device_stack_has_callable(
1455              default_graph._device_function_stack)):
1456        # Putting the device in the cache key ensures that call-site device
1457        # annotations are respected.
1458        device_functions = tuple(default_graph._device_functions_outer_to_inner)
1459      else:
1460        device_functions = ()
1461    # pylint: enable=protected-access
1462    return CacheKey(input_signature, parent_graph, device_functions,
1463                    colocation_stack)
1464
1465  def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
1466    """Create a `ConcreteFunction` from `args` and `kwargs`."""
1467    if self.input_signature is None:
1468      arglen = len(args)
1469    else:
1470      arglen = len(self.input_signature)
1471    base_arg_names = self._function_spec.arg_names[:arglen]
1472    num_missing_args = arglen - len(self._function_spec.arg_names)
1473    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
1474    # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
1475    # where the prefix comes from self._function_spec.vararg_name.
1476    missing_arg_names = [
1477        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
1478    ]
1479    arg_names = base_arg_names + missing_arg_names
1480    graph_function = ConcreteFunction(
1481        func_graph_module.func_graph_from_py_func(
1482            self._name,
1483            self._python_function,
1484            args,
1485            kwargs,
1486            self.input_signature,
1487            autograph=self._autograph,
1488            autograph_options=self._autograph_options,
1489            arg_names=arg_names,
1490            override_flat_arg_shapes=override_flat_arg_shapes,
1491            capture_by_value=self._capture_by_value),
1492        self._function_attributes)
1493
1494    # pylint: disable=protected-access
1495    # Tell the ConcreteFunction to clean up its graph once it goes out of
1496    # scope. ConcreteFunction does not do this in its constructor since it
1497    # gets used in some places (like Keras) where the FuncGraph lives
1498    # longer than the ConcreteFunction.
1499    graph_function._garbage_collector = ConcreteFunctionGarbageCollector(
1500        graph_function.graph)
1501    # pylint: enable=protected-access
1502
1503    return graph_function
1504
1505  def _maybe_define_function(self, args, kwargs):
1506    """Gets a function for these inputs, defining it if necessary.
1507
1508    `args` and `kwargs` can be None if this `Function` was created with an
1509    `input_signature`.
1510
1511    Args:
1512      args: The varargs for the Python function.
1513      kwargs: The keyword args for the Python function.
1514
1515    Returns:
1516      A graph function corresponding to the input signature implied by args and
1517      kwargs, as well as the inputs that the object should be called with.
1518
1519    Raises:
1520      ValueError: If inputs are incompatible with the input signature.
1521      TypeError: If the function inputs include non-hashable objects.
1522      RuntimeError: If there's an internal bug (inconsistency) in handling
1523        shape relaxation retracing.
1524    """
1525    if self.input_signature is None or args is not None or kwargs is not None:
1526      args, kwargs = self._function_spec.canonicalize_function_inputs(
1527          *args, **kwargs)
1528    cache_key = self._cache_key(args, kwargs)
1529
1530    try:
1531      hash(cache_key)
1532    except TypeError as e:
1533      raise TypeError(
1534          "Arguments supplied to `defun`-generated functions must be"
1535          " hashable.  Original error: %s" % e)
1536
1537    with self._lock:
1538      graph_function = self._function_cache.primary.get(cache_key, None)
1539      if graph_function is not None:
1540        return graph_function, args, kwargs
1541
1542      logging.vlog(1,
1543                   "Creating new FuncGraph for Python function %r (key: %r)",
1544                   self._python_function, cache_key)
1545      logging.vlog(2,
1546                   "Python function signature [args: %s] [kwargs: %s]",
1547                   args,
1548                   kwargs)
1549
1550      call_context_key = cache_key.replace(input_signature=None)
1551
1552      # If an input signature was provided, or this calling context has not
1553      # had a cache miss so far, build the function directly and bypass shape
1554      # relaxation retracing.
1555      if (self.input_signature is not None
1556          or call_context_key not in self._function_cache.missed):
1557        self._function_cache.missed.add(call_context_key)
1558        graph_function = self._create_graph_function(args, kwargs)
1559        self._function_cache.primary[cache_key] = graph_function
1560        return graph_function, args, kwargs
1561
1562      rank_only_cache_key = self._cache_key(
1563          args, kwargs, include_tensor_ranks_only=True)
1564
1565      arg_shapes = _flat_shape_list(args, kwargs)
1566      relaxed_arg_shapes = self._function_cache.arg_relaxed_shapes.get(
1567          rank_only_cache_key, None)
1568      relaxed_arg_function = self._function_cache.arg_relaxed.get(
1569          rank_only_cache_key, None)
1570
1571      if (relaxed_arg_function is not None
1572          and _compatible_shapes(flat_relaxed=relaxed_arg_shapes,
1573                                 flat_to_check=arg_shapes)):
1574        return relaxed_arg_function, args, kwargs
1575
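      # No reusable relaxed-shape function exists. Seed the relaxed shapes with
      # this call's shapes, or widen the stored shapes to a common shape
      # compatible with both, then retrace using the relaxed shapes.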
1576      if relaxed_arg_shapes is None:
1577        relaxed_arg_shapes = arg_shapes
1578      else:
1579        if len(arg_shapes) != len(relaxed_arg_shapes):
1580          raise RuntimeError("Expected arg_shapes len to match "
1581                             "relaxed_arg_shapes len: %d vs. %d"
1582                             % (len(arg_shapes), len(relaxed_arg_shapes)))
1583        relaxed_arg_shapes = [
1584            _common_shape(x, y) for (x, y) in zip(
1585                arg_shapes, relaxed_arg_shapes)]
1586      self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = (
1587          relaxed_arg_shapes)
1588      graph_function = self._create_graph_function(
1589          args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
1590      self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function
1591
1592      return graph_function, args, kwargs
1593
1594
1595def register(func, *args, **kwargs):
1596  """Register a specialization of a `Function` into the graph.
1597
1598  This does not actually call the function with the inputs; it only puts the
1599  function definition into the graph. Registering the function with different
1600  input parameters results in multiple versions of the function in the graph.
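
  For example, a minimal sketch using this module's own helpers (`ops` and
  `constant_op` are the framework modules imported at the top of this file):

  ```python
  @defun
  def add(a, b):
    return a + b

  with ops.Graph().as_default():
    # Adds the traced function (and its gradient functions) to the default
    # graph without executing it.
    register(add, constant_op.constant(1.0), constant_op.constant(2.0))
  ```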
1601
1602  Args:
1603    func: the `Function` instance generated by a @defun decorator.
1604    *args: input arguments for the Python function.
1605    **kwargs: input keyword arguments for the Python function.
1606
1607  Returns:
1608    A `ConcreteFunction` object specialized to inputs and execution context.
1609
1610  Raises:
1611    ValueError: When the input function is not a defun-wrapped Python function.
1612  """
1613  if not isinstance(func, Function):
1614    raise ValueError("Only a defun-generated Function can be registered. "
1615                     "Got type: %s" % type(func))
1616  concrete_func = func.get_concrete_function(*args, **kwargs)
1617  concrete_func.add_to_graph(register_gradient_functions=True)
1618  return concrete_func
1619
1620
1621def validate_signature(signature):
1622  if any(not isinstance(arg, tensor_spec.TensorSpec)
1623         for arg in nest.flatten(signature)):
1624    raise TypeError("Invalid input_signature %s; input_signature must be "
1625                    "a possibly nested sequence of TensorSpec objects." %
1626                    (signature,))
1626
1627
1628def defun(func=None,
1629          input_signature=None,
1630          autograph=True,
1631          experimental_autograph_options=None):
1632  """Compiles a Python function into a callable TensorFlow graph.
1633
1634  `defun` (short for "define function") compiles a Python function
1635  composed of TensorFlow operations into a callable that executes a `tf.Graph`
1636  containing those operations. The callable produced by `defun` contains only
1637  the subgraph of TensorFlow operations that were executed when the Python
1638  function was called with a particular input signature, defined as a list
1639  of the shapes and dtypes of the Python function's Tensor-valued arguments and
1640  the values of its non-Tensor Python objects.
1641
1642  When eager execution is enabled, the ability to create graphs from Python
1643  functions makes it possible to incrementally trade off debuggability and
1644  interactivity for performance.  Functions compiled with `defun` cannot be
1645  inspected with `pdb`; however, executing a graph
1646  generated by `defun` sometimes takes less time and memory than eagerly
1647  executing the corresponding Python function, since specifying computations as
1648  graphs allows for optimizations like automatic buffer reuse and
1649  parallelization among ops. Note that executing a `defun`-compiled function
1650  incurs a small constant overhead, so eagerly executing sufficiently small
1651  Python functions might take less time than executing their corresponding
1652  `defun`-generated graphs.
1653
1654  For a Python function to be compatible with `defun`, all of its arguments must
1655  be hashable Python objects or lists thereof. The function itself may not
1656  modify the list/map structure of its arguments. Additionally, it must return
1657  zero or more `tf.Tensor` objects. If the Python function returns
1658  a `tf.Variable`, its compiled version will return the value of that variable
1659  as a `tf.Tensor`.
1660
1661  Executing a graph generated by `defun` respects device annotations (i.e.,
1662  all `with tf.device` directives present in a Python function will also be
1663  present in its corresponding graph), but it is not yet possible to execute the
1664  generated graphs across multiple machines.
1665
1666  _Example Usage_
1667
1668  ```python
1669  import tensorflow as tf
1670
1671  tf.enable_eager_execution()
1672
1673  # A simple example.
1674  def f(x, y):
1675    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
1676
1677  g = tf.contrib.eager.defun(f)
1678
1679  x = tf.constant([[2.0, 3.0]])
1680  y = tf.constant([[3.0, -2.0]])
1681
1682  # `f` and `g` will return the same value, but `g` will be executed as a
1683  # TensorFlow graph.
1684  assert f(x, y).numpy() == g(x, y).numpy()
1685
1686  # `defun` is capable of compiling Python functions that close over Python
1687  # objects, including Tensors and Variables.
1688  @tf.contrib.eager.defun
1689  def h():
1690    return f(x, y)
1691
1692  assert (h().numpy() == f(x, y).numpy()).all()
1693
1694  # `defun` automatically lifts variables out of the graphs it creates,
1695  # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
1696  # `tf.keras.Model` objects.
1697  class MyModel(tf.keras.Model):
1698
1699    def __init__(self, keep_probability=0.2):
1700      super(MyModel, self).__init__()
1701      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
1702      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
1703      self.keep_probability = keep_probability
1704
1705    @tf.contrib.eager.defun
1706    def call(self, inputs, training=True):
1707      x = self.dense2(self.dense1(inputs))
1708      if training:
1709        return tf.nn.dropout(x, self.keep_probability)
1710      else:
1711        return x
1712
1713  model = MyModel()
1714  model(x, training=True)  # executes a graph, with dropout
1715  model(x, training=False) # executes a graph, without dropout
1716
1717  # `defun`-compiled functions are differentiable.
1718  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
1719  with tf.GradientTape() as tape:
1720    outputs = model(x)
1721  gradient = tape.gradient(outputs, model.trainable_variables)
1722  optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
1723                            model.trainable_variables))
1724  ```
1725
1726  When using `defun`, there are subtleties regarding inputs, Python control
1727  flow, and variable creation that one should be aware of. For concreteness, let
1728  `f` be a Python function that returns zero or more `tf.Tensor` objects and
1729  let `F = defun(f)`. `F` builds a graph for each unique input signature it
1730  sees, Python control flow is baked into graphs, and operations related to
1731  variable initialization are automatically lifted out of the graphs that `F`
1732  generates and placed in the eager context if executing eagerly or into an
1733  outer graph otherwise.
1734
1735  _Input Signatures_
1736
1737  By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
1738  for every unique sequence of the shapes and dtypes of Tensor arguments and
1739  the values of Python objects it is invoked with. For example, calling
1740  `F(tf.random_uniform([2]))` will execute a different graph than
1741  `F(tf.random_uniform([3]))` because the two inputs have different shapes.
1742  The first time that `F(*args, **kwargs)` is called with a particular sequence
1743  of Tensor shapes and dtypes and Python values, it constructs a graph by
1744  tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
1745  input signature inferred from `(*args, **kwargs)` and cached for future reuse.
1746
1747  NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
1748  before being passed to `f`, and are treated as Tensors for caching. This
1749  allows a function to be called multiple times with NumPy arrays having
1750  different values but the same shape and dtype without re-tracing each time.
1751
1752  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
1753  define TensorFlow functions without explicitly specifying their signatures.
1754  However, this policy is conservative and potentially expensive; for example,
1755  when different invocations of your function have differently-shaped Tensor
1756  inputs, this policy might generate more graph functions than necessary. To
1757  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
1758  optional `input_signature` argument specifying the shapes and dtypes of the
1759  inputs. In particular, the shapes may be partially unspecified, with `None`s
1760  in the unknown dimensions.  When an input signature is provided,
1761  `tf.contrib.eager.defun` will only instantiate a single graph for the
1762  decorated Python function. The following is an example:
1763
1764  ```python
1765  import tensorflow as tf
1766
1767  # The first `TensorSpec` below describes the shape and dtype of `words`,
1768  # and the second describes the shape and dtype of `another_tensor`. Note that
1769  # the last dimension of the `words` `TensorSpec` is left unspecified.
1770  @tf.contrib.eager.defun(input_signature=[
1771    tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
1772    tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
1773  ])
1774  def my_sequence_model(words, another_tensor):
1775    ...
1776
1777  # Note how the third dimension of the first input can vary freely.
1778  words = tf.random_uniform([50, 300, 10])
1779  second_input = tf.random_uniform([300, 100])
1780  my_sequence_model(words, second_input)
1781
1782  words = tf.random_uniform([50, 300, 20])
1783  my_sequence_model(words, second_input)
1784
1785  # Passing an input with an incompatible shape will raise an error.
1786  words = tf.random_uniform([50, 100, 20])
1787  my_sequence_model(words, second_input)  # <---- This will raise an error.
1788
1789  ```
1790
1791  Python functions that are compiled with an `input_signature` must only accept
1792  Tensors as arguments and must not take arbitrary keyword arguments (**kwargs).
1793
1794  _Tracing_
1795
1796  Be aware that because `F` only logs TensorFlow operations, all the other
1797  Python code that `f` executes will only shape the _construction_ of the graphs
1798  that `F` executes: the Python code won't be executed when the graphs
1799  themselves are executed, though it will be executed every time the Python
1800  function is traced (and a given Python function might be traced multiple
1801  times, once for each input signature it is invoked with). For example, whereas
1802  the Python function
1803
1804  ```python
1805  import tensorflow as tf
1806  import numpy as np
1807
1808  tf.enable_eager_execution()
1809
1810  def add_noise():
1811    return tf.eye(5) + np.random.randn(5, 5)
1812  ```
1813
1814  will return a different output every time it is invoked, the compiled function
1815  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
1816  every time it is called, since a particular random offset generated by NumPy
1817  will be inserted into the graph as a TensorFlow constant. The solution is to
1818  replace the call to `np.random.randn` with `tf.random_normal((5, 5))`.
1819
1820  _Python Side-Effects_
1821
1822  A corollary of the previous discussion on tracing is the following: If a
1823  Python function `f` has Python side-effects, then executing `f` multiple times
1824  will not necessarily be semantically equivalent to executing `F =
1825  tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
1826  that `defun` only captures the subgraph of TensorFlow operations that is
1827  constructed when `f` is called in a graph-building context.
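
  For example, in the following minimal sketch the `print` call is a Python
  side-effect and runs only while the function is being traced, not when the
  generated graph is executed:

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  def f():
    print("Tracing f")  # Runs only during tracing.
    return tf.constant(1.0)

  F = tf.contrib.eager.defun(f)
  F()  # Prints "Tracing f", then executes the generated graph.
  F()  # Does not print; the cached graph is reused.
  ```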
1828
1829  _Python Control Flow_
1830
1831  The structure of many machine learning computations depends upon whether one is
1832  training or validating, and it is common to nest specialized logic under `if
1833  training:` blocks. By mapping each input signature to a unique graph, `defun`
1834  lets users transparently compile such code, as the following code snippet
1835  demonstrates:
1836
1837  ```python
1838  import tensorflow as tf
1839
1840  tf.enable_eager_execution()
1841
1842  @tf.contrib.eager.defun
1843  def lossy_matmul(W, x, training=True):
1844    outputs = tf.matmul(W, x)
1845    if training:
1846      outputs = tf.nn.dropout(outputs, keep_prob=0.2)
1847    return outputs
1848
1849  W = tf.random_normal((3, 5))
1850  x = tf.random_normal((5, 1))
1851
1852  # Executes a graph that applies dropout.
1853  lossy_outputs = lossy_matmul(W, x, training=True)
1854
1855  # Executes a graph that does not apply dropout.
1856  exact_outputs = lossy_matmul(W, x, training=False)
1857  ```
1858
1859  _TensorFlow Control Flow_
1860
1861  When `autograph` is `True`, data-dependent control flow is allowed as well.
1862  Control flow statements that depend on `Tensor` values are staged into
1863  corresponding TensorFlow ops. For example, the following code will work as
1864  expected:
1865
1866  ```python
1867  @tf.contrib.eager.defun
1868  def dynamic_rnn_loop(cell, seq):
1869    state, output = cell.zero_state()
1870    for input in seq:
1871      state, output = cell(input, state)
1872    return output
1873  ```
1874
1875  For more information see `tf.autograph`.
1876
1877  _Variables_
1878
1879  TensorFlow operations related to variable creation and initialization are
1880  automatically lifted out of the graphs generated by `defun`. In practice, this
1881  implies that variable creation and initialization only happen the first time
1882  `F` is called, and that variables are reused every time thereafter. Many
1883  TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
1884  first time they are called and reuse them thereafter. Automatic variable
1885  lifting makes it possible to compile these APIs without extra effort, at the
1886  cost of introducing a discrepancy between the semantics of executing Python
1887  functions and their corresponding compiled functions. For example:
1888
1889  ```python
1890  import tensorflow as tf
1891
1892  tf.enable_eager_execution()
1893
1894  def fn():
1895    x = tf.Variable(0.0)
1896    x.assign_add(1.0)
1897    return x.read_value()
1898
1899  # `fn` is a Python function, so x is created, initialized, and destroyed upon
1900  # every invocation
1901  assert fn().numpy() == fn().numpy() == 1.0
1902
1903  compiled = tf.contrib.eager.defun(fn)
1904
1905  # Compiling `fn` with `defun` hoists all variables outside of the generated
1906  # graph, so initialization happens exactly once.
1907  assert compiled().numpy() == 1.0
1908  assert compiled().numpy() == 2.0
1909  ```
1910
1911  Finally, because each input signature is bound to a unique graph, if your
1912  Python function constructs `tf.Variable` objects, then each graph constructed
1913  for that Python function will reference a unique set of variables. To
1914  circumvent this problem, we recommend against compiling Python functions that
1915  create `tf.Variable` objects. Instead, Python functions should either
1916  lexically close over `tf.Variable` objects or accept them as arguments,
1917  preferably encapsulated in an object-oriented container. If you must create
1918  variables inside your Python function and you want each graph generated for it
1919  to reference the same set of variables, add logic to your Python function that
1920  ensures that variables are only created the first time it is called and are
1921  reused for every subsequent invocation; note that this is precisely what
1922  `tf.keras.layers.Layer` objects do, so we recommend using them to represent
1923  variable-bearing computations whenever possible.
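
  For example, the recommended pattern of closing over a variable rather than
  creating it inside the compiled function looks like this (a minimal sketch,
  assuming eager execution is enabled):

  ```python
  v = tf.Variable(0.0)

  @tf.contrib.eager.defun
  def increment():
    v.assign_add(1.0)
    return v.read_value()

  assert increment().numpy() == 1.0
  assert increment().numpy() == 2.0
  ```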
1924
1925  Args:
1926    func: function to be compiled. If `func` is None, returns a
1927      decorator that can be invoked with a single argument - `func`. The
1928      end result is equivalent to providing all the arguments up front.
1929      In other words, defun(input_signature=...)(func) is equivalent to
1930      defun(func, input_signature=...). The former allows
1931      the following use case:
1932        @tf.contrib.eager.defun(input_signature=...)
1933        def foo(...):
1934          ...
1935
1936    input_signature: A possibly nested sequence of
1937      `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
1938      the Tensors that will be supplied to this function. If `None`, a separate
1939      function is instantiated for each inferred input signature.  If a
1940      signature is specified, every input to `func` must be a `Tensor`, and
1941      `func` cannot accept `**kwargs`.
1942    autograph: Whether `func` should be compiled before
1943      constructing the graph. See https://www.tensorflow.org/guide/autograph
1944      for more information.
1945    experimental_autograph_options: Experimental knobs (in the form of a tuple
1946      of tensorflow.autograph.Feature values) to control behavior when
1947      autograph=True.
1948
1949
1950  Returns:
1951     If `func` is not None, returns a callable that will execute the compiled
1952     function (and return zero or more `tf.Tensor` objects).
1953     If `func` is None, returns a decorator that, when invoked with a single
1954     `func` argument, returns a callable equivalent to the case above.
1955
1956  Raises:
1957    TypeError: If `input_signature` is neither `None` nor a sequence of
1958      `tf.contrib.eager.TensorSpec` objects.
1959  """
1960  return defun_with_attributes(
1961      func=func,
1962      input_signature=input_signature,
1963      autograph=autograph,
1964      experimental_autograph_options=experimental_autograph_options)
1965
1966
1967def defun_with_attributes(func=None,
1968                          input_signature=None,
1969                          attributes=None,
1970                          autograph=True,
1971                          experimental_autograph_options=None):
1972  """Compiles a Python function into a callable TensorFlow graph.
1973
1974  This function supports adding extra function attributes. See detailed
1975  documentation in defun(). Currently this is not exposed in the public API
1976  since we do not expect users to use attributes directly, and attributes will
1977  not work by themselves. This assumption might change in the future.
1978
1979  Args:
1980    func: function to be compiled.
1981    input_signature: same as defun()'s input_signature.
1982    attributes: A dictionary of arguments to add to the function def as
1983      attributes. Currently only primitive types are supported as values, and
1984      only whitelisted attribute names are allowed. A non-whitelisted attribute
1985      name or an unsupported value results in a ValueError. `func_name` is one
1986      of the whitelisted names; it is a Python string that sets the name of
1987      this `ConcreteFunction` in the graph.
1988    autograph: same as defun()'s autograph.
1989    experimental_autograph_options: same as defun()'s
1990      experimental_autograph_options.
1991
1992  Returns:
1993    Same as the return value of defun, with attributes added to the function in
1994    graph.
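
  For example, a minimal sketch using the whitelisted `func_name` attribute
  described above (`my_func` is an illustrative Python function):

  ```python
  compiled = defun_with_attributes(
      my_func, attributes={"func_name": "my_compiled_func"})
  ```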
1995  """
1996  if input_signature is not None:
1997    validate_signature(input_signature)
1998
1999  # TODO(apassos): deal with captured global state. Deal with control flow.
2000  def decorated(function):
2001    try:
2002      if attributes:
2003        name = attributes.pop("func_name", function.__name__)
2004      else:
2005        name = function.__name__
2006    except AttributeError:
2007      name = "function"
2008    return tf_decorator.make_decorator(
2009        function,
2010        Function(
2011            function,
2012            name,
2013            input_signature=input_signature,
2014            attributes=attributes,
2015            autograph=autograph,
2016            autograph_options=experimental_autograph_options))
2017
2018  # This code path is for the `foo = tfe.defun(foo, ...)` use case
2019  if func is not None:
2020    return decorated(func)
2021
2022  # This code path is for the
2023  #
2024  # @tfe.defun(...)
2025  # def foo(...):
2026  #    ...
2027  #
2028  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
2029  return decorated
2030
2031
2032# When a method is bound to objects of this type, it allows AutoGraph to
2033# recover a weak reference to the original method's self pointer, so that it
2034# can execute it consistently with class_method_to_instance_method's
2035# bound_method_wrapper.
2036# TODO(b/119246461): This is not pretty. Use a descriptor instead?
2037class TfMethodTarget(object):
2038  """Binding target for methods replaced by function and defun."""
2039
2040  def __init__(self, target, original_python_function):
2041    self.weakrefself_target__ = target
2042    self.weakrefself_func__ = weakref.ref(original_python_function)
2043
2044  @property
2045  def target(self):
2046    return self.weakrefself_target__()
2047
2048  def call(self, args, kwargs):
2049    wrapped_fn = self.weakrefself_func__()
2050    if tf_inspect.ismethod(wrapped_fn):
2051      wrapped_fn = six.get_unbound_function(wrapped_fn)
2052    return wrapped_fn(self.weakrefself_target__(), *args, **kwargs)
2053
2054
2055def class_method_to_instance_method(original_function, instance):
2056  """Constructs a new `Function` with `self` bound."""
2057  weak_instance = weakref.ref(instance)
2058
2059  # Note: while we could bind to a weakref proxy instead, that causes the
2060  # bound method to be unhashable.
2061  bound_method = types_lib.MethodType(
2062      original_function.python_function,
2063      TfMethodTarget(weak_instance, original_function.python_function))
2064
2065  # original_function is expected to be of one of the two `Function` types
2066  # (defined either in function.py or def_function.py).
2067  assert hasattr(original_function, "_name")
2068  assert hasattr(original_function, "_autograph")
2069  assert hasattr(original_function, "_function_spec")
2070  assert hasattr(original_function, "python_function")
2071
2072  weak_bound_method_wrapper = None
2073  def bound_method_wrapper(*args, **kwargs):
2074    """Wraps either a dummy MethodType or a converted AutoGraph function."""
2075    # __wrapped__ allows AutoGraph to swap in a converted function.
2076    strong_bound_method_wrapper = weak_bound_method_wrapper()
2077    wrapped_fn = strong_bound_method_wrapper.__wrapped__
2078
2079    if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__:
2080      # If __wrapped__ was not replaced, then call original_function.
2081      # TODO(mdan): For better consistency, use the wrapper's call().
2082      wrapped_fn = original_function.python_function
2083      if tf_inspect.ismethod(wrapped_fn):
2084        wrapped_fn = six.get_unbound_function(wrapped_fn)
2085      return wrapped_fn(weak_instance(), *args, **kwargs)
2086
2087    # If __wrapped__ was replaced, then it is always an unbound function.
2088    # However, the replacer is still responsible for attaching self properly.
2089    # TODO(mdan): Is it possible to do it here instead?
2090    return wrapped_fn(*args, **kwargs)
2091  weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
2092
2093  # pylint: disable=protected-access
2094  # We make a dummy MethodType object to generate the correct bound method
2095  # signature. The actual call is to a function with a weak reference to
2096  # `instance`.
2097  instance_func = type(original_function)(
2098      tf_decorator.make_decorator(bound_method, bound_method_wrapper),
2099      name=original_function._name,
2100      autograph=original_function._autograph,
2101      input_signature=original_function.input_signature)
2102  # pylint: enable=protected-access
2103
2104  # And we wrap the function with tf_decorator so inspection works correctly
2105  wrapped_instance_func = tf_decorator.make_decorator(
2106      original_function.python_function, instance_func)
2107  return wrapped_instance_func
2108
2109
2110class _FunctionGarbageCollector(object):
2111  """Cleans up cycles when a defun goes out of scope."""
2112
2113  def __init__(self, cache):
2114    self._cache = cache
2115
2116  def __del__(self):
2117    if func_graph_module is None or memory is None:
2118      return
2119    try:
2120      while self._cache:
2121        self._cache.popitem()
2122      memory.dismantle_ordered_dict(self._cache)
2123    except:  # pylint: disable=bare-except
2124      pass
2125
2126
2127class ConcreteFunctionGarbageCollector(object):
2128  """Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""
2129
2130  def __init__(self, func_graph):
2131    self._func_graph = func_graph
2132
2133  def release(self):
2134    """Call off the FuncGraph deletion."""
2135    self._func_graph = None
2136
2137  def __del__(self):
2138    if func_graph_module is None or memory is None or self._func_graph is None:
2139      return
2140    try:
2141      func_graph_module.dismantle_func_graph(self._func_graph)
2142    except:  # pylint: disable=bare-except
2143      pass
2144