1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=unidiomatic-typecheck 16"""API for defining graph functions with some additional eager semantics. 17 18def_function.function wraps the function concept in function.py ("defun") to 19allow initializing `tf.Variable`s with subgraphs of the function. For example: 20 21```python 22class M(tf.Module): 23 def __init__(self): 24 self.v_opinit = None 25 self.v_arginit = None 26 27 @tf.function 28 def __call__(self, x): 29 # Variables are only created on the first call to the function. This is a 30 # common pattern in layer libraries. 31 if self.v_opinit is None: 32 # self.v_opinit will outlive the function call, but `tf.ones` is traced as 33 # part of the function body before the `tf.Variable` object is 34 # created. This subgraph is easy to lift out of the function. 35 self.v_opinit = tf.Variable(tf.ones([])) 36 37 # If arguments feed into variable initialization, it can be very tricky to 38 # disentangle from the rest of the function. We don't attempt it. 39 self.v_arginit = tf.Variable(tf.ones(tf.shape(x)) * tf.constant(2.)) 40 return self.v_opinit + self.v_arginit + x 41``` 42 43These patterns with "defun" throw an error asking the user to put the variable's 44initializer in a lambda. 
With tf.function they work with eager semantics either 45by lifting the subgraph out of the function and using it to initialize the 46variable, or by initializing variables on the first call to the function (if 47they weren't already initialized by something else, e.g. a checkpoint API). The 48latter requires tf.conds, and is not well supported by TF-XLA, so we only do it 49when necessary. 50 51Since these patterns are relatively common in layer libraries, we expose the 52wrapper in this file as `tf.function`. The function concept in function.py is an 53internal implementation detail. 54 55In order to support these variable initialization patterns, tf.function defines 56a variable subtype (UnliftedInitializerVariable) which collects the input 57subgraph. This type of variable replaces the regular variable type on the first 58tf.function trace. To exclude initializers from the function body (the `tf.ones` 59ops above and associated assignment operations), tf.function traces a second 60time if it sees variables on the first call. 
61""" 62 63from __future__ import absolute_import 64from __future__ import division 65from __future__ import print_function 66 67import functools 68import threading 69import weakref 70import six 71 72from google.protobuf import text_format as _text_format 73from google.protobuf.message import DecodeError 74from tensorflow.core.framework import attr_value_pb2 75from tensorflow.python.distribute.parallel_device import parallel_device 76from tensorflow.python.eager import context 77from tensorflow.python.eager import function as function_lib 78from tensorflow.python.eager import lift_to_graph 79from tensorflow.python.eager import monitoring 80from tensorflow.python.framework import errors 81from tensorflow.python.framework import func_graph as func_graph_module 82from tensorflow.python.framework import ops 83from tensorflow.python.ops import array_ops 84from tensorflow.python.ops import control_flow_ops 85from tensorflow.python.ops import control_flow_util 86from tensorflow.python.ops import math_ops 87from tensorflow.python.ops import random_ops 88from tensorflow.python.ops import resource_variable_ops 89from tensorflow.python.platform import tf_logging as logging 90from tensorflow.python.profiler import trace 91from tensorflow.python.training.tracking import base as trackable 92from tensorflow.python.types import core 93from tensorflow.python.util import deprecation 94from tensorflow.python.util import nest 95from tensorflow.python.util import object_identity 96from tensorflow.python.util import tf_decorator 97from tensorflow.python.util import traceback_utils 98from tensorflow.python.util.tf_export import tf_export 99 100FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10 101FREQUENT_TRACING_WARNING_THRESHOLD = 5 102FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2 103ALLOW_DYNAMIC_VARIABLE_CREATION = False 104 105_tf_function_counter = monitoring.Counter( 106 "/tensorflow/core/tf_function_counter", 107 "Counter for the number of tf.functions created when Eager 
execution is " 108 "enabled.", 109 # jit_compile is "0" or "1". 110 "jit_compile") 111 112 113class _FrequentTracingDetector(object): 114 """Class keeping track of how many recent calls triggered tracing.""" 115 116 __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"] 117 118 def __init__(self): 119 self._calls_per_tracings = [] 120 self._total_warning_count = 0 121 self._call_count = 0 122 123 def called_with_tracing(self, function_name, omit_warning): 124 """Updates the list of most recent calls' tracing information. 125 126 Warns the user when recent calls caused retracing too often. 127 128 Args: 129 function_name: the python function being traced. 130 omit_warning: If 'True', this call will not warn the user even if 131 retracing happens too often. 132 """ 133 self._call_count += 1 134 self._calls_per_tracings.append(1) 135 136 while self._calls_per_tracings: 137 if (self._call_count - self._calls_per_tracings[0] > 138 FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY): 139 self._call_count -= self._calls_per_tracings.pop(0) 140 else: 141 break 142 143 if (omit_warning or self._total_warning_count >= 144 FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR): 145 return 146 if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD: 147 self._total_warning_count += 1 148 logging.warning( 149 "{} out of the last {} calls to {} triggered tf.function " 150 "retracing. Tracing is expensive and the excessive number of " 151 "tracings could be due to (1) creating @tf.function repeatedly in " 152 "a loop, (2) passing tensors with different shapes, (3) passing " 153 "Python objects instead of tensors. For (1), please define your " 154 "@tf.function outside of the loop. For (2), @tf.function has " 155 "experimental_relax_shapes=True option that relaxes argument " 156 "shapes that can avoid unnecessary retracing. 
For (3), please " 157 "refer to " 158 "https://www.tensorflow.org/guide/function#controlling_retracing" 159 " and https://www.tensorflow.org/api_docs/python/tf/function for " 160 " more details.".format( 161 len(self._calls_per_tracings), self._call_count, function_name)) 162 163 def called_without_tracing(self): 164 # We don't count tracing when users load a concrete function directly or 165 # call get_concrete_function, so the first call can be not a tracing call. 166 if not self._calls_per_tracings: 167 self._calls_per_tracings = [0] 168 self._calls_per_tracings[-1] += 1 169 self._call_count += 1 170 171 172class _FrequentTracingDetectorManager(object): 173 """Class for the management of all _FrequentTracingDetector objects.""" 174 175 __slots__ = ["_detectors", "_lock"] 176 177 def __init__(self): 178 self._detectors = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock) 179 self._lock = threading.Lock() 180 181 def _get_detector(self, key): 182 if key not in self._detectors: 183 self._detectors[key] = _FrequentTracingDetector() 184 return self._detectors[key] 185 186 def called_without_tracing(self, key): 187 with self._lock: 188 detector = self._get_detector(key) 189 detector.called_without_tracing() 190 191 def called_with_tracing(self, key, function_name, omit_warning): 192 with self._lock: 193 detector = self._get_detector(key) 194 detector.called_with_tracing(function_name, omit_warning) 195 196 197_frequent_tracing_detector_manager = _FrequentTracingDetectorManager() 198 199 200class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable): 201 """Variable which does not lift its initializer out of function context. 202 203 Instances of this variable, when created, build a graph which runs their 204 initializer inside a tf.cond(is_initialized) block. 205 206 This can only be created inside a defun called from (eventually) eager 207 mode. That is, non-function-building graphs are not supported. 
208 """ 209 210 def __init__(self, 211 initial_value=None, 212 trainable=None, 213 caching_device=None, 214 name=None, 215 dtype=None, 216 constraint=None, 217 add_initializers_to=None, 218 lifted_initializer_graph=None, 219 synchronization=None, 220 aggregation=None, 221 shape=None, 222 **unused_kwargs): 223 """Creates a variable. 224 225 Args: 226 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 227 which is the initial value for the Variable. The initial value must have 228 a shape specified unless `validate_shape` is set to False. Can also be a 229 callable with no argument that returns the initial value when called. 230 (Note that initializer functions from init_ops.py must first be bound 231 to a shape before being used here.) 232 trainable: If `True`, GradientTapes automatically watch uses of this 233 Variable. 234 caching_device: Optional device string or function describing where the 235 Variable should be cached for reading. Defaults to the Variable's 236 device. If not `None`, caches on another device. Typical use is to 237 cache on the device where the Ops using the Variable reside, to 238 deduplicate copying through `Switch` and other conditional statements. 239 name: Optional name for the variable. Defaults to `'Variable'` and gets 240 uniquified automatically. 241 dtype: If set, initial_value will be converted to the given type. 242 If None, either the datatype will be kept (if initial_value is 243 a Tensor) or float32 will be used (if it is a Python object convertible 244 to a Tensor). 245 constraint: An optional projection function to be applied to the variable 246 after being updated by an `Optimizer` (e.g. used to implement norm 247 constraints or value constraints for layer weights). The function must 248 take as input the unprojected Tensor representing the value of the 249 variable and return the Tensor for the projected value 250 (which must have the same shape). 
Constraints are not safe to 251 use when doing asynchronous distributed training. 252 add_initializers_to: if not None and not in legacy graph mode, the 253 initializer tensor will be added to this map in addition to adding the 254 assignment to the function. 255 lifted_initializer_graph: FuncGraph to try to lift initializers to. 256 synchronization: Indicates when a distributed a variable will be 257 aggregated. Accepted values are constants defined in the class 258 `tf.VariableSynchronization`. By default the synchronization is set to 259 `AUTO` and the current `DistributionStrategy` chooses 260 when to synchronize. 261 aggregation: Indicates how a distributed variable will be aggregated. 262 Accepted values are constants defined in the class 263 `tf.VariableAggregation`. 264 shape: (optional) The shape of this variable. If None, the shape of 265 `initial_value` will be used. When setting this argument to 266 `tf.TensorShape(None)` (representing an unspecified shape), the variable 267 can be assigned with values of different shapes. 268 269 Raises: 270 ValueError: If the initial value is not specified, or does not have a 271 shape and `validate_shape` is `True`. 272 RuntimeError: If called outside of a function definition. 273 """ 274 with ops.init_scope(): 275 self._in_graph_mode = not context.executing_eagerly() 276 if not ops.inside_function(): 277 # If we've been init_scope()d out of the function definition nothing to do 278 # here; we can't really do the capturing or conditional logic. 279 resource_variable_ops.ResourceVariable.__init__( 280 self, initial_value=initial_value, trainable=trainable, 281 caching_device=caching_device, name=name, dtype=dtype, 282 constraint=constraint) 283 return 284 if initial_value is None: 285 raise ValueError("`initial_value` must be a Tensor or a Python " 286 "object convertible to a Tensor. 
Got None.") 287 init_from_fn = callable(initial_value) 288 289 if constraint is not None and not callable(constraint): 290 raise ValueError(f"`constraint` with type {type(constraint)} must be a " 291 "callable.") 292 293 with ops.name_scope(name, "Variable", [] 294 if init_from_fn else [initial_value]) as scope_name: 295 with ops.name_scope("Initializer"): 296 if init_from_fn: 297 initial_value = initial_value() 298 if isinstance(initial_value, trackable.CheckpointInitialValue): 299 self._maybe_initialize_trackable() 300 self._update_uid = initial_value.checkpoint_position.restore_uid 301 initial_value = initial_value.wrapped_value 302 303 initial_value = ops.convert_to_tensor(initial_value, 304 name="initial_value", dtype=dtype) 305 assert initial_value is not None 306 307 # Don't use `shape or initial_value.shape` since TensorShape has 308 # overridden `__bool__`. 309 if shape is None: 310 shape = initial_value.shape 311 312 # Use the constructor for UninitializedVariable to start. Outside the name 313 # scope so we don't double up the prefix. 
314 super(UnliftedInitializerVariable, self).__init__( 315 trainable=trainable, 316 caching_device=caching_device, 317 name=name, 318 shape=shape, 319 dtype=initial_value.dtype, 320 constraint=constraint, 321 synchronization=synchronization, 322 aggregation=aggregation, 323 extra_handle_data=initial_value, 324 **unused_kwargs) 325 326 with ops.name_scope(scope_name): 327 if self._in_graph_mode: 328 with ops.init_scope(): 329 outer_graph = ops.get_default_graph() 330 func_graph = ops.get_default_graph() 331 function_placeholders = ( 332 func_graph.inputs + func_graph.internal_captures) 333 placeholder_ops = set( 334 [tensor.op for tensor in function_placeholders]) 335 lifted_initializer = lift_to_graph.lift_to_graph( 336 [initial_value], outer_graph, 337 disallowed_placeholders=placeholder_ops)[initial_value] 338 with ops.init_scope(): 339 self._initial_value = lifted_initializer 340 with ops.name_scope("IsInitialized"): 341 self._is_initialized_op = ( 342 resource_variable_ops.var_is_initialized_op(self._handle)) 343 if initial_value is not None: 344 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): 345 self._initializer_op = resource_variable_ops.assign_variable_op( 346 self._handle, lifted_initializer, name=n) 347 elif context.executing_eagerly(): 348 # In this case, both current scope and init scope are eager. 349 # Assign_variable_op will be executed immediately. So we don't need to 350 # add it to "add_initializers_to" to lift it out. 351 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): 352 resource_variable_ops.assign_variable_op( 353 self._handle, initial_value, name=n) 354 else: 355 # Init scope is eager but current scope is graph. We will lift out this 356 # variable by addint it into "add_initializers_to". 
357 if add_initializers_to is not None: 358 add_initializers_to.append((self, initial_value)) 359 360 def assign_fn(): 361 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): 362 resource_variable_ops.assign_variable_op( 363 self._handle, 364 initial_value, 365 name=n) 366 # Returning values to keep tf.cond happy. 367 return ops.convert_to_tensor(1) 368 def not_assign_fn(): 369 return ops.convert_to_tensor(0) 370 # Note: this cond is always guaranteed to run because we're inside a 371 # defun which will insert automatic control dependencies. It will only 372 # execute assign_fn if lifting failed. 373 graph = ops.get_default_graph() 374 375 # Capture the handle ahead of time in order to avoid querying the shape 376 # of the handle which helps async execution performance 377 graph.capture(self._handle, shape=()) 378 control_flow_ops.cond( 379 resource_variable_ops.var_is_initialized_op(self._handle), 380 not_assign_fn, assign_fn) 381 382 383RUN_FUNCTIONS_EAGERLY = False 384 385 386@deprecation.deprecated( 387 None, 388 "Use `tf.config.run_functions_eagerly` instead of the experimental " 389 "version.") 390@tf_export("config.experimental_run_functions_eagerly") 391def experimental_run_functions_eagerly(run_eagerly): 392 """Enables / disables eager execution of `tf.function`s. 393 394 Calling `tf.config.experimental_run_functions_eagerly(True)` will make all 395 invocations of `tf.function` run eagerly instead of running as a traced graph 396 function. 397 398 See `tf.config.run_functions_eagerly` for an example. 399 400 Note: This flag has no effect on functions passed into tf.data transformations 401 as arguments. tf.data functions are never executed eagerly and are always 402 executed as a compiled Tensorflow Graph. 403 404 Args: 405 run_eagerly: Boolean. Whether to run functions eagerly. 
406 """ 407 return run_functions_eagerly(run_eagerly) 408 409 410@tf_export("config.run_functions_eagerly") 411def run_functions_eagerly(run_eagerly): 412 """Enables / disables eager execution of `tf.function`s. 413 414 Calling `tf.config.run_functions_eagerly(True)` will make all 415 invocations of `tf.function` run eagerly instead of running as a traced graph 416 function. 417 418 This can be useful for debugging. 419 420 >>> def my_func(a): 421 ... print("Python side effect") 422 ... return a + a 423 >>> a_fn = tf.function(my_func) 424 425 >>> # A side effect the first time the function is traced 426 >>> a_fn(tf.constant(1)) 427 Python side effect 428 <tf.Tensor: shape=(), dtype=int32, numpy=2> 429 430 >>> # No further side effect, as the traced function is called 431 >>> a_fn(tf.constant(2)) 432 <tf.Tensor: shape=(), dtype=int32, numpy=4> 433 434 >>> # Now, switch to eager running 435 >>> tf.config.run_functions_eagerly(True) 436 >>> # Side effect, as the function is called directly 437 >>> a_fn(tf.constant(2)) 438 Python side effect 439 <tf.Tensor: shape=(), dtype=int32, numpy=4> 440 441 >>> # Turn this back off 442 >>> tf.config.run_functions_eagerly(False) 443 444 Note: This flag has no effect on functions passed into tf.data transformations 445 as arguments. tf.data functions are never executed eagerly and are always 446 executed as a compiled Tensorflow Graph. 447 448 Args: 449 run_eagerly: Boolean. Whether to run functions eagerly. 
450 """ 451 global RUN_FUNCTIONS_EAGERLY 452 RUN_FUNCTIONS_EAGERLY = bool(run_eagerly) 453 454 455@deprecation.deprecated( 456 None, 457 "Use tf.config.functions_run_eagerly instead of the experimental version.") 458@tf_export("config.experimental_functions_run_eagerly") 459def experimental_functions_run_eagerly(): 460 """Returns the value of the `experimental_run_functions_eagerly` setting.""" 461 return functions_run_eagerly() 462 463 464@tf_export("config.functions_run_eagerly") 465def functions_run_eagerly(): 466 """Returns the value of the `run_functions_eagerly` setting.""" 467 return RUN_FUNCTIONS_EAGERLY 468 469 470def _evaluate_var_is_initialized(variables): 471 """Compute booleans indicating whether each variable is initialized.""" 472 with ops.init_scope(): 473 var_is_initialized = [] 474 for v in variables: 475 var_is_initialized.append( 476 resource_variable_ops.var_is_initialized_op(v.handle)) 477 try: 478 # Stack all the var_is_initialized values into one tensor and interpret 479 # the numpy value. This will reduce the number of RPCs between client and 480 # worker in the remote case. 481 return array_ops.stack(var_is_initialized).numpy() 482 except errors.UnimplementedError: 483 # Some devices do not support implicit copy-off to host. Fall back to 484 # variable-by-variable processing. 485 for index, v in enumerate(variables): 486 try: 487 numpy_value = var_is_initialized[index].numpy() 488 except errors.UnimplementedError: 489 # This is a variable on a parallel device; we'll extract its value on 490 # each replica and assert that they're identical. 
491 components = parallel_device.unpack(var_is_initialized[index]) 492 with ops.device(None): 493 components = array_ops.stack(components) 494 all_initialized = math_ops.reduce_all(components).numpy() 495 any_initialized = math_ops.reduce_any(components).numpy() 496 if all_initialized != any_initialized: 497 raise NotImplementedError( 498 f"Some but not all components of a parallel variable {v!r} " 499 "were initialized between their creation in a tf.function and " 500 "the function's trace having completed. This is not " 501 "supported; consider initializing either all or none of the " 502 "components, or moving initialization out of the function.") 503 numpy_value = all_initialized 504 var_is_initialized[index] = numpy_value 505 return var_is_initialized 506 507 508class FunctionDeleter(object): 509 510 __slots__ = ["func_graph"] 511 512 def __init__(self, func_graph): 513 self.func_graph = func_graph 514 515 def __del__(self): 516 try: 517 func_graph_module.dismantle_func_graph(self.func_graph) 518 except: # pylint: disable=bare-except 519 # Note: bare except here because this can be noisy at shutdown time. 520 pass 521 522 523class OptionalXlaContext(object): 524 """Wrapper for XLA context optionally applied under a context manager.""" 525 526 def __init__(self, is_compiled): 527 wrap = is_compiled and not control_flow_util.GraphOrParentsInXlaContext( \ 528 ops.get_default_graph()) 529 self.xla_context = control_flow_ops.XLAControlFlowContext() \ 530 if wrap else None 531 532 def __enter__(self): 533 if self.xla_context: 534 self.xla_context.Enter() 535 536 def __exit__(self, t, value, traceback): 537 if self.xla_context: 538 self.xla_context.Exit() 539 540 541# TODO(mdan): Consider expose this type for instance type checking. 542@tf_export("__internal__.function.Function", v1=[]) 543class Function(core.GenericFunction): 544 """A `tf.types.experimental.GenericFunction` created by `tf.function`. 
545 546 Currently, individual methods/attributes under this class are not guaranteed 547 by the TF API contract, and are subject to future changes. 548 """ 549 550 def __init__(self, 551 python_function, 552 name, 553 input_signature=None, 554 autograph=True, 555 jit_compile=None, 556 experimental_implements=None, 557 experimental_autograph_options=None, 558 experimental_relax_shapes=False, 559 experimental_follow_type_hints=None): 560 """Initializes a `Function`. 561 562 Args: 563 python_function: the function to be wrapped. 564 name: the name given to it. 565 input_signature: See the documentation for `tf.function`. 566 autograph: See the documentation for `tf.function`. 567 jit_compile: See the documentation for `tf.function`. 568 experimental_implements: See the documentation for `tf.function`. 569 experimental_autograph_options: See the documentation for `tf.function`. 570 experimental_relax_shapes: See the documentation for `tf.function`. 571 experimental_follow_type_hints: See the documentation for `tf.function`. 572 573 Raises: 574 ValueError: if `input_signature` is not None and the `python_function`'s 575 argspec has keyword arguments. 576 """ 577 self._lock = threading.Lock() 578 self._python_function = python_function 579 self._function_spec = function_lib.FunctionSpec.from_function_and_signature( 580 python_function, 581 input_signature, 582 jit_compile=jit_compile, 583 experimental_follow_type_hints=experimental_follow_type_hints, 584 ) 585 self._implements = experimental_implements 586 # If `True`, the function uses the rendezvous of the parent. This is only 587 # needed to support code where raw send/recv operations are inserted and 588 # when functions are run in graph mode where they may not be inlined. 
589 self._shared_rendezvous = None 590 self._autograph = autograph 591 self._experimental_autograph_options = experimental_autograph_options 592 self._experimental_relax_shapes = experimental_relax_shapes 593 self._jit_compile = jit_compile 594 if experimental_follow_type_hints is None: 595 experimental_follow_type_hints = False 596 self._experimental_follow_type_hints = experimental_follow_type_hints 597 self._created_variables = None # GUARDED_BY(self._lock) 598 self._stateful_fn = None # GUARDED_BY(self._lock) 599 self._stateless_fn = None # GUARDED_BY(self._lock) 600 self._descriptor_cache = weakref.WeakKeyDictionary() 601 self._name = name 602 self._input_signature = input_signature 603 self._key_for_call_stats = self._get_key_for_call_stats() 604 self._omit_frequent_tracing_warning = False 605 ops._tf_function_api_guage.get_cell().set(True) # pylint: disable=protected-access 606 607 def __getstate__(self): 608 """Custom pickling, to omit unpickleable objects.""" 609 result = self.__dict__.copy() 610 del result["_lock"] 611 del result["_descriptor_cache"] 612 del result["_key_for_call_stats"] 613 return result 614 615 def __setstate__(self, state): 616 """Restore from pickled state.""" 617 self.__dict__ = state 618 self._lock = threading.Lock() 619 self._descriptor_cache = weakref.WeakKeyDictionary() 620 self._key_for_call_stats = self._get_key_for_call_stats() 621 622 def _get_key_for_call_stats(self): 623 """Returns key instance to track call stats and retracings. 624 625 The key instance a best-effort to preserve global consistency. 626 """ 627 target_function = self._python_function 628 # `__wrapped__` is a conventional Python attribute that a higher-order 629 # function keeps its original function's instance. We also directly use 630 # this attribute for dealing with a class method. See 631 # `bound_method_wrapper` in `function.py`. 
If we don't use `__wrapped__`, 632 # all class methods will return the same `bound_method_wrapper` instance 633 # from this function. 634 while hasattr(target_function, "__wrapped__"): 635 target_function = target_function.__wrapped__ 636 637 if hasattr(target_function, "__func__"): 638 target_function = target_function.__func__ 639 640 if hasattr(target_function, "__code__"): 641 return target_function.__code__ 642 643 return self._python_function 644 645 def _defun_with_scope(self, scope): 646 """Creates a defun wrapped inside a variable creator scope.""" 647 648 weak_wrapped_fn = None 649 compile_with_xla = self._jit_compile 650 651 def wrapped_fn(*args, **kwds): 652 """Wraps `self._python_function` in a variable creator scope.""" 653 # We register a variable creator with reduced priority. If an outer 654 # variable creator is just modifying keyword arguments to the variable 655 # constructor, this will work harmoniously. Since the `scope` registered 656 # here actually creates the variable, it taking priority would otherwise 657 # ignore the outer creator. 658 # 659 # If an outer variable creator calls the variable constructor manually, 660 # for example creating a MirroredVariable, then they won't call our 661 # creator. This means we won't be able to trace the initialization graph, 662 # and so variable initializers can't depend on function arguments. This is 663 # better than the alternative, tracing the initialization graph but giving 664 # the user a variable type they didn't want. 665 default_graph = ops.get_default_graph() 666 with default_graph._variable_creator_scope(scope, priority=50): # pylint: disable=protected-access 667 # __wrapped__ allows AutoGraph to swap in a converted function. We give 668 # the function a weak reference to itself to avoid a reference cycle. 
669 with OptionalXlaContext(compile_with_xla): 670 out = weak_wrapped_fn().__wrapped__(*args, **kwds) 671 return out 672 673 weak_wrapped_fn = weakref.ref(wrapped_fn) 674 675 return self._defun(tf_decorator.make_decorator( 676 self._python_function, 677 wrapped_fn)) 678 679 def _create_implements_attribute(self): 680 """Creates the attribute value corresponding to IMPLEMENTS_ATTRIBUTE_NAME.""" 681 attributes = {} 682 if isinstance(self._implements, str): 683 # First check if the IMPLEMENTS_ATTRIBUTE_NAME is specified as a 684 # NameAttrList. This is used when apart from the function name being 685 # implemented, a list of attributes is also being specified. 686 # The attributes are specified as key-value pairs in the NameAttrList 687 # of the corresponding AttrValue. The function name will be in the 688 # 'name' field of the NameAttrList. Else, it is just a string 689 # corresponding to the function name. 690 try: 691 implements_attr = six.ensure_text(self._implements, "utf-8") 692 attr_value = attr_value_pb2.AttrValue() 693 nameattrlist = attr_value_pb2.NameAttrList() 694 _text_format.Merge(implements_attr, nameattrlist) 695 attr_value.func.CopyFrom(nameattrlist) 696 attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = attr_value 697 except (_text_format.ParseError, DecodeError): 698 attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements 699 return attributes 700 701 def _defun(self, fn): 702 """Returns a defun generated from the input function.""" 703 attributes = {} 704 705 if self._implements is not None: 706 attributes = self._create_implements_attribute() 707 708 share = self._shared_rendezvous 709 if share is not None: 710 attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share 711 712 if self._jit_compile is not None: 713 attributes.update(_XlaMustCompile=bool(self._jit_compile)) 714 if self._jit_compile: 715 attributes.update(_noinline=True) 716 if not attributes: 717 attributes = None 718 return function_lib.defun_with_attributes( 
fn,
        input_signature=self.input_signature,
        attributes=attributes,
        autograph=self._autograph,
        jit_compile=self._jit_compile,
        experimental_autograph_options=self._experimental_autograph_options,
        experimental_follow_type_hints=self._experimental_follow_type_hints,
        experimental_relax_shapes=self._experimental_relax_shapes)

  def _initialize(self, args, kwds, add_initializers_to=None):
    """Initializes, on the first call.

    Creates two `Function`s, one that will allow creation of variables
    and one that won't.

    Additionally runs a trace for the `Function` that allows creation
    of variables.

    On return the following attributes are set on `self`:
    `_created_variables` (weak references to every variable created while
    tracing), `_stateful_fn`, `_lifted_initializer_graph`, `_graph_deleter`,
    `_concrete_stateful_fn` (the traced stateful concrete function for
    `args`/`kwds`) and `_stateless_fn`.

    Args:
      args: Arguments to the underlying python callable.
      kwds: Keyword arguments to the python callable.
      add_initializers_to: Where to collect variable initializers, if not None.

    Raises:
      TypeError: if an `input_signature` was given that is shorter than the
        number of function arguments without a default value.
    """

    if self._input_signature is not None:
      arglen = len(self._input_signature)
      arg_names_len = len(self.function_spec.arg_names)
      default_arg_len = len(self.function_spec.fullargspec.defaults or ())
      required_arg_len = arg_names_len - default_arg_len
      # The input signature must cover all required function arguments.
      if arglen < required_arg_len:
        missing_tensor_specs = self.function_spec.arg_names[
            arglen:required_arg_len]
        raise TypeError(
            f"The decorated function {self._name} has {required_arg_len} "
            f"required argument(s), but tf.function was only passed an "
            f"input_signature of length {arglen}. This covers {arglen} "
            f"required argument(s): {self.function_spec.arg_names[:arglen]}, "
            f"but TensorSpecs are still required for the remaining "
            f"{len(missing_tensor_specs)} argument(s): {missing_tensor_specs}.")

    created_variables = []
    lifted_initializer_graph = func_graph_module.FuncGraph("initializer")

    def variable_capturing_scope(unused_next_creator, **kwds):
      """Creates UnliftedInitializerVariables and saves references to them."""
      v = UnliftedInitializerVariable(
          add_initializers_to=add_initializers_to,
          lifted_initializer_graph=lifted_initializer_graph, **kwds)
      # Only weak references are kept here; callers must hold strong
      # references themselves (see the garbage-collection error in `_call`).
      created_variables.append(weakref.ref(v))
      return v

    self._created_variables = created_variables
    self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
    self._stateful_fn._name = self._name  # pylint: disable=protected-access
    # Force the definition of the function for these arguments
    self._lifted_initializer_graph = lifted_initializer_graph
    self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    self._concrete_stateful_fn = (
        self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
            *args, **kwds))

    def invalid_creator_scope(*unused_args, **unused_kwds):
      """Disables variable creation."""
      raise ValueError(
          "tf.function-decorated function tried to create "
          "variables on non-first call.")

    # Any retrace of the stateless function that tries to create a variable
    # hits `invalid_creator_scope` and raises ValueError.
    self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
    self._stateless_fn._name = self._name  # pylint: disable=protected-access

  def _clone(self, python_function):
    """Clone the function with different python function.

    Args:
      python_function: The callable for the clone. If None, the current
        `self._python_function` is reused.

    Returns:
      A new `Function` built with the same name, input signature, autograph,
      jit_compile, implements and experimental options as `self`.
      `_shared_rendezvous` is copied over when it is set.
    """
    f = Function(
        python_function=(self._python_function
                         if python_function is None else python_function),
        name=self._name,
        input_signature=self._input_signature,
        autograph=self._autograph,
        jit_compile=self._jit_compile,
        experimental_implements=self._implements,
        experimental_autograph_options=self._experimental_autograph_options,
        experimental_relax_shapes=self._experimental_relax_shapes,
        experimental_follow_type_hints=self._experimental_follow_type_hints)

    if self._shared_rendezvous:
      f._shared_rendezvous = self._shared_rendezvous  # pylint: disable=protected-access

    return f

  def _decorate(self, decorator):
    """Allows the captured Python function to be decorated in place.

    This method is only safe to call when the Function has not been called by a
    user. It makes sense to use this method to push a decorator into the
    function rather than wrapping the function in the decorator.

    We use this in tf.Module to allow user annotated `tf.functions` to remain as
    `Function` objects but still automatically enter the Module name_scope
    when they are evaluated like all other methods.

    Args:
      decorator: A callable accepting a single argument which is the function
        to decorate and returning a callable result.

    Raises:
      ValueError: If the function has been called a ValueError is raised.
    """
    # Either wrapper existing means `_initialize` has already traced.
    if self._stateful_fn is not None or self._stateless_fn is not None:
      raise ValueError(
          "Functions cannot be decorated after they have been traced.")

    self._python_function = decorator(self._python_function)
    # The spec is derived from the Python callable, so rebuild it for the
    # decorated callable while keeping the same input signature.
    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
        self._python_function, self.input_signature)

  # TODO: Remove this private method after updating all its uses
  # A good moment to do this could be when the experimental label is removed
  def _get_tracing_count(self):
    """Private alias of `experimental_get_tracing_count`."""
    return self.experimental_get_tracing_count()

  def experimental_get_tracing_count(self):
    """Returns the number of times the function has been traced.

    For more information on when a function is traced and when it is
    traced multiple times see https://www.tensorflow.org/guide/function.
    Example:

    >>> @tf.function
    ... def double(a):
    ...   return a + a
    >>> double(tf.constant(1))
    >>> double(tf.constant(2))
    >>> double.experimental_get_tracing_count()
    1
    >>> double(tf.constant("a"))
    >>> double.experimental_get_tracing_count()
    2


    The first time experimental_get_tracing_count is called
    it returns 1, as the function is traced the first
    time it is called, and the second time the same graph is used
    since we're calling it with a parameter of the same type.

    The second time experimental_get_tracing_count is called
    it returns 2, as we called double with a
    different argument type, and so it was traced again.

    The count is the sum of the traces made by the internal stateful and
    stateless wrappers, and is 0 while neither wrapper exists.
    """
    result = self._stateless_fn.tracing_count if self._stateless_fn else 0
    result += self._stateful_fn.tracing_count if self._stateful_fn else 0
    return result

  @property
  def _run_functions_eagerly(self):
    """Current value of the module-level `RUN_FUNCTIONS_EAGERLY` flag."""
    return RUN_FUNCTIONS_EAGERLY

  @traceback_utils.filter_traceback
  def __call__(self, *args, **kwds):
    # Implements GenericFunction.__call__.
    # When functions run eagerly, the Python function is called directly (no
    # tracing); otherwise the call goes through `_call` and tracing statistics
    # are recorded in the `trace.Trace` metadata and the frequent-tracing
    # detector.
    if self._run_functions_eagerly:
      with trace.Trace(self._name, tf_function_call="eager"):
        return self._python_function(*args, **kwds)

    # Only count the statistics the first time, before initialization took
    # place.
    if self._created_variables is None:
      compiled = bool(self._jit_compile and
                      not control_flow_util.GraphOrParentsInXlaContext(
                          ops.get_default_graph()))
      # For nested functions, increment the counter only when a function with
      # jit_compile=True is called within a function with jit_compile=False. We
      # count this special case to correctly record that both jit_compile=True
      # and jit_compile=False is being used for parts of the outer function.
      if ops.executing_eagerly_outside_functions() and (
          context.executing_eagerly() or compiled):
        # Labels must be strings in Python, so we convert 'compiled' to a string
        _tf_function_counter.get_cell(str(int(compiled))).increase_by(1)

    # Compare the tracing count before and after the call to learn whether
    # this call caused a (re)trace.
    tracing_count = self.experimental_get_tracing_count()
    with trace.Trace(self._name) as tm:
      # TODO(cheshire): Do not duplicate the XLAControlFlowContext annotation.
      compiler = "xla" if self._jit_compile else "nonXla"

      with OptionalXlaContext(self._jit_compile):
        result = self._call(*args, **kwds)

      new_tracing_count = self.experimental_get_tracing_count()
      without_tracing = (tracing_count == new_tracing_count)
      execution_mode = "notTraced" if without_tracing else "traced"
      tm.set_metadata(tf_function_call=execution_mode + "-" + compiler,
                      tracing_count=new_tracing_count)

    # Frequent-retracing detection is only fed from eager calls.
    if context.executing_eagerly():
      if without_tracing:
        _frequent_tracing_detector_manager.called_without_tracing(
            self._key_for_call_stats)
      else:
        _frequent_tracing_detector_manager.called_with_tracing(
            self._key_for_call_stats, self._python_function,
            self._omit_frequent_tracing_warning)

    return result

  def _call(self, *args, **kwds):
    """Calls the graph function.

    `self._lock` is held only while choosing a path (and, on the first call,
    while `_initialize` runs). It is released before any traced function is
    executed. Paths:

    * Variables were created by the first trace (and, when
      `ALLOW_DYNAMIC_VARIABLE_CREATION` is set, `_stateful_fn` is None):
      run `_stateless_fn`, which refuses to create variables.
    * A stateful trace already exists: run `_stateful_fn`. Creating
      variables there is a ValueError unless
      `ALLOW_DYNAMIC_VARIABLE_CREATION` is set.
    * First call: `_initialize` traces the function. If no variables were
      created, the traced stateful concrete function is called directly.
      Otherwise the variables are initialized by lifting their initializers
      out of the function and `_stateless_fn` is run. If lifting raises
      `UnliftableError`, a `cond` runs `_stateless_fn` when every created
      variable is already initialized, and the stateful concrete function
      (which runs the initializers) otherwise.

    Args:
      *args: Positional arguments for the wrapped Python function.
      **kwds: Keyword arguments for the wrapped Python function.

    Returns:
      The result of the selected traced function.

    Raises:
      ValueError: if variables are created on a non-first call (without
        `ALLOW_DYNAMIC_VARIABLE_CREATION`), or if a variable created inside
        the function has been garbage-collected (checked on the cond-based
        initialization path).
    """
    self._lock.acquire()
    if ALLOW_DYNAMIC_VARIABLE_CREATION:
      # NOTE(review): `_initialize` always sets `_stateful_fn`, so in the
      # visible code this is only true before initialization, when
      # `_created_variables` is still falsy.
      condition = self._created_variables and self._stateful_fn is None
    else:
      condition = self._created_variables
    if condition:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    try:
      # This is the first call of __call__, so we have to initialize.
      initializers = []
      self._initialize(args, kwds, add_initializers_to=initializers)
    finally:
      # At this point we know that the initialization is complete (or less
      # interestingly an exception was raised) so we no longer need a lock.
      self._lock.release()

    if self._created_variables:
      try:
        # Attempt to initialize variables eagerly and without conds by lifting
        # out initialization graphs. This is the only initialization strategy
        # compatible with XLA at the moment.
        self._initialize_uninitialized_variables(initializers)
      except lift_to_graph.UnliftableError:
        pass  # Fall through to cond-based initialization.
      else:
        # Lifting succeeded, so variables are initialized and we can run the
        # stateless function.
        return self._stateless_fn(*args, **kwds)
    else:
      _, _, _, filtered_flat_args = \
          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
              *args, **kwds)
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._call_flat(
          filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access

    def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
      """Conditionally runs initialization if it's needed."""
      # `condition` becomes the logical AND of "is initialized" over every
      # variable created during the first trace.
      condition = True
      for wr in self._created_variables:
        variable = wr()
        if variable is None:
          raise ValueError(
              "A tf.Variable created inside your tf.function has been"
              " garbage-collected. Your code needs to keep Python references"
              " to variables created inside `tf.function`s.\n"
              "\n"
              "A common way to raise this error is to create and return a"
              " variable only referenced inside your function:\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              " v = tf.Variable(1.0)\n"
              " return v\n"
              "\n"
              "v = f() # Crashes with this error message!\n"
              "\n"
              "The reason this crashes is that @tf.function annotated"
              " function returns a **`tf.Tensor`** with the **value** of the"
              " variable when the function is called rather than the"
              " variable instance itself. As such there is no code holding a"
              " reference to the `v` created inside the function and Python"
              " garbage collects it.\n"
              "\n"
              "The simplest way to fix this issue is to create variables"
              " outside the function and capture them:\n"
              "\n"
              "v = tf.Variable(1.0)\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              " return v\n"
              "\n"
              "f() # <tf.Tensor: numpy=1.>\n"
              "v.assign_add(1.)\n"
              "f() # <tf.Tensor: numpy=2.>")
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                variable.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(
              self._concrete_stateful_fn._call_flat,  # pylint: disable=protected-access
              inner_filtered_flat_args,
              captured_inputs=self._concrete_stateful_fn.captured_inputs))

    # We've created variables and are unable to lift the initialization graphs,
    # so we fall back to initializing with conds while running the function.
    canon_args, canon_kwds, _, filtered_flat_args = \
        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
            *args, **kwds)
    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
                                            filtered_flat_args)

  def experimental_get_compiler_ir(self, *args, **kwargs):
    # Implements GenericFunction.experimental_get_compiler_ir
    # Returns `compiler_ir_generator(stage="hlo", device_name=None)`, which
    # yields raw bytes for the "*_serialized" stages listed below and a UTF-8
    # decoded string for every other stage. Raises ValueError unless the
    # function was created with jit_compile=True.
    context.ensure_initialized()
    if not self._jit_compile:
      raise ValueError("Compiler IR can only be returned for functions marked "
                       "with 'jit_compile=True'")

    concrete_fn = self.get_concrete_function(*args, **kwargs)
    fn_name = concrete_fn.name

    # pylint: disable=protected-access
    _, _, _, filtered_flat_args = \
        concrete_fn._function_spec.canonicalize_function_inputs(
            *args, **kwargs)

    def compiler_ir_generator(stage="hlo", device_name=None):
      # TODO(cheshire): This is a hack to get the current "preferred" device,
      # there is no current API to get it otherwise.
      if device_name is None:
        device_name = random_ops.random_normal([]).device
      res_bytes = context.context().get_compiler_ir(
          device_name=device_name,
          stage=stage,
          function_name=fn_name,
          args=list(filtered_flat_args) + concrete_fn.captured_inputs)
      if stage in ("hlo_serialized", "optimized_hlo_serialized",
                   "optimized_hlo_proto_serialized"):
        return res_bytes
      else:
        return res_bytes.decode("utf-8")

    return compiler_ir_generator

  @property
  def python_function(self):
    """The python function wrapped in this tf.function."""
    return self._python_function

  @property
  def input_signature(self):
    # The input signature recorded in this function's `FunctionSpec`.
    return self._function_spec.input_signature

  @property
  def function_spec(self):
    # The `FunctionSpec` built from the wrapped Python function.
    return self._function_spec

  def pretty_printed_concrete_signatures(self, verbose=True):
    # Signatures of all concrete functions traced so far, separated by a blank
    # line when `verbose` is True and by a single newline otherwise.
    joiner = "\n\n" if verbose else "\n"
    return joiner.join([
        c.pretty_printed_signature(verbose=verbose)
        for c in self._list_all_concrete_functions()
    ])

  def _initialize_uninitialized_variables(self, initializers):
    """Make and call a `ConcreteFunction` which initializes variables.

    Args:
      initializers: List of `(variable, initializer)` pairs, as collected by
        `_initialize`. Variables that `_evaluate_var_is_initialized` reports
        as already initialized are skipped.

    Raises:
      lift_to_graph.UnliftableError: presumably when an initializer cannot be
        lifted out of the function graph (`_call` catches this error and
        falls back to cond-based initialization).
    """

    if not initializers:
      return

    var_is_initialized = _evaluate_var_is_initialized(
        [v for v, _ in initializers])

    # Note: using defun here avoids an infinite recursion.
    # Most of the code in this function runs eagerly with init_scope, where
    # autograph is not necessary.
    @function_lib.defun(autograph=False)
    def initialize_variables():
      op_map = object_identity.ObjectIdentityDictionary()

      # Collect the initializers of variables that still need initializing.
      inits = []
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        inits.append(init)

      # Lift all pending initializers into this graph in one pass, then
      # assign each lifted value to its variable.
      if inits:
        op_map = lift_to_graph.lift_to_graph(
            inits, ops.get_default_graph(), op_map=op_map)
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        v.assign(op_map[init], read_value=False)

    with ops.init_scope():
      return initialize_variables.get_concrete_function()()

  def get_initialization_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` which initializes this function's variables.

    Requires that this function hasn't been accessed yet through either calling
    it or calling get_concrete_function. Fails if we cannot build an initializer
    function which does not depend on the concrete values of the inputs to this
    function.

    Note that running this function will overwrite any values currently assigned
    to variables, for example restores from a checkpoint.

    Args:
      *args: arguments to the underlying python callable.
      **kwargs: keyword arguments to the python callable.

    Returns:
      A `ConcreteFunction` object which initializes the variables of this
      function.

    Raises:
      RuntimeError: if called after the variables have been initialized.
    """
    with self._lock:
      if self._stateful_fn is not None:
        raise RuntimeError(
            "get_initialization_function cannot be called after the function "
            "has been used")
      # Here we trace the function, collect the initializers, and attempt to
      # extract them and run them eagerly. Fail only if we cannot do so.
      initializers = []
      self._initialize(args, kwargs, add_initializers_to=initializers)

    # Note: using defun here avoids an infinite recursion.
    @function_lib.defun
    def initialize_variables():
      # Unlike `_initialize_uninitialized_variables`, every variable is
      # assigned unconditionally.
      for v, init in initializers:
        v.assign(
            lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init],
            read_value=False)

    return initialize_variables.get_concrete_function()

  def _list_all_concrete_functions(self):
    """Returns all concrete functions.

    These are the values cached by the stateful and stateless wrappers. When an
    input signature is set, its concrete function is traced first so that it
    is included.
    """
    if self.input_signature is not None:
      self.get_concrete_function()
    concrete_functions = []
    # pylint: disable=protected-access
    if self._stateful_fn:
      concrete_functions.extend(
          self._stateful_fn._function_cache.all_values())
    if self._stateless_fn:
      concrete_functions.extend(
          self._stateless_fn._function_cache.all_values())
    # pylint: enable=protected-access
    return concrete_functions

  def _list_all_concrete_functions_for_serialization(self):
    """Returns all concrete functions for serialization.

    Signatures that contain a `func_graph_module.UnknownArgument` are logged
    and skipped. Signatures that match one already seen (via
    `function_lib.is_same_structure(..., check_values=True)`) are
    de-duplicated. Each remaining signature is then re-traced through
    `get_concrete_function`.

    Returns:
      A list of instances of `ConcreteFunction`.
    """
    concrete_functions = self._list_all_concrete_functions()
    seen_signatures = []
    for concrete_function in concrete_functions:
      signature = concrete_function.structured_input_signature
      flattened = nest.flatten(signature)
      if any(
          isinstance(arg, func_graph_module.UnknownArgument)
          for arg in flattened):
        logging.info("Unsupported signature for serialization: %s.", signature)
        continue
      equal_to_signature = functools.partial(
          function_lib.is_same_structure, signature, check_values=True)
      if not any(equal_to_signature(s) for s in seen_signatures):
        seen_signatures.append(signature)

    # Re-create concrete functions for these signatures. Re-creating ensures
    # that if the cache key has changed, the function will be traced again.
    concrete_functions = []
    # Each structured input signature unpacks into (args, kwargs).
    for args, kwargs in seen_signatures:
      concrete_functions.append(self.get_concrete_function(*args, **kwargs))
    return concrete_functions

  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Unlike `get_concrete_function(...)`, the graph will be deleted when the
    returned function is deleted. It's useful to avoid creating a reference
    cycle when you know for sure that the graph will be no longer used without
    the returned function.

    If the function has not been initialized yet, this first runs
    `_initialize` and initializes any created variables by lifting their
    initializers.

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A TensorFlow function which takes exactly one `tf.Tensor` per argument.

    Raises:
      ValueError: if variables are created while tracing after the first
        trace (raised here for the stateful wrapper, or by the variable
        creator scope installed on the stateless wrapper).
    """
    with self._lock:
      if self._stateful_fn is None:
        initializers = []
        self._initialize(args, kwargs, add_initializers_to=initializers)
        self._initialize_uninitialized_variables(initializers)

    if self._created_variables:
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
          *args, **kwargs)
    elif self._stateful_fn is not None:
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
          *args, **kwargs)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return concrete

  def get_concrete_function(self, *args, **kwargs):
    # Implements GenericFunction.get_concrete_function.
    # Same as `_get_concrete_function_garbage_collected`, but the garbage
    # collector is released so the graph is not deleted together with the
    # returned function.
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
    concrete._garbage_collector.release()  # pylint: disable=protected-access
    return concrete

  def __get__(self, instance, owner):
    """Makes it possible to defun instance methods."""
    del owner
    # `instance` here is the instance that this `Function` was accessed through
    # e.g., for
    #
    #   class Foo(object):
    #
    #     @function.defun
    #     def bar(self):
    #       ...
    #
    #   foo = Foo()
    #   foo.bar()  # `foo.bar` is a `Function` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`). We create a
    # new instance of `Function` here to allow different instances each
    # to create variables once, thereby allowing methods to be decorated with
    # tf.function. Keeps a cache to avoid retracing the function every time the
    # descriptor is accessed.
    if instance not in self._descriptor_cache:
      # Accessed through the class itself: return this `Function` unchanged.
      if instance is None:
        return self
      self._descriptor_cache[instance] = (
          function_lib.class_method_to_instance_method(self, instance))
    return self._descriptor_cache[instance]


@tf_export("function")
@deprecation.deprecated_args(None,
                             "experimental_compile is deprecated, use "
                             "jit_compile instead", "experimental_compile")
def function(func=None,
             input_signature=None,
             autograph=True,
             jit_compile=None,
             experimental_implements=None,
             experimental_autograph_options=None,
             experimental_relax_shapes=False,
             experimental_compile=None,
             experimental_follow_type_hints=None) -> core.GenericFunction:
  """Compiles a function into a callable TensorFlow graph.

  `tf.function` constructs a `tf.types.experimental.GenericFunction` that
  executes a TensorFlow graph (`tf.Graph`) created by trace-compiling the
  TensorFlow operations in `func`. More information on the topic can be found
  in [Introduction to Graphs and tf.function]
  (https://www.tensorflow.org/guide/intro_to_graphs).

  See [Better Performance with tf.function]
  (https://www.tensorflow.org/guide/function) for tips on performance and
  known limitations.

  Example usage:

  >>> @tf.function
  ... def f(x, y):
  ...   return x ** 2 + y
  >>> x = tf.constant([2, 3])
  >>> y = tf.constant([3, -2])
  >>> f(x, y)
  <tf.Tensor: ... numpy=array([7, 7], ...)>

  The trace-compilation allows non-TensorFlow operations to execute, but under
  special conditions. In general, only TensorFlow operations are guaranteed to
  run and create fresh results whenever the `GenericFunction` is called.

  ## Features

  `func` may use data-dependent control flow, including `if`, `for`, `while`,
  `break`, `continue` and `return` statements:

  >>> @tf.function
  ... def f(x):
  ...
if tf.reduce_sum(x) > 0: 1333 ... return x * x 1334 ... else: 1335 ... return -x // 2 1336 >>> f(tf.constant(-2)) 1337 <tf.Tensor: ... numpy=1> 1338 1339 `func`'s closure may include `tf.Tensor` and `tf.Variable` objects: 1340 1341 >>> @tf.function 1342 ... def f(): 1343 ... return x ** 2 + y 1344 >>> x = tf.constant([-2, -3]) 1345 >>> y = tf.Variable([3, -2]) 1346 >>> f() 1347 <tf.Tensor: ... numpy=array([7, 7], ...)> 1348 1349 `func` may also use ops with side effects, such as `tf.print`, `tf.Variable` 1350 and others: 1351 1352 >>> v = tf.Variable(1) 1353 >>> @tf.function 1354 ... def f(x): 1355 ... for i in tf.range(x): 1356 ... v.assign_add(i) 1357 >>> f(3) 1358 >>> v 1359 <tf.Variable ... numpy=4> 1360 1361 Important: Any Python side-effects (appending to a list, printing with 1362 `print`, etc) will only happen once, when `func` is traced. To have 1363 side-effects executed into your `tf.function` they need to be written 1364 as TF ops: 1365 1366 >>> l = [] 1367 >>> @tf.function 1368 ... def f(x): 1369 ... for i in x: 1370 ... l.append(i + 1) # Caution! Will only happen once when tracing 1371 >>> f(tf.constant([1, 2, 3])) 1372 >>> l 1373 [<tf.Tensor ...>] 1374 1375 Instead, use TensorFlow collections like `tf.TensorArray`: 1376 1377 >>> @tf.function 1378 ... def f(x): 1379 ... ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True) 1380 ... for i in range(len(x)): 1381 ... ta = ta.write(i, x[i] + 1) 1382 ... return ta.stack() 1383 >>> f(tf.constant([1, 2, 3])) 1384 <tf.Tensor: ..., numpy=array([2, 3, 4], ...)> 1385 1386 ## `tf.function` creates polymorphic callables 1387 1388 Internally, `tf.types.experimental.GenericFunction` may contain multiple 1389 `tf.types.experimental.ConcreteFunction`s, each specialized to arguments with 1390 different data types or shapes, since TensorFlow can perform more 1391 optimizations on graphs of specific shapes, dtypes and values of constant 1392 arguments. 
`tf.function` treats any pure Python values as opaque objects (best 1393 thought of as compile-time constants), and builds a separate `tf.Graph` for 1394 each set of Python arguments that it encounters. 1395 For more information, see the 1396 [tf.function guide](https://www.tensorflow.org/guide/function?hl=en#rules_of_tracing) 1397 1398 Executing a `GenericFunction` will select and execute the appropriate 1399 `ConcreteFunction` based on the argument types and values. 1400 1401 To obtain an individual `ConcreteFunction`, use the 1402 `GenericFunction.get_concrete_function` method. It can be called with the 1403 same arguments as `func` and returns a 1404 `tf.types.experimental.ConcreteFunction`. `ConcreteFunction`s are backed by a 1405 single `tf.Graph`: 1406 1407 >>> @tf.function 1408 ... def f(x): 1409 ... return x + 1 1410 >>> isinstance(f.get_concrete_function(1).graph, tf.Graph) 1411 True 1412 1413 `ConcreteFunction`s can be executed just like `GenericFunction`s, but their 1414 input is restricted to the types to which they're specialized. 1415 1416 ## Retracing 1417 1418 `ConcreteFunctions` are built (traced) on the fly, as the `GenericFunction` is 1419 called with new TensorFlow types or shapes, or with new Python values as 1420 arguments. When `GenericFunction` builds a new trace, it is said that `func` 1421 is retraced. Retracing is a frequent performance concern for `tf.function` as 1422 it can be considerably slower than executing a graph that's already been 1423 traced. It is ideal to minimize the amount of retracing in your code. 1424 1425 Caution: Passing python scalars or lists as arguments to `tf.function` will 1426 usually retrace. To avoid this, pass numeric arguments as Tensors whenever 1427 possible: 1428 1429 >>> @tf.function 1430 ... def f(x): 1431 ...
return tf.abs(x) 1432 >>> f1 = f.get_concrete_function(1) 1433 >>> f2 = f.get_concrete_function(2) # Slow - compiles new graph 1434 >>> f1 is f2 1435 False 1436 >>> f1 = f.get_concrete_function(tf.constant(1)) 1437 >>> f2 = f.get_concrete_function(tf.constant(2)) # Fast - reuses f1 1438 >>> f1 is f2 1439 True 1440 1441 Python numerical arguments should only be used when they take few distinct 1442 values, such as hyperparameters like the number of layers in a neural network. 1443 1444 ## Input signatures 1445 1446 For Tensor arguments, `GenericFunction` creates a new `ConcreteFunction` for 1447 every unique set of input shapes and datatypes. The example below creates two 1448 separate `ConcreteFunction`s, each specialized to a different shape: 1449 1450 >>> @tf.function 1451 ... def f(x): 1452 ... return x + 1 1453 >>> vector = tf.constant([1.0, 1.0]) 1454 >>> matrix = tf.constant([[3.0]]) 1455 >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix) 1456 False 1457 1458 An "input signature" can be optionally provided to `tf.function` to control 1459 this process. The input signature specifies the shape and type of each 1460 Tensor argument to the function using a `tf.TensorSpec` object. More general 1461 shapes can be used. This ensures only one `ConcreteFunction` is created, and 1462 restricts the `GenericFunction` to the specified shapes and types. It is 1463 an effective way to limit retracing when Tensors have dynamic shapes. 1464 1465 >>> @tf.function( 1466 ... input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) 1467 ... def f(x): 1468 ... return x + 1 1469 >>> vector = tf.constant([1.0, 1.0]) 1470 >>> matrix = tf.constant([[3.0]]) 1471 >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix) 1472 True 1473 1474 ## Variables may only be created once 1475 1476 `tf.function` only allows creating new `tf.Variable` objects when it is called 1477 for the first time: 1478 1479 >>> class MyModule(tf.Module): 1480 ...
def __init__(self): 1481 ... self.v = None 1482 ... 1483 ... @tf.function 1484 ... def __call__(self, x): 1485 ... if self.v is None: 1486 ... self.v = tf.Variable(tf.ones_like(x)) 1487 ... return self.v * x 1488 1489 In general, it is recommended to create `tf.Variable`s outside of 1490 `tf.function`. 1491 In simple cases, persisting state across `tf.function` boundaries may be 1492 implemented using a pure functional style in which state is represented by 1493 `tf.Tensor`s passed as arguments and returned as return values. 1494 1495 Contrast the two styles below: 1496 1497 >>> state = tf.Variable(1) 1498 >>> @tf.function 1499 ... def f(x): 1500 ... state.assign_add(x) 1501 >>> f(tf.constant(2)) # Non-pure functional style 1502 >>> state 1503 <tf.Variable ... numpy=3> 1504 1505 >>> state = tf.constant(1) 1506 >>> @tf.function 1507 ... def f(state, x): 1508 ... state += x 1509 ... return state 1510 >>> state = f(state, tf.constant(2)) # Pure functional style 1511 >>> state 1512 <tf.Tensor: ... numpy=3> 1513 1514 ## Python operations execute only once per trace 1515 1516 `func` may contain TensorFlow operations mixed with pure Python operations. 1517 However, when the function is executed, only the TensorFlow operations will 1518 run. The Python operations run only once, at trace time. If TensorFlow 1519 operations depend on results from Python operations, those results will be 1520 frozen into the graph. 1521 1522 >>> @tf.function 1523 ... def f(a, b): 1524 ... print('this runs at trace time; a is', a, 'and b is', b) 1525 ...
return b 1526 >>> f(1, tf.constant(1)) 1527 this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32) 1528 <tf.Tensor: shape=(), dtype=int32, numpy=1> 1529 1530 >>> f(1, tf.constant(2)) 1531 <tf.Tensor: shape=(), dtype=int32, numpy=2> 1532 1533 >>> f(2, tf.constant(1)) 1534 this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32) 1535 <tf.Tensor: shape=(), dtype=int32, numpy=1> 1536 1537 >>> f(2, tf.constant(2)) 1538 <tf.Tensor: shape=(), dtype=int32, numpy=2> 1539 1540 ## Using type annotations to improve performance 1541 1542 `experimental_follow_type_hints` can be used along with type annotations to 1543 reduce retracing by automatically casting any Python values to `tf.Tensor` 1544 (something that is not done by default, unless you use input signatures). 1545 1546 >>> @tf.function(experimental_follow_type_hints=True) 1547 ... def f_with_hints(x: tf.Tensor): 1548 ... print('Tracing') 1549 ... return x 1550 >>> @tf.function(experimental_follow_type_hints=False) 1551 ... def f_no_hints(x: tf.Tensor): 1552 ... print('Tracing') 1553 ... return x 1554 >>> f_no_hints(1) 1555 Tracing 1556 <tf.Tensor: shape=(), dtype=int32, numpy=1> 1557 >>> f_no_hints(2) 1558 Tracing 1559 <tf.Tensor: shape=(), dtype=int32, numpy=2> 1560 >>> f_with_hints(1) 1561 Tracing 1562 <tf.Tensor: shape=(), dtype=int32, numpy=1> 1563 >>> f_with_hints(2) 1564 <tf.Tensor: shape=(), dtype=int32, numpy=2> 1565 1566 Args: 1567 func: the function to be compiled. If `func` is None, `tf.function` returns 1568 a decorator that can be invoked with a single argument - `func`. In other 1569 words, `tf.function(input_signature=...)(func)` is equivalent to 1570 `tf.function(func, input_signature=...)`. The former can be used as a 1571 decorator. 1572 input_signature: A possibly nested sequence of `tf.TensorSpec` objects 1573 specifying the shapes and dtypes of the Tensors that will be supplied to 1574 this function.
If `None`, a separate function is instantiated for each 1575 inferred input signature. If input_signature is specified, every input to 1576 `func` must be a `Tensor`, and `func` cannot accept `**kwargs`. 1577 autograph: Whether autograph should be applied on `func` before tracing a 1578 graph. Data-dependent control flow requires `autograph=True`. For more 1579 information, see the [tf.function and AutoGraph guide]( 1580 https://www.tensorflow.org/guide/function#autograph_transformations). 1581 jit_compile: If `True`, compiles the function using 1582 [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations, 1583 such as fusion, and attempts to emit more efficient code. This may 1584 drastically improve the performance. If set to `True`, 1585 the whole function needs to be compilable by XLA, or an 1586 `errors.InvalidArgumentError` is thrown. 1587 If `None` (default), compiles the function with XLA when running on TPU 1588 and goes through the regular function execution path when running on 1589 other devices. 1590 If `False`, executes the function without XLA compilation. Set this value 1591 to `False` when directly running a multi-device function on TPUs (e.g. two 1592 TPU cores, one TPU core and its host CPU). 1593 Not all functions are compilable, see a list of 1594 [sharp corners](https://tensorflow.org/xla/known_issues). 1595 experimental_implements: If provided, contains a name of a "known" function 1596 this implements. For example "mycompany.my_recurrent_cell". 1597 This is stored as an attribute in inference function, 1598 which can then be detected when processing serialized function. 1599 See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md) # pylint: disable=line-too-long 1600 for details. 
For an example of utilizing this attribute, see this 1601 [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc). 1602 The code linked above automatically detects and substitutes a function that 1603 implements "embedded_matmul" and allows TFLite to substitute its own 1604 implementations. For instance, a TensorFlow user can use this 1605 attribute to mark that their function also implements 1606 `embedded_matmul` (perhaps more efficiently!) 1607 by specifying it using this parameter: 1608 `@tf.function(experimental_implements="embedded_matmul")` 1609 This can either be specified as just the string name of the function or 1610 a NameAttrList corresponding to a list of key-value attributes associated 1611 with the function name. The name of the function will be in the 'name' 1612 field of the NameAttrList. To define a formal TF op for the function that this 1613 implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr) 1614 project. 1615 experimental_autograph_options: Optional tuple of 1616 `tf.autograph.experimental.Feature` values. 1617 experimental_relax_shapes: When True, `tf.function` may generate fewer 1618 graphs that are less specialized on input shapes. 1619 experimental_compile: Deprecated alias to 'jit_compile'. 1620 experimental_follow_type_hints: When True, the function may use type 1621 annotations from `func` to optimize the tracing performance. For example, 1622 arguments annotated with `tf.Tensor` will automatically be converted 1623 to a Tensor. 1624 1625 Returns: 1626 If `func` is not None, returns a `tf.types.experimental.GenericFunction`. 1627 If `func` is None, returns a decorator that, when invoked with a single 1628 `func` argument, returns a `tf.types.experimental.GenericFunction`. 1629 1630 Raises: 1631 `ValueError` when attempting to use `jit_compile=True`, but XLA support is 1632 not available.
1633 """ 1634 if func is not None: 1635 function_lib.validate_python_function(func) 1636 if input_signature is not None: 1637 function_lib.validate_signature(input_signature) 1638 if experimental_follow_type_hints is None: 1639 experimental_follow_type_hints = False 1640 1641 def decorated(inner_function): 1642 try: 1643 name = inner_function.__name__ 1644 except AttributeError: 1645 name = "function" 1646 return tf_decorator.make_decorator( 1647 inner_function, 1648 decorator_name="tf.function", 1649 decorator_func=Function( 1650 inner_function, 1651 name, 1652 input_signature=input_signature, 1653 autograph=autograph, 1654 experimental_autograph_options=experimental_autograph_options, 1655 experimental_relax_shapes=experimental_relax_shapes, 1656 1657 # TODO(b/171825496): Update once `experimental_compile` is removed 1658 # entirely in favor of 'jit_compile'. 1659 jit_compile=deprecation.deprecated_argument_lookup( 1660 "jit_compile", 1661 jit_compile, 1662 "experimental_compile", 1663 experimental_compile), 1664 experimental_implements=experimental_implements, 1665 experimental_follow_type_hints=experimental_follow_type_hints)) 1666 1667 # This code path is for the `foo = tf.function(foo, ...)` use case 1668 if func is not None: 1669 return decorated(func) 1670 1671 # This code path is for the 1672 # 1673 # @tf.function(...) 1674 # def foo(...): 1675 # ... 1676 # 1677 # use case, which is equivalent to `foo = tf.function(...)(foo)` 1678 return decorated 1679