# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""API for defining graph functions with some additional eager semantics.

def_function.function wraps the function concept in function.py ("defun") to
allow initializing `tf.Variable`s with subgraphs of the function. For example:

```python
class M(tf.Module):
  def __init__(self):
    self.v_opinit = None
    self.v_arginit = None

  @tf.function
  def __call__(self, x):
    # Variables are only created on the first call to the function. This is a
    # common pattern in layer libraries.
    if self.v_opinit is None:
      # self.v_opinit will outlive the function call, but `tf.ones` is traced as
      # part of the function body before the `tf.Variable` object is
      # created. This subgraph is easy to lift out of the function.
      self.v_opinit = tf.Variable(tf.ones([]))

      # If arguments feed into variable initialization, it can be very tricky to
      # disentangle from the rest of the function. We don't attempt it.
      self.v_arginit = tf.Variable(tf.ones(tf.shape(x)) * tf.constant(2.))
    return self.v_opinit + self.v_arginit + x
```

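As an illustrative sketch (assuming eager execution), calling an instance of
`M` behaves as follows: the first call creates and initializes the variables,
and later calls reuse them.

```python
m = M()
m(tf.constant(3.))  # Creates and initializes v_opinit and v_arginit.
m(tf.constant(4.))  # Reuses the already-initialized variables.
```
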
These patterns with "defun" throw an error asking the user to put the variable's
initializer in a lambda. With tf.function they work with eager semantics either
by lifting the subgraph out of the function and using it to initialize the
variable, or by initializing variables on the first call to the function (if
they weren't already initialized by something else, e.g. a checkpoint API). The
latter requires tf.conds, and is not well supported by TF-XLA, so we only do it
when necessary.

Since these patterns are relatively common in layer libraries, we expose the
wrapper in this file as `tf.function`. The function concept in function.py is an
internal implementation detail.

In order to support these variable initialization patterns, tf.function defines
a variable subtype (UnliftedInitializerVariable) which collects the input
subgraph. This type of variable replaces the regular variable type on the first
tf.function trace. To exclude initializers from the function body (the `tf.ones`
ops above and associated assignment operations), tf.function traces a second
time if it sees variables on the first call.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import threading
import weakref
import six

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import traceback_utils
from tensorflow.python.util.tf_export import tf_export

FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
FREQUENT_TRACING_WARNING_THRESHOLD = 5
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
ALLOW_DYNAMIC_VARIABLE_CREATION = False

_tf_function_counter = monitoring.Counter(
    "/tensorflow/core/tf_function_counter",
    "Counter for the number of tf.functions created when Eager execution is "
    "enabled.",
    # jit_compile is "0" or "1".
    "jit_compile")


class _FrequentTracingDetector(object):
  """Class keeping track of how many recent calls triggered tracing."""

  __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]

  def __init__(self):
    self._calls_per_tracings = []
    self._total_warning_count = 0
    self._call_count = 0

  def called_with_tracing(self, function_name, omit_warning):
    """Updates the list of most recent calls' tracing information.

    Warns the user when recent calls caused retracing too often.

    Args:
      function_name: the python function being traced.
      omit_warning: If 'True', this call will not warn the user even if
        retracing happens too often.
    """
    self._call_count += 1
    self._calls_per_tracings.append(1)

    while self._calls_per_tracings:
      if (self._call_count - self._calls_per_tracings[0] >
          FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
        self._call_count -= self._calls_per_tracings.pop(0)
      else:
        break

    if (omit_warning or self._total_warning_count >=
        FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
      return
    if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
      self._total_warning_count += 1
      logging.warning(
          "{} out of the last {} calls to {} triggered tf.function "
          "retracing. Tracing is expensive and the excessive number of "
          "tracings could be due to (1) creating @tf.function repeatedly in "
          "a loop, (2) passing tensors with different shapes, (3) passing "
          "Python objects instead of tensors. For (1), please define your "
          "@tf.function outside of the loop. For (2), @tf.function has the "
          "experimental_relax_shapes=True option that relaxes argument "
          "shapes, which can avoid unnecessary retracing. For (3), please "
          "refer to "
          "https://www.tensorflow.org/guide/function#controlling_retracing"
          " and https://www.tensorflow.org/api_docs/python/tf/function for "
          "more details.".format(
              len(self._calls_per_tracings), self._call_count, function_name))

  def called_without_tracing(self):
    # We don't count tracing when users load a concrete function directly or
    # call get_concrete_function, so the first call may not be a tracing call.
    if not self._calls_per_tracings:
      self._calls_per_tracings = [0]
    self._calls_per_tracings[-1] += 1
    self._call_count += 1


class _FrequentTracingDetectorManager(object):
  """Class for the management of all _FrequentTracingDetector objects."""

  __slots__ = ["_detectors", "_lock"]

  def __init__(self):
    self._detectors = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
    self._lock = threading.Lock()

  def _get_detector(self, key):
    if key not in self._detectors:
      self._detectors[key] = _FrequentTracingDetector()
    return self._detectors[key]

  def called_without_tracing(self, key):
    with self._lock:
      detector = self._get_detector(key)
      detector.called_without_tracing()

  def called_with_tracing(self, key, function_name, omit_warning):
    with self._lock:
      detector = self._get_detector(key)
      detector.called_with_tracing(function_name, omit_warning)


_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()


class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
  """Variable which does not lift its initializer out of function context.

  Instances of this variable, when created, build a graph which runs their
  initializer inside a tf.cond(is_initialized) block.

  This can only be created inside a defun called from (eventually) eager
  mode. That is, non-function-building graphs are not supported.
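
  Conceptually, when the initializer cannot be lifted out of the function, the
  graph this class builds behaves like the following sketch (illustrative
  pseudocode, not the literal ops):

  ```python
  if not var_is_initialized(handle):
    assign(handle, initial_value)
  ```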
208  """
209
210  def __init__(self,
211               initial_value=None,
212               trainable=None,
213               caching_device=None,
214               name=None,
215               dtype=None,
216               constraint=None,
217               add_initializers_to=None,
218               lifted_initializer_graph=None,
219               synchronization=None,
220               aggregation=None,
221               shape=None,
222               **unused_kwargs):
223    """Creates a variable.
224
225    Args:
226      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
227        which is the initial value for the Variable. The initial value must have
228        a shape specified unless `validate_shape` is set to False. Can also be a
229        callable with no argument that returns the initial value when called.
230        (Note that initializer functions from init_ops.py must first be bound
231         to a shape before being used here.)
232      trainable: If `True`, GradientTapes automatically watch uses of this
233        Variable.
234      caching_device: Optional device string or function describing where the
235        Variable should be cached for reading.  Defaults to the Variable's
236        device.  If not `None`, caches on another device.  Typical use is to
237        cache on the device where the Ops using the Variable reside, to
238        deduplicate copying through `Switch` and other conditional statements.
239      name: Optional name for the variable. Defaults to `'Variable'` and gets
240        uniquified automatically.
241      dtype: If set, initial_value will be converted to the given type.
242        If None, either the datatype will be kept (if initial_value is
243       a Tensor) or float32 will be used (if it is a Python object convertible
244       to a Tensor).
245      constraint: An optional projection function to be applied to the variable
246        after being updated by an `Optimizer` (e.g. used to implement norm
247        constraints or value constraints for layer weights). The function must
248        take as input the unprojected Tensor representing the value of the
249        variable and return the Tensor for the projected value
250        (which must have the same shape). Constraints are not safe to
251        use when doing asynchronous distributed training.
252      add_initializers_to: if not None and not in legacy graph mode, the
253        initializer tensor will be added to this map in addition to adding the
254        assignment to the function.
255      lifted_initializer_graph: FuncGraph to try to lift initializers to.
256      synchronization: Indicates when a distributed a variable will be
257        aggregated. Accepted values are constants defined in the class
258        `tf.VariableSynchronization`. By default the synchronization is set to
259        `AUTO` and the current `DistributionStrategy` chooses
260        when to synchronize.
261      aggregation: Indicates how a distributed variable will be aggregated.
262        Accepted values are constants defined in the class
263        `tf.VariableAggregation`.
264      shape: (optional) The shape of this variable. If None, the shape of
265        `initial_value` will be used. When setting this argument to
266        `tf.TensorShape(None)` (representing an unspecified shape), the variable
267        can be assigned with values of different shapes.
268
269    Raises:
270      ValueError: If the initial value is not specified, or does not have a
271        shape and `validate_shape` is `True`.
272      RuntimeError: If called outside of a function definition.
273    """
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
    if not ops.inside_function():
      # If we've been init_scope()d out of the function definition, there is
      # nothing to do here; we can't really do the capturing or conditional
      # logic.
      resource_variable_ops.ResourceVariable.__init__(
          self, initial_value=initial_value, trainable=trainable,
          caching_device=caching_device, name=name, dtype=dtype,
          constraint=constraint)
      return
    if initial_value is None:
      raise ValueError("`initial_value` must be a Tensor or a Python "
                       "object convertible to a Tensor. Got None.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError(f"`constraint` with type {type(constraint)} must be a "
                       "callable.")

    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as scope_name:
      with ops.name_scope("Initializer"):
        if init_from_fn:
          initial_value = initial_value()
        if isinstance(initial_value, trackable.CheckpointInitialValue):
          self._maybe_initialize_trackable()
          self._update_uid = initial_value.checkpoint_position.restore_uid
          initial_value = initial_value.wrapped_value

        initial_value = ops.convert_to_tensor(initial_value,
                                              name="initial_value", dtype=dtype)
      assert initial_value is not None

      # Don't use `shape or initial_value.shape` since TensorShape has
      # overridden `__bool__`.
      if shape is None:
        shape = initial_value.shape

    # Use the constructor for UninitializedVariable to start. Outside the name
    # scope so we don't double up the prefix.
    super(UnliftedInitializerVariable, self).__init__(
        trainable=trainable,
        caching_device=caching_device,
        name=name,
        shape=shape,
        dtype=initial_value.dtype,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        extra_handle_data=initial_value,
        **unused_kwargs)

    with ops.name_scope(scope_name):
      if self._in_graph_mode:
        with ops.init_scope():
          outer_graph = ops.get_default_graph()
        func_graph = ops.get_default_graph()
        function_placeholders = (
            func_graph.inputs + func_graph.internal_captures)
        placeholder_ops = set(
            [tensor.op for tensor in function_placeholders])
        lifted_initializer = lift_to_graph.lift_to_graph(
            [initial_value], outer_graph,
            disallowed_placeholders=placeholder_ops)[initial_value]
        with ops.init_scope():
          self._initial_value = lifted_initializer
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = resource_variable_ops.assign_variable_op(
                  self._handle, lifted_initializer, name=n)
      elif context.executing_eagerly():
        # In this case, both current scope and init scope are eager.
        # Assign_variable_op will be executed immediately. So we don't need to
        # add it to "add_initializers_to" to lift it out.
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          resource_variable_ops.assign_variable_op(
              self._handle, initial_value, name=n)
      else:
        # Init scope is eager but current scope is graph. We will lift out this
        # variable by adding it to "add_initializers_to".
        if add_initializers_to is not None:
          add_initializers_to.append((self, initial_value))

        def assign_fn():
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            resource_variable_ops.assign_variable_op(
                self._handle,
                initial_value,
                name=n)
            # Returning values to keep tf.cond happy.
          return ops.convert_to_tensor(1)
        def not_assign_fn():
          return ops.convert_to_tensor(0)
        # Note: this cond is always guaranteed to run because we're inside a
        # defun which will insert automatic control dependencies. It will only
        # execute assign_fn if lifting failed.
        graph = ops.get_default_graph()

        # Capture the handle ahead of time in order to avoid querying the shape
        # of the handle, which helps async execution performance.
        graph.capture(self._handle, shape=())
        control_flow_ops.cond(
            resource_variable_ops.var_is_initialized_op(self._handle),
            not_assign_fn, assign_fn)


RUN_FUNCTIONS_EAGERLY = False


@deprecation.deprecated(
    None,
    "Use `tf.config.run_functions_eagerly` instead of the experimental "
    "version.")
@tf_export("config.experimental_run_functions_eagerly")
def experimental_run_functions_eagerly(run_eagerly):
  """Enables / disables eager execution of `tf.function`s.

  Calling `tf.config.experimental_run_functions_eagerly(True)` will make all
  invocations of `tf.function` run eagerly instead of running as a traced graph
  function.

  See `tf.config.run_functions_eagerly` for an example.

  Note: This flag has no effect on functions passed into tf.data transformations
  as arguments. tf.data functions are never executed eagerly and are always
  executed as a compiled TensorFlow graph.

  Args:
    run_eagerly: Boolean. Whether to run functions eagerly.
  """
  return run_functions_eagerly(run_eagerly)


@tf_export("config.run_functions_eagerly")
def run_functions_eagerly(run_eagerly):
  """Enables / disables eager execution of `tf.function`s.

  Calling `tf.config.run_functions_eagerly(True)` will make all
  invocations of `tf.function` run eagerly instead of running as a traced graph
  function.

  This can be useful for debugging.

  >>> def my_func(a):
  ...  print("Python side effect")
  ...  return a + a
  >>> a_fn = tf.function(my_func)

  >>> # A side effect the first time the function is traced
  >>> a_fn(tf.constant(1))
  Python side effect
  <tf.Tensor: shape=(), dtype=int32, numpy=2>

  >>> # No further side effect, as the traced function is called
  >>> a_fn(tf.constant(2))
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  >>> # Now, switch to eager running
  >>> tf.config.run_functions_eagerly(True)
  >>> # Side effect, as the function is called directly
  >>> a_fn(tf.constant(2))
  Python side effect
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  >>> # Turn this back off
  >>> tf.config.run_functions_eagerly(False)

  Note: This flag has no effect on functions passed into tf.data transformations
  as arguments. tf.data functions are never executed eagerly and are always
  executed as a compiled TensorFlow graph.

  Args:
    run_eagerly: Boolean. Whether to run functions eagerly.
  """
  global RUN_FUNCTIONS_EAGERLY
  RUN_FUNCTIONS_EAGERLY = bool(run_eagerly)


@deprecation.deprecated(
    None,
    "Use tf.config.functions_run_eagerly instead of the experimental version.")
@tf_export("config.experimental_functions_run_eagerly")
def experimental_functions_run_eagerly():
  """Returns the value of the `experimental_run_functions_eagerly` setting."""
  return functions_run_eagerly()


@tf_export("config.functions_run_eagerly")
def functions_run_eagerly():
  """Returns the value of the `run_functions_eagerly` setting."""
  return RUN_FUNCTIONS_EAGERLY


def _evaluate_var_is_initialized(variables):
  """Compute booleans indicating whether each variable is initialized."""
  with ops.init_scope():
    var_is_initialized = []
    for v in variables:
      var_is_initialized.append(
          resource_variable_ops.var_is_initialized_op(v.handle))
    try:
      # Stack all the var_is_initialized values into one tensor and interpret
      # the numpy value. This will reduce the number of RPCs between client and
      # worker in the remote case.
      return array_ops.stack(var_is_initialized).numpy()
    except errors.UnimplementedError:
      # Some devices do not support implicit copy-off to host. Fall back to
      # variable-by-variable processing.
      for index, v in enumerate(variables):
        try:
          numpy_value = var_is_initialized[index].numpy()
        except errors.UnimplementedError:
          # This is a variable on a parallel device; we'll extract its value on
          # each replica and assert that they're identical.
          components = parallel_device.unpack(var_is_initialized[index])
          with ops.device(None):
            components = array_ops.stack(components)
            all_initialized = math_ops.reduce_all(components).numpy()
            any_initialized = math_ops.reduce_any(components).numpy()
          if all_initialized != any_initialized:
            raise NotImplementedError(
                f"Some but not all components of a parallel variable {v!r} "
                "were initialized between their creation in a tf.function and "
                "the function's trace having completed. This is not "
                "supported; consider initializing either all or none of the "
                "components, or moving initialization out of the function.")
          numpy_value = all_initialized
        var_is_initialized[index] = numpy_value
  return var_is_initialized


class FunctionDeleter(object):

  __slots__ = ["func_graph"]

  def __init__(self, func_graph):
    self.func_graph = func_graph

  def __del__(self):
    try:
      func_graph_module.dismantle_func_graph(self.func_graph)
    except:  # pylint: disable=bare-except
      # Note: bare except here because this can be noisy at shutdown time.
      pass


class OptionalXlaContext(object):
  """Wrapper for XLA context optionally applied under a context manager."""

  def __init__(self, is_compiled):
    wrap = is_compiled and not control_flow_util.GraphOrParentsInXlaContext( \
              ops.get_default_graph())
    self.xla_context = control_flow_ops.XLAControlFlowContext() \
        if wrap else None

  def __enter__(self):
    if self.xla_context:
      self.xla_context.Enter()

  def __exit__(self, t, value, traceback):
    if self.xla_context:
      self.xla_context.Exit()


# TODO(mdan): Consider exposing this type for instance type checking.
@tf_export("__internal__.function.Function", v1=[])
class Function(core.GenericFunction):
  """A `tf.types.experimental.GenericFunction` created by `tf.function`.

  Currently, individual methods/attributes under this class are not guaranteed
  by the TF API contract, and are subject to future changes.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               autograph=True,
               jit_compile=None,
               experimental_implements=None,
               experimental_autograph_options=None,
               experimental_relax_shapes=False,
               experimental_follow_type_hints=None):
    """Initializes a `Function`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: See the documentation for `tf.function`.
      autograph: See the documentation for `tf.function`.
      jit_compile: See the documentation for `tf.function`.
      experimental_implements: See the documentation for `tf.function`.
      experimental_autograph_options: See the documentation for `tf.function`.
      experimental_relax_shapes: See the documentation for `tf.function`.
      experimental_follow_type_hints: See the documentation for `tf.function`.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
    """
    self._lock = threading.Lock()
    self._python_function = python_function
    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
        python_function,
        input_signature,
        jit_compile=jit_compile,
        experimental_follow_type_hints=experimental_follow_type_hints,
    )
    self._implements = experimental_implements
    # If `True`, the function uses the rendezvous of the parent. This is only
    # needed to support code where raw send/recv operations are inserted and
    # when functions are run in graph mode where they may not be inlined.
    self._shared_rendezvous = None
    self._autograph = autograph
    self._experimental_autograph_options = experimental_autograph_options
    self._experimental_relax_shapes = experimental_relax_shapes
    self._jit_compile = jit_compile
    if experimental_follow_type_hints is None:
      experimental_follow_type_hints = False
    self._experimental_follow_type_hints = experimental_follow_type_hints
    self._created_variables = None  # GUARDED_BY(self._lock)
    self._stateful_fn = None  # GUARDED_BY(self._lock)
    self._stateless_fn = None  # GUARDED_BY(self._lock)
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._name = name
    self._input_signature = input_signature
    self._key_for_call_stats = self._get_key_for_call_stats()
    self._omit_frequent_tracing_warning = False
    ops._tf_function_api_guage.get_cell().set(True)  # pylint: disable=protected-access

  def __getstate__(self):
    """Custom pickling, to omit unpickleable objects."""
    result = self.__dict__.copy()
    del result["_lock"]
    del result["_descriptor_cache"]
    del result["_key_for_call_stats"]
    return result

  def __setstate__(self, state):
    """Restore from pickled state."""
    self.__dict__ = state
    self._lock = threading.Lock()
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._key_for_call_stats = self._get_key_for_call_stats()

  def _get_key_for_call_stats(self):
    """Returns key instance to track call stats and retracings.

    The key instance is a best-effort attempt to preserve global consistency.
    """
    target_function = self._python_function
    # `__wrapped__` is a conventional Python attribute in which a higher-order
    # function keeps its original function.  We also directly use
    # this attribute for dealing with a class method.  See
    # `bound_method_wrapper` in `function.py`.  If we don't use `__wrapped__`,
    # all class methods will return the same `bound_method_wrapper` instance
    # from this function.
    while hasattr(target_function, "__wrapped__"):
      target_function = target_function.__wrapped__

    if hasattr(target_function, "__func__"):
      target_function = target_function.__func__

    if hasattr(target_function, "__code__"):
      return target_function.__code__

    return self._python_function

  def _defun_with_scope(self, scope):
    """Creates a defun wrapped inside a variable creator scope."""

    weak_wrapped_fn = None
    compile_with_xla = self._jit_compile

    def wrapped_fn(*args, **kwds):
      """Wraps `self._python_function` in a variable creator scope."""
      # We register a variable creator with reduced priority. If an outer
      # variable creator is just modifying keyword arguments to the variable
      # constructor, this will work harmoniously. Since the `scope` registered
      # here actually creates the variable, it taking priority would otherwise
      # ignore the outer creator.
      #
      # If an outer variable creator calls the variable constructor manually,
      # for example creating a MirroredVariable, then they won't call our
      # creator. This means we won't be able to trace the initialization graph,
      # and so variable initializers can't depend on function arguments. This is
      # better than the alternative, tracing the initialization graph but giving
      # the user a variable type they didn't want.
      default_graph = ops.get_default_graph()
      with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
        # __wrapped__ allows AutoGraph to swap in a converted function. We give
        # the function a weak reference to itself to avoid a reference cycle.
        with OptionalXlaContext(compile_with_xla):
          out = weak_wrapped_fn().__wrapped__(*args, **kwds)
        return out

    weak_wrapped_fn = weakref.ref(wrapped_fn)

    return self._defun(tf_decorator.make_decorator(
        self._python_function,
        wrapped_fn))

  def _create_implements_attribute(self):
    """Creates the attribute value corresponding to IMPLEMENTS_ATTRIBUTE_NAME."""
    attributes = {}
    if isinstance(self._implements, str):
      # First check if the IMPLEMENTS_ATTRIBUTE_NAME is specified as a
      # NameAttrList. This is used when apart from the function name being
      # implemented, a list of attributes is also being specified.
      # The attributes are specified as key-value pairs in the NameAttrList
      # of the corresponding AttrValue. The function name will be in the
      # 'name' field of the NameAttrList. Else, it is just a string
      # corresponding to the function name.
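      #
      # Illustrative (hypothetical) values for `experimental_implements`:
      #   a plain function name, e.g. "embedding_matmul", or a serialized
      #   NameAttrList in text format, e.g.
      #   'name: "embedding_matmul" attr { key: "use_fp16" value { b: true } }'.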
      try:
        implements_attr = six.ensure_text(self._implements, "utf-8")
        attr_value = attr_value_pb2.AttrValue()
        nameattrlist = attr_value_pb2.NameAttrList()
        _text_format.Merge(implements_attr, nameattrlist)
        attr_value.func.CopyFrom(nameattrlist)
        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = attr_value
      except (_text_format.ParseError, DecodeError):
        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
    return attributes

  def _defun(self, fn):
    """Returns a defun generated from the input function."""
    attributes = {}

    if self._implements is not None:
      attributes = self._create_implements_attribute()

    share = self._shared_rendezvous
    if share is not None:
      attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share

    if self._jit_compile is not None:
      attributes.update(_XlaMustCompile=bool(self._jit_compile))
      if self._jit_compile:
        attributes.update(_noinline=True)
    if not attributes:
      attributes = None
    return function_lib.defun_with_attributes(
        fn,
        input_signature=self.input_signature,
        attributes=attributes,
        autograph=self._autograph,
        jit_compile=self._jit_compile,
        experimental_autograph_options=self._experimental_autograph_options,
        experimental_follow_type_hints=self._experimental_follow_type_hints,
        experimental_relax_shapes=self._experimental_relax_shapes)

  def _initialize(self, args, kwds, add_initializers_to=None):
    """Initializes, on the first call.

    Creates two `Function`s, one that will allow creation of variables
    and one that won't.

    Additionally runs a trace for the `Function` that allows creation
    of variables.

    Args:
      args: Arguments to the underlying python callable.
      kwds: Keyword arguments to the python callable.
      add_initializers_to: Where to collect variable initializers, if not None.
    """

    if self._input_signature is not None:
      arglen = len(self._input_signature)
      arg_names_len = len(self.function_spec.arg_names)
      default_arg_len = len(self.function_spec.fullargspec.defaults or ())
      required_arg_len = arg_names_len - default_arg_len
      # The input signature must cover all required function arguments.
      if arglen < required_arg_len:
        missing_tensor_specs = self.function_spec.arg_names[
            arglen:required_arg_len]
        raise TypeError(
            f"The decorated function {self._name} has {required_arg_len} "
            f"required argument(s), but tf.function was only passed an "
            f"input_signature of length {arglen}. This covers {arglen} "
            f"required argument(s): {self.function_spec.arg_names[:arglen]}, "
            f"but TensorSpecs are still required for the remaining "
            f"{len(missing_tensor_specs)} argument(s): {missing_tensor_specs}.")

    created_variables = []
    lifted_initializer_graph = func_graph_module.FuncGraph("initializer")

    def variable_capturing_scope(unused_next_creator, **kwds):
      """Creates UnliftedInitializerVariables and saves references to them."""
      v = UnliftedInitializerVariable(
          add_initializers_to=add_initializers_to,
          lifted_initializer_graph=lifted_initializer_graph, **kwds)
      created_variables.append(weakref.ref(v))
      return v

    self._created_variables = created_variables
    self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
    self._stateful_fn._name = self._name  # pylint: disable=protected-access
    # Force the definition of the function for these arguments
    self._lifted_initializer_graph = lifted_initializer_graph
    self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    self._concrete_stateful_fn = (
        self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
            *args, **kwds))

    def invalid_creator_scope(*unused_args, **unused_kwds):
      """Disables variable creation."""
      raise ValueError(
          "tf.function-decorated function tried to create "
          "variables on non-first call.")

    self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
    self._stateless_fn._name = self._name  # pylint: disable=protected-access

  def _clone(self, python_function):
    """Clone the function with a different python function."""
    f = Function(
        python_function=(self._python_function
                         if python_function is None else python_function),
        name=self._name,
        input_signature=self._input_signature,
        autograph=self._autograph,
        jit_compile=self._jit_compile,
        experimental_implements=self._implements,
        experimental_autograph_options=self._experimental_autograph_options,
        experimental_relax_shapes=self._experimental_relax_shapes,
        experimental_follow_type_hints=self._experimental_follow_type_hints)

    if self._shared_rendezvous:
      f._shared_rendezvous = self._shared_rendezvous  # pylint: disable=protected-access

    return f

  def _decorate(self, decorator):
    """Allows the captured Python function to be decorated in place.

    This method is only safe to call when the Function has not been called by a
    user. It makes sense to use this method to push a decorator into the
    function rather than wrapping the function in the decorator.

    We use this in tf.Module to allow user-annotated `tf.functions` to remain as
    `Function` objects but still automatically enter the Module name_scope
    when they are evaluated like all other methods.

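    A hedged sketch of the kind of decorator this supports (the wrapper below
    is hypothetical, not the actual tf.Module implementation):

    ```python
    def enter_name_scope(inner):
      def wrapper(*args, **kwargs):
        with tf.name_scope("my_module"):
          return inner(*args, **kwargs)
      return wrapper

    fn._decorate(enter_name_scope)  # `fn` is a not-yet-called Function.
    ```
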
    Args:
      decorator: A callable accepting a single argument which is the function
        to decorate and returning a callable result.

    Raises:
      ValueError: If the function has already been called.
    """
    if self._stateful_fn is not None or self._stateless_fn is not None:
      raise ValueError(
          "Functions cannot be decorated after they have been traced.")

    self._python_function = decorator(self._python_function)
    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
        self._python_function, self.input_signature)

  # TODO: Remove this private method after updating all its uses
  # A good moment to do this could be when the experimental label is removed
  def _get_tracing_count(self):
    return self.experimental_get_tracing_count()

  def experimental_get_tracing_count(self):
    """Returns the number of times the function has been traced.

    For more information on when a function is traced and when it is
    traced multiple times, see https://www.tensorflow.org/guide/function.
    Example:

    >>> @tf.function
    ... def double(a):
    ...   return a + a
    >>> double(tf.constant(1))
    >>> double(tf.constant(2))
    >>> double.experimental_get_tracing_count()
    1
    >>> double(tf.constant("a"))
    >>> double.experimental_get_tracing_count()
    2

    The first time experimental_get_tracing_count is called
    it returns 1, as the function is traced the first
    time it is called, and the second time the same graph is used
    since we're calling it with a parameter of the same type.

    The second time experimental_get_tracing_count is called
    it returns 2, as we called double with a
    different argument type, and so it was traced again.

    """
    result = self._stateless_fn.tracing_count if self._stateless_fn else 0
    result += self._stateful_fn.tracing_count if self._stateful_fn else 0
    return result

  @property
  def _run_functions_eagerly(self):
    return RUN_FUNCTIONS_EAGERLY

  @traceback_utils.filter_traceback
  def __call__(self, *args, **kwds):
    # Implements GenericFunction.__call__.
    if self._run_functions_eagerly:
      with trace.Trace(self._name, tf_function_call="eager"):
        return self._python_function(*args, **kwds)

    # Only count the statistics the first time, before initialization took
    # place.
    if self._created_variables is None:
      compiled = bool(self._jit_compile and
                      not control_flow_util.GraphOrParentsInXlaContext(
                          ops.get_default_graph()))
      # For nested functions, increment the counter only when a function with
      # jit_compile=True is called within a function with jit_compile=False. We
      # count this special case to correctly record that both jit_compile=True
      # and jit_compile=False are being used for parts of the outer function.
      if ops.executing_eagerly_outside_functions() and (
          context.executing_eagerly() or compiled):
        # Labels must be strings in Python, so we convert 'compiled' to a string
        _tf_function_counter.get_cell(str(int(compiled))).increase_by(1)

    tracing_count = self.experimental_get_tracing_count()
    with trace.Trace(self._name) as tm:
      # TODO(cheshire): Do not duplicate the XLAControlFlowContext annotation.
      compiler = "xla" if self._jit_compile else "nonXla"

      with OptionalXlaContext(self._jit_compile):
        result = self._call(*args, **kwds)

      new_tracing_count = self.experimental_get_tracing_count()
      without_tracing = (tracing_count == new_tracing_count)
      execution_mode = "notTraced" if without_tracing else "traced"
      tm.set_metadata(tf_function_call=execution_mode + "-" + compiler,
                      tracing_count=new_tracing_count)

    if context.executing_eagerly():
      if without_tracing:
        _frequent_tracing_detector_manager.called_without_tracing(
            self._key_for_call_stats)
      else:
        _frequent_tracing_detector_manager.called_with_tracing(
            self._key_for_call_stats, self._python_function,
            self._omit_frequent_tracing_warning)

    return result

  def _call(self, *args, **kwds):
    """Calls the graph function."""
    self._lock.acquire()
    if ALLOW_DYNAMIC_VARIABLE_CREATION:
      condition = self._created_variables and self._stateful_fn is None
    else:
      condition = self._created_variables
    if condition:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    try:
      # This is the first call of __call__, so we have to initialize.
      initializers = []
      self._initialize(args, kwds, add_initializers_to=initializers)
    finally:
      # At this point we know that the initialization is complete (or less
      # interestingly an exception was raised) so we no longer need a lock.
      self._lock.release()

    if self._created_variables:
      try:
        # Attempt to initialize variables eagerly and without conds by lifting
        # out initialization graphs. This is the only initialization strategy
        # compatible with XLA at the moment.
        self._initialize_uninitialized_variables(initializers)
      except lift_to_graph.UnliftableError:
        pass  # Fall through to cond-based initialization.
      else:
        # Lifting succeeded, so variables are initialized and we can run the
        # stateless function.
        return self._stateless_fn(*args, **kwds)
    else:
      _, _, _, filtered_flat_args = \
          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
              *args, **kwds)
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._call_flat(
          filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access

    def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
      """Conditionally runs initialization if it's needed."""
      condition = True
      for wr in self._created_variables:
        variable = wr()
        if variable is None:
          raise ValueError(
              "A tf.Variable created inside your tf.function has been"
              " garbage-collected. Your code needs to keep Python references"
              " to variables created inside `tf.function`s.\n"
              "\n"
              "A common way to raise this error is to create and return a"
              " variable only referenced inside your function:\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  v = tf.Variable(1.0)\n"
              "  return v\n"
              "\n"
              "v = f()  # Crashes with this error message!\n"
              "\n"
              "The reason this crashes is that a @tf.function-annotated"
              " function returns a **`tf.Tensor`** with the **value** of the"
              " variable when the function is called rather than the"
              " variable instance itself. As such there is no code holding a"
              " reference to the `v` created inside the function and Python"
              " garbage collects it.\n"
              "\n"
              "The simplest way to fix this issue is to create variables"
              " outside the function and capture them:\n"
              "\n"
              "v = tf.Variable(1.0)\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  return v\n"
              "\n"
              "f()  # <tf.Tensor: numpy=1.>\n"
              "v.assign_add(1.)\n"
              "f()  # <tf.Tensor: numpy=2.>")
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                variable.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(
              self._concrete_stateful_fn._call_flat,  # pylint: disable=protected-access
              inner_filtered_flat_args,
              captured_inputs=self._concrete_stateful_fn.captured_inputs))

    # We've created variables and are unable to lift the initialization graphs,
    # so we fall back to initializing with conds while running the function.
    canon_args, canon_kwds, _, filtered_flat_args = \
        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
            *args, **kwds)
    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
                                            filtered_flat_args)

  def experimental_get_compiler_ir(self, *args, **kwargs):
    # Implements GenericFunction.experimental_get_compiler_ir
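    #
    # Illustrative usage sketch (assumes a function created with
    # jit_compile=True; the returned generator takes an optional `stage`):
    #
    #   @tf.function(jit_compile=True)
    #   def f(x):
    #     return x + 1.
    #
    #   print(f.experimental_get_compiler_ir(tf.constant(1.))(stage="hlo"))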
    context.ensure_initialized()
    if not self._jit_compile:
      raise ValueError("Compiler IR can only be returned for functions marked "
                       "with 'jit_compile=True'")

    concrete_fn = self.get_concrete_function(*args, **kwargs)
    fn_name = concrete_fn.name

    # pylint: disable=protected-access
    _, _, _, filtered_flat_args = \
        concrete_fn._function_spec.canonicalize_function_inputs(
            *args, **kwargs)

    def compiler_ir_generator(stage="hlo", device_name=None):
      # TODO(cheshire): This is a hack to get the current "preferred" device,
      # as there is no current API to get it otherwise.
      if device_name is None:
        device_name = random_ops.random_normal([]).device
      res_bytes = context.context().get_compiler_ir(
          device_name=device_name,
          stage=stage,
          function_name=fn_name,
          args=list(filtered_flat_args) + concrete_fn.captured_inputs)
      if stage in ("hlo_serialized", "optimized_hlo_serialized",
                   "optimized_hlo_proto_serialized"):
        return res_bytes
      else:
        return res_bytes.decode("utf-8")

    return compiler_ir_generator

  @property
  def python_function(self):
    """The python function wrapped in this tf.function."""
    return self._python_function

  @property
  def input_signature(self):
    return self._function_spec.input_signature

  @property
  def function_spec(self):
    return self._function_spec

  def pretty_printed_concrete_signatures(self, verbose=True):
    joiner = "\n\n" if verbose else "\n"
    return joiner.join([
        c.pretty_printed_signature(verbose=verbose)
        for c in self._list_all_concrete_functions()
    ])

  def _initialize_uninitialized_variables(self, initializers):
    """Make and call a `ConcreteFunction` which initializes variables."""

    if not initializers:
      return

    var_is_initialized = _evaluate_var_is_initialized(
        [v for v, _ in initializers])

    # Note: using defun here avoids an infinite recursion.
    # Most of the code in this function runs eagerly with init_scope, where
    # autograph is not necessary.
    @function_lib.defun(autograph=False)
    def initialize_variables():
      op_map = object_identity.ObjectIdentityDictionary()

      inits = []
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        inits.append(init)

      if inits:
        op_map = lift_to_graph.lift_to_graph(
            inits, ops.get_default_graph(), op_map=op_map)
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        v.assign(op_map[init], read_value=False)

    with ops.init_scope():
      return initialize_variables.get_concrete_function()()

  def get_initialization_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` which initializes this function's variables.

    Requires that this function hasn't been accessed yet through either calling
    it or calling get_concrete_function. Fails if we cannot build an initializer
    function which does not depend on the concrete values of the inputs to this
    function.

    Note that running this function will overwrite any values currently assigned
    to variables, for example values restored from a checkpoint.

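    A minimal usage sketch (illustrative; `M` here is a hypothetical module
    that creates its variable lazily inside a `tf.function`):

    ```python
    class M(tf.Module):
      def __init__(self):
        self.v = None

      @tf.function
      def __call__(self):
        if self.v is None:
          self.v = tf.Variable(1.0)
        return self.v + 1.0

    m = M()
    init_fn = m.__call__.get_initialization_function()
    init_fn()  # Initializes m.v without running m's traced body.
    ```
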
    Args:
      *args: arguments to the underlying python callable.
      **kwargs: keyword arguments to the python callable.

    Returns:
      A `ConcreteFunction` object which initializes the variables of this
      function.

    Raises:
      RuntimeError: if called after the variables have been initialized.
    """
    with self._lock:
      if self._stateful_fn is not None:
        raise RuntimeError(
            "get_initialization_function cannot be called after the function "
            "has been used")
      # Here we trace the function, collect the initializers, and attempt to
      # extract them and run them eagerly. Fail only if we cannot do so.
      initializers = []
      self._initialize(args, kwargs, add_initializers_to=initializers)

    # Note: using defun here avoids an infinite recursion.
    @function_lib.defun
    def initialize_variables():
      for v, init in initializers:
        v.assign(
            lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init],
            read_value=False)

    return initialize_variables.get_concrete_function()

  def _list_all_concrete_functions(self):
    """Returns all concrete functions."""
    if self.input_signature is not None:
      self.get_concrete_function()
    concrete_functions = []
    # pylint: disable=protected-access
    if self._stateful_fn:
      concrete_functions.extend(
          self._stateful_fn._function_cache.all_values())
    if self._stateless_fn:
      concrete_functions.extend(
          self._stateless_fn._function_cache.all_values())
    # pylint: enable=protected-access
    return concrete_functions

  def _list_all_concrete_functions_for_serialization(self):
    """Returns all concrete functions for serialization.

    Returns:
      A list of instances of `ConcreteFunction`.
    """
    concrete_functions = self._list_all_concrete_functions()
    seen_signatures = []
    for concrete_function in concrete_functions:
      signature = concrete_function.structured_input_signature
      flattened = nest.flatten(signature)
      if any(
          isinstance(arg, func_graph_module.UnknownArgument)
          for arg in flattened):
        logging.info("Unsupported signature for serialization: %s.", signature)
        continue
      equal_to_signature = functools.partial(
          function_lib.is_same_structure, signature, check_values=True)
      if not any(equal_to_signature(s) for s in seen_signatures):
        seen_signatures.append(signature)

    # Re-create concrete functions for these signatures. Re-creating ensures
    # that if the cache key has changed, the function will be traced again.
    concrete_functions = []
    for args, kwargs in seen_signatures:
      concrete_functions.append(self.get_concrete_function(*args, **kwargs))
    return concrete_functions

1213  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
1214    """Returns a `ConcreteFunction` specialized to inputs and execution context.
1215
1216    Unlike `get_concrete_function(...)`, the graph will be deleted when the
1217    returned function is deleted.  It's useful to avoid creating a reference
1218    cycle when you know for sure that the graph will be no longer used without
1219    the returned function.
1220
1221    Args:
1222      *args: inputs to specialize on.
1223      **kwargs: inputs to specialize on.
1224
1225    Returns:
1226      A TensorFlow function which takes exactly one `tf.Tensor` per argument.
1227
1228    Raises:
1229      ValueError: if this object has not yet been called on concrete values.
1230    """
1231    with self._lock:
1232      if self._stateful_fn is None:
1233        initializers = []
1234        self._initialize(args, kwargs, add_initializers_to=initializers)
1235        self._initialize_uninitialized_variables(initializers)
1236
1237    if self._created_variables:
1238      # In this case we have created variables on the first call, so we run the
1239      # defunned version which is guaranteed to never create variables.
1240      return self._stateless_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
1241          *args, **kwargs)
1242    elif self._stateful_fn is not None:
1243      # In this case we have not created variables on the first call. So we can
1244      # run the first trace but we should fail if variables are created.
1245      concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
1246          *args, **kwargs)
1247      if self._created_variables:
1248        raise ValueError("Creating variables on a non-first call to a function"
1249                         " decorated with tf.function.")
1250      return concrete
1251
1252  def get_concrete_function(self, *args, **kwargs):
1253    # Implements GenericFunction.get_concrete_function.
1254    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1255    concrete._garbage_collector.release()  # pylint: disable=protected-access
1256    return concrete
1257
1258  def __get__(self, instance, owner):
1259    """Makes it possible to defun instance methods."""
1260    del owner
1261    # `instance` here is the instance that this `Function` was accessed through
1262    # e.g., for
1263    #
1264    #   class Foo(object):
1265    #
1266    #     @function.defun
1267    #     def bar(self):
1268    #       ...
1269    #
1270    #   foo = Foo()
1271    #   foo.bar()  # `foo.bar` is a `Function` instance
1272    #
1273    # then `instance` will be `foo` (and `owner` will be `Foo`).  We create a
1274    # new instance of `Function` here to allow each instance to create
1275    # variables once, thereby allowing methods to be decorated with
1276    # tf.function. We keep a cache to avoid retracing the function every time the
1277    # descriptor is accessed.
1278    if instance not in self._descriptor_cache:
1279      if instance is None:
1280        return self
1281      self._descriptor_cache[instance] = (
1282          function_lib.class_method_to_instance_method(self, instance))
1283    return self._descriptor_cache[instance]
1284
1285
1286@tf_export("function")
1287@deprecation.deprecated_args(None,
1288                             "experimental_compile is deprecated, use "
1289                             "jit_compile instead", "experimental_compile")
1290def function(func=None,
1291             input_signature=None,
1292             autograph=True,
1293             jit_compile=None,
1294             experimental_implements=None,
1295             experimental_autograph_options=None,
1296             experimental_relax_shapes=False,
1297             experimental_compile=None,
1298             experimental_follow_type_hints=None) -> core.GenericFunction:
1299  """Compiles a function into a callable TensorFlow graph.
1300
1301  `tf.function` constructs a `tf.types.experimental.GenericFunction` that
1302  executes a TensorFlow graph (`tf.Graph`) created by trace-compiling the
1303  TensorFlow operations in `func`. More information on the topic can be found
1304  in [Introduction to Graphs and tf.function]
1305  (https://www.tensorflow.org/guide/intro_to_graphs).
1306
1307  See [Better Performance with tf.function]
1308  (https://www.tensorflow.org/guide/function) for tips on performance and
1309  known limitations.
1310
1311  Example usage:
1312
1313  >>> @tf.function
1314  ... def f(x, y):
1315  ...   return x ** 2 + y
1316  >>> x = tf.constant([2, 3])
1317  >>> y = tf.constant([3, -2])
1318  >>> f(x, y)
1319  <tf.Tensor: ... numpy=array([7, 7], ...)>
1320
1321  The trace-compilation allows non-TensorFlow operations to execute, but under
1322  special conditions. In general, only TensorFlow operations are guaranteed to
1323  run and create fresh results whenever the `GenericFunction` is called.
1324
1325  ## Features
1326
1327  `func` may use data-dependent control flow, including `if`, `for`, `while`,
1328  `break`, `continue` and `return` statements:
1329
1330  >>> @tf.function
1331  ... def f(x):
1332  ...   if tf.reduce_sum(x) > 0:
1333  ...     return x * x
1334  ...   else:
1335  ...     return -x // 2
1336  >>> f(tf.constant(-2))
1337  <tf.Tensor: ... numpy=1>
1338
1339  `func`'s closure may include `tf.Tensor` and `tf.Variable` objects:
1340
1341  >>> @tf.function
1342  ... def f():
1343  ...   return x ** 2 + y
1344  >>> x = tf.constant([-2, -3])
1345  >>> y = tf.Variable([3, -2])
1346  >>> f()
1347  <tf.Tensor: ... numpy=array([7, 7], ...)>
1348
1349  `func` may also use ops with side effects, such as `tf.print`, `tf.Variable`
1350  and others:
1351
1352  >>> v = tf.Variable(1)
1353  >>> @tf.function
1354  ... def f(x):
1355  ...   for i in tf.range(x):
1356  ...     v.assign_add(i)
1357  >>> f(3)
1358  >>> v
1359  <tf.Variable ... numpy=4>
1360
1361  Important: Any Python side-effects (appending to a list, printing with
1362  `print`, etc.) will only happen once, when `func` is traced. To have
1363  side-effects executed each time the `tf.function` runs, they need to be
1364  written as TF ops:
1365
1366  >>> l = []
1367  >>> @tf.function
1368  ... def f(x):
1369  ...   for i in x:
1370  ...     l.append(i + 1)    # Caution! Will only happen once when tracing
1371  >>> f(tf.constant([1, 2, 3]))
1372  >>> l
1373  [<tf.Tensor ...>]
1374
1375  Instead, use TensorFlow collections like `tf.TensorArray`:
1376
1377  >>> @tf.function
1378  ... def f(x):
1379  ...   ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
1380  ...   for i in range(len(x)):
1381  ...     ta = ta.write(i, x[i] + 1)
1382  ...   return ta.stack()
1383  >>> f(tf.constant([1, 2, 3]))
1384  <tf.Tensor: ..., numpy=array([2, 3, 4], ...)>
1385
1386  ## `tf.function` creates polymorphic callables
1387
1388  Internally, `tf.types.experimental.GenericFunction` may contain multiple
1389  `tf.types.experimental.ConcreteFunction`s, each specialized to arguments with
1390  different data types or shapes, since TensorFlow can perform more
1391  optimizations on graphs of specific shapes, dtypes and values of constant
1392  arguments. `tf.function` treats any pure Python values as opaque objects (best
1393  thought of as compile-time constants), and builds a separate `tf.Graph` for
1394  each set of Python arguments that it encounters.
1395  For more information, see the
1396  [tf.function guide](https://www.tensorflow.org/guide/function?hl=en#rules_of_tracing).
1397
1398  Executing a `GenericFunction` will select and execute the appropriate
1399  `ConcreteFunction` based on the argument types and values.
1400
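  For example (an illustrative sketch with a hypothetical `double` function;
  output reprs abbreviated), calling with an integer `tf.Tensor` and then a
  floating-point one traces and dispatches to two different `ConcreteFunction`s:

  >>> @tf.function
  ... def double(a):
  ...   return a + a
  >>> double(tf.constant(1))
  <tf.Tensor: ... numpy=2>
  >>> double(tf.constant(1.5))
  <tf.Tensor: ... numpy=3.0>
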
1401  To obtain an individual `ConcreteFunction`, use the
1402  `GenericFunction.get_concrete_function` method. It can be called with the
1403  same arguments as `func` and returns a
1404  `tf.types.experimental.ConcreteFunction`. `ConcreteFunction`s are backed by a
1405  single `tf.Graph`:
1406
1407  >>> @tf.function
1408  ... def f(x):
1409  ...   return x + 1
1410  >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
1411  True
1412
1413  `ConcreteFunction`s can be executed just like `GenericFunction`s, but their
1414  input is restricted to the types to which they're specialized.
1415
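  For instance (reusing `f` from the example above; output repr abbreviated), a
  `ConcreteFunction` traced for scalar float Tensors accepts other scalar float
  Tensors:

  >>> cf = f.get_concrete_function(tf.constant(1.0))
  >>> cf(tf.constant(2.0))
  <tf.Tensor: ... numpy=3.0>
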
1416  ## Retracing
1417
1418  `ConcreteFunction`s are built (traced) on the fly, as the `GenericFunction` is
1419  called with new TensorFlow types or shapes, or with new Python values as
1420  arguments. When `GenericFunction` builds a new trace, it is said that `func`
1421  is retraced. Retracing is a frequent performance concern for `tf.function` as
1422  it can be considerably slower than executing a graph that's already been
1423  traced. Aim to minimize the amount of retracing in your code.
1424
1425  Caution: Passing Python scalars or lists as arguments to `tf.function` will
1426  usually retrace. To avoid this, pass numeric arguments as Tensors whenever
1427  possible:
1428
1429  >>> @tf.function
1430  ... def f(x):
1431  ...   return tf.abs(x)
1432  >>> f1 = f.get_concrete_function(1)
1433  >>> f2 = f.get_concrete_function(2)  # Slow - compiles new graph
1434  >>> f1 is f2
1435  False
1436  >>> f1 = f.get_concrete_function(tf.constant(1))
1437  >>> f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
1438  >>> f1 is f2
1439  True
1440
1441  Python numerical arguments should only be used when they take few distinct
1442  values, such as hyperparameters like the number of layers in a neural network.
1443
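  For instance (an illustrative sketch; `num_layers` here is a hypothetical
  hyperparameter, not a TensorFlow API), a Python integer that takes only a few
  distinct values is a reasonable non-Tensor argument:

  >>> @tf.function
  ... def apply_layers(x, num_layers):
  ...   for _ in range(num_layers):
  ...     x = x + 1
  ...   return x
  >>> apply_layers(tf.constant(1), 2)
  <tf.Tensor: ... numpy=3>
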
1444  ## Input signatures
1445
1446  For Tensor arguments, `GenericFunction` creates a new `ConcreteFunction` for
1447  every unique set of input shapes and datatypes. The example below creates two
1448  separate `ConcreteFunction`s, each specialized to a different shape:
1449
1450  >>> @tf.function
1451  ... def f(x):
1452  ...   return x + 1
1453  >>> vector = tf.constant([1.0, 1.0])
1454  >>> matrix = tf.constant([[3.0]])
1455  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
1456  False
1457
1458  An "input signature" can be optionally provided to `tf.function` to control
1459  this process. The input signature specifies the shape and type of each
1460  Tensor argument to the function using a `tf.TensorSpec` object. More general
1461  shapes can be used. This ensures only one `ConcreteFunction` is created, and
1462  restricts the `GenericFunction` to the specified shapes and types. It is
1463  an effective way to limit retracing when Tensors have dynamic shapes.
1464
1465  >>> @tf.function(
1466  ...     input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
1467  ... def f(x):
1468  ...   return x + 1
1469  >>> vector = tf.constant([1.0, 1.0])
1470  >>> matrix = tf.constant([[3.0]])
1471  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
1472  True
1473
1474  ## Variables may only be created once
1475
1476  `tf.function` only allows creating new `tf.Variable` objects when it is called
1477  for the first time:
1478
1479  >>> class MyModule(tf.Module):
1480  ...   def __init__(self):
1481  ...     self.v = None
1482  ...
1483  ...   @tf.function
1484  ...   def __call__(self, x):
1485  ...     if self.v is None:
1486  ...       self.v = tf.Variable(tf.ones_like(x))
1487  ...     return self.v * x
1488
1489  In general, it is recommended to create `tf.Variable`s outside of
1490  `tf.function`.
1491  In simple cases, persisting state across `tf.function` boundaries may be
1492  implemented using a pure functional style in which state is represented by
1493  `tf.Tensor`s passed as arguments and returned as return values.
1494
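  For example (a sketch of the recommended pattern; `Scaler` is a hypothetical
  module), the variable can be created eagerly in `__init__` and only read or
  updated inside the traced call:

  >>> class Scaler(tf.Module):
  ...   def __init__(self):
  ...     self.v = tf.Variable(2.0)
  ...
  ...   @tf.function
  ...   def __call__(self, x):
  ...     return self.v * x
  >>> Scaler()(tf.constant(3.0))
  <tf.Tensor: ... numpy=6.0>
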
1495  Contrast the two styles below:
1496
1497  >>> state = tf.Variable(1)
1498  >>> @tf.function
1499  ... def f(x):
1500  ...   state.assign_add(x)
1501  >>> f(tf.constant(2))  # Non-pure functional style
1502  >>> state
1503  <tf.Variable ... numpy=3>
1504
1505  >>> state = tf.constant(1)
1506  >>> @tf.function
1507  ... def f(state, x):
1508  ...   state += x
1509  ...   return state
1510  >>> state = f(state, tf.constant(2))  # Pure functional style
1511  >>> state
1512  <tf.Tensor: ... numpy=3>
1513
1514  ## Python operations execute only once per trace
1515
1516  `func` may contain TensorFlow operations mixed with pure Python operations.
1517  However, when the function is executed, only the TensorFlow operations will
1518  run. The Python operations run only once, at trace time. If TensorFlow
1519  operations depend on results from Python operations, those results will be
1520  frozen into the graph.
1521
1522  >>> @tf.function
1523  ... def f(a, b):
1524  ...   print('this runs at trace time; a is', a, 'and b is', b)
1525  ...   return b
1526  >>> f(1, tf.constant(1))
1527  this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32)
1528  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1529
1530  >>> f(1, tf.constant(2))
1531  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1532
1533  >>> f(2, tf.constant(1))
1534  this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32)
1535  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1536
1537  >>> f(2, tf.constant(2))
1538  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1539
1540  ## Using type annotations to improve performance
1541
1542  `experimental_follow_type_hints` can be used along with type annotations to
1543  reduce retracing by automatically casting any Python values to `tf.Tensor`
1544  (something that is not done by default, unless you use input signatures).
1545
1546  >>> @tf.function(experimental_follow_type_hints=True)
1547  ... def f_with_hints(x: tf.Tensor):
1548  ...   print('Tracing')
1549  ...   return x
1550  >>> @tf.function(experimental_follow_type_hints=False)
1551  ... def f_no_hints(x: tf.Tensor):
1552  ...   print('Tracing')
1553  ...   return x
1554  >>> f_no_hints(1)
1555  Tracing
1556  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1557  >>> f_no_hints(2)
1558  Tracing
1559  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1560  >>> f_with_hints(1)
1561  Tracing
1562  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1563  >>> f_with_hints(2)
1564  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1565
1566  Args:
1567    func: the function to be compiled. If `func` is None, `tf.function` returns
1568      a decorator that can be invoked with a single argument - `func`. In other
1569      words, `tf.function(input_signature=...)(func)` is equivalent to
1570      `tf.function(func, input_signature=...)`. The former can be used as
1571      a decorator.
1572    input_signature: A possibly nested sequence of `tf.TensorSpec` objects
1573      specifying the shapes and dtypes of the Tensors that will be supplied to
1574      this function. If `None`, a separate function is instantiated for each
1575      inferred input signature.  If `input_signature` is specified, every input to
1576      `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
1577    autograph: Whether autograph should be applied on `func` before tracing a
1578      graph. Data-dependent control flow requires `autograph=True`. For more
1579      information, see the [tf.function and AutoGraph guide](
1580      https://www.tensorflow.org/guide/function#autograph_transformations).
1581    jit_compile: If `True`, compiles the function using
1582      [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
1583      such as fusion, and attempts to emit more efficient code. This may
1584      drastically improve performance. If set to `True`,
1585      the whole function needs to be compilable by XLA, or an
1586      `errors.InvalidArgumentError` is thrown.
1587      If `None` (default), compiles the function with XLA when running on TPU
1588      and goes through the regular function execution path when running on
1589      other devices.
1590      If `False`, executes the function without XLA compilation.  Set this value
1591      to `False` when directly running a multi-device function on TPUs (e.g. two
1592      TPU cores, one TPU core and its host CPU).
1593      Not all functions are compilable; see a list of
1594      [sharp corners](https://tensorflow.org/xla/known_issues).
1595    experimental_implements: If provided, contains the name of a "known" function
1596      this implements, for example "mycompany.my_recurrent_cell".
1597      This is stored as an attribute in the inference function,
1598      which can then be detected when processing the serialized function.
1599      See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)  # pylint: disable=line-too-long
1600      for details.  For an example of utilizing this attribute, see this
1601      [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc).
1602      The linked code automatically detects and substitutes functions that
1603      implement "embedded_matmul" and allows TFLite to substitute its own
1604      implementations. For instance, a TensorFlow user can use this
1605      attribute to mark that their function also implements
1606      `embedded_matmul` (perhaps more efficiently!)
1607      by specifying it using this parameter:
1608      `@tf.function(experimental_implements="embedded_matmul")`
1609      This can either be specified as just the string name of the function or
1610      a NameAttrList corresponding to a list of key-value attributes associated
1611      with the function name. The name of the function will be in the 'name'
1612      field of the NameAttrList. To define a formal TF op that this function
1613      implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
1614      project.
1615    experimental_autograph_options: Optional tuple of
1616      `tf.autograph.experimental.Feature` values.
1617    experimental_relax_shapes: When True, `tf.function` may generate fewer
1618      graphs that are less specialized on input shapes.
1619    experimental_compile: Deprecated alias to `jit_compile`.
1620    experimental_follow_type_hints: When True, the function may use type
1621      annotations from `func` to optimize the tracing performance. For example,
1622      arguments annotated with `tf.Tensor` will automatically be converted
1623      to a Tensor.
1624
1625  Returns:
1626     If `func` is not None, returns a `tf.types.experimental.GenericFunction`.
1627     If `func` is None, returns a decorator that, when invoked with a single
1628     `func` argument, returns a `tf.types.experimental.GenericFunction`.
1629
1630  Raises:
1631     `ValueError` when attempting to use `jit_compile=True`, but XLA support is
1632     not available.
1633  """
1634  if func is not None:
1635    function_lib.validate_python_function(func)
1636  if input_signature is not None:
1637    function_lib.validate_signature(input_signature)
1638  if experimental_follow_type_hints is None:
1639    experimental_follow_type_hints = False
1640
1641  def decorated(inner_function):
1642    try:
1643      name = inner_function.__name__
1644    except AttributeError:
1645      name = "function"
1646    return tf_decorator.make_decorator(
1647        inner_function,
1648        decorator_name="tf.function",
1649        decorator_func=Function(
1650            inner_function,
1651            name,
1652            input_signature=input_signature,
1653            autograph=autograph,
1654            experimental_autograph_options=experimental_autograph_options,
1655            experimental_relax_shapes=experimental_relax_shapes,
1656
1657            # TODO(b/171825496): Update once `experimental_compile` is removed
1658            # entirely in favor of 'jit_compile'.
1659            jit_compile=deprecation.deprecated_argument_lookup(
1660                "jit_compile",
1661                jit_compile,
1662                "experimental_compile",
1663                experimental_compile),
1664            experimental_implements=experimental_implements,
1665            experimental_follow_type_hints=experimental_follow_type_hints))
1666
1667  # This code path is for the `foo = tf.function(foo, ...)` use case
1668  if func is not None:
1669    return decorated(func)
1670
1671  # This code path is for the
1672  #
1673  # @tf.function(...)
1674  # def foo(...):
1675  #    ...
1676  #
1677  # use case, which is equivalent to `foo = tf.function(...)(foo)`
1678  return decorated
1679