# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining graph functions with eager semantics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import threading
import weakref
import six

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export

FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
FREQUENT_TRACING_WARNING_THRESHOLD = 5
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2


_tf_function_counter = monitoring.Counter(
    "/tensorflow/core/tf_function_counter",
    "Counter for the number of tf.functions created when Eager execution is "
    "enabled.",
    # jit_compile is "0" or "1".
    "jit_compile")


class _FrequentTracingDetector(object):
  """Class keeping track of how many recent calls triggered tracing."""

  __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]

  def __init__(self):
    self._calls_per_tracings = []
    self._total_warning_count = 0
    self._call_count = 0

  def called_with_tracing(self, function_name, omit_warning):
    """Updates the list of most recent calls' tracing information.

    Warns the user when recent calls caused retracing too often.

    Args:
      function_name: the python function being traced.
      omit_warning: If `True`, this call will not warn the user even if
        retracing happens too often.
    """
    self._call_count += 1
    self._calls_per_tracings.append(1)

    while self._calls_per_tracings:
      if (self._call_count - self._calls_per_tracings[0] >
          FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
        self._call_count -= self._calls_per_tracings.pop(0)
      else:
        break

    if (omit_warning or self._total_warning_count >=
        FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
      return
    if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
      self._total_warning_count += 1
      logging.warning(
          "{} out of the last {} calls to {} triggered tf.function "
          "retracing. Tracing is expensive and the excessive number of "
          "tracings could be due to (1) creating @tf.function repeatedly in "
          "a loop, (2) passing tensors with different shapes, (3) passing "
          "Python objects instead of tensors. For (1), please define your "
          "@tf.function outside of the loop. For (2), @tf.function has the "
          "experimental_relax_shapes=True option that relaxes argument "
          "shapes and can avoid unnecessary retracing. For (3), please "
          "refer to "
          "https://www.tensorflow.org/guide/function#controlling_retracing"
          " and https://www.tensorflow.org/api_docs/python/tf/function for "
          "more details.".format(
              len(self._calls_per_tracings), self._call_count, function_name))

  def called_without_tracing(self):
    # We don't count tracing when users load a concrete function directly or
    # call get_concrete_function, so the first call may not be a tracing call.
    if not self._calls_per_tracings:
      self._calls_per_tracings = [0]
    self._calls_per_tracings[-1] += 1
    self._call_count += 1


class _FrequentTracingDetectorManager(object):
  """Class for the management of all _FrequentTracingDetector objects."""

  __slots__ = ["_detectors", "_lock"]

  def __init__(self):
    self._detectors = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
    self._lock = threading.Lock()

  def _get_detector(self, key):
    if key not in self._detectors:
      self._detectors[key] = _FrequentTracingDetector()
    return self._detectors[key]

  def called_without_tracing(self, key):
    with self._lock:
      detector = self._get_detector(key)
      detector.called_without_tracing()

  def called_with_tracing(self, key, function_name, omit_warning):
    with self._lock:
      detector = self._get_detector(key)
      detector.called_with_tracing(function_name, omit_warning)


_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()


class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
  """Variable which does not lift its initializer out of function context.

  Instances of this variable, when created, build a graph which runs their
  initializer inside a tf.cond(is_initialized) block.

  This can only be created inside a defun called from (eventually) eager
  mode. That is, non-function-building graphs are not supported.
  """

  def __init__(self,
               initial_value=None,
               trainable=None,
               caching_device=None,
               name=None,
               dtype=None,
               constraint=None,
               add_initializers_to=None,
               lifted_initializer_graph=None,
               synchronization=None,
               aggregation=None,
               shape=None,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable.
        The initial value must have a shape specified unless `validate_shape`
        is set to False. Can also be a callable with no argument that returns
        the initial value when called. (Note that initializer functions from
        init_ops.py must first be bound to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type. If
        None, either the datatype will be kept (if initial_value is a Tensor)
        or float32 will be used (if it is a Python object convertible to a
        Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the
        variable can be assigned with values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
    if not ops.inside_function():
      # If we've been init_scope()d out of the function definition, there is
      # nothing to do here; we can't really do the capturing or conditional
      # logic.
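      # Instead, fall back to a plain ResourceVariable, whose initializer is
      # handled in the usual (non-function) way.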
      resource_variable_ops.ResourceVariable.__init__(
          self, initial_value=initial_value, trainable=trainable,
          caching_device=caching_device, name=name, dtype=dtype,
          constraint=constraint)
      return
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as scope_name:
      with ops.name_scope("Initializer"):
        if init_from_fn:
          initial_value = initial_value()
        if isinstance(initial_value, trackable.CheckpointInitialValue):
          self._maybe_initialize_trackable()
          self._update_uid = initial_value.checkpoint_position.restore_uid
          initial_value = initial_value.wrapped_value

        initial_value = ops.convert_to_tensor(initial_value,
                                              name="initial_value", dtype=dtype)
      assert initial_value is not None

      # Don't use `shape or initial_value.shape` since TensorShape has
      # overridden `__bool__`.
      if shape is None:
        shape = initial_value.shape

    # Use the constructor for UninitializedVariable to start. Outside the name
    # scope so we don't double up the prefix.
    super(UnliftedInitializerVariable, self).__init__(
        trainable=trainable,
        caching_device=caching_device,
        name=name,
        shape=shape,
        dtype=initial_value.dtype,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        extra_handle_data=initial_value,
        **unused_kwargs)

    with ops.name_scope(scope_name):
      if self._in_graph_mode:
        with ops.init_scope():
          outer_graph = ops.get_default_graph()
        func_graph = ops.get_default_graph()
        function_placeholders = (
            func_graph.inputs + func_graph.internal_captures)
        placeholder_ops = set(
            [tensor.op for tensor in function_placeholders])
        lifted_initializer = lift_to_graph.lift_to_graph(
            [initial_value], outer_graph,
            disallowed_placeholders=placeholder_ops)[initial_value]
        with ops.init_scope():
          self._initial_value = lifted_initializer
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = resource_variable_ops.assign_variable_op(
                  self._handle, lifted_initializer, name=n)
      elif context.executing_eagerly():
        # In this case, both current scope and init scope are eager.
        # assign_variable_op will be executed immediately. So we don't need to
        # add it to "add_initializers_to" to lift it out.
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          resource_variable_ops.assign_variable_op(
              self._handle, initial_value, name=n)
      else:
        # Init scope is eager but current scope is graph. We will lift out this
        # variable by adding it into "add_initializers_to".
        if add_initializers_to is not None:
          add_initializers_to.append((self, initial_value))

        def assign_fn():
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            resource_variable_ops.assign_variable_op(
                self._handle,
                initial_value,
                name=n)
            # Returning values to keep tf.cond happy.
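            # (Both branches of the cond below must return compatible values,
            # so the no-op branch returns a dummy scalar as well.)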
            return ops.convert_to_tensor(1)
        def not_assign_fn():
          return ops.convert_to_tensor(0)
        # Note: this cond is always guaranteed to run because we're inside a
        # defun which will insert automatic control dependencies. It will only
        # execute assign_fn if lifting failed.
        graph = ops.get_default_graph()

        # Capture the handle ahead of time in order to avoid querying the shape
        # of the handle, which helps async execution performance.
        graph.capture(self._handle, shape=())
        control_flow_ops.cond(
            resource_variable_ops.var_is_initialized_op(self._handle),
            not_assign_fn, assign_fn)


RUN_FUNCTIONS_EAGERLY = False


@deprecation.deprecated(
    None,
    "Use `tf.config.run_functions_eagerly` instead of the experimental "
    "version.")
@tf_export("config.experimental_run_functions_eagerly")
def experimental_run_functions_eagerly(run_eagerly):
  """Enables / disables eager execution of `tf.function`s.

  Calling `tf.config.experimental_run_functions_eagerly(True)` will make all
  invocations of `tf.function` run eagerly instead of running as a traced graph
  function.

  See `tf.config.run_functions_eagerly` for an example.

  Note: This flag has no effect on functions passed into tf.data transformations
  as arguments. tf.data functions are never executed eagerly and are always
  executed as a compiled TensorFlow Graph.

  Args:
    run_eagerly: Boolean. Whether to run functions eagerly.
  """
  return run_functions_eagerly(run_eagerly)


@tf_export("config.run_functions_eagerly")
def run_functions_eagerly(run_eagerly):
  """Enables / disables eager execution of `tf.function`s.

  Calling `tf.config.run_functions_eagerly(True)` will make all
  invocations of `tf.function` run eagerly instead of running as a traced graph
  function.

  This can be useful for debugging.

  >>> def my_func(a):
  ...   print("Python side effect")
  ...   return a + a
  >>> a_fn = tf.function(my_func)

  >>> # A side effect the first time the function is traced
  >>> a_fn(tf.constant(1))
  Python side effect
  <tf.Tensor: shape=(), dtype=int32, numpy=2>

  >>> # No further side effect, as the traced function is called
  >>> a_fn(tf.constant(2))
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  >>> # Now, switch to eager running
  >>> tf.config.run_functions_eagerly(True)
  >>> # Side effect, as the function is called directly
  >>> a_fn(tf.constant(2))
  Python side effect
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  >>> # Turn this back off
  >>> tf.config.run_functions_eagerly(False)

  Note: This flag has no effect on functions passed into tf.data transformations
  as arguments. tf.data functions are never executed eagerly and are always
  executed as a compiled TensorFlow Graph.

  Args:
    run_eagerly: Boolean. Whether to run functions eagerly.
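
  You can query whether this mode is currently enabled with
  `tf.config.functions_run_eagerly` (defined below in this module):

  >>> tf.config.run_functions_eagerly(True)
  >>> tf.config.functions_run_eagerly()
  True
  >>> tf.config.run_functions_eagerly(False)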
  """
  global RUN_FUNCTIONS_EAGERLY
  RUN_FUNCTIONS_EAGERLY = bool(run_eagerly)


@deprecation.deprecated(
    None,
    "Use tf.config.functions_run_eagerly instead of the experimental version.")
@tf_export("config.experimental_functions_run_eagerly")
def experimental_functions_run_eagerly():
  """Returns the value of the `experimental_run_functions_eagerly` setting."""
  return functions_run_eagerly()


@tf_export("config.functions_run_eagerly")
def functions_run_eagerly():
  """Returns the value of the `run_functions_eagerly` setting."""
  return RUN_FUNCTIONS_EAGERLY


def _evaluate_var_is_initialized(variables):
  """Compute booleans indicating whether each variable is initialized."""
  with ops.init_scope():
    var_is_initialized = []
    for v in variables:
      var_is_initialized.append(
          resource_variable_ops.var_is_initialized_op(v.handle))
    try:
      # Stack all the var_is_initialized values into one tensor and interpret
      # the numpy value. This will reduce the number of RPCs between client and
      # worker in the remote case.
      return array_ops.stack(var_is_initialized).numpy()
    except errors.UnimplementedError:
      # Some devices do not support implicit copy-off to host. Fall back to
      # variable-by-variable processing.
      for index, v in enumerate(variables):
        try:
          numpy_value = var_is_initialized[index].numpy()
        except errors.UnimplementedError:
          # This is a variable on a parallel device; we'll extract its value on
          # each replica and assert that they're identical.
          components = parallel_device.unpack(var_is_initialized[index])
          with ops.device(None):
            components = array_ops.stack(components)
            all_initialized = math_ops.reduce_all(components).numpy()
            any_initialized = math_ops.reduce_any(components).numpy()
          if all_initialized != any_initialized:
            raise NotImplementedError(
                ("Some but not all components of a parallel variable {} were "
                 "initialized between their creation in a tf.function and "
                 "the function's trace having completed. This is not yet "
                 "supported; consider initializing either all or none of the "
                 "components, or moving initialization out of the function."
                ).format(repr(v)))
          numpy_value = all_initialized
        var_is_initialized[index] = numpy_value
      return var_is_initialized


class FunctionDeleter(object):

  __slots__ = ["func_graph"]

  def __init__(self, func_graph):
    self.func_graph = func_graph

  def __del__(self):
    try:
      func_graph_module.dismantle_func_graph(self.func_graph)
    except:  # pylint: disable=bare-except
      # Note: bare except here because this can be noisy at shutdown time.
      pass


class Function(object):
  """Wrapper class for the graph functions defined for a Python function.

  See the documentation for `tf.function` for more information on the semantics
  of defined functions.

  `Function` is thread-compatible.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               autograph=True,
               jit_compile=None,
               experimental_implements=None,
               experimental_autograph_options=None,
               experimental_relax_shapes=False,
               experimental_follow_type_hints=None):
    """Initializes a `Function`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: See the documentation for `tf.function`.
      autograph: See the documentation for `tf.function`.
      jit_compile: See the documentation for `tf.function`.
      experimental_implements: See the documentation for `tf.function`.
      experimental_autograph_options: See the documentation for `tf.function`.
      experimental_relax_shapes: See the documentation for `tf.function`.
      experimental_follow_type_hints: See the documentation for `tf.function`.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
    """
    self._lock = threading.Lock()
    self._python_function = python_function
    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
        python_function,
        input_signature,
        jit_compile=jit_compile,
        experimental_follow_type_hints=experimental_follow_type_hints,
    )
    self._implements = experimental_implements
    # If `True`, the function uses the rendezvous of the parent. This is only
    # needed to support code where raw send/recv operations are inserted and
    # when functions are run in graph mode where they may not be inlined.
    self._shared_rendezvous = None
    self._autograph = autograph
    self._experimental_autograph_options = experimental_autograph_options
    self._experimental_relax_shapes = experimental_relax_shapes
    self._jit_compile = jit_compile
    if experimental_follow_type_hints is None:
      experimental_follow_type_hints = False
    self._experimental_follow_type_hints = experimental_follow_type_hints
    self._created_variables = None  # GUARDED_BY(self._lock)
    self._stateful_fn = None  # GUARDED_BY(self._lock)
    self._stateless_fn = None  # GUARDED_BY(self._lock)
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._name = name
    self._input_signature = input_signature
    self._key_for_call_stats = self._get_key_for_call_stats()
    self._omit_frequent_tracing_warning = False
    ops._tf_function_api_guage.get_cell().set(True)  # pylint: disable=protected-access

  def __getstate__(self):
    """Custom pickling, to omit unpickleable objects."""
    result = self.__dict__.copy()
    del result["_lock"]
    del result["_descriptor_cache"]
    del result["_key_for_call_stats"]
    return result

  def __setstate__(self, state):
    """Restore from pickled state."""
    self.__dict__ = state
    self._lock = threading.Lock()
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._key_for_call_stats = self._get_key_for_call_stats()

  def _get_key_for_call_stats(self):
    """Returns a key instance used to track call stats and retracings.

    The key instance is a best-effort attempt to preserve global consistency.
    """
    target_function = self._python_function
    # `__wrapped__` is a conventional Python attribute in which a higher-order
    # function keeps its original function. We also directly use this attribute
    # for dealing with a class method. See `bound_method_wrapper` in
    # `function.py`. If we don't use `__wrapped__`, all class methods will
    # return the same `bound_method_wrapper` instance from this function.
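    # (For example, functools.wraps sets `__wrapped__`, so this loop walks back
    # through a stack of decorators to the user's original function.)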
    while hasattr(target_function, "__wrapped__"):
      target_function = target_function.__wrapped__

    if hasattr(target_function, "__func__"):
      target_function = target_function.__func__

    if hasattr(target_function, "__code__"):
      return target_function.__code__

    return self._python_function

  def _defun_with_scope(self, scope):
    """Creates a defun wrapped inside a variable creator scope."""

    weak_wrapped_fn = None
    compile_with_xla = self._jit_compile

    def wrapped_fn(*args, **kwds):
      """Wraps `self._python_function` in a variable creator scope."""
      # We register a variable creator with reduced priority. If an outer
      # variable creator is just modifying keyword arguments to the variable
      # constructor, this will work harmoniously. Since the `scope` registered
      # here actually creates the variable, it taking priority would otherwise
      # ignore the outer creator.
      #
      # If an outer variable creator calls the variable constructor manually,
      # for example creating a MirroredVariable, then they won't call our
      # creator. This means we won't be able to trace the initialization graph,
      # and so variable initializers can't depend on function arguments. This is
      # better than the alternative, tracing the initialization graph but giving
      # the user a variable type they didn't want.
      default_graph = ops.get_default_graph()
      with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
        # __wrapped__ allows AutoGraph to swap in a converted function. We give
        # the function a weak reference to itself to avoid a reference cycle.
        if compile_with_xla and \
            not control_flow_util.GraphOrParentsInXlaContext(default_graph):
          xla_context = control_flow_ops.XLAControlFlowContext()
          try:
            xla_context.Enter()
            out = weak_wrapped_fn().__wrapped__(*args, **kwds)
          finally:
            xla_context.Exit()
        else:
          out = weak_wrapped_fn().__wrapped__(*args, **kwds)
        return out

    weak_wrapped_fn = weakref.ref(wrapped_fn)

    return self._defun(tf_decorator.make_decorator(
        self._python_function,
        wrapped_fn))

  def _create_implements_attribute(self):
    """Creates the attribute value corresponding to IMPLEMENTS_ATTRIBUTE_NAME."""
    attributes = {}
    if isinstance(self._implements, str):
      # First check if the IMPLEMENTS_ATTRIBUTE_NAME is specified as a
      # NameAttrList. This is used when, apart from the function name being
      # implemented, a list of attributes is also being specified.
      # The attributes are specified as key-value pairs in the NameAttrList
      # of the corresponding AttrValue. The function name will be in the
      # 'name' field of the NameAttrList. Otherwise, it is just a string
      # corresponding to the function name.
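      # For example, `experimental_implements` may be just a name such as
      # "embedded_matmul", or a NameAttrList textproto such as
      # 'name: "embedded_matmul" attr { key: "priority" value { i: 1 } }'
      # (the attribute key/value here is illustrative only).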
      try:
        implements_attr = six.ensure_text(self._implements, "utf-8")
        attr_value = attr_value_pb2.AttrValue()
        nameattrlist = attr_value_pb2.NameAttrList()
        _text_format.Merge(implements_attr, nameattrlist)
        attr_value.func.CopyFrom(nameattrlist)
        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = attr_value
      except (_text_format.ParseError, DecodeError):
        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
    return attributes

  def _defun(self, fn):
    """Returns a defun generated from the input function."""
    attributes = {}

    if self._implements is not None:
      attributes = self._create_implements_attribute()

    share = self._shared_rendezvous
    if share is not None:
      attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share

    if self._jit_compile is not None:
      attributes.update(_XlaMustCompile=bool(self._jit_compile))
      if self._jit_compile:
        attributes.update(_noinline=True)
    if not attributes:
      attributes = None
    return function_lib.defun_with_attributes(
        fn,
        input_signature=self.input_signature,
        attributes=attributes,
        autograph=self._autograph,
        jit_compile=self._jit_compile,
        experimental_autograph_options=self._experimental_autograph_options,
        experimental_follow_type_hints=self._experimental_follow_type_hints,
        experimental_relax_shapes=self._experimental_relax_shapes)

  def _initialize(self, args, kwds, add_initializers_to=None):
    """Initializes, on the first call.

    Creates two `Function`s, one that will allow creation of variables
    and one that won't.

    Additionally runs a trace for the `Function` that allows creation
    of variables.

    Args:
      args: Arguments to the underlying python callable.
      kwds: Keyword arguments to the python callable.
      add_initializers_to: Where to collect variable initializers, if not None.
    """

    created_variables = []
    lifted_initializer_graph = func_graph_module.FuncGraph("initializer")

    def variable_capturing_scope(unused_next_creator, **kwds):
      """Creates UnliftedInitializerVariables and saves references to them."""
      v = UnliftedInitializerVariable(
          add_initializers_to=add_initializers_to,
          lifted_initializer_graph=lifted_initializer_graph, **kwds)
      created_variables.append(weakref.ref(v))
      return v

    self._created_variables = created_variables
    self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
    self._stateful_fn._name = self._name  # pylint: disable=protected-access
    # Force the definition of the function for these arguments
    self._lifted_initializer_graph = lifted_initializer_graph
    self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    self._concrete_stateful_fn = (
        self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
            *args, **kwds))

    compiled = bool(self._jit_compile and
                    not control_flow_util.GraphOrParentsInXlaContext(
                        ops.get_default_graph()))
    # For nested functions, increment the counter only when a function with
    # jit_compile=True is called within a function with jit_compile=False. We
    # count this special case to correctly record that both jit_compile=True
    # and jit_compile=False are being used for parts of the outer function.
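    # The counter cell label is "0" or "1": it is "1" only when this trace is
    # actually XLA-compiled here (jit_compile requested and we are not already
    # inside an XLA context), matching the "jit_compile" label declared on
    # _tf_function_counter above.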
    if ops.executing_eagerly_outside_functions() and (
        context.executing_eagerly() or compiled):
      # Labels must be strings in Python, so we convert 'compiled' to a string
      _tf_function_counter.get_cell(str(int(compiled))).increase_by(1)

    def invalid_creator_scope(*unused_args, **unused_kwds):
      """Disables variable creation."""
      raise ValueError(
          "tf.function-decorated function tried to create "
          "variables on non-first call.")

    self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
    self._stateless_fn._name = self._name  # pylint: disable=protected-access

  def _clone(self, python_function):
    """Clone the function with different python function."""
    f = Function(
        python_function=(self._python_function
                         if python_function is None else python_function),
        name=self._name,
        input_signature=self._input_signature,
        autograph=self._autograph,
        jit_compile=self._jit_compile,
        experimental_implements=self._implements,
        experimental_autograph_options=self._experimental_autograph_options,
        experimental_relax_shapes=self._experimental_relax_shapes,
        experimental_follow_type_hints=self._experimental_follow_type_hints)

    if self._shared_rendezvous:
      f._shared_rendezvous = self._shared_rendezvous  # pylint: disable=protected-access

    return f

  def _decorate(self, decorator):
    """Allows the captured Python function to be decorated in place.

    This method is only safe to call when the Function has not been called by a
    user. It makes sense to use this method to push a decorator into the
    function rather than wrapping the function in the decorator.

    We use this in tf.Module to allow user-annotated `tf.function`s to remain
    as `Function` objects but still automatically enter the Module name_scope
    when they are evaluated like all other methods.

    Args:
      decorator: A callable accepting a single argument which is the function
        to decorate and returning a callable result.

    Raises:
      ValueError: If the function has already been called.
    """
    if self._stateful_fn is not None or self._stateless_fn is not None:
      raise ValueError(
          "Functions cannot be decorated after they have been traced.")

    self._python_function = decorator(self._python_function)
    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
        self._python_function, self.input_signature)

  # TODO: Remove this private method after updating all its uses
  # A good moment to do this could be when the experimental label is removed
  def _get_tracing_count(self):
    return self.experimental_get_tracing_count()

  def experimental_get_tracing_count(self):
    """Returns the number of times the function has been traced.

    For more information on when a function is traced and when it is
    traced multiple times see https://www.tensorflow.org/guide/function.
    Example:

    >>> @tf.function
    ... def double(a):
    ...   return a + a
    >>> double(tf.constant(1))
    <tf.Tensor: shape=(), dtype=int32, numpy=2>
    >>> double(tf.constant(2))
    <tf.Tensor: shape=(), dtype=int32, numpy=4>
    >>> double.experimental_get_tracing_count()
    1
    >>> double(tf.constant("a"))
    <tf.Tensor: shape=(), dtype=string, numpy=b'aa'>
    >>> double.experimental_get_tracing_count()
    2

    The first time experimental_get_tracing_count is called
    it returns 1, as the function is traced the first
    time it is called, and the second time the same graph is used
    since we're calling it with a parameter of the same type.

    The second time experimental_get_tracing_count is called
    it returns 2, as we called double with a
    different argument type, and so it was traced again.

    """
    result = self._stateless_fn.tracing_count if self._stateless_fn else 0
    result += self._stateful_fn.tracing_count if self._stateful_fn else 0
    return result

  @property
  def _run_functions_eagerly(self):
    return RUN_FUNCTIONS_EAGERLY

  def __call__(self, *args, **kwds):
    """Calls the graph function and warns about too-frequent tracing."""
    if self._run_functions_eagerly:
      with trace.Trace(self._name, tf_function_call="eager"):
        return self._python_function(*args, **kwds)

    tracing_count = self.experimental_get_tracing_count()
    with trace.Trace(self._name) as tm:
      result = self._call(*args, **kwds)
      compiler = "xla" if self._jit_compile else "nonXla"
      new_tracing_count = self.experimental_get_tracing_count()
      without_tracing = (tracing_count == new_tracing_count)
      execution_mode = "notTraced" if without_tracing else "traced"
      tm.set_metadata(tf_function_call=execution_mode + "-" + compiler,
                      tracing_count=new_tracing_count)

    if context.executing_eagerly():
      if without_tracing:
        _frequent_tracing_detector_manager.called_without_tracing(
            self._key_for_call_stats)
      else:
        _frequent_tracing_detector_manager.called_with_tracing(
            self._key_for_call_stats, self._python_function,
            self._omit_frequent_tracing_warning)

    return result

  def _call(self, *args, **kwds):
    """Calls the graph function."""
    self._lock.acquire()
    if self._created_variables:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    try:
      # This is the first call of __call__, so we have to initialize.
      initializers = []
      self._initialize(args, kwds, add_initializers_to=initializers)
    finally:
      # At this point we know that the initialization is complete (or less
      # interestingly an exception was raised) so we no longer need a lock.
      self._lock.release()

    if self._created_variables:
      try:
        # Attempt to initialize variables eagerly and without conds by lifting
        # out initialization graphs. This is the only initialization strategy
        # compatible with XLA at the moment.
        self._initialize_uninitialized_variables(initializers)
      except lift_to_graph.UnliftableError:
        pass  # Fall through to cond-based initialization.
      else:
        # Lifting succeeded, so variables are initialized and we can run the
        # stateless function.
        return self._stateless_fn(*args, **kwds)
    else:
      _, _, _, filtered_flat_args = \
          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
              *args, **kwds)
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._call_flat(
          filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access

    def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
      """Conditionally runs initialization if it's needed."""
      condition = True
      for wr in self._created_variables:
        variable = wr()
        if variable is None:
          raise ValueError(
              "A tf.Variable created inside your tf.function has been"
              " garbage-collected. Your code needs to keep Python references"
              " to variables created inside `tf.function`s.\n"
              "\n"
              "A common way to raise this error is to create and return a"
              " variable only referenced inside your function:\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  v = tf.Variable(1.0)\n"
              "  return v\n"
              "\n"
              "v = f()  # Crashes with this error message!\n"
              "\n"
              "The reason this crashes is that a @tf.function-annotated"
              " function returns a **`tf.Tensor`** with the **value** of the"
              " variable when the function is called rather than the"
              " variable instance itself. As such there is no code holding a"
              " reference to the `v` created inside the function and Python"
              " garbage collects it.\n"
              "\n"
              "The simplest way to fix this issue is to create variables"
              " outside the function and capture them:\n"
              "\n"
              "v = tf.Variable(1.0)\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  return v\n"
              "\n"
              "f()  # <tf.Tensor: numpy=1.>\n"
              "v.assign_add(1.)\n"
              "f()  # <tf.Tensor: numpy=2.>")
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                variable.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(
              self._concrete_stateful_fn._call_flat,  # pylint: disable=protected-access
              inner_filtered_flat_args,
              captured_inputs=self._concrete_stateful_fn.captured_inputs))

    # We've created variables and are unable to lift the initialization graphs,
    # so we fall back to initializing with conds while running the function.
    canon_args, canon_kwds, _, filtered_flat_args = \
        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
            *args, **kwds)
    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
                                            filtered_flat_args)

  def experimental_get_compiler_ir(self, *args, **kwargs):
    """Returns compiler IR for the compiled function.

    This API is intended *only* for debugging as there are no guarantees on
    backwards compatibility of returned IR or the allowed values of `stage`.

    Args:
      *args: Arguments used for compilation; same arguments as used for calling
        the function. Need to be eager tensors.
      **kwargs: Keyword arguments used for compilation.

    Returns:
      Function callable with the following kwargs:
        - `stage` at which the compiler IR should be serialized.
          Allowed values are:
          - `hlo`: HLO output after conversion from TF
            (https://www.tensorflow.org/xla/operation_semantics).
          - `hlo_serialized`: Like stage=`hlo`, but the output is a serialized
            HLO module proto (a bytes object).
          - `optimized_hlo`: HLO after compiler optimizations.
          - `optimized_hlo_serialized`: Like stage=`optimized_hlo`, but the
            output is a serialized HLO module proto (a bytes object).
          - `optimized_hlo_dot`: optimized HLO in DOT format suitable for
            Graphviz.
        - `device_name` can be either None, in which case the preferred device
          is used for compilation, or a device name. It can be a full device
          name, or a partial one, e.g., `/device:CPU:0`.

      For example, for

      ```python
      @tf.function(jit_compile=True)
      def f(x):
        return x + 1

      f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')
      ```

      the output is:

      ```
      HloModule a_inference_f_13__.9

      ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
        %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
        %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
        %constant.3 = f32[] constant(1)
        %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
        %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                                     f32[10,10]{1,0} %broadcast.4)
        %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
        %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
        ROOT %get-tuple-element.8 = f32[10,10]{1,0}
          get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
      }
      ```

    Raises:
      ValueError: If an invalid `stage` is selected or if applied to a function
        which is not compiled (`jit_compile=True` is not set).
      TypeError: When called with input in graph mode.
    """
    context.ensure_initialized()
    if not self._jit_compile:
      raise ValueError("Compiler IR can only be returned for functions marked "
                       "with 'jit_compile=True'")

    concrete_fn = self.get_concrete_function(*args, **kwargs)
    fn_name = concrete_fn.name

    # pylint: disable=protected-access
    _, _, _, filtered_flat_args = \
        concrete_fn._function_spec.canonicalize_function_inputs(
            *args, **kwargs)

    def compiler_ir_generator(stage="hlo", device_name=None):
      # TODO(cheshire): This is a hack to get the current "preferred" device,
      # there is no current API to get it otherwise.
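      # Creating a trivial op eagerly and reading back the device it was placed
      # on gives the runtime's current "preferred" device, which is then used
      # as the compilation target.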
      if device_name is None:
        device_name = random_ops.random_normal([]).device
      res_bytes = context.context().get_compiler_ir(
          device_name=device_name,
          stage=stage,
          function_name=fn_name,
          args=list(filtered_flat_args) + concrete_fn.captured_inputs)
      if stage in ("hlo_serialized", "optimized_hlo_serialized"):
        return res_bytes
      else:
        return res_bytes.decode("utf-8")

    return compiler_ir_generator

  @property
  def python_function(self):
    """The python function wrapped in this tf.function."""
    return self._python_function

  @property
  def input_signature(self):
    return self._function_spec.input_signature

  @property
  def function_spec(self):
    return self._function_spec

  def pretty_printed_concrete_signatures(self, verbose=True):
    joiner = "\n\n" if verbose else "\n"
    return joiner.join([
        c.pretty_printed_signature(verbose=verbose)
        for c in self._list_all_concrete_functions()
    ])

  def _initialize_uninitialized_variables(self, initializers):
    """Make and call a `ConcreteFunction` which initializes variables."""

    if not initializers:
      return

    var_is_initialized = _evaluate_var_is_initialized(
        [v for v, _ in initializers])

    # Note: using defun here avoids an infinite recursion.
    # Most of the code in this function runs eagerly with init_scope, where
    # autograph is not necessary.
    @function_lib.defun(autograph=False)
    def initialize_variables():
      op_map = object_identity.ObjectIdentityDictionary()

      inits = []
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        inits.append(init)

      if inits:
        op_map = lift_to_graph.lift_to_graph(
            inits, ops.get_default_graph(), op_map=op_map)
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        v.assign(op_map[init], read_value=False)

    with ops.init_scope():
      return initialize_variables.get_concrete_function()()

  def get_initialization_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` which initializes this function's variables.

    Requires that this function hasn't been accessed yet through either calling
    it or calling get_concrete_function. Fails if we cannot build an initializer
    function which does not depend on the concrete values of the inputs to this
    function.

    Note that running this function will overwrite any values currently assigned
    to variables, for example values restored from a checkpoint.

    Args:
      *args: arguments to the underlying python callable.
      **kwargs: keyword arguments to the python callable.

    Returns:
      A `ConcreteFunction` object which initializes the variables of this
      function.

    Raises:
      RuntimeError: if called after the variables have been initialized.
    """
    with self._lock:
      if self._stateful_fn is not None:
        raise RuntimeError(
            "get_initialization_function cannot be called after the function "
            "has been used")
      # Here we trace the function, collect the initializers, and attempt to
      # extract them and run them eagerly. Fail only if we cannot do so.
      initializers = []
      self._initialize(args, kwargs, add_initializers_to=initializers)

    # Note: using defun here avoids an infinite recursion.
    @function_lib.defun
    def initialize_variables():
      for v, init in initializers:
        v.assign(
            lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init],
            read_value=False)

    return initialize_variables.get_concrete_function()

  def _list_all_concrete_functions(self):
    """Returns all concrete functions."""
    if self.input_signature is not None:
      self.get_concrete_function()
    concrete_functions = []
    # pylint: disable=protected-access
    if self._stateful_fn:
      concrete_functions.extend(
          self._stateful_fn._function_cache.all_values())
    if self._stateless_fn:
      concrete_functions.extend(
          self._stateless_fn._function_cache.all_values())
    # pylint: enable=protected-access
    return concrete_functions

  def _list_all_concrete_functions_for_serialization(self):
    """Returns all concrete functions for serialization.

    Returns:
      A list of instances of `ConcreteFunction`.
    """
    concrete_functions = self._list_all_concrete_functions()
    seen_signatures = []
    for concrete_function in concrete_functions:
      signature = concrete_function.structured_input_signature
      flattened = nest.flatten(signature)
      if any(
          isinstance(arg, func_graph_module.UnknownArgument)
          for arg in flattened):
        logging.info("Unsupported signature for serialization: %s.", signature)
        continue
      equal_to_signature = functools.partial(
          function_lib.is_same_structure, signature, check_values=True)
      if not any(equal_to_signature(s) for s in seen_signatures):
        seen_signatures.append(signature)

    # Re-create concrete functions for these signatures. Re-creating ensures
    # that if the cache key has changed, the function will be traced again.
    concrete_functions = []
    for args, kwargs in seen_signatures:
      concrete_functions.append(self.get_concrete_function(*args, **kwargs))
    return concrete_functions

  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Unlike `get_concrete_function(...)`, the graph will be deleted when the
    returned function is deleted. It's useful to avoid creating a reference
    cycle when you know for sure that the graph will no longer be used without
    the returned function.

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A TensorFlow function which takes exactly one `tf.Tensor` per argument.

    Raises:
      ValueError: if this object has not yet been called on concrete values.
    """
    with self._lock:
      if self._stateful_fn is None:
        initializers = []
        self._initialize(args, kwargs, add_initializers_to=initializers)
        self._initialize_uninitialized_variables(initializers)

    if self._created_variables:
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
          *args, **kwargs)
    elif self._stateful_fn is not None:
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
          *args, **kwargs)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return concrete

  def get_concrete_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    If this `Function` was created with an `input_signature`, `args` and
    `kwargs` may be omitted. With an input signature there is only one
    concrete function associated with this `Function`.

    If there is no fixed `input_signature` associated with this
    `Function`, positional and keyword arguments to `get_concrete_function`
    follow the same rules as input signature specification, with `tf.TensorSpec`
    objects describing `tf.Tensor`s which will be passed to the concrete
    function.

    Each `tf.Tensor` argument to the concrete function must have a unique name,
    either because it is the only one associated with a named argument of the
    Python function or because an explicit `name=` was passed to its
    `tf.TensorSpec` object. These names become the argument names for the
    concrete function.

    Arguments to the concrete function may always be specified as keyword
    arguments, naming the Tensor input. Positional arguments may be used instead
    when each preceding argument to the Python function is a Tensor.

    ```python
    @tf.function
    def f(x):
      return x

    f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))
    f_concrete(tf.constant(1.))
    f_concrete(x=tf.constant(1.))
    ```

    Nested structures containing Tensors may be specified when retrieving
    concrete functions. Structures with multiple Tensors are expanded into
    multiple arguments of the concrete function. Since multiple concrete
    function arguments are associated with one argument to the original
    function, these Tensors must be named explicitly. Tensors in nested
    structures may not be passed using positional arguments when calling the
    concrete function.

    ```python
    f_concrete2 = f.get_concrete_function(
        (tf.TensorSpec(None, tf.float64, name="first"),
         tf.TensorSpec([], tf.float32, name="second")))
    # Keyword arguments are required when identifying Tensors in nested
    # structures.
    f_concrete2(first=tf.constant([1.]), second=tf.constant(0.))
    ```

    Functions with fixed input signatures have only one concrete function
    associated with them, which can be retrieved without specifying any
    arguments. As before, Tensors must have unique names, either inferred from
    the argument names in the original Python function or specified
    explicitly.

    ```python
    @tf.function(input_signature=(tf.TensorSpec(None, tf.float32),))
    def f_sig(y):
      return y

    f_sig_concrete = f_sig.get_concrete_function()
    f_sig_concrete(tf.constant(1.))
    f_sig_concrete(y=tf.constant(1.))
    ```

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A TensorFlow function which takes exactly one `tf.Tensor` per argument.

    Raises:
      ValueError: if this object has not yet been called on concrete values.
    """
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
    concrete._garbage_collector.release()  # pylint: disable=protected-access
    return concrete

  def __get__(self, instance, owner):
    """Makes it possible to defun instance methods."""
    del owner
    # `instance` here is the instance that this `Function` was accessed through
    # e.g., for
    #
    # class Foo(object):
    #
    #   @function.defun
    #   def bar(self):
    #     ...
    #
    # foo = Foo()
    # foo.bar()  # `foo.bar` is a `Function` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`). We create a
    # new instance of `Function` here so that different instances can each
    # create variables once, thereby allowing methods to be decorated with
    # tf.function. Keeps a cache to avoid retracing the function every time the
    # descriptor is accessed.
    if instance not in self._descriptor_cache:
      if instance is None:
        return self
      self._descriptor_cache[instance] = (
          function_lib.class_method_to_instance_method(self, instance))
    return self._descriptor_cache[instance]


@tf_export("function")
@deprecation.deprecated_args(None,
                             "experimental_compile is deprecated, use "
                             "jit_compile instead", "experimental_compile")
def function(func=None,
             input_signature=None,
             autograph=True,
             jit_compile=None,
             experimental_implements=None,
             experimental_autograph_options=None,
             experimental_relax_shapes=False,
             experimental_compile=None,
             experimental_follow_type_hints=None):
  """Compiles a function into a callable TensorFlow graph.

  `tf.function` constructs a callable that executes a TensorFlow graph
  (`tf.Graph`) created by trace-compiling the TensorFlow operations in `func`,
  effectively executing `func` as a TensorFlow graph.

  Example usage:

  >>> @tf.function
  ... def f(x, y):
  ...   return x ** 2 + y
  >>> x = tf.constant([2, 3])
  >>> y = tf.constant([3, -2])
  >>> f(x, y)
  <tf.Tensor: ... numpy=array([7, 7], ...)>

  _Features_

  `func` may use data-dependent control flow, including `if`, `for`, `while`,
  `break`, `continue` and `return` statements:

  >>> @tf.function
  ... def f(x):
  ...   if tf.reduce_sum(x) > 0:
  ...     return x * x
  ...   else:
  ...     return -x // 2
  >>> f(tf.constant(-2))
  <tf.Tensor: ... numpy=1>

  `func`'s closure may include `tf.Tensor` and `tf.Variable` objects:

  >>> @tf.function
  ... def f():
  ...   return x ** 2 + y
  >>> x = tf.constant([-2, -3])
  >>> y = tf.Variable([3, -2])
  >>> f()
  <tf.Tensor: ... numpy=array([7, 7], ...)>

  `func` may also use ops with side effects, such as `tf.print`, `tf.Variable`
  and others:

  >>> v = tf.Variable(1)
  >>> @tf.function
  ... def f(x):
  ...   for i in tf.range(x):
  ...     v.assign_add(i)
  >>> f(3)
  >>> v
  <tf.Variable ... numpy=4>

  Important: Any Python side-effects (appending to a list, printing with
  `print`, etc) will only happen once, when `func` is traced. To have
  side effects executed each time `func` is called, they need to be written
  as TF ops:

  >>> l = []
  >>> @tf.function
  ... def f(x):
  ...   for i in x:
  ...     l.append(i + 1)  # Caution! Will only happen once when tracing
  >>> f(tf.constant([1, 2, 3]))
  >>> l
  [<tf.Tensor ...>]

  Instead, use TensorFlow collections like `tf.TensorArray`:

  >>> @tf.function
  ... def f(x):
  ...   ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  ...   for i in range(len(x)):
  ...     ta = ta.write(i, x[i] + 1)
  ...   return ta.stack()
  >>> f(tf.constant([1, 2, 3]))
  <tf.Tensor: ..., numpy=array([2, 3, 4], ...)>

  _`tf.function` is polymorphic_

  Internally, `tf.function` can build more than one graph, to support arguments
  with different data types or shapes, since TensorFlow can build more
  efficient graphs that are specialized on shapes and dtypes. `tf.function`
  also treats any pure Python value as an opaque object, and builds a separate
  graph for each set of Python arguments that it encounters.

  To obtain an individual graph, use the `get_concrete_function` method of
  the callable created by `tf.function`. It can be called with the same
  arguments as `func` and returns a special `tf.Graph` object:

  >>> @tf.function
  ... def f(x):
  ...   return x + 1
  >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
  True

  Caution: Passing python scalars or lists as arguments to `tf.function` will
  always build a new graph. To avoid this, pass numeric arguments as Tensors
  whenever possible:

  >>> @tf.function
  ... def f(x):
  ...   return tf.abs(x)
  >>> f1 = f.get_concrete_function(1)
  >>> f2 = f.get_concrete_function(2)  # Slow - builds new graph
  >>> f1 is f2
  False
  >>> f1 = f.get_concrete_function(tf.constant(1))
  >>> f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
  >>> f1 is f2
  True

  Python numerical arguments should only be used when they take few distinct
  values, such as hyperparameters like the number of layers in a neural
  network.

  _Input signatures_

  For Tensor arguments, `tf.function` instantiates a separate graph for every
  unique set of input shapes and datatypes. The example below creates two
  separate graphs, each specialized to a different shape:

  >>> @tf.function
  ... def f(x):
  ...   return x + 1
  >>> vector = tf.constant([1.0, 1.0])
  >>> matrix = tf.constant([[3.0]])
  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
  False

  An "input signature" can be optionally provided to `tf.function` to control
  the graphs traced. The input signature specifies the shape and type of each
  Tensor argument to the function using a `tf.TensorSpec` object. More general
  shapes can be used. This is useful to avoid creating multiple graphs when
  Tensors have dynamic shapes. It also restricts the shape and datatype of
  Tensors that can be used:

  >>> @tf.function(
  ...     input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  ... def f(x):
  ...   return x + 1
  >>> vector = tf.constant([1.0, 1.0])
  >>> matrix = tf.constant([[3.0]])
  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
  True

  _Variables may only be created once_

  `tf.function` only allows creating new `tf.Variable` objects when it is called
  for the first time:

  >>> class MyModule(tf.Module):
  ...   def __init__(self):
  ...     self.v = None
  ...
  ...   @tf.function
  ...   def __call__(self, x):
  ...     if self.v is None:
  ...     if self.v is None:
  ...       self.v = tf.Variable(tf.ones_like(x))
  ...     return self.v * x

  In general, it is recommended to create stateful objects like `tf.Variable`
  outside of `tf.function` and to pass them as arguments.

  _Using type annotations to improve performance_

  `experimental_follow_type_hints` can be used along with type annotations to
  improve performance by reducing the number of expensive graph retracings.
  For example, an argument annotated with `tf.Tensor` is converted to a Tensor
  even when the input is a non-Tensor value.

  >>> @tf.function(experimental_follow_type_hints=True)
  ... def f_with_hints(x: tf.Tensor):
  ...   print('Tracing')
  ...   return x
  >>> @tf.function(experimental_follow_type_hints=False)
  ... def f_no_hints(x: tf.Tensor):
  ...   print('Tracing')
  ...   return x
  >>> f_no_hints(1)
  Tracing
  <tf.Tensor: shape=(), dtype=int32, numpy=1>
  >>> f_no_hints(2)
  Tracing
  <tf.Tensor: shape=(), dtype=int32, numpy=2>
  >>> f_with_hints(1)
  Tracing
  <tf.Tensor: shape=(), dtype=int32, numpy=1>
  >>> f_with_hints(2)
  <tf.Tensor: shape=(), dtype=int32, numpy=2>

  Args:
    func: the function to be compiled. If `func` is None, `tf.function` returns
      a decorator that can be invoked with a single argument - `func`. In other
      words, `tf.function(input_signature=...)(func)` is equivalent to
      `tf.function(func, input_signature=...)`. The former can be used as a
      decorator.
    input_signature: A possibly nested sequence of `tf.TensorSpec` objects
      specifying the shapes and dtypes of the Tensors that will be supplied to
      this function. If `None`, a separate function is instantiated for each
      inferred input signature. If input_signature is specified, every input to
      `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
    autograph: Whether autograph should be applied on `func` before tracing a
      graph. Data-dependent control flow requires `autograph=True`. For more
      information, see the [tf.function and AutoGraph guide](
      https://www.tensorflow.org/guide/function).
    jit_compile: If `True`, compiles the function using
      [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
      such as fusion, and attempts to emit more efficient code. This may
      drastically improve performance. If set to `True`, the whole function
      needs to be compilable by XLA, or an `errors.InvalidArgumentError` is
      thrown.
      If `None` (default), compiles the function with XLA when running on TPU
      and goes through the regular function execution path when running on
      other devices.
      If `False`, executes the function without XLA compilation. Set this value
      to `False` when directly running a multi-device function on TPUs (e.g.
      two TPU cores, one TPU core and its host CPU).
      Not all functions are compilable; see the list of
      [sharp corners](https://tensorflow.org/xla/known_issues).
    experimental_implements: If provided, contains the name of a "known"
      function this implements. For example "mycompany.my_recurrent_cell".
      This is stored as an attribute of the inference function, which can then
      be detected when processing serialized functions. See
      [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)  # pylint: disable=line-too-long
      for details.
      For an example of utilizing this attribute see this
      [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc)
      The linked code automatically detects and substitutes functions that
      implement "embedded_matmul", which allows TFLite to substitute its own
      implementation. For instance, a TensorFlow user can use this attribute
      to mark that their function also implements `embedded_matmul` (perhaps
      more efficiently!) by specifying it with this parameter:
      `@tf.function(experimental_implements="embedded_matmul")`.
      This can be specified either as just the string name of the function or
      as a NameAttrList corresponding to a list of key-value attributes
      associated with the function name. The name of the function will be in
      the 'name' field of the NameAttrList. To define a formal TF op that this
      function implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
      project.
    experimental_autograph_options: Optional tuple of
      `tf.autograph.experimental.Feature` values.
    experimental_relax_shapes: When True, `tf.function` may generate fewer
      graphs that are less specialized on input shapes.
    experimental_compile: Deprecated alias to 'jit_compile'.
    experimental_follow_type_hints: When True, the function may use type
      annotations from `func` to optimize the tracing performance. For example,
      arguments annotated with `tf.Tensor` will automatically be converted
      to a Tensor.

  Returns:
    If `func` is not None, returns a callable that will execute the compiled
    function (and return zero or more `tf.Tensor` objects).
    If `func` is None, returns a decorator that, when invoked with a single
    `func` argument, returns a callable equivalent to the case above.

  Raises:
    `ValueError` when attempting to use `jit_compile=True`, but XLA support is
    not linked.
  """
  # TODO(mdan): Link to `tf.types` section once published.
  if input_signature is not None:
    function_lib.validate_signature(input_signature)
  if experimental_follow_type_hints is None:
    experimental_follow_type_hints = False

  def decorated(inner_function):
    try:
      name = inner_function.__name__
    except AttributeError:
      name = "function"
    return tf_decorator.make_decorator(
        inner_function,
        decorator_name="tf.function",
        decorator_func=Function(
            inner_function,
            name,
            input_signature=input_signature,
            autograph=autograph,
            experimental_autograph_options=experimental_autograph_options,
            experimental_relax_shapes=experimental_relax_shapes,

            # TODO(b/171825496): Update once `experimental_compile` is removed
            # entirely in favor of 'jit_compile'.
            jit_compile=deprecation.deprecated_argument_lookup(
                "jit_compile",
                jit_compile,
                "experimental_compile",
                experimental_compile),
            experimental_implements=experimental_implements,
            experimental_follow_type_hints=experimental_follow_type_hints))

  # This code path is for the `foo = tf.function(foo, ...)` use case
  if func is not None:
    return decorated(func)

  # This code path is for the
  #
  # @tf.function(...)
  # def foo(...):
  #   ...
  #
  # use case, which is equivalent to `foo = tf.function(...)(foo)`
  return decorated
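
# A minimal usage sketch of the decorator defined above, tying together the
# `input_signature` and `jit_compile` arguments documented in its docstring.
# This is illustrative only and not part of the library API: it assumes user
# code with `import tensorflow as tf`, and `scaled_matmul` is a hypothetical
# name. It is kept inside comments so importing this module has no side
# effects.
#
#   import tensorflow as tf
#
#   @tf.function(
#       input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32),
#                        tf.TensorSpec(shape=[4, None], dtype=tf.float32)],
#       jit_compile=True)
#   def scaled_matmul(a, b):
#     # Traced once for the declared signature; compiled with XLA because
#     # jit_compile=True.
#     return 2.0 * tf.matmul(a, b)
#
#   out = scaled_matmul(tf.ones([2, 4]), tf.ones([4, 3]))  # shape (2, 3)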