# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Python front-end supports for functions.
16
17NOTE: At this time, functions are experimental and subject to change!. Proceed
18with caution.
19"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import hashlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect


class Defun(object):
  """Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects.  Call the decorator with one positional argument for
  each argument of the function to decorate, giving the expected type of that
  argument as the value.

  For example, if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(tf.float32, tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function, it adds the `call` ops to the
  default graph. In addition, it adds the definition of the function into the
  default graph. Because the addition of the function into the graph
  is deferred, the decorator can be used anywhere in the program.

  Any variables created inside the function are hoisted into the outer graph.
  Note that the variables are created in the variable scope that was active
  during the first call to the function. Subsequent function calls will refer
  to the same set of variables.

  Definitions of functions in a graph are frozen as soon as the graph is used
  to create a session. However, new functions and new calls to existing
  functions may be added to the graph, with the new functions themselves
  becoming immediately frozen.

  Example, but also see the [How To on functions](link_needed).

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```
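
  A gradient can be attached via the decorator's keyword arguments. A hedged
  sketch (`MyFuncGrad` and `MyFunc2` are illustrative names, not part of this
  module) using `python_grad_func`, which must follow the
  `tf.RegisterGradient` interface:

  ```python
  # Receives the call op and one gradient per output; returns one gradient
  # per input.
  def MyFuncGrad(op, dz0, dz1):
    return dz0 + dz1, dz0 - dz1  # Grads w.r.t. x and y of (x + y, x - y).

  @tf.Defun(tf.float32, tf.float32, python_grad_func=MyFuncGrad)
  def MyFunc2(x, y):
    return x + y, x - y
  ```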
  """

  def __init__(self, *input_types, **kwargs):
    """Create a `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`
      **kwargs: Optional keyword arguments, including
         func_name - (optional).  A python string, the name to use to
           declare this `Function` in the graph.

         grad_func - (optional).  A function implementing the gradient
           of the function-to-register.  This must be a
           `_DefinedFunction` object. The gradient
           function must satisfy the criterion defined in
           function.proto:GradientDef.

         python_grad_func - (optional).  A function implementing the
           gradient of the function python-side. This function must
           take the current op and the gradients w.r.t. its outputs,
           and return the gradients w.r.t. the inputs. That is, it must
           implement the interface expected by `tf.RegisterGradient`.
           This will be called by tf.gradients to add the gradient ops
           to the graph. At most one of grad_func and python_grad_func
           can be specified.

         out_names - (optional). A list of strings, one per output
           tensor.

         shape_func - (optional). A function taking the op and returning a list
           of static shapes to set for the function's outputs.
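
    As a hedged sketch (the names here are illustrative, not part of the API),
    `shape_func` can pin output shapes that shape inference would otherwise
    leave unknown:

    ```python
    @tf.Defun(tf.float32, shape_func=lambda op: [op.inputs[0].get_shape()])
    def Double(x):
      return x * 2.0
    ```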
    """
    self._input_types = input_types
    self._func_name = kwargs.pop("func_name", None)
    self._grad_func = kwargs.pop("grad_func", None)
    self._python_grad_func = kwargs.pop("python_grad_func", None)
    self._out_names = kwargs.pop("out_names", None)
    self._extra_kwargs = kwargs
  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError(f"Function {func} must be a callable.")

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "Functions with argument defaults or keyword arguments are not "
          f"supported. {func} has defaults {argspec.defaults} and keywords "
          f"{argspec.keywords}.")

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of tf.function input types is not compatible with the "
            f"allowed arguments of {func}. The tf.function has {num} input "
            f"types, while the python function allows a minimum of {min_args} "
            f"and a maximum of {max_args} arguments.")
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)


class _DefinedFunctionDeleter(object):
  """Unregister function from eager context."""

  __slots__ = ["name"]

  def __init__(self, name):
    self.name = name

  def __del__(self):
    try:
      context.remove_function(self.name)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


class _DefinedFunction(object):
  """_DefinedFunction encapsulates a function definition and its properties.

  Attributes:
    name: The function name.
    definition: The definition of this function. A FunctionDef proto.
    grad_func_name: If not None, the name of this function's gradient function.
    python_grad_func: A python callable implementing the gradient of
      the function python-side.
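
  For example (a hedged sketch; `Defun` returns a _DefinedFunction when
  explicit input types are given):

  ```python
  f = Defun(dtypes.float32)(lambda x: x * 2.0)
  print(f.name)        # Derived name plus a deterministic hash suffix.
  fdef = f.definition  # Lazily builds and returns the FunctionDef proto.
  ```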
  """

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               allowlisted_stateful_ops=None,
               capture_resource_var_by_value=True,
               **kwargs):
    """Creates _DefinedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case the name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      allowlisted_stateful_ops: A set of ops whose statefulness is ignored;
        when `capture_by_value` is True, they are copied into the function
        body.
      capture_resource_var_by_value: Boolean (defaults to True). If False,
        captured resource variables return the handle instead of the value.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    if self._allowlisted_stateful_ops is None:
      self._allowlisted_stateful_ops = set()
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily
    self._definition = None
    # Constructed only when C API is enabled, lazily
    self._c_func = None
    self._function_deleter = None
    self._sub_functions = {}  # Constructed with _definition or _c_func
    # pylint: disable=protected-access
    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access

    # Get the innermost device if possible.
    self._caller_device = device_funcs[-1] if device_funcs else None

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled the whole _definition is available and this is simply
    # another reference to _definition.signature
    self._op_def = None

    assert isinstance(input_types, (list, tuple))
    self._arg_types = input_types
    self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
                       for i in range(len(input_types))]

  @property
  def name(self):
    """Function name."""
    self._create_definition_if_needed()
    return self._func_name

  @property
  def definition(self):
    """Function definition proto."""
    self._create_definition_if_needed()
    if self._c_func:
      with c_api_util.tf_buffer() as buf:
        c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
        fdef = function_pb2.FunctionDef()
        proto_data = c_api.TF_GetBuffer(buf)
        fdef.ParseFromString(compat.as_bytes(proto_data))
        with ops.init_scope():
          if context.executing_eagerly():
            context.add_function(self._c_func.func)
            self._function_deleter = _DefinedFunctionDeleter(
                fdef.signature.name)
      return fdef
    return self._definition

  @property
  def _signature(self):
    self._create_definition_if_needed()
    return self._op_def

  def set_grad_func(self, grad_func):
    """Specifies the gradient function of this function."""
    assert not self._grad_func
    assert isinstance(grad_func, _DefinedFunction)
    self._grad_func = grad_func

  @property
  def grad_func_name(self):
    """Returns the name of the gradient function."""
    return self._grad_func.name if self._grad_func else None

  @property
  def python_grad_func(self):
    """Python gradient function callable."""
    return self._python_grad_func

  @property
  def declared_input_types(self):
    """Returns the list of data types of explicitly declared inputs."""
    return self._input_types

  @property
  def captured_inputs(self):
    """Returns the list of implicitly captured inputs."""
    self._create_definition_if_needed()
    return self._extra_inputs

  @property
  def stateful_ops(self):
    """Returns the list of stateful ops in function definition.

    Returns:
      A list of (op.name, op.type) pairs.
    """
    self._create_definition_if_needed()
    return self._stateful_ops

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet."""
    with context.graph_mode():
      self._create_definition_if_needed_impl()

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed."""
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between the
    # func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    parent_graph = ops.get_default_graph()
    collections_ref = {
        key: parent_graph.get_collection_ref(key) for key in variable_keys}

    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          [],  # control_outputs
          [],  # control_output_names
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
                                         serialized)

  def _create_hash_str(self, input_arg, output_arg, node_def):
    """Creates an 8-character string unique to this input.

    Args:
      input_arg: the input_arg field of an OpDef
                 (e.g. self._definition.signature.input_arg)
      output_arg: the output_arg field of an OpDef
                 (e.g. self._definition.signature.output_arg)
      node_def: the node_def field of a FunctionDef
                (e.g. self._definition.node_def)

    Returns:
      The unique string for this input
    """
    hasher = hashlib.sha1()

    def update_num(n):
      hasher.update(compat.as_bytes("%x" % n))

    def update_str(s):
      update_num(len(s))
      hasher.update(compat.as_bytes(s))

    def update_strs(slist):
      update_num(len(slist))
      for s in slist:
        update_str(s)

    for adef in input_arg:
      update_str(adef.SerializeToString())

    for adef in output_arg:
      update_str(adef.SerializeToString())

    for n in sorted(node_def, key=lambda n: n.name):
      update_str(n.name)
      update_str(n.op)
      update_strs(n.input)
      update_num(len(n.attr))
      # NOTE: protobuf map serialization does not guarantee ordering.
      for k in sorted(n.attr):
        update_str(k)
        update_str(n.attr[k].SerializeToString())

    return hasher.hexdigest()[:8]

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    self._create_definition_if_needed()

    # Adds this function into 'g'.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      g._add_function(self)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.
    for f in self._sub_functions.values():
      f.add_to_graph(g)

    # Adds its gradient function, too.
    if self._grad_func:
      self._grad_func.add_to_graph(g)

  def __call__(self, *args, **kwargs):
    self.add_to_graph(ops.get_default_graph())
    args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
    ret, op = _call(self._signature, *args, **kwargs)

    # Set a hidden attr in 'op' so that gradients_impl can refer back
    # to this _DefinedFunction instance to access python_grad_func.
    assert isinstance(op, ops.Operation)
    setattr(op, "__defun", self)

    if self._shape_func is not None:
      shapes = self._shape_func(op)
      if len(shapes) != len(op.outputs):
        raise ValueError(f"shape_func {self._shape_func} produced "
                         f"{len(shapes):d} shapes, which does not match "
                         f"{len(op.outputs)} outputs.")
      for (t, shape) in zip(op.outputs, shapes):
        t.set_shape(shape)
    return ret


class _OverloadedFunction(object):
  """_OverloadedFunction encapsulates an overloaded function.

  _OverloadedFunction maintains a mapping from input types to
  instantiated _DefinedFunction in self._overload.

  """

  def __init__(self,
               func,
               argnames,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               **kwargs):
    """Creates _OverloadedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      func_name: The function name. Defaults to None, in which case the name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: A list of strings for the function return value names.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._argnames = argnames
    self._func_name = func_name
    assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._extra_kwargs = kwargs
    self._overload = {}

  def instantiate(self, input_types):
    """Instantiate this function given input argument types.

    Args:
      input_types: A list of data types for the inputs.

    Returns:
      _DefinedFunction for the given input types.

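    For example (a hedged sketch, where `Plus = Defun()(lambda x, y: x + y)`
    is an overloaded function):

    ```python
    f32_plus = Plus.instantiate([dtypes.float32, dtypes.float32])
    ```
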
    """
    # Stringify the type list.
    key = _type_list_to_str(input_types)
    defined = self._overload.get(key)
    if not defined:
      # If not defined yet, define the function given the input types.
      name = self._func_name
      if name is not None:
        name = "_".join([name, key])
      defined = _DefinedFunction(
          self._func,
          self._argnames,
          input_types,
          name,
          None,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)
      _ = defined.name  # Fully instantiate the function definition.
      if self._grad_func:
        # If _grad_func is given, it is another
        # _OverloadedFunction. We need to instantiate it with the
        # right input types.
        output_types = [
            dtypes.DType(_.type) for _ in defined._signature.output_arg  # pylint: disable=protected-access
        ]
        # pylint: disable=protected-access
        defined._grad_func = self._grad_func.instantiate(input_types +
                                                         output_types)
        # pylint: enable=protected-access
      self._overload[key] = defined
    return defined

  def __call__(self, *args, **kwargs):
    input_types = []
    args = list(args)
    for (i, x) in enumerate(args):
      x = ops.convert_to_tensor(x)
      if not isinstance(x, ops.Tensor):
        raise ValueError(f"Expected a Tensor but got {x} with type {type(x)}.")
      input_types.append(x.dtype)
      args[i] = x
    return self.instantiate(input_types)(*args, **kwargs)


class _FuncGraph(ops.Graph):
  """A helper for constructing a function.

  _FuncGraph overrides ops.Graph's create_op() so that we can keep
  track of all inputs into every op created inside the function.  If
  any input is from another graph, we keep track of it in self._captured
  and substitute the input with a placeholder.

  Each captured input's corresponding placeholder is converted into a
  function argument, and the caller passes in the captured tensor.
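
  For example (a hedged sketch; `constant_op` is assumed to come from
  tensorflow.python.framework):

  ```python
  three = constant_op.constant(3.0)  # Created in the outer (default) graph.

  @Defun(dtypes.float32)
  def Scale(x):
    return x * three  # `three` is captured and becomes an extra input.
  ```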
  """

  def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
               capture_resource_var_by_value, *args, **kwargs):
    super(_FuncGraph, self).__init__(*args, **kwargs)
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._building_function = True
    self._outer_graph = ops.get_default_graph()
    self._vscope = vs.get_variable_scope()
    self._old_custom_getter = self._vscope.custom_getter

    # The name of the function.
    self.name = name
    # Placeholder tensors representing the inputs to this function. The tensors
    # are in this _FuncGraph.
    self.inputs = []
    # Tensors that will be returned by this function. The tensors are in this
    # _FuncGraph.
    self.outputs = []
    # Maps external tensor -> internal tensor (e.g. input placeholder).
    self._captured = {}
    # The external tensors that have been captured as inputs and must be passed
    # to this function (empty if capturing by value, otherwise these are the
    # keys of _captured).
    self.extra_inputs = []
    # Input placeholders that have been added for captured values (empty if
    # capturing by value).
    self.extra_args = []
    # Captured variables.
    # TODO(skyewm): is this needed?
    self.extra_vars = []

  # pylint: disable=g-doc-return-or-yield

  @property
  def outer_graph(self):
    """The graph active when this _FuncGraph was created."""
    return self._outer_graph

  @tf_contextlib.contextmanager
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Overridden from `tf.Graph` to update both the init_scope container
    and the present inner container. This is necessary to make sure setting
    containers applies correctly both to created variables and to stateful
    ops.

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
        yields the container name.
    """
    original_container = self._container
    # pylint: disable=protected-access
    with ops.init_scope():
      original_init_container = ops.get_default_graph()._container
    try:
      self._container = container_name
      with ops.init_scope():
        ops.get_default_graph()._container = container_name
      yield self._container
    finally:
      self._container = original_container
      with ops.init_scope():
        ops.get_default_graph()._container = original_init_container
    # pylint: enable=protected-access

  # pylint: enable=g-doc-return-or-yield

  def getvar(
      self,
      getter,
      name,
      shape=None,
      dtype=None,
      initializer=None,
      reuse=None,
      trainable=True,
      collections=None,  # pylint: disable=redefined-outer-name
      use_resource=None,
      **kwargs):
    """A custom variable getter."""
    # Here, we switch the default graph to the outer graph and ask the
    # variable scope in which the function is defined to give us the
    # variable. The variable is stashed in extra_vars and returned to
    # the caller.
    #
    # We capture these variables so that the variable definition is
    # hoisted upward to the outermost graph.
    with self._outer_graph.as_default():
      # pylint: disable=protected-access
      var = self._vscope.get_variable(
          vs._get_default_variable_store(),
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          use_resource=use_resource)
      self.extra_vars.append(var)
      if (isinstance(var, resource_variable_ops.BaseResourceVariable) and
          self._capture_resource_var_by_value):
        # For resource-based variables read the variable outside the function
        # and pass in the value. This ensures that the function is pure and
        # differentiable. TODO(apassos) this may have performance problems if
        # the function will only do embedding lookups on the variable.
        return var.value()
      return var

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    for i, x in enumerate(inputs):
      if isinstance(x, ops.EagerTensor) or x.graph is not self:
        inputs[i] = self.capture(x)
    return super(_FuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def capture(self, tensor, name=None):
    """Adds the given tensor to this graph and returns the captured tensor."""
    if tensor.ref() in self._captured:
      # Captured already.
      return self._captured[tensor.ref()]
    elif self._capture_by_value:
      return self._add_tensor_and_parents(tensor)
    else:
      return self._capture_tensor_as_extra_input(tensor, name)

  @property
  def captures(self):
    """Pairs of external tensors and their captured counterparts."""
    return [(k.deref(), v) for k, v in self._captured.items()]

  def _capture_tensor_as_extra_input(self, tensor, name=None):
    # Substitute with a placeholder.
    self.extra_inputs.append(tensor)
    # Hoist the new input placeholder out of any control flow context
    # we're currently in.
    with ops.control_dependencies(None):
      ph = array_ops.placeholder(
          tensor.dtype, shape=tensor.get_shape(), name=name)
    # pylint: disable=protected-access
    if isinstance(tensor, ops.EagerTensor):
      handle_data = tensor._handle_data
      if handle_data:
        handle_data = handle_data.SerializeToString()
    else:
      handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
                                                tensor._as_tf_output())

    if handle_data:
      c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
                                  compat.as_bytes(handle_data))
    # pylint: enable=protected-access
    self.inputs.append(ph)
    self._captured[tensor.ref()] = ph
    self.extra_args.append(ph)
    if _is_guaranteed_const(tensor):
      with ops.control_dependencies(None):
        return array_ops.guarantee_const(ph)
    else:
      return ph

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    # pylint: disable=protected-access
    op_def = graph_to_function_def._get_op_def(op)
    if op._is_stateful and op not in self._allowlisted_stateful_ops:
      raise ValueError(f"Cannot capture a stateful node (name:{op.name}, "
                       f"type:{op.type}) by value.")
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError(f"Cannot capture a placeholder (name:{op.name}, "
                       f"type:{op.type}) by value.")
    # pylint: enable=protected-access

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self._create_op_internal(
        op.type,
        captured_inputs, [o.dtype for o in op.outputs],
        name=op.name,
        attrs=op.node_def.attr,
        op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t.ref()] = captured_t

    return captured_op


def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            allowlisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Returns a _FuncGraph generated from `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into the
      function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes.
    allowlisted_stateful_ops: A set of ops whose statefulness is ignored; they
      are copied into the function body when capturing by value.
    capture_resource_var_by_value: Boolean (defaults to True). If False,
      captured resource variables return the handle instead of the value.

  Returns:
    A _FuncGraph.

  Raises:
    ValueError: if func returns None.
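
  For example (a hedged sketch, in graph mode):

  ```python
  fg = func_graph_from_py_func(
      lambda x, y: x + y, ["x", "y"], [dtypes.float32, dtypes.float32])
  print(len(fg.inputs), len(fg.outputs))  # -> 2 1
  ```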
  """
  if not name:
    name = function_utils.get_func_name(func)
  func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops,
                          capture_resource_var_by_value)

  with func_graph.as_default(), ops.device(device):
    # pylint: disable=protected-access
    if collections_ref is not None:
      func_graph._collections = collections_ref
    if container is not None:
      func_graph._container = container
    if colocation_stack is not None:
      func_graph._colocation_stack = colocation_stack
    # pylint: enable=protected-access

    if arg_shapes is None:
      arg_shapes = [None] * len(arg_types)

    # Create placeholders for the function arguments.
    for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes):
      argholder = array_ops.placeholder(argtype, shape=argshape, name=argname)
      func_graph.inputs.append(argholder)
    # Call func and gather the output tensors.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      outputs = func(*func_graph.inputs)

    # There is no way of distinguishing between a function not returning
    # anything and a function returning None in Python.
    # We need to allow the former and ideally want to forbid the latter as
    # it is most likely user error.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, we allow a single None return and interpret it as a function
    # with no output.
    if outputs is None:
      outputs = []
    else:
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any(_ is None for _ in outputs):
        raise ValueError(f"Function {name} can not return None.")
    # Ensures each output is a Tensor in the function graph.
    outputs = [ops.convert_to_tensor(t) for t in outputs]
    outputs = [func_graph.capture(t) if t.graph is not func_graph else t
               for t in outputs]
    func_graph.outputs = outputs
  return func_graph


def _is_guaranteed_const(tensor):
  """Determines whether `tensor` is guaranteed to be a constant.

  A tensor is guaranteed to be a constant if either it was produced by
  a `GuaranteeConst` op or if all of its inputs are guaranteed to be
  constants.

  Args:
    tensor: The tensor for which to determine const-ness.

  Returns:
    True if `tensor` is guaranteed to be a constant, False otherwise.
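
  For example (a hedged sketch, in graph mode):

  ```python
  x = array_ops.placeholder(dtypes.float32)
  y = array_ops.guarantee_const(x)
  _is_guaranteed_const(y)  # True: `y` is produced by a GuaranteeConst op.
  _is_guaranteed_const(x)  # False: a placeholder is not a constant.
  ```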
  """

  if isinstance(tensor, ops.EagerTensor):
    return False

  class Work(object):

    def __init__(self, op, leaving):
      self.op = op
      self.leaving = leaving

  is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst"
  constants = set([])

  def all_inputs_const(op):
    # If all inputs of an op are guaranteed constants, then we can infer that
    # the op produces a constant as well.
    return op.inputs and all(inp.op in constants for inp in op.inputs)

  visited = set([])
  stack = [Work(tensor.op, leaving=False)]
  while stack:
    work = stack.pop()
    if work.leaving:
      if all_inputs_const(work.op):
        constants.add(work.op)
      continue
    visited.add(work.op)
    if is_guaranteed_const(work.op):
      constants.add(work.op)
      continue

    # This op will be revisited after all its inputs are checked for
    # const-ness.
    stack.append(Work(work.op, leaving=True))
    for inp in work.op.inputs:
      if inp.op not in visited:
        stack.append(Work(inp.op, leaving=False))
  return tensor.op in constants


def _call(sig, *inputs, **kwargs):
  """Adds a node calling a function.

  This adds a `call` op to the default graph that calls the function
  of signature `sig`, passing the tensors in `inputs` as arguments.
  It returns the outputs of the call, which are one or more tensors.

  `sig` is the `OpDef` signature of a `_DefinedFunction` object.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  You can pass an optional keyword parameter `noinline=True|False` to
  instruct the runtime not to inline the function body into the call
  site.

  Args:
    sig: OpDef. The signature of the function.
    *inputs: arguments to the function.
    **kwargs: Optional keyword arguments.  Can only contain 'name' or
        'noinline'.

  Returns:
    A 2-element tuple. First element: a Tensor if the function returns a
    single value; a tuple of Tensors if the function returns multiple values;
    the Operation if the function returns no values. Second element: the
    Operation.

  Raises:
    ValueError: if the arguments are invalid.
  """
  if len(inputs) != len(sig.input_arg):
    raise ValueError(f"Expected {len(sig.input_arg):d} arguments, got "
                     f"{len(inputs):d}.")
  name = kwargs.pop("name", None)
  g = ops.get_default_graph()
  func_name = sig.name
  if name is None:
    name = func_name
  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
  output_types = [dtypes.DType(x.type) for x in sig.output_arg]
  op = g._create_op_internal(  # pylint: disable=protected-access
      func_name, list(inputs), output_types, name=name, attrs=attrs, op_def=sig)
  if op.outputs:
    if len(op.outputs) == 1:
      ret = op.outputs[0]
    else:
      ret = tuple(op.outputs)
  else:
    ret = op
  return ret, op


def _from_definition(fdef, grad_func=None):
  """Creates a _DefinedFunction initialized from a FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.

  # The Python callable is only needed to create a FunctionDef. Since we have
  # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we
  # have access to such a callable here).
  func = None
  argnames = [arg.name for arg in fdef.signature.input_arg]
  input_types = tuple(
      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
  func_name = fdef.signature.name
  # Note: FunctionDefs do not include python gradient functions, so if the
  # original _DefinedFunction included one it will not be reflected here.
  python_grad_func = None
  out_names = [arg.name for arg in fdef.signature.output_arg]
  result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
                            python_grad_func, out_names)
  # pylint: disable=protected-access
  serialized = fdef.SerializeToString()
  c_func = c_api.TF_FunctionImportFunctionDef(serialized)
  result._c_func = c_api_util.ScopedTFFunction(c_func)
  result._extra_inputs = []
  result._op_def = fdef.signature
  # pylint: enable=protected-access

  return result


def from_library(lib):
  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.

  This method handles assigning the correct gradient functions to each
  function.

  Args:
    lib: a FunctionDefLibrary

  Returns:
    A list of _DefinedFunctions

  Raises:
    ValueError: `lib` is invalid
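
  For example (a hedged sketch), the functions recorded in a graph can be
  round-tripped through its library:

  ```python
  lib = ops.get_default_graph().as_graph_def().library
  fns = from_library(lib)
  ```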
  """
  if not lib.function and not lib.gradient:
    return []

  # function name -> FunctionDef proto
  funcs = {fdef.signature.name: fdef for fdef in lib.function}

  # Validate that all referenced function names have function defs
  for g in lib.gradient:
    if g.function_name not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.function_name}' "
                       f"FunctionDef\n{lib}")
    if g.gradient_func not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.gradient_func}' "
                       f"FunctionDef\n{lib}")

  # function name -> gradient function name
  func_to_grad = collections.defaultdict(lambda: None)
  # gradient function name -> names of functions having that grad function
  grad_to_funcs = collections.defaultdict(list)

  for gdef in lib.gradient:
    func_to_grad[gdef.function_name] = gdef.gradient_func
    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)

  # Start with functions without gradients
  ready = [
      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
  ]
  if not ready:
    raise ValueError(
        f"FunctionDefLibrary contains cyclic gradient functions!\n{lib}")
  # function name -> _DefinedFunction
  initialized = {}

  while ready:
    fdef = ready.pop()
    name = fdef.signature.name

    grad = initialized.get(func_to_grad[name])
    if func_to_grad[name]:
      assert grad
    defined_func = _from_definition(fdef, grad_func=grad)
    initialized[name] = defined_func

    ready.extend(funcs[f] for f in grad_to_funcs[name])

  return initialized.values()


def _get_experimental_kwarg_as_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  # NOTE: check bool before int, since bool is a subclass of int in Python.
  if isinstance(value, bool):
    return attr_value_pb2.AttrValue(b=value)
  elif isinstance(value, int):
    return attr_value_pb2.AttrValue(i=value)
  elif isinstance(value, float):
    return attr_value_pb2.AttrValue(f=value)
  elif isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError(f"Attribute {attr_name} must be bool, int, float, or "
                     f"str. Got {type(value)}.")


def _get_kwarg_as_str_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError(f"Attribute {attr_name} must be str. Got {type(value)}.")


def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Parses **kwargs into a node's attributes."""
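  # A hedged sketch of the mapping performed here, e.g.:
  #   _parse_kwargs_as_attrs("f", noinline=True, experimental_tag="x")
  # yields {"_noinline": AttrValue(b=True),
  #         "_disable_call_shape_inference": AttrValue(b=True),
  #         "experimental_tag": AttrValue(s=b"x")}.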
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  # For compatibility with previous behavior, Defun does not perform shape
  # inference through its function call operations.
  attrs["_disable_call_shape_inference"] = attr_value_pb2.AttrValue(b=True)

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))
    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())
    # pylint: enable=protected-access

  kwargs_keys = list(kwargs.keys())
  for key in kwargs_keys:
    if key.startswith("experimental_"):
      attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
      del kwargs[key]
    # Support for https://github.com/tensorflow/community/pull/113/files.
    elif key == "_implements" or key == "_reference":
      attrs[key] = _get_kwarg_as_str_attr(key, kwargs[key])
      del kwargs[key]
  if kwargs:
    raise ValueError(f"Unknown keyword arguments: {kwargs.keys()}.")
  return attrs


def get_extra_vars():
  """Returns the variables captured by the function being defined.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the variables created inside the function
    body so far. Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_vars
  else:
    return []


def get_extra_inputs():
  """Returns the input tensors captured by the function being defined.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the tensors accessed inside the function body
    but defined outside it so far. Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_inputs
  else:
    return []


def get_extra_args():
  """Returns the corresponding function arguments for the captured inputs.

  Returns:
    If the default graph is being used to define a function, the
    returned list contains the placeholders used inside the function
    body that correspond to the tensors returned by get_extra_inputs().
    Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_args
  else:
    return []
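
# A hedged usage sketch (commented out; `outer` is an illustrative tensor
# created in the default graph before the function definition):
#
#   @Defun(dtypes.float32)
#   def Foo(x):
#     y = x * outer                     # `outer` is captured.
#     ph = get_extra_args()[0]          # Placeholder standing in for `outer`.
#     assert get_extra_inputs()[0] is outer
#     return y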


def _type_list_to_str(types):
  if any(_ not in _DTYPE_TO_STR for _ in types):
    unsupported_types = [type_ for type_ in types if type_ not in _DTYPE_TO_STR]
    raise ValueError(f"Unsupported dtypes {unsupported_types} in "
                     "`types`. Supported dtypes are "
                     f"{_DTYPE_TO_STR.keys()}.")
  return "".join(_DTYPE_TO_STR[_] for _ in types)


# NOTE: The list needs to be extended when more data types are added.
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "u8",
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16"
}
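
# For example (a hedged sketch): _type_list_to_str([dtypes.float32,
# dtypes.int64]) == "f32i64". _OverloadedFunction uses this string to key its
# cache of instantiated _DefinedFunctions.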


def function_def_from_tf_function(c_func):
  """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
  with c_api_util.tf_buffer() as buf:
    c_api.TF_FunctionToFunctionDef(c_func, buf)
    data = c_api.TF_GetBuffer(buf)
  fdef = function_pb2.FunctionDef()
  fdef.ParseFromString(compat.as_bytes(data))
  return fdef