# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Python front-end supports for functions.
16
17NOTE: At this time, functions are experimental and subject to change!. Proceed
18with caution.
19"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import hashlib

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect


class Defun(object):
  """Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects.  Call the decorator with named arguments, one for each
  argument of the function to decorate, with the expected type of the argument
  as value.

  For example, if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(tf.float32, tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function, it adds the `call` ops to the
  default graph. In addition, it adds the definition of the function into the
  default graph. Because the addition of the function into the graph
  is deferred, the decorator can be used anywhere in the program.

  Any variables created inside the function are hoisted into the outer graph.
  Note that the variables are created in the variable scope that was active
  during the first call to the function. Subsequent function calls will refer to
  the same set of variables.
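
  For example (a sketch; `Scale` and `w` are illustrative names):

      @Defun(tf.float32)
      def Scale(x):
        w = tf.get_variable("w", [], tf.float32)  # hoisted to the outer graph
        return x * w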

  Definitions of functions in a graph are frozen as soon as the graph is used to
  create a session. However, new functions and new calls to existing functions
  may be added to the graph, with the new functions themselves becoming
  immediately frozen.

  Example (but also see the [How To on functions](link_needed)):

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```
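
  A `python_grad_func` receives the forward op and the gradients w.r.t. its
  outputs, and must return the gradients w.r.t. its inputs (the interface
  expected by `tf.RegisterGradient`). A minimal sketch, with the illustrative
  names `MyIdentity` and `_MyIdentityGrad`:

  ```python
  def _MyIdentityGrad(op, dy):
    # One gradient per input of the forward function.
    return dy

  @tf.Defun(tf.float32, python_grad_func=_MyIdentityGrad)
  def MyIdentity(x):
    return x
  ```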
  """

  def __init__(self, *input_types, **kwargs):
    """Create a `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`.
      **kwargs: Optional keyword arguments, including
         func_name - (optional).  A python string, the name to use to
           declare this `Function` in the graph.

         grad_func - (optional).  A function implementing the gradient
           of the function-to-register.  This must be a
           `_DefinedFunction` object. The gradient
           function must satisfy the criterion defined in
           function.proto:GradientDef.

         python_grad_func - (optional).  A function implementing the
           gradient of the function python-side. This function must
           take the current op and the gradients w.r.t. its outputs,
           and return the gradients w.r.t. the inputs. That is, it must
           implement the interface expected by `tf.RegisterGradient`.
           This will be called by tf.gradients to add the gradient ops
           to the graph. At most one of grad_func and python_grad_func
           can be specified.

         out_names - (optional). A list of strings, one per output
           tensor.

         shape_func - (optional). A function taking the op and returning a list
           of static shapes to set for the function's outputs.
    """
    self._input_types = input_types
    self._func_name = kwargs.pop("func_name", None)
    self._grad_func = kwargs.pop("grad_func", None)
    self._python_grad_func = kwargs.pop("python_grad_func", None)
    self._out_names = kwargs.pop("out_names", None)
    self._extra_kwargs = kwargs

  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError("function %s must be callable" % func)

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "Functions with argument defaults or keyword arguments are not"
          " supported. {} has defaults {} and keywords {}.".format(
              func, argspec.defaults, argspec.keywords))

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of specified input types (%d) is incompatible with "
            "the function's arguments (expected between %d and %d)." %
            (num, min_args, max_args))
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and the input type list is empty.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)


class _DefinedFunction(object):
  """_DefinedFunction encapsulates a function definition and its properties.

  Attributes:
    name: The function name.
    definition: The definition of this function. A FunctionDef proto.
    grad_func_name: If not None, the name of this function's gradient function.
    python_grad_func: A python callable implementing the gradient of
      the function python-side.
  """

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               whitelisted_stateful_ops=None,
               capture_resource_var_by_value=True,
               **kwargs):
    """Creates _DefinedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case the name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      whitelisted_stateful_ops: A set of stateful ops that may be copied into
        the function body when `capture_by_value` is True.
      capture_resource_var_by_value: Boolean (defaults to True). If False,
        captured resource variables return their handles instead of their
        values.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._whitelisted_stateful_ops = whitelisted_stateful_ops
    if self._whitelisted_stateful_ops is None:
      self._whitelisted_stateful_ops = set()
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily.
    self._definition = None
    # Constructed only when C API is enabled, lazily.
    self._c_func = None
    self._sub_functions = dict()  # Constructed with _definition or _c_func.
    # pylint: disable=protected-access
    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access

    # Get the innermost device if possible.
    self._caller_device = device_funcs[-1] if device_funcs else None

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled the whole _definition is available and this is simply
    # another reference to _definition.signature.
    self._op_def = None

    assert isinstance(input_types, (list, tuple))
    self._arg_types = input_types
    self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
                       for i in range(len(input_types))]

  @property
  def name(self):
    """Function name."""
    self._create_definition_if_needed()
    return self._func_name

  @property
  def definition(self):
    """Function definition proto."""
    self._create_definition_if_needed()
    if self._c_func:
      with c_api_util.tf_buffer() as buf:
        c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
        fdef = function_pb2.FunctionDef()
        proto_data = c_api.TF_GetBuffer(buf)
        fdef.ParseFromString(compat.as_bytes(proto_data))
      return fdef
    return self._definition

  @property
  def _signature(self):
    self._create_definition_if_needed()
    return self._op_def

  def set_grad_func(self, grad_func):
    """Specifies the gradient function of this function."""
    assert not self._grad_func
    assert isinstance(grad_func, _DefinedFunction)
    self._grad_func = grad_func

  @property
  def grad_func_name(self):
    """Returns the name of the gradient function."""
    return self._grad_func.name if self._grad_func else None

  @property
  def python_grad_func(self):
    """Python gradient function callable."""
    return self._python_grad_func

  @property
  def declared_input_types(self):
    """Returns the list of data types of explicitly declared inputs."""
    return self._input_types

  @property
  def captured_inputs(self):
    """Returns the list of implicitly captured inputs."""
    self._create_definition_if_needed()
    return self._extra_inputs

  @property
  def stateful_ops(self):
    """Returns the list of stateful ops in the function definition.

    Returns:
      A list of (op.name, op.type) pairs.
    """
    self._create_definition_if_needed()
    return self._stateful_ops

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet."""
    with context.graph_mode():
      self._create_definition_if_needed_impl()

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed."""
    if self._definition is not None or self._c_func is not None:
      return

    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        whitelisted_stateful_ops=self._whitelisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef.
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          [],  # control_outputs
          [],  # control_output_names
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set).
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op.op_def.is_stateful]

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value.
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
                                         serialized)

  def _create_hash_str(self, input_arg, output_arg, node_def):
    """Creates an 8-character string unique to this input.

    Args:
      input_arg: the input_arg field of an OpDef
                 (e.g. self._definition.signature.input_arg)
      output_arg: the output_arg field of an OpDef
                 (e.g. self._definition.signature.output_arg)
      node_def: the node_def field of a FunctionDef
                (e.g. self._definition.node_def)

    Returns:
      The unique string for this input.
    """
    hasher = hashlib.sha1()

    def update_num(n):
      hasher.update(compat.as_bytes("%x" % n))

    def update_str(s):
      update_num(len(s))
      hasher.update(compat.as_bytes(s))

    def update_strs(slist):
      update_num(len(slist))
      for s in slist:
        update_str(s)

    for adef in input_arg:
      update_str(adef.SerializeToString())

    for adef in output_arg:
      update_str(adef.SerializeToString())

    for n in sorted(node_def, key=lambda n: n.name):
      update_str(n.name)
      update_str(n.op)
      update_strs(n.input)
      update_num(len(n.attr))
      # NOTE: protobuf map serialization does not guarantee ordering.
      for k in sorted(n.attr):
        update_str(k)
        update_str(n.attr[k].SerializeToString())

    return hasher.hexdigest()[:8]

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    self._create_definition_if_needed()

    # Adds this function into 'g'.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      g._add_function(self)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.
    for f in self._sub_functions.values():
      f.add_to_graph(g)

    # Adds its gradient function, too.
    if self._grad_func:
      self._grad_func.add_to_graph(g)

  def __call__(self, *args, **kwargs):
    self.add_to_graph(ops.get_default_graph())
    args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
    ret, op = _call(self._signature, *args, **kwargs)

    # Set a hidden attr in 'op' so that gradients_impl can refer back
    # to this _DefinedFunction instance to access python_grad_func.
    assert isinstance(op, ops.Operation)
    setattr(op, "__defun", self)

    if self._shape_func is not None:
      shapes = self._shape_func(op)
      if len(shapes) != len(op.outputs):
        raise ValueError("shape_func produced %d shapes for %d outputs" %
                         (len(shapes), len(op.outputs)))
      for (t, shape) in zip(op.outputs, shapes):
        t.set_shape(shape)
    return ret

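# Note: shape inference is skipped at function call sites
# (`compute_shapes=False` in `_call` below), so `shape_func` is the hook for
# supplying static output shapes. A minimal sketch (`Double` is an
# illustrative name):
#
#   @Defun(tf.float32, shape_func=lambda op: [op.inputs[0].get_shape()])
#   def Double(x):
#     return x * 2.0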

class _OverloadedFunction(object):
  """_OverloadedFunction encapsulates an overloaded function.

  _OverloadedFunction maintains a mapping from input types to
  instantiated _DefinedFunction in self._overload.

  """

  def __init__(self,
               func,
               argnames,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               **kwargs):
    """Creates _OverloadedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      func_name: The function name. Defaults to None, in which case the name
        is derived from 'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: A list of strings for the function return value names.
      **kwargs: The keyword arguments. **kwargs is passed to every call
        site of this function.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._argnames = argnames
    self._func_name = func_name
    assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._extra_kwargs = kwargs
    self._overload = {}

  def instantiate(self, input_types):
    """Instantiates this function given input argument types.

    Args:
      input_types: A list of data types for the inputs.

    Returns:
      _DefinedFunction for the given input types.

    """
    # Stringify the type list.
    key = _type_list_to_str(input_types)
    defined = self._overload.get(key)
    if not defined:
      # If not defined yet, define the function given the input types.
      name = self._func_name
      if name is not None:
        name = "_".join([name, key])
      defined = _DefinedFunction(
          self._func,
          self._argnames,
          input_types,
          name,
          None,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)
      _ = defined.name  # Fully instantiate the function definition.
      if self._grad_func:
        # If _grad_func is given, it is another
        # _OverloadedFunction. We need to instantiate it with the
        # right input types.
        output_types = [
            dtypes.DType(_.type) for _ in defined._signature.output_arg  # pylint: disable=protected-access
        ]
        # pylint: disable=protected-access
        defined._grad_func = self._grad_func.instantiate(input_types +
                                                         output_types)
        # pylint: enable=protected-access
      self._overload[key] = defined
    return defined

  def __call__(self, *args, **kwargs):
    input_types = []
    args = list(args)
    for (i, x) in enumerate(args):
      x = ops.convert_to_tensor(x)
      if not isinstance(x, ops.Tensor):
        raise ValueError("Expected a Tensor but got %s" % x)
      input_types.append(x.dtype)
      args[i] = x
    return self.instantiate(input_types)(*args, **kwargs)

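# Illustrative sketch: a `Defun` with no declared input types produces an
# _OverloadedFunction, which instantiates a separate _DefinedFunction per
# input-type signature on first call (`Add` is an illustrative name):
#
#   @Defun()
#   def Add(x, y):
#     return x + y
#
#   Add(tf.constant(1.0), tf.constant(2.0))  # instantiates an f32f32 overload
#   Add(tf.constant(1), tf.constant(2))      # instantiates an i32i32 overload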

class _FuncGraph(ops.Graph):
  """A helper for constructing a function.

  _FuncGraph overrides ops.Graph's create_op() so that we can keep
  track of all inputs into every op created inside the function.  If
  any input is from another graph, we keep track of it in self._captured
  and substitute the input with a placeholder.

  Each captured input's corresponding placeholder is converted into a
  function argument and the caller passes in the captured tensor.
  """

  def __init__(self, name, capture_by_value, whitelisted_stateful_ops,
               capture_resource_var_by_value, *args, **kwargs):
    super(_FuncGraph, self).__init__(*args, **kwargs)
    self._capture_by_value = capture_by_value
    self._whitelisted_stateful_ops = whitelisted_stateful_ops
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._building_function = True
    self._outer_graph = ops.get_default_graph()
    self._vscope = vs.get_variable_scope()
    self._old_custom_getter = self._vscope.custom_getter

    # The name of the function.
    self.name = name
    # Placeholder tensors representing the inputs to this function. The tensors
    # are in this _FuncGraph.
    self.inputs = []
    # Tensors that will be returned by this function. The tensors are in this
    # _FuncGraph.
    self.outputs = []
    # Maps external tensor -> internal tensor (e.g. input placeholder).
    self._captured = {}
    # The external tensors that have been captured as inputs and must be passed
    # to this function (empty if capturing by value, otherwise these are the
    # keys of _captured).
    self.extra_inputs = []
    # Input placeholders that have been added for captured values (empty if
    # capturing by value).
    self.extra_args = []
    # Captured variables.
    # TODO(skyewm): is this needed?
    self.extra_vars = []

  # pylint: disable=g-doc-return-or-yield

  @tf_contextlib.contextmanager
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Overridden from `tf.Graph` to update both the init_scope container
    and the present inner container. This is necessary to make sure setting
    containers applies correctly both to created variables and to stateful
    ops.

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
        yields the container name.
    """
    original_container = self._container
    # pylint: disable=protected-access
    with ops.init_scope():
      original_init_container = ops.get_default_graph()._container
    try:
      self._container = container_name
      with ops.init_scope():
        ops.get_default_graph()._container = container_name
      yield self._container
    finally:
      self._container = original_container
      with ops.init_scope():
        ops.get_default_graph()._container = original_init_container
    # pylint: enable=protected-access

  # pylint: enable=g-doc-return-or-yield

  def getvar(
      self,
      getter,
      name,
      shape=None,
      dtype=None,
      initializer=None,
      reuse=None,
      trainable=True,
      collections=None,  # pylint: disable=redefined-outer-name
      use_resource=None,
      **kwargs):
    """A custom variable getter."""
    # Here, we switch the default graph to the outer graph and ask the
    # variable scope in which the function is defined to give us the
    # variable. The variable is stashed in extra_vars and returned to
    # the caller.
    #
    # We capture these variables so that the variable definition is
    # hoisted upward to the outermost graph.
    with self._outer_graph.as_default():
      # pylint: disable=protected-access
      var = self._vscope.get_variable(
          vs._get_default_variable_store(),
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          use_resource=use_resource)
      self.extra_vars.append(var)
      if (isinstance(var, resource_variable_ops.ResourceVariable) and
          self._capture_resource_var_by_value):
        # For resource-based variables read the variable outside the function
        # and pass in the value. This ensures that the function is pure and
        # differentiable. TODO(apassos) this may have performance problems if
        # the function will only do embedding lookups on the variable.
        return var.value()
      return var

  def create_op(self, op_type, inputs, dtypes=None, **kwargs):  # pylint: disable=redefined-outer-name
    for i, x in enumerate(inputs):
      if isinstance(x, ops.EagerTensor) or x.graph is not self:
        inputs[i] = self.capture(x)
    return super(_FuncGraph, self).create_op(op_type, inputs,
                                             dtypes=dtypes, **kwargs)

  def capture(self, tensor, name=None):
    """Adds the given tensor to this graph and returns the captured tensor."""
    if tensor in self._captured:
      # Captured already.
      return self._captured[tensor]
    elif self._capture_by_value:
      return self._add_tensor_and_parents(tensor)
    else:
      return self._capture_tensor_as_extra_input(tensor, name)

  def _capture_tensor_as_extra_input(self, tensor, name=None):
    # Substitute with a placeholder.
    self.extra_inputs.append(tensor)
    # Hoist the new input placeholder out of any control flow context
    # we're currently in.
    with ops.control_dependencies(None):
      ph = array_ops.placeholder(
          tensor.dtype, shape=tensor.get_shape(), name=name)
    # pylint: disable=protected-access
    if isinstance(tensor, ops.EagerTensor):
      handle_data = tensor._handle_data
      if handle_data:
        handle_data = handle_data.SerializeToString()
    else:
      handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
                                                tensor._as_tf_output())

    if handle_data:
      c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
                                  compat.as_bytes(handle_data))
    # pylint: enable=protected-access
    self.inputs.append(ph)
    self._captured[tensor] = ph
    self.extra_args.append(ph)
    if _is_guaranteed_const(tensor):
      with ops.control_dependencies(None):
        return array_ops.guarantee_const(ph)
    else:
      return ph

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    # pylint: disable=protected-access
    op_def = graph_to_function_def._get_op_def(op)
    # pylint: enable=protected-access
    if op_def.is_stateful and op not in self._whitelisted_stateful_ops:
      raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
                       "by value." % (op.name, op.type))
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
                       "by value." % (op.name, op.type))

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self.create_op(
        op.type,
        captured_inputs, [o.dtype for o in op.outputs],
        name=op.name,
        attrs=op.node_def.attr,
        op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t] = captured_t

    return captured_op


def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            whitelisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Returns a _FuncGraph generated from `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into the
      function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes.
    whitelisted_stateful_ops: A set of stateful ops that may be copied into
      the function graph when capturing by value.
    capture_resource_var_by_value: Boolean (defaults to True). If False,
      captured resource variables return their handles instead of their
      values.

  Returns:
    A _FuncGraph.

  Raises:
    ValueError: if func returns None.
  """
  if not name:
    name = function_utils.get_func_name(func)
  func_graph = _FuncGraph(name, capture_by_value, whitelisted_stateful_ops,
                          capture_resource_var_by_value)

  with func_graph.as_default(), ops.device(device):
    # pylint: disable=protected-access
    if collections_ref is not None:
      func_graph._collections = collections_ref
    if container is not None:
      func_graph._container = container
    if colocation_stack is not None:
      func_graph._colocation_stack = colocation_stack
    # pylint: enable=protected-access

    if arg_shapes is None:
      arg_shapes = [None] * len(arg_types)

    # Create placeholders for the function arguments.
    for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes):
      argholder = array_ops.placeholder(argtype, shape=argshape, name=argname)
      func_graph.inputs.append(argholder)
    # Call func and gather the output tensors.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      outputs = func(*func_graph.inputs)

    # There is no way of distinguishing between a function not returning
    # anything and a function returning None in Python.
    # We need to allow the former and ideally want to forbid the latter as
    # it is most likely user error.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, we allow a single None return and interpret it as a function
    # with no output.
    if outputs is None:
      outputs = []
    else:
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any(_ is None for _ in outputs):
        raise ValueError("Function %s cannot return None." % name)
    # Ensures each output is a Tensor in the function graph.
    outputs = [ops.convert_to_tensor(t) for t in outputs]
    outputs = [func_graph.capture(t) if t.graph is not func_graph else t
               for t in outputs]
    func_graph.outputs = outputs
  return func_graph

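# Illustrative sketch: building a _FuncGraph by hand. In normal use this is
# called by _DefinedFunction rather than directly:
#
#   fg = func_graph_from_py_func(
#       lambda x: x * 2.0, ["x"], [dtypes.float32], name="double")
#   # fg.inputs holds the argument placeholder, fg.outputs the result tensor,
#   # and fg.extra_inputs any tensors captured from the outer graph.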

def _is_guaranteed_const(tensor):
  """Determines whether `tensor` is guaranteed to be a constant.

  A tensor is guaranteed to be a constant if either it was produced by
  a `GuaranteeConst` op or if all of its inputs are guaranteed to be
  constants.

  Args:
    tensor: The tensor for which to determine const-ness.

  Returns:
    True if `tensor` is guaranteed to be a constant, False otherwise.
  """

  if isinstance(tensor, ops.EagerTensor):
    return False

  class Work(object):

    def __init__(self, op, leaving):
      self.op = op
      self.leaving = leaving

  is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst"
  constants = set([])

  def all_inputs_const(op):
    # If all inputs of an op are guaranteed constants, then we can infer that
    # the op produces a constant as well.
    return op.inputs and all(inp.op in constants for inp in op.inputs)

  visited = set([])
  stack = [Work(tensor.op, leaving=False)]
  while stack:
    work = stack.pop()
    if work.leaving:
      if all_inputs_const(work.op):
        constants.add(work.op)
      continue
    visited.add(work.op)
    if is_guaranteed_const(work.op):
      constants.add(work.op)
      continue

    # This op will be revisited after all its inputs are checked for
    # const-ness.
    stack.append(Work(work.op, leaving=True))
    for inp in work.op.inputs:
      if inp.op not in visited:
        stack.append(Work(inp.op, leaving=False))
  return tensor.op in constants

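# Illustrative sketch: wrapping a tensor in `guarantee_const` marks it so
# that, when captured by a function, the capture placeholder is wrapped in
# GuaranteeConst as well (see _capture_tensor_as_extra_input above):
#
#   c = array_ops.guarantee_const(some_outer_tensor)  # illustrative name
#   # _is_guaranteed_const(c) -> True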

def _call(sig, *inputs, **kwargs):
  """Adds a node calling a function.

  This adds a `call` op to the default graph that calls the function
  of signature `sig`, passing the tensors in `inputs` as arguments.
  It returns the outputs of the call, which are one or more tensors.

  `sig` is the OpDef signature of a `_DefinedFunction` object.

  You can pass an optional keyword parameter `name=string` to name the
  added operation.

  You can pass an optional keyword parameter `noinline=True|False` to
  instruct the runtime not to inline the function body into the call
  site.

  Args:
    sig: OpDef. The signature of the function.
    *inputs: arguments to the function.
    **kwargs: Optional keyword arguments.  Can only contain 'name' or
        'noinline'.

  Returns:
     A 2-element tuple. First element: a Tensor if the function returns a single
     value; a tuple of Tensors if the function returns multiple values; the
     Operation if the function returns no values. Second element: the Operation.

  Raises:
    ValueError: if the arguments are invalid.
  """
  if len(inputs) != len(sig.input_arg):
    raise ValueError("Expected number of arguments: %d, received: %d" % (len(
        sig.input_arg), len(inputs)))
  name = kwargs.pop("name", None)
  g = ops.get_default_graph()
  func_name = sig.name
  if name is None:
    name = func_name
  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
  output_types = [dtypes.DType(x.type) for x in sig.output_arg]
  op = g.create_op(
      func_name,
      list(inputs),
      output_types,
      name=name,
      attrs=attrs,
      op_def=sig,
      compute_shapes=False)
  if op.outputs:
    if len(op.outputs) == 1:
      ret = op.outputs[0]
    else:
      ret = tuple(op.outputs)
  else:
    ret = op
  return ret, op

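# Illustrative sketch: `noinline` flows from the call-site kwargs through
# `_parse_kwargs_as_attrs` into a `_noinline` attr on the call op, e.g.
# (with `MyFunc` as defined in the Defun docstring above):
#
#   c, d = MyFunc(a, b, name="mycall", noinline=True)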

def _from_definition(fdef, grad_func=None):
  """Creates a _DefinedFunction initialized from a FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.

  # The Python callable is only needed to create a FunctionDef. Since we have
  # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we
  # have access to such a callable here).
  func = None
  argnames = [arg.name for arg in fdef.signature.input_arg]
  input_types = tuple(
      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
  func_name = fdef.signature.name
  # Note: FunctionDefs do not include python gradient functions, so if the
  # original _DefinedFunction included one it will not be reflected here.
  python_grad_func = None
  out_names = [arg.name for arg in fdef.signature.output_arg]
  result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
                            python_grad_func, out_names)
  # pylint: disable=protected-access
  serialized = fdef.SerializeToString()
  c_func = c_api.TF_FunctionImportFunctionDef(serialized)
  result._c_func = c_api_util.ScopedTFFunction(c_func)
  result._extra_inputs = []
  result._op_def = fdef.signature
  # pylint: enable=protected-access

  return result


def from_library(lib):
  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.

  This method handles assigning the correct gradient functions to each
  function.

  Args:
    lib: a FunctionDefLibrary

  Returns:
    A list of _DefinedFunctions

  Raises:
    ValueError: `lib` is invalid
  """
  if not lib.function and not lib.gradient:
    return []

  # function name -> FunctionDef proto
  funcs = {fdef.signature.name: fdef for fdef in lib.function}

  # Validate that all referenced function names have function defs.
  for g in lib.gradient:
    if g.function_name not in funcs:
      raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
                       (g.function_name, str(lib)))
    if g.gradient_func not in funcs:
      raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
                       (g.gradient_func, str(lib)))

  # function name -> gradient function name
  func_to_grad = collections.defaultdict(lambda: None)
  # gradient function name -> names of functions having that grad function
  grad_to_funcs = collections.defaultdict(list)

  for gdef in lib.gradient:
    func_to_grad[gdef.function_name] = gdef.gradient_func
    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)

  # Start with functions without gradients.
  ready = [
      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
  ]
  if not ready:
    raise ValueError(
        "FunctionDefLibrary contains cyclic gradient functions!\n" + str(lib))
  # function name -> _DefinedFunction
  initialized = {}

  while ready:
    fdef = ready.pop()
    name = fdef.signature.name

    grad = initialized.get(func_to_grad[name])
    if func_to_grad[name]:
      assert grad
    defined_func = _from_definition(fdef, grad_func=grad)
    initialized[name] = defined_func

    ready.extend(funcs[f] for f in grad_to_funcs[name])

  return initialized.values()

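# Illustrative sketch: round-tripping functions through a graph's library.
# Gradient links recorded in the library are re-wired onto the returned
# _DefinedFunctions:
#
#   lib = ops.get_default_graph().as_graph_def().library
#   fns = from_library(lib)
#   for f in fns:
#     f.add_to_graph(some_other_graph)  # `some_other_graph` is illustrative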

def _get_experimental_kwarg_as_attr(attr_name, value):
  """Creates an AttrValue for a python object."""
  if isinstance(value, bool):
    # Check bool before int: bool is a subclass of int in Python.
    return attr_value_pb2.AttrValue(b=value)
  elif isinstance(value, int):
    return attr_value_pb2.AttrValue(i=value)
  elif isinstance(value, float):
    return attr_value_pb2.AttrValue(f=value)
  elif isinstance(value, str):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  else:
    raise ValueError("Unsupported attribute type for %s with type %s" %
                     (attr_name, type(value)))


def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Parses **kwargs into a node's attributes."""
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))
    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())
    # pylint: enable=protected-access

  kwargs_keys = list(kwargs.keys())
  for key in kwargs_keys:
    if key.startswith("experimental_"):
      attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
      del kwargs[key]

  if kwargs:
    raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
  return attrs

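# Illustrative sketch of the kwarg -> attr mapping performed above:
#
#   _parse_kwargs_as_attrs("f", noinline=True, experimental_tag="v1")
#   # -> {"_noinline": AttrValue(b=True),
#   #     "experimental_tag": AttrValue(s=b"v1")}
#
# `experimental_tag` is an illustrative attribute name.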

def get_extra_vars():
  """Returns the variables captured by the function being defined.

  Returns:
    If the default graph is being used to define a function, the
    returned list of variables are those created inside the function
    body so far. Otherwise, returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_vars
  else:
    return []


def get_extra_inputs():
  """Returns the input tensors captured by the function being defined.

  Returns:
    If the default graph is being used to define a function, the
    returned list of tensors are those accessed inside the function body
    but defined outside the function body so far. Otherwise, returns an
    empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_inputs
  else:
    return []


def get_extra_args():
  """Returns the corresponding function arguments for the captured inputs.

  Returns:
    If the default graph is being used to define a function, the
    returned list of placeholders are those used inside the function
    body corresponding to those returned by get_extra_inputs(). Otherwise,
    returns an empty list.
  """
  g = ops.get_default_graph()
  if isinstance(g, _FuncGraph):
    return g.extra_args
  else:
    return []

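# Illustrative sketch: inside a function body, each captured outer tensor in
# get_extra_inputs() is paired positionally with the inner placeholder in
# get_extra_args() that stands in for it:
#
#   outer = tf.constant(1.0)
#
#   @Defun(tf.float32)
#   def F(x):
#     y = x + outer              # captures `outer`
#     assert get_extra_inputs()  # [outer]
#     assert get_extra_args()    # [placeholder standing in for `outer`]
#     return y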

def _type_list_to_str(types):
  if any(_ not in _DTYPE_TO_STR for _ in types):
    raise ValueError("Unsupported dtypes: %s" % types)
  return "".join([_DTYPE_TO_STR[_] for _ in types])


# NOTE: The list needs to be extended when more data types are added.
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "u8",
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16"
}

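# Illustrative sketch: the encoding above produces the keys used by
# _OverloadedFunction._overload, e.g.
#
#   _type_list_to_str([dtypes.float32, dtypes.int64])  # -> "f32i64"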

def function_def_from_tf_function(c_func):
  """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
  with c_api_util.tf_buffer() as buf:
    c_api.TF_FunctionToFunctionDef(c_func, buf)
    data = c_api.TF_GetBuffer(buf)
  fdef = function_pb2.FunctionDef()
  fdef.ParseFromString(compat.as_bytes(data))
  return fdef

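# Illustrative sketch (assuming `f` is a _DefinedFunction whose C function has
# been created, e.g. after accessing `f.name`):
#
#   fdef = function_def_from_tf_function(f._c_func.func)
#   print(fdef.signature.name)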