# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import itertools
import pprint
import threading
import types as types_lib
import weakref

import numpy as np
import six
from six.moves import map

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import monitoring
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import save_context
from tensorflow.python.types import core
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency (roughly
# tf.function->autograph->->dataset->tf.function).
# TODO(b/133251390): Use a regular import.
ag_ctx = lazy_loader.LazyLoader(
    "ag_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
np_arrays = lazy_loader.LazyLoader(
    "np_arrays", globals(),
    "tensorflow.python.ops.numpy_ops.np_arrays")


FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"

_graph_building_time_counter = monitoring.Counter(
    "/tensorflow/core/tf_function/graph_building_time_usecs",
    "Time for tf.function to build a graph (us).")


# TODO(b/195985838): cleanup this function.
def _make_input_signature_hashable(elem):
  """Rewrite input signature to be hashable.

  We replace nested variables in the input signature with TensorSpec in order
  to be hashable.

  Args:
    elem: Input signature element

  Returns:
    A hashable object for the requested input signature
  """
  try:
    hash(elem)
  except TypeError:
    # TODO(slebedev): consider using nest.
    if isinstance(elem, tuple):
      return tuple(map(_make_input_signature_hashable, elem))

    # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
    # all recognized types to be hashable.
    assert isinstance(elem, weakref.ReferenceType)
    v = elem()

    if resource_variable_ops.is_resource_variable(v):
      # We special case variables here to use unique_id as the cache key. This
      # ensures we have to retrace whenever a different variable is passed in.
      # This is needed to support cases where the user may use the id of a
      # variable in the function perhaps as a lookup in a dictionary.
      #
      # This choice leads to more retracing when we could have possibly used the
      # shape and dtype instead. However, we expect the number of variables in a
      # program to be bounded, and correspondingly the number of retraces.
      #
      # Note we also include the class name to avoid collisions with strings.
      return v.__class__, v._unique_id  # pylint: disable=protected-access

    if _is_ndarray(v):
      # Numpy arrays are not hashable, but when calling functions we treat them
      # in the same way as tf.Tensors.
      if not hasattr(v, "shape") or not hasattr(v, "dtype"):
        # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
        v = _as_ndarray(v)
      return tensor_spec.TensorSpec(v.shape, v.dtype)

    raise ValueError("Arguments to a tf.function must be a nested structure of "
                     "Tensors, Variables, NumPy arrays, or hashable Python "
                     f"objects, got {type(v)}.")

  return elem
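

# Example (illustrative sketch, comments only, not executed as part of this
# module; `def_function`/`variables` are assumed sibling modules): because
# variables are keyed by `(class, _unique_id)` above, a tf.function retraces
# when a *different* variable is passed, while NumPy arrays with the same
# shape and dtype share a trace since they reduce to the same TensorSpec:
#
#   @def_function.function
#   def f(x):
#     return x + 1.
#
#   f(variables.Variable(1.))     # first trace
#   f(variables.Variable(1.))     # retrace: different `_unique_id`
#   f(np.ones([2], np.float32))   # new trace keyed by TensorSpec([2], float32)
#   f(np.zeros([2], np.float32))  # reuses the previous trace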


CacheKey = collections.namedtuple("CacheKey", [
    "input_signature",
    "parent_graph",
    "device_functions",
    "colocation_stack",
    "in_cross_replica_context",
    "variable_policy",
    "xla_context_id",
])


def _type_spec_for(x):
  """Returns a TypeSpec for `x`, or `None` if `x` doesn't have a TensorSpec."""
  if isinstance(x, ops.Tensor):
    return tensor_spec.TensorSpec.from_tensor(x)
  elif isinstance(x, type_spec.TypeSpec):
    return x
  elif isinstance(x, composite_tensor.CompositeTensor):
    return x._type_spec  # pylint: disable=protected-access
  else:
    return None


def _is_type_subset(a, b):
  """Returns true if TypeSpec `b` is a subset of type `a` (or if a is None)."""
  if a is None:
    return True
  else:
    return a.most_specific_compatible_type(b) == a


def _shape_relaxed_type_for_composite_tensor(x):
  """Returns a shape-relaxed TypeSpec for x (if composite) or x (if not)."""
  if isinstance(x, composite_tensor.CompositeTensor):
    # pylint: disable=protected-access
    return x._type_spec._with_tensor_ranks_only()
  else:
    return x


def common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`."""
  if (x is None) != (y is None):
    raise RuntimeError(
        "Cannot find a common shape when LHS shape is None but RHS shape "
        f"is not (or vice versa): {x} vs. {y}.")
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError(f"`x` must be a TensorShape, got type {type(x)}.")
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError(f"`y` must be a TensorShape, got type {type(y)}.")
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
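

# Illustrative behaviour of `common_shape` (comments only, not executed):
#
#   common_shape(tensor_shape.TensorShape([2, 3]),
#                tensor_shape.TensorShape([None, 3]))
#   # -> TensorShape([None, 3]): the first dimension disagrees, so it relaxes.
#
#   common_shape(tensor_shape.TensorShape([2, 3]),
#                tensor_shape.TensorShape([2, 3, 4]))
#   # -> TensorShape(None): differing ranks relax the whole shape.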


def is_same_structure(structure1,
                      structure2,
                      check_values=False):
  """Check two structures for equality, optionally of types and of values."""
  try:
    nest.assert_same_structure(structure1, structure2, expand_composites=True)
  except (ValueError, TypeError):
    return False
  if check_values:
    flattened1 = nest.flatten(structure1, expand_composites=True)
    flattened2 = nest.flatten(structure2, expand_composites=True)
    # First check the types to avoid AttributeErrors.
    if any(type(f1) != type(f2) for f1, f2 in zip(flattened1, flattened2)):
      return False
    return flattened1 == flattened2
  return True


def _parse_func_attrs(attributes):
  """Convert the keyword arguments into function_def attributes.

  Currently only supports primitive types: bool, int, float and string.

  Args:
    attributes: the dictionary of attributes.

  Returns:
    A dict of attributes where the key is the name of attribute and the value
    is the AttrValue proto.

  Raises:
    ValueError: If the kwargs contains unallowlisted name or unsupported value
      types.
  """
  attrs = {}
  for key, value in attributes.items():
    if isinstance(value, attr_value_pb2.AttrValue):
      attrs[key] = value
    # bool type check has to happen before int since bool is a subclass of int.
    elif isinstance(value, bool):
      attrs[key] = attr_value_pb2.AttrValue(b=value)
    elif isinstance(value, int):
      attrs[key] = attr_value_pb2.AttrValue(i=value)
    elif isinstance(value, float):
      attrs[key] = attr_value_pb2.AttrValue(f=value)
    elif isinstance(value, (str, bytes, six.text_type)):
      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
    else:
      raise ValueError(f"Attribute {key} must be bool, int, float, string, or "
                       f"AttrValue. Got {type(value)}.")
  return attrs


class _InterpolateFunctionError(object):
  """Context manager that interpolates the exception from 'top_level_func'."""

  __slots__ = ["_func"]

  def __init__(self, top_level_func):
    self._func = top_level_func

  def __enter__(self):
    pass

  def __exit__(self, typ, exc, tb):
    if not exc or not isinstance(exc, errors.OpError):
      return False
    message = compat.as_text(exc.message)
    _, tags = error_interpolation.parse_message(message)
    g = None
    func_stack = []
    for t in tags:
      if t.type == "function_node":
        # TODO(mdan): Tests should cover this.
        if t.name == compat.as_str(self._func.name):
          g = self._func.graph
        elif g:
          next_func = g._get_function(t.name)  # pylint: disable=protected-access
          if next_func is not None and isinstance(next_func,
                                                  _EagerDefinedFunction):
            g = next_func.graph
        if g:
          func_stack.append(g.name)
        else:
          func_stack.append("<unknown>")
    if g:
      message = error_interpolation.interpolate(message, g)
      if len(func_stack) >= 2:
        message += "\n\nFunction call stack:\n"
        message += " -> ".join(func_stack)
        message += "\n"
      exc._message = message  # pylint: disable=protected-access
    return False


_function_callbacks = set()


def add_function_callback(function_callback):
  """Add a callback function for Function creation.

  The callback function has the signature:

    `def function_callback(function, name, graph, inputs, outputs):`

  where:
  - `function`: _EagerDefinedFunction being created before finalizing the graph.
    Do not modify the function directly but instead modify the graph.
  - `name`: name of the function.
  - `graph`: Graph of the function.
  - `inputs`: `tuple` of tensors used as inputs to the function.
  - `outputs`: `tuple` of tensors used as outputs from the function.

  The callback is at the top of the `_EagerDefinedFunction` construction,
  giving the callback an opportunity to make the last edits to the graph. Do
  not make changes to `graph`, `inputs`, and `outputs` manually, but, instead,
  set the `graph` as the default then define ops.

  Repeated registration of the same callback function is idempotent.
  After a callback is added, it can be removed with the
  `remove_function_callback()` method.

  Args:
    function_callback: The callback to add.
  """
  _function_callbacks.add(function_callback)


def remove_function_callback(function_callback):
  """Remove an already-added function callback.

  See the doc string of `add_function_callback()` for more information.

  Args:
    function_callback: The callback to remove.
  """
  _function_callbacks.remove(function_callback)
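

# Example (illustrative sketch; `_log_new_function` is hypothetical and not
# part of this module): a callback that records every function as its graph is
# finalized.
#
#   def _log_new_function(function, name, graph, inputs, outputs):
#     del function  # Modify `graph` (as the default graph) instead.
#     logging.info("Creating function %s with %d inputs and %d outputs",
#                  name, len(inputs), len(outputs))
#
#   add_function_callback(_log_new_function)
#   ...  # Trace tf.functions; the callback fires for each new graph function.
#   remove_function_callback(_log_new_function)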


def clear_function_callbacks():
  """Clear all function callbacks, if any have been registered."""
  _function_callbacks.clear()


_FORWARD_PREFIX = "__forward_"
_BACKWARD_PREFIX = "__backward_"
_INFERENCE_PREFIX = "__inference_"


def _forward_name(n):
  """The name of a generated forward defun named n."""
  return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid())


def _backward_name(n):
  """The name of a generated backward defun named n."""
  return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid())


def _inference_name(n):
  """The name of a forward-but-no-gradient defun named n."""
  return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid())
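

# Illustrative naming (comments only): for a FuncGraph named "f", the helpers
# above produce names such as
#
#   _inference_name("f")  # -> "__inference_f_42"
#   _forward_name("f")    # -> "__forward_f_43"
#   _backward_name("f")   # -> "__backward_f_44"
#
# where the numeric suffix comes from `ops.uid()` and is unique per process.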


def _enclosing_xla_context():
  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
        return context_
      context_ = context_.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    graph = getattr(graph, "outer_graph", None)
  return None


class _EagerDefinedFunctionDeleter(object):
  """Unregister function from eager context."""

  __slots__ = ["name"]

  def __init__(self, name):
    self.name = name

  def __del__(self):
    try:
      context.remove_function(self.name)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


class FunctionAlreadyGarbageCollectedError(Exception):

  def __init__(self, function_name):
    super(FunctionAlreadyGarbageCollectedError, self).__init__(
        "{} has already been garbage collected and cannot be called.".format(
            function_name))


# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
class _EagerDefinedFunction(object):
  """Callable with the interface of `framework.function._DefinedFunction`.

  `_EagerDefinedFunction` encapsulates a function definition and its properties,
  and it provides a method for calling the encapsulated function. Some Ops
  take functions as attributes, which have type `func`; an instance of this
  class may be provided as the value of these `func` attributes.
  """

  def __init__(self, name, graph, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function.
      inputs: the tensors in the graph to be used as inputs to the function.
      outputs: the tensors in the graph which will be outputs from the function.
      attrs: dict mapping names of attributes to their AttrValue values.
    """
    for function_callback in _function_callbacks:
      function_callback(self, name, graph, tuple(inputs), tuple(outputs))

    input_ops = set(arg.op for arg in inputs)
    operations = [op for op in graph.get_operations() if op not in input_ops]

    graph_output_names = graph._output_names  # pylint: disable=protected-access
    if (graph_output_names is not None and
        all(ops.tensor_id(t) in graph_output_names for t in outputs)):
      output_names = [
          compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
      ]
      if len(set(output_names)) != len(output_names):
        # There are duplicate names for some reason, probably an invalid
        # signature. Revert to auto-naming.
        output_names = []
    else:
      output_names = []
    fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        output_names,
        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
        [],  # control_output_names
        None,
        compat.as_str(""))

    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name),
                                                     serialized)

    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    # signature), but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tf_session.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    self._name = compat.as_bytes(function_def.signature.name)
    with ops.init_scope():
      if context.executing_eagerly():
        context.ensure_initialized()
        context.add_function(fn)
        self._function_deleter = _EagerDefinedFunctionDeleter(self.name)
        self._registered_on_context = True
    self.definition = function_def
    self.signature = function_def.signature
    self._num_outputs = len(self.signature.output_arg)
    self._output_types = [o.type for o in self.signature.output_arg]
    self._output_shapes = [o.shape for o in outputs]
    self._control_captures = graph.control_captures
    # Shallow copy outputs since ConcreteFunction may mutate it.
    self._func_graph_outputs = list(outputs)
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = c_api_util.ScopedTFFunction(fn)
    self._grad_func = None
    self.graph = graph
    self._stateful_ops = tuple(op for op in operations if op._is_stateful)  # pylint: disable=protected-access

  def add_to_graph(self, g=None):
    """Add the function to the current context or a graph, if supplied.

    Args:
      g: the graph to add the function to. If not supplied, the function will
        be added to the current context.
    """
    # pylint: disable=protected-access
    if not g and context.executing_eagerly():
      ctx = context.context()
      if not ctx.has_function(self.name):
        ctx.add_function_def(self.definition)
    else:
      if not g._is_function(self.name):
        g._add_function(self)
      for f in self.graph._functions.values():
        if not g._is_function(f.name):
          g._add_function(f)
    # pylint: enable=protected-access

  @property
  def name(self):
    return self._name

  @property
  def stateful_ops(self):
    return self._stateful_ops
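
  # Rough sketch of how instances are driven (comments only, simplified from
  # the `call` method below):
  #
  #   outputs = fn.call(ctx, flat_args)
  #   # - executing eagerly: dispatches through execute.execute(...) using the
  #   #   registered function name (e.g. "__inference_f_123").
  #   # - graph building: emits a PartitionedCall/StatefulPartitionedCall op
  #   #   via functional_ops.partitioned_call(...).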

  def call(self, ctx, args, cancellation_manager=None):
    """Calls this function with `args` as inputs.

    `ConcreteFunction` execution respects device annotations only if the
    function won't be compiled with XLA.

    Args:
      ctx: a Context object
      args: a list of arguments to supply this function with.
      cancellation_manager: a `CancellationManager` object that can be used to
        cancel function execution.

    Returns:
      The outputs of the function call.

    Raises:
      ValueError: if the number of arguments is incorrect.
      FunctionAlreadyGarbageCollectedError: if the function is no longer
        available to be called because it has been garbage collected.
    """
    if len(args) != len(self.signature.input_arg):
      raise ValueError(
          f"Signature specifies {len(list(self.signature.input_arg))} "
          f"arguments, got: {len(args)}.")

    # If the `ScopedTFFunction` (accessed via `_c_func`) has already been
    # cleaned up as a part of garbage collection, this `_EagerDefinedFunction`
    # should also be garbage and is likely being called as part of a `__del__`
    # elsewhere. In that case, there's nothing we can do, so we raise an
    # exception for the caller to handle.
    if self._c_func.has_been_garbage_collected:
      raise FunctionAlreadyGarbageCollectedError(self.name)

    function_call_options = ctx.function_call_options
    if function_call_options.config_proto_serialized is None:
      config = function_utils.get_disabled_rewriter_config()
    else:
      config = function_call_options.config_proto_serialized
    executor_type = function_call_options.executor_type or ""

    executing_eagerly = ctx.executing_eagerly()
    attrs = ("executor_type", executor_type, "config_proto", config)
    if executing_eagerly:
      with _InterpolateFunctionError(self):
        if cancellation_manager is None:
          outputs = execute.execute(
              str(self.signature.name),
              num_outputs=self._num_outputs,
              inputs=args,
              attrs=attrs,
              ctx=ctx)
        else:
          outputs = execute.execute_with_cancellation(
              str(self.signature.name),
              num_outputs=self._num_outputs,
              inputs=args,
              attrs=attrs,
              ctx=ctx,
              cancellation_manager=cancellation_manager)
      # Replace empty list with None
      outputs = outputs or None
    else:
      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
      # creates `PartitionedCallOp` kernels by default, or remove the previous
      # branch if a TPU kernel is registered for `PartitionedCall`.
      with _InterpolateFunctionError(self):
        with ops.control_dependencies(self._control_captures):
          # The caller must use record_operation to record this operation in the
          # eager case, so we enforce the same requirement for the non-eager
          # case by explicitly pausing recording. We don't have a gradient
          # registered for PartitionedCall, so recording this operation confuses
          # forwardprop code (GradientTape manages to ignore it).
          with tape.stop_recording():
            outputs = functional_ops.partitioned_call(
                args=args,
                f=self,
                tout=self._output_types,
                executing_eagerly=executing_eagerly,
                config=config,
                executor_type=executor_type)

    for i, func_graph_output in enumerate(self._func_graph_outputs):
      handle_data_util.copy_handle_data(func_graph_output, outputs[i])
    if executing_eagerly:
      return outputs
    else:
      # TODO(b/128924522): This additional set_shape should not be
      # necessary. ShapeRefiner likely needs to inspect handle_data. Remove this
      # once that's done.
      for i, shape in enumerate(self._output_shapes):
        outputs[i].set_shape(shape)
      return outputs


def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph):
  """Creates forward and backward functions from the function graphs."""
  forward_function_name = _forward_name(forward_graph.name)
  common_attributes = dict(attrs)
  # NB: forward and backward function need to drop the "_implements"
  # attribute, because their signature contains all the intermediate tensors
  # that they compute. Thus they don't have a stable signature which can
  # be directly optimized downstream.
  # See for more details:
  # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
  common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
  backward_function_attr = _parse_func_attrs(
      {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
  backward_function_attr.update(common_attributes)
  backward_function = ConcreteFunction(
      backwards_graph, attrs=backward_function_attr)
  forward_function_attr = _parse_func_attrs({
      BACKWARD_FUNCTION_ATTRIBUTE_NAME:
          backward_function.name})
  forward_function_attr.update(common_attributes)
  forward_function = _EagerDefinedFunction(
      forward_function_name, forward_graph, forward_graph.inputs,
      forward_graph.outputs, forward_function_attr)
  return forward_function, backward_function


class _DelayedRewriteGradientFunctions(object):
  """Caches forward/backward functions with a delayed forward rewrite."""

  def __init__(self, func_graph, attrs, func_graph_deleter):
    """Construct an inference function and initialize caches."""
    # A map from the number of forward function outputs with accepted gradients
    # to forward and backward functions, used to cache non-tape backward
    # function generation.
    self._cached_function_pairs = {}
    self._func_graph = func_graph
    self._inference_function = _EagerDefinedFunction(
        _inference_name(self._func_graph.name), self._func_graph,
        self._func_graph.inputs, self._func_graph.outputs, attrs)
    self._attrs = attrs
    self._gradient_name = None
    # Note that the FuncGraph is mutated later, so we need to inspect it now to
    # figure out the user-specified outputs of the inference function.
    self._num_inference_outputs = len(self._func_graph.outputs)
    self._func_graph_deleter = func_graph_deleter

  def forward_backward(self, num_doutputs=None):
    """A possibly-cached pair of forward and backward functions."""
    if num_doutputs is None:
      num_doutputs = self._num_inference_outputs
    forward_backward = self._cached_function_pairs.get(num_doutputs)
    if forward_backward is not None:
      return forward_backward
    forward, backward = self._construct_forward_backward(num_doutputs)
    self._cached_function_pairs[num_doutputs] = (forward, backward)
    return forward, backward

  def _construct_forward_backward(self, num_doutputs):
    """Constructs a pair of forward and backward functions.

    Args:
      num_doutputs: The constructed backprop function will take output gradients
        for the first `num_doutputs` outputs of the forward function. Defaults
        to the number of outputs for the inference function, but when
        higher-order gradients are computed this will increase to include side
        outputs.

    Returns:
      A pair of (forward_function, backward_function):
        forward_function: A re-generated inference function (an
          _EagerDefinedFunction) to account for new side outputs, if any extra
          were required when building the backward pass.
        backward_function: A ConcreteFunction that takes `num_doutputs`
          arguments and returns gradients with respect to inputs of the forward
          function.
    """
    trainable_outputs = [
        output for output in self._func_graph.outputs[:num_doutputs]
        if backprop_util.IsTrainable(output)]

    signature = []
    for t in trainable_outputs:
      signature.append(
          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))

    def _backprop_function(*grad_ys):
      with ops.device(None):
        return gradients_util._GradientsHelper(  # pylint: disable=protected-access
            trainable_outputs,
            self._func_graph.inputs,
            grad_ys=grad_ys,
            src_graph=self._func_graph)

    with self._func_graph.as_default():
      backwards_graph = func_graph_module.FuncGraph(
          _backward_name(self._func_graph.name))
      func_graph_module.func_graph_from_py_func(
          name=backwards_graph.name,
          python_func=_backprop_function,
          args=[], kwargs={},
          signature=signature,
          func_graph=backwards_graph)
      backwards_graph_captures = backwards_graph.external_captures
      captures_from_forward = [
          c for c in backwards_graph_captures if
          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]

      existing_outputs = object_identity.ObjectIdentitySet(
          self._func_graph.outputs)
      for capture in captures_from_forward:
        if capture not in existing_outputs:
          existing_outputs.add(capture)
          self._func_graph.outputs.append(capture)

      forward_function, backward_function = _create_forward_backward_with_graph(
          self._attrs, self._func_graph, backwards_graph)
    return forward_function, backward_function

  def _rewrite_forward_and_call_backward(self, op, *doutputs):
    """Add outputs to the forward call and feed them to the grad function."""
    forward_function, backwards_function = self.forward_backward(len(doutputs))
    if not backwards_function.outputs:
      return backwards_function.structured_outputs
    forward_function.add_to_graph(op.graph)

    # pylint: disable=protected-access
    # Rewrite an inference call op to be a forward call op
    op._set_func_attr("f", forward_function.name)
    op._set_type_list_attr("Tout", forward_function._output_types)
    op._add_outputs(
        forward_function._output_types[len(op.outputs):],
        forward_function._output_shapes[len(op.outputs):])
    for i in range(len(op.outputs)):
      func_graph_output = forward_function._func_graph_outputs[i]
      handle_data_util.copy_handle_data(func_graph_output, op.outputs[i])
    # pylint: enable=protected-access

    capture_mapping = dict(
        zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
    remapped_captures = [
        capture_mapping.get(ops.tensor_id(capture), capture)
        for capture in backwards_function.captured_inputs
    ]

    # Replace Nones with zeros since we're calling a graph function which
    # expects numeric inputs.
    cleaned_doutputs = []
    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
      if backprop_util.IsTrainable(placeholder):
        if isinstance(doutput, ops.IndexedSlices):
          # Gradient passed to a backward ConcreteFunction must be tf.Tensor,
          # so we convert tf.IndexedSlices to tf.Tensor.
          cleaned_doutputs.append(ops.convert_to_tensor(doutput))
        elif doutput is not None:
          cleaned_doutputs.append(doutput)
        else:
          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))

    # Compute the gradients using the side outputs
    return backwards_function._call_flat(  # pylint: disable=protected-access
        cleaned_doutputs, remapped_captures)

  def get_gradient_function(self):
    """Returns gradient function.

    The gradient rewrites an inference call op to a forward call op, but does
    not modify a pre-existing forward call op. It then computes the gradient
    from the output's gradients and the side outputs of the forward op.
    """
    return self._rewrite_forward_and_call_backward

  def forward(self, inference_args=None, input_tangents=None):
    """A forward function with only user-specified outputs.

    The call operation for the returned inference function can be rewritten into
    a forward function. This only happens if the backward function (from the
    `backward` method) ends up being used to compute gradients.

    This approach avoids constructing unnecessary graphs, but it only works if
    we are calling this function when not executing eagerly.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function. Unused, but taken for compatibility with
        _TapeGradientFunctions.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`. Unused; if required, tape functions must be used
        instead.

    Returns:
      An _EagerDefinedFunction.
    """
    del inference_args  # unused
    if input_tangents:
      # This class does not support special-cased forwardprop. The arguments are
      # here for compatibility with _TapeGradientFunctions.
      raise errors.InternalError("unexpectedly got forwardprop information in "
                                 "a class that does not support forwardprop.")
    return self._inference_function

  def _backward(self, outputs):
    """Fetch a backward function for `outputs` from the forward function."""
    def _backward_function(*args):
      call_op = outputs[0].op
      return self._rewrite_forward_and_call_backward(call_op, *args)
    return _backward_function, outputs

  def record(self, flat_outputs, inference_args, input_tangents):
    """Record the function call operation.

    _DelayedRewriteGradientFunctions supports only first-order backprop tape
    gradients (and then only when graph building). It does not work with
    higher-order tape gradients or forward autodiff, but does work with
    higher-order symbolic gradients (tf.gradients).

    Args:
      flat_outputs: The result of running `forward`.
      inference_args: A flat list of Tensors with inference inputs to the
        operation.
      input_tangents: A flat list of Tensors with input tangents consumed by the
        operation.
    """
    backward_function, to_record = self._backward(flat_outputs)
    tape.record_operation(self._inference_function.signature.name,
                          to_record, inference_args + input_tangents,
                          backward_function)
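

# Illustrative sketch (comments only; `traced_fn` stands for any traced
# ConcreteFunction and the module names are assumptions): the delayed rewrite
# above only fires when symbolic gradients are requested while graph building,
# e.g.
#
#   @def_function.function
#   def g(x):
#     y = traced_fn(x)                        # emits an inference
#                                             # PartitionedCall op
#     return gradients_impl.gradients(y, x)
#
# Taking tf.gradients rewrites the call op in place to the regenerated forward
# function (with side outputs) and then calls the generated backward function.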


# Contains information about a forward function wrapped to compute jvps.
_ForwardWrapper = collections.namedtuple(
    "_ForwardWrapper", (
        # The wrapper Graph.
        "graph",
        # A flat list of non-tangent Tensor outputs from the wrapped forward
        # function.
        "outputs",
        # Indices for output tangents, same format as
        # forwardprop_util.pack_tangents.
        "output_indices",
        # A flat list of tangents for `outputs`.
        "output_tangents"))


class _TapeGradientFunctions(object):
  """Caches forward and backward functions compatible with eager gradients.

  In contrast to the delayed-rewrite approach in
  `_DelayedRewriteGradientFunctions` which only works with delayed execution,
  the forward function generated by this class has a fixed set of outputs which
  may be preserved by a tape in order to compute gradients later.

  This class is abstract; its child classes differ in how many side outputs of
  the forward function their backward function accepts gradients for, which
  determines whether higher-order tape gradients are possible.
  """

  def __init__(self, func_graph, attrs, func_graph_deleter,
               forwardprop_input_indices, delayed_rewrite_functions,
               need_gradients_for_jvps):
    self._func_graph = func_graph
    self._forward_graph = None
    self._attrs = attrs
    self._forward = None
    self._backward = None
    self._num_outputs = len(func_graph.outputs)
    self._func_graph_deleter = func_graph_deleter
    self._forwardprop_input_indices = forwardprop_input_indices
    self._forwardprop_output_indices = None
    self._num_forwardprop_outputs = 0
    self._num_inference_outputs = len(func_graph.outputs)
    self._num_trainable_inference_outputs = len(
        [t for t in func_graph.outputs if backprop_util.IsTrainable(t)])
    self._delayed_rewrite_functions = delayed_rewrite_functions
    self._need_gradients_for_jvps = need_gradients_for_jvps

  def _build_functions_for_outputs(
      self, outputs, inference_args, input_tangents):
    """Forward+backward functions where the backward function sees `outputs`."""
    # First figure out which of `outputs` are trainable. We'll accept gradients
    # for each of these in the backward function.
    handles_to_variables = self._func_graph.variable_captures
    trainable_outputs = []
    trainable_indices = []
    for index, output in enumerate(outputs):
      if backprop_util.IsTrainable(output):
        # Swap in the Variable object for resource handles if we can so
        # sparse gradients work.
        output = handles_to_variables.get(id(output), output)
        trainable_outputs.append(output)
        trainable_indices.append(index)

    backwards_graph = func_graph_module.FuncGraph(
        _backward_name(self._func_graph.name))
    with backwards_graph.as_default():
      gradients_wrt_outputs = []
      for output in trainable_outputs:
        gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
            output)
        gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
        handle_data_util.copy_handle_data(output, gradient_placeholder)
        gradients_wrt_outputs.append(gradient_placeholder)
      with ops.device(None):
        gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
            trainable_outputs,
            self._func_graph.inputs,
            grad_ys=gradients_wrt_outputs,
            src_graph=self._func_graph)

      if input_tangents:
        # Convert IndexedSlices to dense tensors (as we do elsewhere for
        # function gradients). Our C++ bindings don't know how to handle them
        # currently.
        gradients_wrt_inputs = nest.map_structure(
            lambda x: ops.convert_to_tensor(x) if x is not None else None,
            gradients_wrt_inputs)
      captures_from_forward = [
          c for c in backwards_graph.external_captures
          if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
      ]
      existing_outputs = object_identity.ObjectIdentitySet(
          self._func_graph.outputs)
      for capture in captures_from_forward:
        if capture not in existing_outputs:
          existing_outputs.add(capture)
          self._func_graph.outputs.append(capture)

    # The ordering of `backwards_graph.inputs` is important: inputs of
    # `backward_function` correspond to outputs (including
    # side outputs) of `self._tape_forward_function`.
    backwards_graph.inputs = (
        gradients_wrt_outputs + backwards_graph.internal_captures)
    backwards_graph.outputs.extend(
        grad
        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
        if grad is not None)
    backwards_graph.structured_outputs = gradients_wrt_inputs

    forward_function, backward_function = _create_forward_backward_with_graph(
        self._attrs, self._func_graph, backwards_graph)

    if not input_tangents:
      # There is no need to special-case forwardprop, so we can return the
      # forward+backward pair we've created without further wrapping.
      return (forward_function, self._func_graph, backward_function,
              # No forwardprop outputs.
              None, 0)
    forward_wrapper = self._wrap_forward_function_with_jvps(
        forward_function, backward_function, inference_args, input_tangents)
    (wrapped_backwards_graph,
     forward_wrapper) = self._wrap_backward_function_with_jvp_backprop(
         backward_function, gradients_wrt_outputs, forward_wrapper)
    # Now that we've added new captures, we need to make sure forward outputs
    # are in the same order the backward function expects them to be in:
    # [inference outputs] + [jvps] + [side outputs] + [captures].
    forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)
    (wrapped_forward_function,
     wrapped_backward_function) = _create_forward_backward_with_graph(
         self._attrs, forward_wrapper.graph, wrapped_backwards_graph)
    if (len(inference_args) + len(input_tangents)
        != len(forward_wrapper.graph.inputs)):
      raise errors.InternalError(
          f"The forward graph had {len(forward_wrapper.graph.inputs)} inputs, "
          f"but we expected {len(inference_args) + len(input_tangents)} "
          f"({len(inference_args)} inference inputs and "
          f"{len(input_tangents)} input tangents).")
    return (wrapped_forward_function, forward_wrapper.graph,
            wrapped_backward_function, forward_wrapper.output_indices,
            len(forward_wrapper.output_tangents))

  def _wrap_forward_function_with_jvps(
      self, forward_function, backward_function,
      inference_args, input_tangents):
    """Adds inline JVP computation to a forward function."""
    forward_wrapper_graph = func_graph_module.FuncGraph(
        _forward_name(self._func_graph.name))
    with forward_wrapper_graph.as_default():
      # Tell forward accumulators to free up space for new JVP computations,
      # since one may be in the process of computing a JVP (if that computation
      # triggered this function building).
      #
      # We'll make symbolic versions of input JVPs, run the forward function
      # under forward accumulators to get symbolic output JVPs, then set those
      # as outputs of the new wrapped forward function.
      with forwardprop_util.push_forwardprop_state():
        forward_captures = {
            ops.tensor_id(internal): external
            for external, internal in self._func_graph.captures}
        for input_index, real_input in enumerate(self._func_graph.inputs):
          # This loop is more or less equivalent to running tf.identity on each
          # of self._func_graph.inputs. However, doing that also captures jvps
          # for resource handles, which confuses the jvp capturing code below
          # (since primal inputs are interwoven with jvp inputs).
          input_placeholder = array_ops.placeholder(
              dtype=real_input.dtype,
              shape=real_input.shape)
          capture = forward_captures.get(ops.tensor_id(real_input))
          if capture is not None:
            forward_wrapper_graph.add_capture(capture, input_placeholder)
            if capture.dtype == dtypes.resource:
              handle_data_util.copy_handle_data(capture, input_placeholder)
          else:
            forward_wrapper_graph.inputs.append(input_placeholder)
        for inp, arg in zip(forward_wrapper_graph.inputs, inference_args):
          tape.record_operation(
              "captured_value", [inp], [arg],
              backward_function=lambda x: [x],
              forward_function=lambda x: [x])
        num_inference_inputs = len(inference_args)
        for tape_indices in self._forwardprop_input_indices:
          for input_index, jvp_index in tape_indices:
            input_placeholder = forward_wrapper_graph.inputs[input_index]
            if len(forward_wrapper_graph.inputs) != jvp_index:
              raise errors.InternalError(
                  f"Expected {jvp_index} forward graph inputs, "
                  f"got {len(forward_wrapper_graph.inputs)}.")
            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
                input_placeholder)
            jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
            external_jvp = input_tangents[jvp_index - num_inference_inputs]
            forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder)
            tensor_shape.TensorShape(
                external_jvp.shape).assert_is_compatible_with(
                    jvp_placeholder.shape)
            tape.record_operation(
                "captured_value",
                [jvp_placeholder],
                [external_jvp],
                backward_function=lambda x: [x],
                forward_function=lambda x: [x])
        forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs]
        gradient_function = (
            self._delayed_rewrite_functions._rewrite_forward_and_call_backward)  # pylint: disable=protected-access
        with ops.get_default_graph()._override_gradient_function(  # pylint: disable=protected-access
            {"PartitionedCall": gradient_function,
             "StatefulPartitionedCall": gradient_function}):
          forward_outputs = forward_function.call(context.context(),
                                                  forward_inputs)
          if isinstance(forward_outputs, ops.Operation):
            # _wrapped_backward_function expects a list, but if the function has
            # no outputs its call() returns an Operation. We need to undo that
            # so we don't cause problems later.
            forward_outputs = []
        py_backward, _ = self._wrap_backward_function(
            self._func_graph, backward_function, forward_outputs)
      # We will never request backward tape gradients for this operation
      # directly since we're wrapping the call; forwardprop will call the
      # backward function (and nested forward accumulators may build
      # higher-order gradients), but any watching GradientTapes should ignore
      # it.
      #
      # TODO(allenl): It might be better to explicitly stop backward recording
      # so we don't use the second-order tape cases unnecessarily.
      tape.record_operation_forwardprop_only(
          forward_function.signature.name,
          forward_outputs, forward_inputs, py_backward, None)
      output_indices, output_tangents = (
          pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
      output_tangents = [forward_wrapper_graph.capture(t)
                         for t in output_tangents]
    return _ForwardWrapper(
        graph=forward_wrapper_graph, outputs=forward_outputs,
        output_indices=output_indices, output_tangents=output_tangents)

  def _wrap_backward_function_with_jvp_backprop(
      self, backward_function, gradients_wrt_outputs, forward_wrapper):
    """Wraps `backward_function` to include gradients for JVPs."""
    wrapped_backwards_graph = func_graph_module.FuncGraph(
        _backward_name(self._func_graph.name))
    with wrapped_backwards_graph.as_default():
      py_backward, recorded_outputs = self._wrap_backward_function(
          self._func_graph, backward_function, forward_wrapper.outputs)
      trainable_index = 0
      forward_doutputs = []
      doutput_args = []
      for output in recorded_outputs:
        if backprop_util.IsTrainable(output):
          doutput = gradients_wrt_outputs[trainable_index]
          doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape)
          doutput_args.append(doutput_placeholder)
          forward_doutputs.append(doutput_placeholder)
          trainable_index += 1
        else:
          doutput_args.append(None)

      dinputs = py_backward(*doutput_args)
      existing_outputs = object_identity.ObjectIdentitySet(
          forward_wrapper.outputs + forward_wrapper.output_tangents)
      num_processed_output_tangents = 0
      gradients_wrt_output_tangents = []
      tangent_doutputs = []
      output_tangents = forward_wrapper.output_tangents
      output_indices = forward_wrapper.output_indices
      if self._need_gradients_for_jvps:
        # TODO(allenl): Consider using a throwaway graph to avoid extra gradient
        # evaluations; gradients for jvps may have common subgraphs.
        while num_processed_output_tangents != len(output_tangents):
          for output in output_tangents[num_processed_output_tangents:]:
            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
                output)
            placeholder = graph_placeholder(gradient_dtype, gradient_shape)
            gradients_wrt_output_tangents.append(placeholder)
            tangent_doutputs.append(placeholder)
          num_processed_output_tangents = len(output_tangents)
          with ops.device(None):
            gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
                output_tangents,
                forward_wrapper.graph.inputs,
                grad_ys=gradients_wrt_output_tangents,
                src_graph=forward_wrapper.graph)
          dinputs = [
              backprop.aggregate_indexed_slices_gradients((existing, new))
              for existing, new in zip(dinputs, gradients_wrt_inputs)
              if existing is not None or new is not None]
          dinputs.extend(gradients_wrt_inputs[len(dinputs):])
          captures_from_forward = [
              c for c in wrapped_backwards_graph.external_captures
              if (not isinstance(c, ops.EagerTensor)
                  and c.graph is forward_wrapper.graph)]
          for capture in captures_from_forward:
            if capture not in existing_outputs:
              existing_outputs.add(capture)
              forward_wrapper.outputs.append(capture)
          output_indices, output_tangents = (
              forwardprop_util.pack_tangents(forward_wrapper.outputs))
          output_tangents = [forward_wrapper.graph.capture(t)
                             for t in output_tangents]
          for t in output_tangents:
            existing_outputs.add(t)
    wrapped_backwards_graph.inputs = (
        forward_doutputs[:self._num_trainable_inference_outputs]
        + tangent_doutputs
        + forward_doutputs[self._num_trainable_inference_outputs:]
        + wrapped_backwards_graph.internal_captures)
    wrapped_backwards_graph.structured_outputs = dinputs
    wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None]
    return (wrapped_backwards_graph,
            forward_wrapper._replace(output_indices=output_indices,
                                     output_tangents=output_tangents))

  def _shuffle_forward_outputs(self, forward_wrapper):
    """Reorders function outputs so captures are last."""
    def _index_map(original):
      if original < self._num_inference_outputs:
        return original
      if original >= len(forward_wrapper.outputs):
        return (original - len(forward_wrapper.outputs)
                + self._num_inference_outputs)
      return original + len(forward_wrapper.output_tangents)
    output_indices = nest.map_structure(
        _index_map, forward_wrapper.output_indices)
    forward_wrapper.graph.outputs = (
        forward_wrapper.outputs[:self._num_inference_outputs]
        + forward_wrapper.output_tangents
        + forward_wrapper.outputs[self._num_inference_outputs:])
    return forward_wrapper._replace(output_indices=output_indices)
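
  # Illustrative index shuffle (comments only): with two inference outputs
  # [O0, O1], one side output S (original index 2) and one output tangent J
  # (original index 3), `_shuffle_forward_outputs` reorders graph.outputs to
  # [O0, O1, J, S]; `_index_map` sends 3 -> 2 (the tangent) and 2 -> 3 (the
  # side output), matching the layout the wrapped backward function expects:
  # [inference outputs] + [jvps] + [side outputs].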

  def forward(self, inference_args, input_tangents):
    """Construct or fetch a forward function with side-outputs.

    When graph building without a tape active, symbolic gradients rely on
    regenerating the backward function for higher-order gradients (to account
    for new side outputs of the rewritten forward function call). Thus there is
    no fixed backward function for this case. However, when a tape is active
    (eager or graph building), we generate fixed backward and forward functions
    at forward function call time.

    This difference between the tape and non-tape cases is to avoid building
    unneeded backward functions while graph building (where we may or may not
    eventually need gradients).

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A forward _EagerDefinedFunction.
    """
    if self._forward is None:
      (self._forward, self._forward_graph, self._backward,
       self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
           self._forward_and_backward_functions(inference_args, input_tangents))
    return self._forward

  def _wrap_backward_function(self, forward_graph, backward, outputs):
    """Create a backward function given `outputs` from the forward function."""
    capture_mapping = dict(
        zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
    captured_inputs = backward.captured_inputs
    remapped_captures = [
        capture_mapping.get(ops.tensor_id(capture), capture)
        for capture in captured_inputs
    ]
    if any(t.graph is forward_graph for t in remapped_captures
           if not isinstance(t, ops.EagerTensor)):
      incorrect_mapping = [t for t in remapped_captures
                           if (not isinstance(t, ops.EagerTensor) and
                               t.graph is not forward_graph)]
      raise errors.InternalError("Failed to map all backward graph captures to "
                                 "the forward graph. Incorrectly mapped: "
                                 f"{incorrect_mapping}.")
    # We may need to use zeros_like to get a zero for variant Tensors with
    # unconnected gradients. We do that in advance so we don't have to hold on
    # to the outputs themselves, which may not be needed otherwise.
    variant_zeros_like = {}
    backward_function_inputs = (len(backward.inputs) - len(captured_inputs))
    recorded_outputs = []
    trainable_recorded_outputs = 0
    skip_positions = []
    if self._num_forwardprop_outputs and not self._need_gradients_for_jvps:
      relevant_outputs = (
          outputs[:self._num_inference_outputs]
          + outputs[self._num_inference_outputs
                    + self._num_forwardprop_outputs:])
    else:
      relevant_outputs = outputs
    for output_index, output in enumerate(relevant_outputs):
      if trainable_recorded_outputs < backward_function_inputs:
        recorded_outputs.append(output)
      if backprop_util.IsTrainable(output):
        trainable_recorded_outputs += 1
      else:
        skip_positions.append(output_index)
      if output.dtype == dtypes.variant:
        variant_zeros_like[output_index] = default_gradient.zeros_like(output)

    def _backward_function_wrapper(*args):
      """Process output gradients and call the backward function."""
      if not backward.outputs:
        return backward.structured_outputs

      processed_args = []
      input_index = 0
      for output_index, arg in enumerate(args):
        # Convert IndexedSlices to dense tensors. The IndexedSlices optimization
        # is only really effective when doing tf.gather(variable) as the
        # adjoint functions for most operations are unlikely to preserve the
        # sparsity in IndexedSlices.
        if isinstance(arg, ops.IndexedSlices):
          arg = ops.convert_to_tensor(arg)
        if output_index in skip_positions:
          continue
        if arg is None:
          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
          # have a Tensor value for each Tensor we thought would be trainable
          # based on its dtype, even if it ended up being unconnected.
          input_placeholder = backward.inputs[
              input_index]
          if input_placeholder.dtype == dtypes.variant:
            arg = variant_zeros_like[output_index]
          else:
            arg = array_ops.zeros(
                *default_gradient.shape_and_dtype(input_placeholder))
        processed_args.append(arg)
        input_index += 1
        if input_index >= backward_function_inputs:
          break
      return backward._call_flat(  # pylint: disable=protected-access
          processed_args, remapped_captures)

    return _backward_function_wrapper, recorded_outputs

  def record(self, flat_outputs, inference_args, input_tangents):
    """Record the function call operation.

    For backprop, indicates the backward function to use and which new Tensors
    must be watched. For forwardprop from eager, the function call itself will
    have produced tangents which need to be recorded.

    Args:
      flat_outputs: The result of running `forward`.
      inference_args: A flat list of Tensors with inference inputs to the
        operation.
      input_tangents: A flat list of Tensors with input tangents consumed by the
        operation.
    """
    backward_function, to_record = self._wrap_backward_function(
        self._forward_graph, self._backward, flat_outputs)
    if self._forwardprop_output_indices:
      tape.record_operation_backprop_only(
          self._forward.signature.name,
          to_record, inference_args,
          backward_function)
      tape.record_operation_forwardprop_only(
          self._forward.signature.name,
          flat_outputs, inference_args + input_tangents,
          backward_function,
          self._forwardprop_output_indices)
    else:
      tape.record_operation(self._forward.signature.name,
                            to_record, inference_args + input_tangents,
                            backward_function)
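

# Rough guide to how the two subclasses below are chosen (comments only,
# simplified; `traced_fn` is a placeholder for any traced function):
# _FirstOrderTapeGradientFunctions is enough when at most first-order tape
# gradients can be requested, e.g.
#
#   with backprop.GradientTape() as t:   # single, non-persistent tape
#     y = traced_fn(x)
#   t.gradient(y, x)
#
# whereas a persistent tape, nested tapes, or active forward accumulators
# route through _HigherOrderTapeGradientFunctions so the backward function
# also accepts gradients for the forward function's side outputs.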
1364 1365 Returns: 1366 A tuple of (forward_function, backward_function): 1367 forward_function: Takes the same inputs as the inference function, but 1368 returns side outputs used by backward_function in addition to the 1369 inference function's outputs. 1370 backward_function: Takes side outputs from forward_function and 1371 gradients with respect to the "real" outputs of forward_function and 1372 returns gradients with respect to the inputs. 1373 """ 1374 outputs = self._func_graph.outputs[:self._num_inference_outputs] 1375 return self._build_functions_for_outputs( 1376 outputs, inference_args, input_tangents) 1377 1378 1379class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions): 1380 """Caches tape-friendly functions for higher-order gradients.""" 1381 1382 # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider 1383 # generalizing if so. 1384 def _forward_and_backward_functions(self, inference_args, input_tangents): 1385 """Forward and backward functions suitable for higher-order gradients. 1386 1387 Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built by 1388 this method accepts gradients for all of the outputs of the returned forward 1389 function, including side outputs. 1390 1391 Args: 1392 inference_args: A flat list of Tensors, arguments to the inference 1393 function. 1394 input_tangents: A flat list of Tensors, jvps associated with 1395 `inference_args`. 1396 1397 Returns: 1398 A tuple of (forward_function, backward_function): 1399 forward_function: Takes the same inputs as the inference function, but 1400 returns side outputs used by backward_function in addition to the 1401 inference function's outputs. 1402 backward_function: Takes side outputs from forward_function and 1403 gradients with respect to all of its outputs, real and side. Returns 1404 gradients with respect to the inputs. 1405 """ 1406 outputs = [] 1407 iteration_count = 0 1408 # First we need to figure out how many side outputs from the forward pass 1409 # will be required. We do this in a temporary graph to avoid actually 1410 # running multiple copies of the backward pass (one per _GradientsHelper 1411 # call). 1412 # 1413 # While computing gradients, the backward function captures Tensors from 1414 # the forward function. We add these as side outputs of the original 1415 # function. However, we then need to accept output gradients with respect 1416 # to these side outputs for higher order gradients to work. Thus we loop 1417 # until the number of outputs of the function stabilizes. Note that this 1418 # is only required for tape gradients, where we need to declare in advance 1419 # all of the forward op's outputs: symbolic gradients with tf.gradients 1420 # instead rely on regenerating backward functions when higher-order 1421 # gradients are requested. 1422 while (len(outputs) < len(self._func_graph.outputs) 1423 # It's possible for gradient generation to add new ops to the forward 1424 # pass. If all of the new outputs are non-trainable, there's no 1425 # reason to continue. 
1426 and any(backprop_util.IsTrainable(output) 1427 for output in self._func_graph.outputs[len(outputs):])): 1428 iteration_count += 1 1429 if iteration_count >= 20 and iteration_count % 5 == 0: 1430 new_op_with_trainable_output = None 1431 num_new_trainable_outputs = 0 1432 for output in self._func_graph.outputs[len(outputs):]: 1433 if backprop_util.IsTrainable(output): 1434 num_new_trainable_outputs += 1 1435 new_op_with_trainable_output = output.op 1436 logging.warning( 1437 ("Determining side outputs for the function '{}' is taking longer " 1438 "than expected ({} iterations, typically this converges in 5 or " 1439 "so). This could indicate that a gradient registration is adding " 1440 "new ops to the forward pass every time gradients are generated. " 1441 "{} new trainable output(s) were added this iteration, one from " 1442 "the following op:\n {}\nThis may indicate a TensorFlow bug, or " 1443 "an issue in a tf.custom_gradient.") 1444 .format( 1445 self._func_graph.name, iteration_count, 1446 num_new_trainable_outputs, new_op_with_trainable_output)) 1447 outputs = list(self._func_graph.outputs) 1448 self._build_functions_for_outputs( 1449 outputs, inference_args, input_tangents) 1450 1451 (forward_function, forward_graph, 1452 backward_function, output_indices, num_output_tangents) = ( 1453 self._build_functions_for_outputs( 1454 outputs, inference_args, input_tangents)) 1455 if (len(self._func_graph.outputs) > len(outputs) 1456 and any(backprop_util.IsTrainable(output) 1457 for output in self._func_graph.outputs[len(outputs):])): 1458 raise errors.InternalError( 1459 "Unexpectedly added new outputs to the forward function when " 1460 "building the backward function: " 1461 f"{self._func_graph.outputs[len(outputs):]}.") 1462 return (forward_function, forward_graph, backward_function, output_indices, 1463 num_output_tangents) 1464 1465 1466class _ForwardBackwardCall(object): 1467 """Holds the state of a function call between execution and recording.""" 1468 1469 __slots__ = [ 1470 "_functions", "_inference_args", "_input_tangents", "_tape_watching" 1471 ] 1472 1473 def __init__(self, functions, inference_args, input_tangents, tape_watching): 1474 """Collects information about the function call. 1475 1476 Args: 1477 functions: An object which produces forward and backward functions, either 1478 a _DelayedRewriteGradientFunctions or a _TapeGradientFunctions object. 1479 inference_args: A flat list of Tensors, arguments to the inference 1480 function. 1481 input_tangents: A flat list of Tensors, jvps associated with 1482 `inference_args`. 1483 tape_watching: Boolean, with True indicating that recording is necessary. 1484 """ 1485 self._functions = functions 1486 self._inference_args = inference_args 1487 self._input_tangents = input_tangents 1488 self._tape_watching = tape_watching 1489 1490 def forward(self): 1491 """Builds or retrieves a forward function for this call.""" 1492 forward_function = self._functions.forward( 1493 self._inference_args, self._input_tangents) 1494 return forward_function, self._inference_args + self._input_tangents 1495 1496 def record(self, flat_outputs): 1497 """Given outputs from the execution of `forward`, records the operation.""" 1498 if (self._tape_watching 1499 and not isinstance(flat_outputs, ops.Operation) 1500 and flat_outputs is not None): 1501 # We only record function calls which have outputs, and then only when a 1502 # tape is watching. 
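      # Illustrative sketch of how this object is driven (mirrors the sequence
      # in ConcreteFunction._call_flat; `ctx` is the eager context):
      #
      #   forward_fn, augmented_args = forward_backward.forward()
      #   flat_outputs = forward_fn.call(ctx, augmented_args)
      #   forward_backward.record(flat_outputs)  # registers the backward fn on the tape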
1503 self._functions.record( 1504 flat_outputs, self._inference_args, self._input_tangents) 1505 1506 1507# Sentinel value used by ConcreteFunction's structured signature to 1508# indicate that a non-tensor parameter should use the value that was 1509# specified when the concrete function was created. 1510_BOUND_VALUE = object() 1511 1512 1513class ConcreteFunction(core.ConcreteFunction): 1514 """A `tf.types.experimental.ConcreteFunction` created from `tf.function`.""" 1515 1516 def __init__(self, 1517 func_graph, 1518 attrs=None, 1519 shared_func_graph=True, 1520 function_spec=None): 1521 """Initialize a `ConcreteFunction`. 1522 1523 Args: 1524 func_graph: An instance of FuncGraph: the function body to wrap. 1525 attrs: (optional) dict mapping names of attributes to their AttrValue 1526 values. Attributes in `attrs` will be included in this function's 1527 definition. 1528 shared_func_graph: If False, the ConcreteFunction takes ownership of 1529 `func_graph` and will break reference cycles when it is deleted. This 1530 makes the FuncGraph inoperable. 1531 function_spec: FunctionSpec for the original function. If not specified, 1532 then this ConcreteFunction may only be called using the flat signature. 1533 1534 Raises: 1535 ValueError: If the number of input_placeholders is not equal to the number 1536 of function inputs. 1537 """ 1538 # _arg_keywords and _num_positional_args define the flat signature. They 1539 # are assigned after construction. 1540 self._arg_keywords = None 1541 self._num_positional_args = None 1542 1543 self._func_graph = func_graph 1544 self._captured_inputs = self._func_graph.external_captures 1545 self._captured_closures = self._func_graph.deferred_external_captures 1546 1547 # function_spec defines the structured signature. 1548 self._set_function_spec(function_spec) 1549 1550 if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs: 1551 # The alternative is to silently drop the "implements" tag, 1552 # but it seems likely it would lead to hard-to-catch bugs. 1553 # Another alternative is to make func_body preserve the order 1554 # of arguments if variables are present. Yet another option 1555 # is to automatically replace variables passed as arguments to functions 1556 # with v.read_value() whenever the "implements" tag is present. 1557 # Anytime we annotate an existing function we probably want to wrap 1558 # it with a safe read_value for backward compatibility. 1559 has_resource_vars = any(inp.dtype == dtypes.resource 1560 for inp in self.inputs) 1561 1562 assert not any( 1563 (has_resource_vars, self._captured_inputs, self._captured_closures) 1564 ), ('Function {name} has "{attr}={value}" attribute and thus cannot ' 1565 "depend on any tensors outside of its signature or modify variables. 
" 1566 "\n\nNote: variables are always captured and cause function " 1567 "re-tracing for every variable called.\n" 1568 " inputs: {inputs}\n captures: {captured}\n" 1569 " closures: {closures}.\n\n" 1570 "To pass a variable to such a function, use " 1571 "variable.read_value().".format( 1572 name=func_graph.name, 1573 attr=IMPLEMENTS_ATTRIBUTE_NAME, 1574 value=attrs[IMPLEMENTS_ATTRIBUTE_NAME], 1575 inputs=self.inputs, 1576 captured=self._captured_inputs, 1577 closures=self._captured_closures)) 1578 self._output_shapes = tuple( 1579 output.shape for output in self._func_graph.outputs) 1580 self._attrs = _parse_func_attrs(attrs or {}) 1581 1582 if shared_func_graph: 1583 self._garbage_collector = None 1584 else: 1585 self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph) 1586 1587 # Pairs of forward and backward functions used for computing gradients. 1588 # 1589 # These each get a reference to the FuncGraph deleter since they use the 1590 # FuncGraph directly. 1591 self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions( 1592 func_graph, self._attrs, self._garbage_collector) 1593 self._first_order_tape_functions = {} 1594 self._higher_order_tape_functions = {} 1595 # Cache the inference function to avoid a (Python) function call when not 1596 # building gradients. 1597 self._inference_function = self._delayed_rewrite_functions.forward() 1598 1599 def _set_function_spec(self, function_spec): 1600 """Enables the structured signature by supplying a function_spec.""" 1601 self._function_spec = None 1602 self._pre_initialized_function_spec = function_spec 1603 1604 # Note: when ConcreteFunctions are built by recreate_function() in 1605 # function_deserialization.py, they don't have a structured_input_signature 1606 # yet. In that case, _initialize_function_spec() gets called by 1607 # _setup_functions_structures() in load.py. 1608 if (function_spec is not None and 1609 self.structured_input_signature is not None): 1610 self._initialize_function_spec() 1611 1612 def _initialize_function_spec(self): 1613 """Updates `self._function_spec` to include varargs and bound variables. 1614 1615 Adds new positional arguments for any varargs (i.e., for args that are 1616 in `structured_input_signature`, but not in the original fullargspec.args). 1617 1618 Replaces `defaults` and `kwonlydefaults` with the `_BOUND_VALUE`, for 1619 all args and kwargs in `structured_input_signature`. 1620 1621 Sets `varkw` and `varargs` to None. 1622 """ 1623 if self._pre_initialized_function_spec is None: 1624 return # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
1625 assert not self._function_spec, "already initialized" 1626 function_spec = self._pre_initialized_function_spec 1627 args = function_spec.fullargspec.args 1628 arg_specs, kwarg_specs = self.structured_input_signature 1629 vararg_indices = range(len(function_spec.arg_names), len(arg_specs)) 1630 fullargspec = tf_inspect.FullArgSpec( 1631 args=list(args) + ["<arg{}>".format(i + 1) for i in vararg_indices], 1632 varargs=None, 1633 varkw=None, 1634 defaults=[_BOUND_VALUE] * len(arg_specs), 1635 kwonlyargs=list(sorted(kwarg_specs)), 1636 kwonlydefaults=dict((k, _BOUND_VALUE) for k in kwarg_specs), 1637 annotations=function_spec.fullargspec.annotations) 1638 self._function_spec = FunctionSpec( 1639 fullargspec, 1640 function_spec.is_method, 1641 function_spec.input_signature, 1642 function_spec.is_pure, 1643 name=self._func_graph.name) 1644 1645 @property 1646 def variables(self): 1647 """Sequence of variables for this function.""" 1648 return tuple(self._func_graph.variables) 1649 1650 @property 1651 def trainable_variables(self): 1652 """Sequence of trainable variables for this function.""" 1653 return tuple(self._func_graph.trainable_variables) 1654 1655 def __call__(self, *args, **kwargs): 1656 """Executes the wrapped function. 1657 1658 ConcreteFunctions have two signatures: 1659 1660 * The signature of the original function wrapped by this ConcreteFunction. 1661 * A flat signature, where each argument accepts a single Tensor. 1662 1663 The original function signature is generally preferred, but the flat input 1664 signature is supported for backward compatibility. 1665 1666 ### Original Function Signature 1667 1668 When calling a ConcreteFunction with the signature of the original function, 1669 each argument must match the type or value that was used when the 1670 ConcreteFunction's graph was traced. In particular: 1671 1672 * Tensor arguments (including CompositeTensors, such as RaggedTensor) must 1673 have matching `TypeSpec`s. 1674 * Non-Tensor arguments (such as booleans or ints) must have equal values. 1675 * Nested arguments (such as lists, tuples, or dictionaries) must have the 1676 same nesting structure; and each nested value must have a matching type 1677 or value. 1678 1679 The default value for any arguments that were traced with non-Tensor values 1680 is the value that was used in the trace. Arguments that were traced with 1681 tensor arguments do not have a default value (even if the original function 1682 had a default value for that argument). 1683 1684 ### Flat Signature 1685 1686 When calling a ConcreteFunction with the flat signature, the arguments 1687 correspond to the flattened component tensors of the arguments that were 1688 used to construct the ConcreteFunction. Parameter names are assigned based 1689 on `TensorSpec.name` (when specified) or the original argument names (with 1690 suffixes automatically added for nested arguments or composite tensors with 1691 multiple components). 1692 1693 Args: 1694 *args: Positional arguments to the concrete function. 1695 **kwargs: Keyword arguments to the concrete function. 1696 1697 Returns: 1698 The result of applying the TF function on the given Tensors. 1699 1700 Raises: 1701 AssertionError: If this `ConcreteFunction` was not created through 1702 `get_concrete_function`. 1703 TypeError: If the arguments do not match the function's signature. 
1704 """ 1705 return self._call_impl(args, kwargs) 1706 1707 def _call_impl(self, args, kwargs, cancellation_manager=None): 1708 """See `__call__` for details.""" 1709 with trace.Trace(self._func_graph.name, tf_function_call="concrete"): 1710 # Construct the list of input tensors: check if the structured signature 1711 # applies first; and if not, then use the flat signature. 1712 if self._function_spec is not None: 1713 try: 1714 return self._call_with_structured_signature(args, kwargs, 1715 cancellation_manager) 1716 except TypeError as structured_err: 1717 try: 1718 return self._call_with_flat_signature(args, kwargs, 1719 cancellation_manager) 1720 except TypeError: 1721 raise structured_err 1722 1723 return self._call_with_flat_signature(args, kwargs, cancellation_manager) 1724 1725 def _call_with_flat_signature(self, args, kwargs, cancellation_manager): 1726 """Executes the wrapped function with the flat signature. 1727 1728 Args: 1729 args: Positional arguments to the concrete function. 1730 kwargs: Keyword arguments to the concrete function. 1731 cancellation_manager: A `CancellationManager` that can be used to cancel 1732 function invocation. 1733 1734 Returns: 1735 The result of applying the function on the Tensors/Variables contained in 1736 `args` and `kwargs`. 1737 Raises: 1738 TypeError: if `args` and `kwargs` do not match the flat signature of this 1739 `ConcreteFunction`. 1740 """ 1741 if len(args) > self._num_positional_args: 1742 raise TypeError( 1743 f"{self._flat_signature_summary()} takes {self._num_positional_args} " 1744 f"positional arguments, got {len(args)}.") 1745 args = list(args) 1746 kwargs = dict(kwargs) 1747 for keyword in self._arg_keywords[len(args):]: 1748 try: 1749 args.append(kwargs.pop(compat.as_str(keyword))) 1750 except KeyError: 1751 specified_keywords = ( 1752 list(self._arg_keywords[:len(args)]) + list(kwargs.keys())) 1753 missing_required_args = sorted( 1754 set(self._arg_keywords) - set(specified_keywords)) 1755 raise TypeError(f"{self._flat_signature_summary()} missing required " 1756 f"arguments: {', '.join(missing_required_args)}.") 1757 if kwargs: 1758 positional_arg_keywords = set(self._arg_keywords[:len(args)]) 1759 for unused_key in kwargs: 1760 if unused_key in positional_arg_keywords: 1761 raise TypeError(f"{self._flat_signature_summary()} got two values " 1762 f"for '{unused_key}'.") 1763 raise TypeError(f"{self._flat_signature_summary()} got unexpected " 1764 f"keyword arguments: {', '.join(sorted(kwargs))}.") 1765 1766 for i, arg in enumerate(args): 1767 if not isinstance( 1768 arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)): 1769 raise TypeError(f"{self._flat_signature_summary()}: expected argument " 1770 f"#{i}(zero-based) to be a Tensor; " 1771 f"got {type(arg).__name__} ({arg}).") 1772 return self._call_flat(args, self.captured_inputs, cancellation_manager) 1773 1774 def _call_with_structured_signature(self, args, kwargs, cancellation_manager): 1775 """Executes the wrapped function with the structured signature. 1776 1777 Args: 1778 args: Positional arguments to the concrete function. 1779 kwargs: Keyword arguments to the concrete function. 1780 cancellation_manager: A `CancellationManager` that can be used to cancel 1781 function invocation. 1782 1783 Returns: 1784 The result of applying the function on the Tensors/Variables contained in 1785 `args` and `kwargs`. 1786 Raises: 1787 TypeError: if `args` and `kwargs` do not match the structured signature 1788 of this `ConcreteFunction`. 
1789 """ 1790 args, kwargs, _, filtered_flat_args = \ 1791 self._function_spec.canonicalize_function_inputs(*args, **kwargs) 1792 self._structured_signature_check_missing_args(args, kwargs) 1793 self._structured_signature_check_unexpected_args(args, kwargs) 1794 self._structured_signature_check_arg_types(args, kwargs) 1795 return self._call_flat( 1796 filtered_flat_args, 1797 captured_inputs=self.captured_inputs, 1798 cancellation_manager=cancellation_manager) 1799 1800 def _structured_signature_check_missing_args(self, args, kwargs): 1801 """Raises a TypeError if any args are missing.""" 1802 arg_specs, kwarg_specs = self.structured_input_signature 1803 missing_arguments = [] 1804 for i, (arg, spec) in enumerate(zip(args, arg_specs)): 1805 if arg is _BOUND_VALUE and _contains_type_spec(spec): 1806 missing_arguments.append(self._function_spec.arg_names[i]) 1807 for (name, arg) in kwargs.items(): 1808 if arg is _BOUND_VALUE and _contains_type_spec(kwarg_specs[name]): 1809 missing_arguments.append(name) 1810 if missing_arguments: 1811 raise TypeError(f"{self._structured_signature_summary()} missing " 1812 "required arguments: " 1813 f"{', '.join(sorted(missing_arguments))}.") 1814 1815 def _structured_signature_check_unexpected_args(self, args, kwargs): 1816 """Raises a TypeError if there are any extra args.""" 1817 arg_specs, kwarg_specs = self.structured_input_signature 1818 if len(args) > len(arg_specs): 1819 raise TypeError( 1820 f"{self._structured_signature_summary()} takes " 1821 f"{len(self._function_spec.arg_names)} positional arguments but got " 1822 f"{len(args)}.") 1823 if len(kwargs) > len(kwarg_specs): 1824 extra_args = set(kwargs) - set(kwarg_specs) 1825 raise TypeError(f"{self._structured_signature_summary()} got unexpected " 1826 f"keyword arguments: {', '.join(extra_args)}.") 1827 1828 def _structured_signature_check_arg_types(self, args, kwargs): 1829 """Raises a TypeError if any args have the wrong type.""" 1830 # Check argument types 1831 arg_specs, kwarg_specs = self.structured_input_signature 1832 for i, (arg, spec) in enumerate(zip(args, arg_specs)): 1833 name = self._function_spec.arg_names[i] 1834 self._structured_signature_check_arg_type(arg, spec, name) 1835 for (name, arg) in kwargs.items(): 1836 self._structured_signature_check_arg_type(arg, kwarg_specs[name], name) 1837 1838 def _structured_signature_check_arg_type(self, arg, spec, name): 1839 """Raise TypeError if `arg`'s type doesn't match `spec`.""" 1840 if arg is _BOUND_VALUE: 1841 return 1842 1843 # Check the overall nested structure of the argument. 1844 try: 1845 nest.assert_same_structure(arg, spec, expand_composites=True) 1846 except (ValueError, TypeError): 1847 try: 1848 nest.assert_same_structure(arg, spec, expand_composites=False) 1849 expected, got = spec, arg 1850 except (ValueError, TypeError): 1851 expected, got = _structure_summary(spec), _structure_summary(arg) 1852 raise TypeError(f"{self._structured_signature_summary()}: argument " 1853 f"{name} had incorrect type\n" 1854 f" expected: {expected}\n" 1855 f" got: {got}") 1856 1857 # Check the type for each leaf in the nested structure. 1858 arg_pieces = nest.flatten(arg, expand_composites=True) 1859 spec_pieces = nest.flatten(spec, expand_composites=True) 1860 for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces): 1861 # TODO(mdan): Use consistent error messages. 1862 if isinstance(spec_piece, tensor_spec.DenseSpec): 1863 # TODO(edloper): Consider calling convert_to_tensor on non-tensor 1864 # values here. 
That would match the behavior of 1865 # _call_concrete_function() in function_deserialization.py. If 1866 # we do, then we need to change the nest assert_same_structure and 1867 # flatten calls above to use shallow variants. 1868 tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable) 1869 if not isinstance(arg_piece, tensor_types): 1870 raise TypeError(f"{self._structured_signature_summary()} expected a " 1871 f"Tensor in {name}, but got " 1872 f"{type(arg_piece).__name__} value {arg_piece}.") 1873 elif arg_piece is not _BOUND_VALUE: 1874 try: 1875 arg_matches_spec = bool(arg_piece == spec_piece) 1876 except (ValueError, TypeError): 1877 logging.vlog(1, "Error matching value with spec", exc_info=True) 1878 arg_matches_spec = False 1879 if not arg_matches_spec: 1880 raise TypeError( 1881 f"ConcreteFunction {self._structured_signature_summary()} was " 1882 f"constructed with {type(spec_piece).__name__} value " 1883 f"{spec_piece} in {name}, but was called with " 1884 f"{type(arg_piece).__name__} value {arg_piece}.") 1885 1886 def _call_flat(self, args, captured_inputs, cancellation_manager=None): 1887 """Executes the wrapped function. 1888 1889 Args: 1890 args: a list of Tensors or Variables. Arguments from the Python function 1891 should be filtered before calling this method: objects aside from 1892 Tensors, CompositeTensors, and Variables are ignored. Any 1893 CompositeTensors should be expanded before calling this method. 1894 captured_inputs: the captured inputs that are also part of the input args 1895 to the actual execution. By default, it should be self._captured_inputs. 1896 cancellation_manager: (Optional.) A `CancellationManager` that can be 1897 used to cancel function invocation. 1898 1899 Returns: 1900 The result of applying the TF function to `args`. 1901 1902 Raises: 1903 ValueError: If `args` contains anything other than Tensors or Variables. 1904 """ 1905 ctx = context.context() 1906 executing_eagerly = ctx.executing_eagerly() 1907 1908 # Copy saveable status of function's graph to current FuncGraph. 1909 default_graph = ops.get_default_graph() 1910 if default_graph.building_function and not self._func_graph.saveable: 1911 default_graph.mark_as_unsaveable(self._func_graph.saving_errors) 1912 1913 if (tape.could_possibly_record() or 1914 hasattr(default_graph, "watch_variable")): 1915 for v in self._func_graph.variables: 1916 resource_variable_ops.variable_accessed(v) 1917 1918 tensor_inputs = [] 1919 variables_used = set([]) 1920 for i, arg in enumerate(args): 1921 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 1922 # We can pass a variable more than once, and in this case we need to 1923 # pass its handle only once. 1924 if id(arg.handle) in variables_used: 1925 continue 1926 resource_variable_ops.variable_accessed(arg) 1927 tensor_inputs.append(arg.handle) 1928 variables_used.add(id(arg.handle)) 1929 elif isinstance(arg, ops.Tensor): 1930 tensor_inputs.append(arg) 1931 if not executing_eagerly: 1932 # If we're graph building, shape inference is on. We check for input 1933 # compatibility up front to avoid hard to debug incompatibilities 1934 # later. 
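          # Illustrative sketch (hypothetical shapes): a trace made with
          # tf.TensorSpec(shape=[2, 3]) rejects a [4, 3] Tensor here, while a
          # trace made with tf.TensorSpec(shape=[None, 3]) accepts any leading
          # dimension, because TensorShape([None, 3]).is_compatible_with([4, 3])
          # is True.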
1935 graph_input_shape = tensor_shape.TensorShape( 1936 self._func_graph.inputs[i].shape) 1937 if not graph_input_shape.is_compatible_with(arg.shape): 1938 if self._arg_keywords: 1939 arg_name = "'{}'".format(self._arg_keywords[i]) 1940 else: 1941 arg_name = "with index {}".format(i) 1942 raise ValueError( 1943 f"The argument {arg_name} (value {arg}) is not compatible with " 1944 "the shape this function was traced with. Expected shape " 1945 f"{self._func_graph.inputs[i].shape}, but got shape " 1946 f"{arg.shape}.\n\nIf you called get_concrete_function, you may " 1947 "need to pass a tf.TensorSpec(..., shape=...) with a less " 1948 "specific shape, having None on axes which can vary.") 1949 else: 1950 raise ValueError(f"{i:d}-th input {arg} must be a Tensor, got " 1951 f"{type(arg)} when calling {self._func_graph.name}.") 1952 args = tensor_inputs + captured_inputs 1953 possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args) 1954 if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE 1955 and executing_eagerly): 1956 # No tape is watching; skip to running the function. 1957 return self._build_call_outputs(self._inference_function.call( 1958 ctx, args, cancellation_manager=cancellation_manager)) 1959 forward_backward = self._select_forward_and_backward_functions( 1960 args, 1961 possible_gradient_type, 1962 executing_eagerly) 1963 forward_function, args_with_tangents = forward_backward.forward() 1964 if executing_eagerly: 1965 flat_outputs = forward_function.call( 1966 ctx, args_with_tangents, cancellation_manager=cancellation_manager) 1967 else: 1968 with default_graph._override_gradient_function( # pylint: disable=protected-access 1969 {"PartitionedCall": self._get_gradient_function(), 1970 "StatefulPartitionedCall": self._get_gradient_function()}): 1971 flat_outputs = forward_function.call(ctx, args_with_tangents) 1972 forward_backward.record(flat_outputs) 1973 return self._build_call_outputs(flat_outputs) 1974 1975 def _experimental_with_cancellation_manager(self, cancellation_manager): 1976 """Returns a callable that invokes a cancellable version of this function. 1977 1978 Args: 1979 cancellation_manager: A `CancellationManager` object that can be used to 1980 cancel function invocation. 1981 1982 Returns: 1983 A callable with the same signature as this concrete function. 1984 """ 1985 1986 def cancellable_call(*args, **kwargs): 1987 return self._call_impl( 1988 args, kwargs, cancellation_manager=cancellation_manager) 1989 1990 return cancellable_call 1991 1992 @property 1993 def name(self): 1994 """`ConcreteFunction` name.""" 1995 return self._delayed_rewrite_functions.forward().name 1996 1997 @property 1998 def graph(self): 1999 """Returns the graph from which this function was constructed.""" 2000 return self._func_graph 2001 2002 @property 2003 def inputs(self): 2004 """Returns tensors in `self.graph` corresponding to arguments.""" 2005 return self._func_graph.inputs 2006 2007 @property 2008 def structured_input_signature(self): 2009 """Returns structured signature for this concrete function. 2010 2011 Returns: 2012 A tuple `(args, kwargs)`, where: 2013 2014 * `args` is a tuple that specifies the expected type or value each for 2015 positional argument. 2016 * `kwargs` is a dictionary that specifies the expected type or value 2017 for each keyword-only argument. 
2018 2019 The type or value for each argument is specified using one of the 2020 following: 2021 2022 * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native 2023 value is expected. 2024 * A Python value, such as an integer, indicating that an equal value 2025 is expected. 2026 * A nested structure of `tf.TypeSpec`s and Python values, indicating 2027 that a corresponding nested structure is expected. 2028 """ 2029 return self._func_graph.structured_input_signature 2030 2031 @property 2032 def outputs(self): 2033 """Returns tensors in `self.graph` corresponding to returned tensors.""" 2034 return self._func_graph.outputs 2035 2036 @property 2037 def structured_outputs(self): 2038 """Returns outputs in `self.graph` as returned by the original function.""" 2039 return self._func_graph.structured_outputs 2040 2041 @property 2042 def captured_inputs(self): 2043 """Returns external Tensors captured by this function. 2044 2045 self.__call__(*args) passes `args + self.captured_inputs` to the function. 2046 """ 2047 from_closures = nest.flatten([x() for x in self._captured_closures], 2048 expand_composites=True) 2049 return self._captured_inputs + from_closures 2050 2051 @property 2052 def function_def(self): 2053 """Returns a `FunctionDef` object representing this function.""" 2054 return self._delayed_rewrite_functions.forward().definition 2055 2056 @property 2057 def output_shapes(self): 2058 """The function's output shapes.""" 2059 return nest.map_structure( 2060 lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)), 2061 composite_tensor.replace_composites_with_components( 2062 self._func_graph.structured_outputs), 2063 expand_composites=False) 2064 2065 @property 2066 def output_dtypes(self): 2067 # TODO(akshayka): Consider removing this. 2068 return nest.map_structure( 2069 lambda x: x.dtype if x is not None else None, 2070 composite_tensor.replace_composites_with_components( 2071 self._func_graph.structured_outputs), 2072 expand_composites=False) 2073 2074 def add_to_graph(self, g=None): 2075 """Registers the function, adds it to the graph g or default graph. 2076 2077 Args: 2078 g: If specified, registers the function with this graph. Defaults to the 2079 current context (either the default graph or the eager context). 2080 """ 2081 # If we are not executing eagerly, adds the function to default graph if no 2082 # graph is specified. 2083 # In case of eager execution, function definition gets added to context 2084 # during construction itself. 2085 2086 if not context.executing_eagerly() and not g: 2087 g = ops.get_default_graph() 2088 self._delayed_rewrite_functions.forward().add_to_graph(g) 2089 2090 def add_gradient_functions_to_graph(self, g=None): 2091 """Add forward/backward functions to graph `g` or the current context.""" 2092 if not context.executing_eagerly() and not g: 2093 g = ops.get_default_graph() 2094 self._delayed_rewrite_functions.forward().add_to_graph(g) 2095 forward_function, backward_function = ( 2096 self._delayed_rewrite_functions.forward_backward()) 2097 forward_function.add_to_graph(g) 2098 backward_function.add_to_graph(g) 2099 2100 def _get_gradient_function(self): 2101 """Returns gradient function. 
It will be lazily created at first call.""" 2102 return self._delayed_rewrite_functions._rewrite_forward_and_call_backward # pylint: disable=protected-access 2103 2104 def _select_forward_and_backward_functions( 2105 self, args, possible_gradient_type, executing_eagerly): 2106 """Selects forward and backward functions based on the calling context. 2107 2108 The forward function computes the "real" function outputs, `self._outputs`, 2109 and any extra values needed by the corresponding backward function. 2110 2111 Args: 2112 args: A flat list of Tensors with all of the inputs to the forward 2113 function (including user-specified and captured inputs). 2114 possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*. 2115 executing_eagerly: Boolean, the value of context.executing_eagerly(). 2116 2117 Returns: 2118 An object with a `forward` method returning a tuple of (forward_function : 2119 _EagerDefinedFunction, augmented_arguments : List), and a corresponding 2120 `record` method which takes outputs from the forward function and records 2121 the operation. forward_function should be called with augmented_arguments. 2122 """ 2123 if executing_eagerly: 2124 input_tangents = forwardprop_util.pack_tangents(args) 2125 else: 2126 input_tangents = forwardprop_util.TangentInfo() 2127 need_gradients_for_jvps = tape.should_record_backprop( 2128 input_tangents.tangents) 2129 # Allows re-use of forward and backward function pairs depending on the 2130 # tapes and forward accumulators watching its inputs. 2131 cache_key = (need_gradients_for_jvps, input_tangents.indices) 2132 if (possible_gradient_type 2133 == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER): 2134 if input_tangents.indices or executing_eagerly: 2135 # There is a single non-persistent tape active, so the user can only 2136 # request first-order gradients from a tape. We can spend less time 2137 # graph building since we know this. 2138 # 2139 # We may still end up computing higher-order gradients, but that'd be 2140 # through `tf.gradients`, which can re-write the forward pass and so 2141 # needs no preparation here. 2142 functions = self._first_order_tape_functions.get(cache_key, None) 2143 if functions is None: 2144 functions = _FirstOrderTapeGradientFunctions( 2145 self._func_graph, self._attrs, self._garbage_collector, 2146 forwardprop_input_indices=input_tangents.indices, 2147 delayed_rewrite_functions=self._delayed_rewrite_functions, 2148 need_gradients_for_jvps=need_gradients_for_jvps) 2149 self._first_order_tape_functions[cache_key] = functions 2150 return _ForwardBackwardCall( 2151 functions, args, input_tangents.tangents, tape_watching=True) 2152 else: 2153 # We can avoid computing second-order gradients in some cases by doing a 2154 # delayed rewrite when graph building. Since we know we'll only compute 2155 # first-order tape gradients, the delayed rewrite is safe: we won't need 2156 # to tell the tape about side outputs. 2157 # 2158 # TODO(allenl): This case is really dirty. It would be better if we 2159 # could temporarily pop all of the current tapes to avoid 2160 # accidentally taking second-order gradients. 2161 return _ForwardBackwardCall( 2162 self._delayed_rewrite_functions, args, input_tangents.tangents, 2163 tape_watching=True) 2164 elif (possible_gradient_type 2165 == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER): 2166 # Either there's a persistent tape watching, or there are multiple nested 2167 # tapes. Either way, the user may request higher-order gradients. 
We'll 2168 # spend a bit more time and make sure higher-order gradients are correct. 2169 functions = self._higher_order_tape_functions.get( 2170 cache_key, None) 2171 if functions is None: 2172 functions = _HigherOrderTapeGradientFunctions( 2173 self._func_graph, self._attrs, self._garbage_collector, 2174 forwardprop_input_indices=input_tangents.indices, 2175 delayed_rewrite_functions=self._delayed_rewrite_functions, 2176 need_gradients_for_jvps=need_gradients_for_jvps) 2177 self._higher_order_tape_functions[cache_key] = functions 2178 return _ForwardBackwardCall(functions, args, input_tangents.tangents, 2179 tape_watching=True) 2180 # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no 2181 # tape is recording. 2182 return _ForwardBackwardCall( 2183 self._delayed_rewrite_functions, args, input_tangents.tangents, 2184 tape_watching=False) 2185 2186 def _build_call_outputs(self, result): 2187 """Maps the fdef output list to actual output structure. 2188 2189 Args: 2190 result: Output lists defined by FunctionDef. 2191 Returns: 2192 The actual call output. 2193 """ 2194 # TODO(jlchu): call C++ version in function.cc when speed is improved 2195 if self._func_graph.structured_outputs is None: 2196 return result 2197 2198 # Replace outputs with results, skipping over any 'None' values. 2199 outputs_list = nest.flatten( 2200 self._func_graph.structured_outputs, expand_composites=True) 2201 j = 0 2202 for i, o in enumerate(outputs_list): 2203 if o is not None: 2204 handle_data_util.copy_handle_data(self.outputs[j], result[j]) 2205 outputs_list[i] = result[j] 2206 j += 1 2207 ret = nest.pack_sequence_as(self._func_graph.structured_outputs, 2208 outputs_list, expand_composites=True) 2209 return ret 2210 2211 @property 2212 def _as_name_attr_list(self): 2213 """Returns a `NameAttrList` representing this function.""" 2214 ret = attr_value_pb2.NameAttrList(name=self.name) 2215 for name, value in self._attrs.items(): 2216 ret.attr[name].CopyFrom(value) 2217 return ret 2218 2219 def _structured_signature_summary(self, default_values=False): 2220 """Returns a string summarizing this function's structured signature. 2221 2222 Args: 2223 default_values: If true, then include default values in the signature. 2224 2225 Returns: 2226 A `string`. 2227 """ 2228 # Note: we can't just use self._funcion_spec.signature_summary(), because 2229 # that would show "_BOUND_VALUE" as the default value for all arguments. 2230 assert self._function_spec is not None 2231 arg_specs, kwarg_specs = self.structured_input_signature 2232 arg_names = list(self._function_spec.arg_names) 2233 2234 # If an explicit input_signature is provided to @tf.function, then any 2235 # arguments with defaults that are not covered by that explicit signature 2236 # are simply dropped from the signature. 2237 # TODO(b/159639913) Look into whether dropping arguments with default values 2238 # from the signature is the right thing to do. 
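    # Illustrative sketch (hypothetical): for a function traced as
    # f(x=tf.TensorSpec([None], tf.float32), y=2), this method returns
    # "f(x, y)" by default and "f(x, y=2)" when default_values=True.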
2239 arg_names = arg_names[:len(arg_specs)] 2240 2241 if default_values: 2242 for i in range(len(arg_names)): 2243 if not _contains_type_spec(arg_specs[i]): 2244 arg_names[i] += "={}".format(arg_specs[i]) 2245 if kwarg_specs: 2246 arg_names.append("*") 2247 for name, spec in kwarg_specs.items(): 2248 arg_names.append(name) 2249 if default_values and not _contains_type_spec(spec): 2250 arg_names[-1] += "={}".format(spec) 2251 signature = f"{self._func_graph.name}({', '.join(arg_names)})" 2252 2253 return signature 2254 2255 def _flat_signature_summary(self): 2256 """Returns a string summarizing this function's flat signature.""" 2257 assert self._arg_keywords is not None 2258 assert self._num_positional_args is not None 2259 arg_names = self._arg_keywords 2260 if self._num_positional_args > len(arg_names): 2261 arg_names.extend( 2262 "<arg{}>".format(i + 1) 2263 for i in range(len(arg_names), self._num_positional_args)) 2264 return f"{self._func_graph.name}({', '.join(arg_names)})" 2265 2266 def pretty_printed_signature(self, verbose=True): 2267 """Returns a string summarizing the signature of this concrete function.""" 2268 if not verbose: 2269 return self._structured_signature_summary(default_values=True) 2270 2271 def pretty_print_spec(spec): 2272 """Returns a string describing the spec for a single argument.""" 2273 if isinstance(spec, tensor_spec.TensorSpec): 2274 return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape) 2275 elif nest.is_sequence(spec): 2276 pieces = nest.flatten(spec, expand_composites=False) 2277 markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))] 2278 structure = nest.pack_sequence_as(spec, markers) 2279 # Ensure dictionaries are sorted by key (for determinism) 2280 result = pprint.pformat(structure, width=10000) 2281 for (marker, piece) in zip(markers, pieces): 2282 result += "\n {}: {}".format(marker, pretty_print_spec(piece)) 2283 return result 2284 else: 2285 return repr(spec) 2286 2287 lines = [self._structured_signature_summary(default_values=True)] 2288 arg_specs, kwarg_specs = self.structured_input_signature 2289 names = list(self._function_spec.arg_names) 2290 2291 # If an explicit input_signature is provided to @tf.function, then any 2292 # arguments with defaults that are not covered by that explicit signature 2293 # are simply dropped from the signature. 2294 # TODO(b/159639913) Look into whether dropping arguments with default values 2295 # from the signature is the right thing to do. 2296 2297 # Note: we can skip bound args, since we already displayed their bound 2298 # value in the signature summary. 2299 arg_details = [] 2300 for (name, spec) in zip(names[:len(arg_specs)], list(arg_specs)): 2301 if _contains_type_spec(spec): 2302 arg_details.append(" {}: {}".format(name, pretty_print_spec(spec))) 2303 2304 if kwarg_specs: 2305 for kwarg in sorted(kwarg_specs): 2306 spec = kwarg_specs[kwarg] 2307 if _contains_type_spec(spec): 2308 arg_details.append(" {}: {}".format( 2309 kwarg, pretty_print_spec(spec))) 2310 2311 if arg_details: 2312 lines.append(" Args:") 2313 lines.extend(arg_details) 2314 lines.append(" Returns:") 2315 2316 def spec_from_value(value): 2317 # For loaded function, structured_outputs are already specs. 
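      # Illustrative sketch: an eager output such as
      # constant_op.constant([1.0, 2.0]) falls through to
      # type_spec.type_spec_from_value and yields
      # TensorSpec(shape=(2,), dtype=tf.float32), while a loaded function's
      # structured_outputs are already TypeSpecs and pass through unchanged.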
2318 if isinstance(value, type_spec.TypeSpec): 2319 return value 2320 return type_spec.type_spec_from_value(value) 2321 2322 lines.append(" {}".format( 2323 pretty_print_spec( 2324 nest.map_structure(spec_from_value, self.structured_outputs)))) 2325 2326 return "\n".join(lines) 2327 2328 def __repr__(self): 2329 if self._function_spec is not None: 2330 return "<ConcreteFunction {} at 0x{:X}>".format( 2331 self.pretty_printed_signature(verbose=False), id(self)) 2332 elif not (self._num_positional_args is None or self._arg_keywords is None): 2333 return "<ConcreteFunction {} at 0x{:X}>".format( 2334 self._flat_signature_summary(), id(self)) 2335 else: 2336 return object.__repr__(self) 2337 2338 def __str__(self): 2339 if self._function_spec is not None: 2340 return "ConcreteFunction {}".format(self.pretty_printed_signature()) 2341 else: 2342 return self.__repr__() 2343 2344 2345_pywrap_utils.RegisterType("Tensor", ops.Tensor) 2346_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) 2347_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices) 2348 2349 2350def _deterministic_dict_values(dictionary): 2351 return tuple(dictionary[key] for key in sorted(dictionary)) 2352 2353 2354class FunctionSpec(object): 2355 """Specification of how to bind arguments to a function.""" 2356 2357 @staticmethod 2358 def from_function_and_signature(python_function, 2359 input_signature, 2360 is_pure=False, 2361 experimental_follow_type_hints=False, 2362 jit_compile=None): 2363 """Create a FunctionSpec instance given a python function and signature. 2364 2365 Args: 2366 python_function: a function to inspect 2367 input_signature: a signature of the function (None, if variable) 2368 is_pure: if True all input arguments (including variables and constants) 2369 will be converted to tensors and no variable changes allowed. 2370 experimental_follow_type_hints: see `tf.function` 2371 jit_compile: see `tf.function` 2372 2373 Returns: 2374 instance of FunctionSpec 2375 """ 2376 fullargspec = tf_inspect.getfullargspec(python_function) 2377 # Checks if the `fullargspec` contains self or cls as its first argument. 2378 is_method = tf_inspect.isanytargetmethod(python_function) 2379 2380 # Treat a wrapped partial function as a special case. For all arguments that 2381 # were overridden with keywords in the partial: 2382 # - remove the corresponding arguments, 2383 # - remove the corresponding keywords. 2384 _, unwrapped = tf_decorator.unwrap(python_function) 2385 # TODO(b/131153379): Consider Python3's fullargspec.kwonlyargs and 2386 # fullargspec.kwonlydefaults. 2387 if isinstance(unwrapped, functools.partial): 2388 # Also consider the Python3 case with kwonlydefaults. 2389 if fullargspec.defaults or fullargspec.kwonlydefaults: 2390 new_defaults = fullargspec.defaults 2391 new_args = fullargspec.args 2392 if fullargspec.defaults: 2393 # To be able to canonicalize the function properly, we want to ignore 2394 # default values that are overridden via a partial kwarg. For example: 2395 # 2396 # def func(a, b, c, d=5, e=7): 2397 # return a, b, c, d, e 2398 # p_func = tf.function(functools.partial(func, 10, e=9)) 2399 # 2400 # Here we want to drop from the defaults the parameter `e`. If we 2401 # forwarded the call to the partial function with a default for `e` 2402 # we would get an error for passing two values for one parameter. 2403 # 2404 # Note that this has a limitation: we can only override parameters at 2405 # the end of the parameter list. 
2406 # 2407 # In this case we want to end up with 3 arguments (b, c, d) and 1 2408 # default value (5). We do this by constructing a mask where 0 stands 2409 # for a value that was overridden by a partial kwarg. The seemingly 2410 # complicated logic below does just that - for arguments (b, c, d, e) 2411 # we would get a mask (1, 1, 1, 0). 2412 old_args = fullargspec.args 2413 old_defaults = fullargspec.defaults 2414 2415 no_default = object() 2416 num_args_without_defaults = len(old_args) - len(old_defaults) 2417 left_padding = tuple([no_default] * num_args_without_defaults) 2418 2419 args_with_defaults = zip(old_args, left_padding + old_defaults) 2420 2421 # Create a mask where 0 stands for args that had a partial kwarg 2422 # defined. 2423 non_keyword_defaults_mask = [ 2424 0 if key in unwrapped.keywords else 1 for key in old_args 2425 ] 2426 # Keep only arguments and defaults that were not kwargs of partial. 2427 new_args_with_defaults = list( 2428 itertools.compress(args_with_defaults, non_keyword_defaults_mask)) 2429 # Keep all args. 2430 new_args = [arg for arg, _ in new_args_with_defaults] 2431 # Keep only real default values. 2432 new_defaults = [ 2433 default for _, default in new_args_with_defaults 2434 if default is not no_default 2435 ] 2436 fullargspec = tf_inspect.FullArgSpec( 2437 args=new_args, 2438 varargs=fullargspec.varargs, 2439 varkw=fullargspec.varkw, 2440 defaults=new_defaults, 2441 kwonlyargs=[], 2442 kwonlydefaults={}, 2443 annotations=fullargspec.annotations) 2444 2445 # Get the function's name. Remove functools.partial wrappers if necessary. 2446 while isinstance(python_function, functools.partial): 2447 python_function = python_function.func 2448 name = getattr(python_function, "__name__", "f") 2449 2450 return FunctionSpec( 2451 fullargspec, 2452 is_method, 2453 input_signature, 2454 is_pure=is_pure, 2455 jit_compile=jit_compile, 2456 experimental_follow_type_hints=experimental_follow_type_hints, 2457 name=name) 2458 2459 def __init__(self, 2460 fullargspec, 2461 is_method, 2462 input_signature, 2463 is_pure=False, 2464 experimental_follow_type_hints=False, 2465 name=None, 2466 jit_compile=None): 2467 """Constructs a FunctionSpec describing a python function. 2468 2469 Args: 2470 fullargspec: `tf_inspect.FullArgSpec` object describing the function. 2471 is_method: True if the function is a method. 2472 input_signature: a signature of the function (None, if variable) 2473 is_pure: if True all input arguments (including variables and constants) 2474 will be converted to tensors and no variable changes allowed. 2475 experimental_follow_type_hints: see `tf.function`. 2476 name: Name of the function 2477 jit_compile: see `tf.function`. 2478 """ 2479 self._fullargspec = fullargspec 2480 self._is_method = is_method 2481 self._is_pure = is_pure 2482 self._jit_compile = jit_compile 2483 self._experimental_follow_type_hints = experimental_follow_type_hints 2484 2485 # TODO(edloper): Include name when serializing for SavedModel? 2486 self._name = name or "f" 2487 2488 if self._is_method: 2489 # Remove `self`: default arguments shouldn't be matched to it. 2490 # TODO(b/127938157): Should this error out if there is no arg to 2491 # be removed? 2492 args = fullargspec.args[1:] 2493 else: 2494 args = fullargspec.args 2495 2496 # A cache mapping from argument name to index, for canonicalizing 2497 # arguments that are called in a keyword-like fashion. 
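    # Illustrative sketch (hypothetical): for `def f(a, b, c=1)` the caches
    # built below are
    #   _args_to_indices == {'a': 0, 'b': 1, 'c': 2},
    #   _arg_indices_to_default_values == {2: 1},
    #   _arg_indices_no_default_values == {0, 1}.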
2498 self._args_to_indices = {arg: i for i, arg in enumerate(args)} 2499 self._arg_names = args 2500 2501 # A cache mapping from arg index to default value, for canonicalization. 2502 default_values = fullargspec.defaults 2503 offset = len(args) - len(default_values or []) 2504 self._arg_indices_to_default_values = { 2505 offset + index: default 2506 for index, default in enumerate(default_values or []) 2507 } 2508 self._arg_indices_no_default_values = set(range(len(args))) - set( 2509 self._arg_indices_to_default_values) 2510 if input_signature is None: 2511 self._input_signature = None 2512 else: 2513 if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()): 2514 raise ValueError("Cannot define a TensorFlow function from a Python " 2515 "function with keyword-only arguments when " 2516 "input_signature is provided.") 2517 2518 if not isinstance(input_signature, (tuple, list)): 2519 raise TypeError(f"input_signature must be either a tuple or a " 2520 f"list, got {type(input_signature)}.") 2521 2522 self._input_signature = tuple(input_signature) 2523 self._flat_input_signature = tuple(nest.flatten(input_signature, 2524 expand_composites=True)) 2525 2526 @property 2527 def fullargspec(self): 2528 return self._fullargspec 2529 2530 @property 2531 def is_method(self): 2532 return self._is_method 2533 2534 @property 2535 def args_to_indices(self): 2536 return self._args_to_indices 2537 2538 @property 2539 def kwargs_to_include(self): 2540 return self._kwargs_to_include 2541 2542 @property 2543 def input_signature(self): 2544 return self._input_signature 2545 2546 @property 2547 def flat_input_signature(self): 2548 return self._flat_input_signature 2549 2550 @property 2551 def is_pure(self): 2552 return self._is_pure 2553 2554 @property 2555 def jit_compile(self): 2556 return self._jit_compile 2557 2558 @property 2559 def arg_names(self): 2560 return self._arg_names 2561 2562 @property 2563 def vararg_name(self): 2564 return self._fullargspec.varargs 2565 2566 @property 2567 def varkw_name(self): 2568 return self._fullargspec.varkw 2569 2570 def signature_summary(self, default_values=False): 2571 """Returns a string summarizing this function's signature. 2572 2573 Args: 2574 default_values: If true, then include default values in the signature. 2575 2576 Returns: 2577 A `string`. 
2578 """ 2579 args = list(self._arg_names) 2580 if default_values: 2581 for (i, default) in self._arg_indices_to_default_values.items(): 2582 args[i] += "={}".format(default) 2583 if self._fullargspec.kwonlyargs: 2584 args.append("*") 2585 for arg_name in self._fullargspec.kwonlyargs: 2586 args.append(arg_name) 2587 if default_values and arg_name in self._fullargspec.kwonlydefaults: 2588 args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name]) 2589 return f"{self._name}({', '.join(args)})" 2590 2591 def _to_tensor_or_tensor_spec(self, x): 2592 return (x if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) 2593 else ops.convert_to_tensor(x)) 2594 2595 def _convert_variables_to_tensors(self, args, kwargs): 2596 args = [self._to_tensor_or_tensor_spec(x) for x in args] 2597 kwargs = {kw: self._to_tensor_or_tensor_spec(x) 2598 for kw, x in kwargs.items()} 2599 return tuple(args), kwargs 2600 2601 def _convert_annotated_args_to_tensors(self, args, kwargs): 2602 """Attempts to autobox arguments annotated as tf.Tensor.""" 2603 if self.input_signature is not None: 2604 return 2605 2606 args = list(args) 2607 for i, arg in enumerate(args): 2608 # See 2609 # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec 2610 if i < len(self._fullargspec.args): 2611 annotation_key = self._fullargspec.args[i] 2612 else: 2613 annotation_key = self._fullargspec.varargs 2614 arg_annotation = self._fullargspec.annotations.get(annotation_key, None) 2615 2616 # TODO(rahulkamat): Change to TensorLike (here ans below) 2617 if arg_annotation == ops.Tensor: 2618 args[i] = self._to_tensor_or_tensor_spec(arg) 2619 2620 for kw, v in kwargs.items(): 2621 if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args: 2622 annotation_key = kw 2623 else: 2624 annotation_key = self._fullargspec.varkw 2625 kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None) 2626 if kwarg_annotation == ops.Tensor: 2627 kwargs[kw] = self._to_tensor_or_tensor_spec(v) 2628 return tuple(args), kwargs 2629 2630 def _validate_inputs(self, flat_inputs): 2631 """Raises an error if inputs contain illegal values.""" 2632 for inp in flat_inputs: 2633 # TODO(b/183107079): Allow these once they're handled properly. 2634 if isinstance(inp, weakref.ref): 2635 raise ValueError( 2636 f"weakref input {inp} not supported for function {self._name}") 2637 2638 def canonicalize_function_inputs(self, *args, **kwargs): 2639 """Canonicalizes `args` and `kwargs`. 2640 2641 Canonicalize the inputs to the Python function using a `FunctionSpec` 2642 instance. In particular, we parse the varargs and kwargs that the 2643 original function was called with into a tuple corresponding to the 2644 Python function's positional (named) arguments and a dictionary 2645 corresponding to its kwargs. Missing default arguments are added. 2646 2647 If this `FunctionSpec` has an input signature, then it is used to convert 2648 arguments to tensors; otherwise, any inputs containing numpy arrays are 2649 converted to tensors. 2650 2651 Additionally, any inputs containing numpy arrays are converted to Tensors. 2652 2653 Args: 2654 *args: The varargs this object was called with. 2655 **kwargs: The keyword args this function was called with. 2656 2657 Returns: 2658 A canonicalized ordering of the inputs, as well as full and filtered 2659 (Tensors and Variables only) versions of their concatenated flattened 2660 representations, represented by a tuple in the form (args, kwargs, 2661 flat_args, filtered_flat_args). 
Here: `args` is a full list of bound 2662 arguments, and `kwargs` contains only true keyword arguments, as opposed 2663 to named arguments called in a keyword-like fashion. 2664 2665 Raises: 2666 ValueError: If a keyword in `kwargs` cannot be matched with a positional 2667 argument when an input signature is specified, or when the inputs 2668 do not conform to the input signature. 2669 """ 2670 if self._is_pure: 2671 args, kwargs = self._convert_variables_to_tensors(args, kwargs) 2672 if self._experimental_follow_type_hints: 2673 args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs) 2674 # Pre-calculate to reduce overhead 2675 arglen = len(args) 2676 if self._input_signature is not None: 2677 if arglen > len(self._input_signature): 2678 raise TypeError(f"{self.signature_summary()} specifies " 2679 f"{len(self._input_signature)} positional arguments, " 2680 f"but got {arglen}.") 2681 for arg in six.iterkeys(kwargs): 2682 index = self._args_to_indices.get(arg, None) 2683 if index is None: 2684 raise TypeError(f"{self.signature_summary()} got unexpected keyword " 2685 f"argument `{arg}`.") 2686 if index >= len(self._input_signature): 2687 raise TypeError( 2688 f"{self.signature_summary()} got keyword argument `{arg}` that " 2689 "was not included in input_signature.") 2690 2691 if not kwargs: 2692 inputs = args 2693 if self._arg_indices_to_default_values: 2694 try: 2695 inputs += tuple(self._arg_indices_to_default_values[i] 2696 for i in range(arglen, len(self._arg_names))) 2697 except KeyError: 2698 missing_args = [ 2699 self._arg_names[i] 2700 for i in range(arglen, len(self._arg_names)) 2701 if i not in self._arg_indices_to_default_values 2702 ] 2703 raise TypeError(f"{self.signature_summary()} missing required " 2704 f"arguments: {', '.join(missing_args)}.") 2705 2706 if self._fullargspec.kwonlydefaults: 2707 kwargs.update(self._fullargspec.kwonlydefaults) 2708 else: 2709 # Maps from index of arg to its corresponding value, according to `args` 2710 # and `kwargs`; seeded with the default values for the named args that 2711 # aren't in `args`. 2712 arg_indices_to_values = { 2713 index: default for index, default in six.iteritems( 2714 self._arg_indices_to_default_values) if index >= arglen 2715 } 2716 consumed_args = [] 2717 missing_arg_indices = self._arg_indices_no_default_values - set( 2718 range(arglen)) 2719 for arg, value in six.iteritems(kwargs): 2720 index = self._args_to_indices.get(arg, None) 2721 if index is not None: 2722 if index < arglen: 2723 raise TypeError(f"{self.signature_summary()} got two values for " 2724 f"{arg!r}.") 2725 arg_indices_to_values[index] = value 2726 # These arguments in 'kwargs' might also belong to 2727 # positional arguments 2728 missing_arg_indices.discard(index) 2729 consumed_args.append(arg) 2730 for arg in consumed_args: 2731 # After this loop, `kwargs` will only contain keyword_only arguments, 2732 # and all positional_or_keyword arguments have been moved to `inputs`. 
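        # For example (hypothetical): for `def f(a, b=2, *, c=3)` called as
        # f(1, c=4), this method returns inputs == (1, 2) and
        # kwargs == {'c': 4}; `b` is filled from its default, while `c`
        # remains a true keyword-only argument.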
2733 kwargs.pop(arg) 2734 inputs = args + _deterministic_dict_values(arg_indices_to_values) 2735 # Exclude positional args with values 2736 if missing_arg_indices: 2737 missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)] 2738 if len(missing_args) == 1: 2739 raise TypeError(f"{self.signature_summary()} missing 1 required " 2740 f"argument: {missing_args[0]}.") 2741 else: 2742 raise TypeError(f"{self.signature_summary()} missing required " 2743 f"arguments: {', '.join(missing_args)}.") 2744 2745 if kwargs and self._input_signature is not None: 2746 raise TypeError("Keyword arguments are not supported when " 2747 "input_signature is provided. Signature: " 2748 f"{self.signature_summary()}.") 2749 2750 if self._fullargspec.kwonlydefaults: 2751 for (kwarg, default) in self._fullargspec.kwonlydefaults.items(): 2752 kwargs.setdefault(kwarg, default) 2753 2754 if self._input_signature is None: 2755 inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs) 2756 kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs) 2757 flat_inputs += flat_kwargs 2758 filtered_flat_inputs += filtered_flat_kwargs 2759 else: 2760 assert not kwargs 2761 inputs, flat_inputs, filtered_flat_inputs = _convert_inputs_to_signature( 2762 inputs, self._input_signature, self._flat_input_signature) 2763 2764 self._validate_inputs(flat_inputs) 2765 2766 return inputs, kwargs, flat_inputs, filtered_flat_inputs 2767 2768 2769def _as_ndarray(value): 2770 """Converts value to an ndarray, assumes _is_ndarray(value).""" 2771 # TODO(tomhennigan) Support __array_interface__ too. 2772 return value.__array__() 2773 2774 2775def _is_ndarray(value): 2776 """Tests whether the given value is an ndarray (and not a TF tensor/var).""" 2777 # TODO(tomhennigan) Support __array_interface__ too. 2778 return hasattr(value, "__array__") and not ( 2779 isinstance(value, ops.Tensor) 2780 or isinstance(value, resource_variable_ops.BaseResourceVariable) 2781 or hasattr(value, "_should_act_as_resource_variable") 2782 2783 # For legacy reasons we do not automatically promote Numpy strings. 2784 or isinstance(value, np.str_) 2785 # NumPy dtypes have __array__ as unbound methods. 2786 or isinstance(value, type) 2787 # CompositeTensors should be flattened instead. 2788 or isinstance(value, composite_tensor.CompositeTensor)) 2789 2790 2791def _convert_numpy_inputs(inputs): 2792 """Convert numpy array inputs to tensors.""" 2793 # We assume that any CompositeTensors have already converted their components 2794 # from numpy arrays to Tensors, so we don't need to expand composites here for 2795 # the numpy array conversion. Instead, we do so because the flattened inputs 2796 # are eventually passed to ConcreteFunction()._call_flat, which requires 2797 # expanded composites. 2798 flat_inputs = nest.flatten(inputs, expand_composites=True) 2799 2800 # Check for NumPy arrays in arguments and convert them to Tensors. 2801 # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps 2802 # finding a way to store them directly in the cache key (currently not 2803 # possible since ndarrays are not hashable). 
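  # Note: the loop below leaves Tensors and resource variables untouched (while
  # recording them in `filtered_flat_inputs`), converts anything ndarray-like
  # to a constant Tensor, and repacks the original structure only if at least
  # one conversion actually happened.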
2804 need_packing = False 2805 filtered_flat_inputs = [] 2806 for index, value in enumerate(flat_inputs): 2807 if isinstance(value, 2808 (ops.Tensor, resource_variable_ops.BaseResourceVariable)): 2809 filtered_flat_inputs.append(value) 2810 elif hasattr(value, "__array__") and not ( 2811 hasattr(value, "_should_act_as_resource_variable") or 2812 isinstance(value, (np.str_, type, composite_tensor.CompositeTensor))): 2813 # This case is equivalent to _is_ndarray(value) == True 2814 a = _as_ndarray(value) 2815 if not isinstance(a, np.ndarray): 2816 raise TypeError(f"The output of __array__ must be an np.ndarray, " 2817 f"got {type(a)} from {value}.") 2818 flat_inputs[index] = constant_op.constant(a) 2819 filtered_flat_inputs.append(flat_inputs[index]) 2820 need_packing = True 2821 if need_packing: 2822 return (nest.pack_sequence_as( 2823 structure=inputs, flat_sequence=flat_inputs, 2824 expand_composites=True), flat_inputs, filtered_flat_inputs) 2825 else: 2826 return inputs, flat_inputs, filtered_flat_inputs 2827 2828 2829def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature): 2830 """Convert inputs to pass into a function with an explicit signature.""" 2831 2832 def format_error_message(inputs, input_signature): 2833 return (" inputs: (\n" + " " + ",\n ".join(str(i) for i in inputs) + 2834 ")\n" + " input_signature: (\n" + " " + 2835 ",\n ".join(str(i) for i in input_signature) + ")") 2836 2837 try: 2838 flatten_inputs = nest.flatten_up_to( 2839 input_signature, 2840 inputs[:len(input_signature)], 2841 expand_composites=True, 2842 check_types=False) # lists are convert to tuples for `tf.data`. 2843 except ValueError: 2844 raise ValueError("Structure of Python function inputs does not match " 2845 "input_signature:\n" 2846 f"{format_error_message(inputs, input_signature)}.") 2847 2848 need_packing = False 2849 for index, (value, spec) in enumerate(zip(flatten_inputs, 2850 flat_input_signature)): 2851 if (isinstance(spec, tensor_spec.TensorSpec) and 2852 not _pywrap_utils.IsTensor(value)): 2853 try: 2854 flatten_inputs[index] = ops.convert_to_tensor( 2855 value, dtype_hint=spec.dtype) 2856 need_packing = True 2857 except ValueError: 2858 raise ValueError("When input_signature is provided, all inputs to " 2859 "the Python function must be convertible to " 2860 "tensors:\n" 2861 f"{format_error_message(inputs, input_signature)}.") 2862 2863 if any(not spec.is_compatible_with(other) for spec, other in zip( 2864 flat_input_signature, 2865 flatten_inputs)): 2866 raise ValueError("Python inputs incompatible with input_signature:\n" 2867 f"{format_error_message(inputs, input_signature)}.") 2868 2869 if need_packing: 2870 inputs = nest.pack_sequence_as( 2871 structure=input_signature, 2872 flat_sequence=flatten_inputs, 2873 expand_composites=True) 2874 2875 flat_inputs = nest.flatten(inputs, expand_composites=True) 2876 2877 return (inputs, flat_inputs, [ 2878 t for t in flat_inputs 2879 if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable)) 2880 ]) 2881 2882 2883class FunctionCache(object): 2884 """A lightweight container for cached functions. 2885 """ 2886 2887 __slots__ = [ 2888 "missed", "primary", "arg_relaxed_specs", "arg_relaxed", 2889 "_garbage_collectors" 2890 ] 2891 2892 def __init__(self): 2893 # The set of functions that have been missed; entries are CacheKey with 2894 # input_signature `None` (e.g. a "call context key") 2895 self.missed = set() 2896 # The primary cache, mapping a fully shaped CacheKey to a function. 
2897 self.primary = collections.OrderedDict() 2898 # A cache key lookup, mapping a CacheKey generated without shape info to a 2899 # flat list of `TypeSpec`s with relaxed shapes (one for each flattened 2900 # argument). Arguments that are not Tensors or `CompositeTensor`s contain a 2901 # `None` for the corresponding relaxed spec. 2902 self.arg_relaxed_specs = collections.OrderedDict() 2903 # The secondary cache, mapping a CacheKey generated without shape info to a 2904 # function. 2905 self.arg_relaxed = collections.OrderedDict() 2906 # All OrderedDicts require manual garbage collection. 2907 self._garbage_collectors = [ 2908 _FunctionGarbageCollector(self.primary), 2909 _FunctionGarbageCollector(self.arg_relaxed), 2910 _FunctionGarbageCollector(self.arg_relaxed_specs)] 2911 2912 def all_values(self): 2913 """A list of all `ConcreteFunction` instances held by this cache.""" 2914 # We need to simultaneously make sure our returned concrete functions are 2915 # unique *and* make sure they are returned in a deterministic order for 2916 # serialization. 2917 # 2918 # TODO(b/174215821): It's likely that we ultimately would just prefer to 2919 # choose the most specific concrete function shape given a set of 2920 # arguments. If and when that is implemented, this logic can be revisited. 2921 primary_functions = set(self.primary.values()) 2922 return list(self.primary.values()) + [ 2923 v for v in self.arg_relaxed.values() if v not in primary_functions] 2924 2925 2926# TODO(mdan): Refactor this and clarify relationship with def_function.Function. 2927# Right now, def_function.Function is the higher level implementation. 2928class Function(object): 2929 """Wrapper class for the graph functions defined for a Python function. 2930 2931 See the documentation for `defun` for more information on the semantics of 2932 defined functions. 2933 2934 `Function` class is thread-compatible meaning that minimal usage of defuns 2935 (defining and calling) is thread-safe, but if users call other methods or 2936 invoke the base `python_function` themselves, external synchronization is 2937 necessary. 2938 In addition, Function is not reentrant, so recursive functions need to call 2939 the wrapped function, not the wrapper. 2940 """ 2941 2942 def __init__(self, 2943 python_function, 2944 name, 2945 input_signature=None, 2946 attributes=None, 2947 autograph=True, 2948 autograph_options=None, 2949 experimental_relax_shapes=False, 2950 capture_by_value=None, 2951 jit_compile=None, 2952 experimental_follow_type_hints=False): 2953 """Initializes a `Function`. 2954 2955 Args: 2956 python_function: the function to be wrapped. 2957 name: the name given to it. 2958 input_signature: a possibly nested sequence of `TensorSpec` objects 2959 specifying the input signature of this function. If `None`, a separate 2960 function is instantiated for each inferred input signature. 2961 attributes: dict, extra keyword arguments that will be added as attribute 2962 of the function. 2963 autograph: whether to use autograph to compile 2964 `python_function`. See https://www.tensorflow.org/guide/autograph for 2965 more information. 2966 autograph_options: Experimental knobs to control behavior 2967 `when autograph=True`. See https://www.tensorflow.org/guide/autograph 2968 for more information. 2969 experimental_relax_shapes: When true, argument shapes may be relaxed to 2970 avoid unnecessary retracing. 2971 capture_by_value: Experimental. Whether to capture resource variables by 2972 value or reference. 
If None, will inherit from a parent context or 2973 default to False. 2974 jit_compile: Force-compile the function with XLA, cf. 2975 def_function.Function doc on jit_compile. 2976 experimental_follow_type_hints: See the documentation for `tf.function`. 2977 2978 Raises: 2979 ValueError: if `input_signature` is not None and the `python_function`'s 2980 argspec has keyword arguments. 2981 """ 2982 self._python_function = python_function 2983 pure_function = attributes and IMPLEMENTS_ATTRIBUTE_NAME in attributes 2984 self._function_spec = FunctionSpec.from_function_and_signature( 2985 python_function, 2986 input_signature, 2987 is_pure=pure_function, 2988 experimental_follow_type_hints=experimental_follow_type_hints) 2989 self._name = name 2990 self._autograph = autograph 2991 self._autograph_options = autograph_options 2992 self._experimental_relax_shapes = experimental_relax_shapes 2993 self._function_cache = FunctionCache() 2994 self._function_attributes = attributes or {} 2995 self._capture_by_value = capture_by_value 2996 self.tracing_count = 0 2997 if self.input_signature is not None: 2998 self._hashable_input_signature = _make_input_signature_hashable( 2999 self.flat_input_signature) 3000 3001 self._lock = threading.Lock() 3002 # _descriptor_cache is a of instance of a class to an instance-specific 3003 # `Function`, used to make sure defun-decorated methods create different 3004 # functions for each instance. 3005 self._descriptor_cache = weakref.WeakKeyDictionary() 3006 self._jit_compile = jit_compile 3007 self._experimental_follow_type_hints = experimental_follow_type_hints 3008 3009 def __call__(self, *args, **kwargs): 3010 """Calls a graph function specialized to the inputs.""" 3011 with self._lock: 3012 (graph_function, 3013 filtered_flat_args) = self._maybe_define_function(args, kwargs) 3014 return graph_function._call_flat( 3015 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access 3016 3017 @property 3018 def python_function(self): 3019 """Returns the wrapped Python function.""" 3020 return self._python_function # pylint: disable=protected-access 3021 3022 @property 3023 def function_spec(self): 3024 return self._function_spec 3025 3026 @property 3027 def input_signature(self): 3028 """Returns the input signature.""" 3029 return self._function_spec.input_signature 3030 3031 @property 3032 def flat_input_signature(self): 3033 """Returns the flattened input signature.""" 3034 return self._function_spec.flat_input_signature 3035 3036 def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs): 3037 """Returns a concrete function which cleans up its graph function.""" 3038 if self.input_signature: 3039 args, kwargs = None, None 3040 with self._lock: 3041 graph_function, _ = self._maybe_define_function(args, kwargs) 3042 return graph_function 3043 3044 def _get_concrete_function_internal(self, *args, **kwargs): 3045 """Bypasses error checking when getting a graph function.""" 3046 graph_function = self._get_concrete_function_internal_garbage_collected( 3047 *args, **kwargs) 3048 # We're returning this concrete function to someone, and they may keep a 3049 # reference to the FuncGraph without keeping a reference to the 3050 # ConcreteFunction object. So we won't clean up the reference cycles 3051 # manually and instead will leave them to Python's garbage collector. 
3052 graph_function._garbage_collector.release() # pylint: disable=protected-access 3053 return graph_function 3054 3055 def _get_concrete_function_garbage_collected(self, *args, **kwargs): 3056 """Returns a `ConcreteFunction` specialized to inputs and execution context. 3057 3058 Unlike `get_concrete_function(...)`, the graph will be deleted when the 3059 returned function is deleted. It's useful to avoid creating a reference 3060 cycle when you know for sure that the graph will be no longer used without 3061 the returned function. 3062 3063 Args: 3064 *args: inputs to specialize on. 3065 **kwargs: inputs to specialize on. 3066 """ 3067 if self.input_signature: 3068 if kwargs: 3069 raise ValueError("Cannot define a TensorFlow function from a Python " 3070 "function with keyword arguments when " 3071 "input_signature is provided, got keyword arguments " 3072 f"({kwargs}) with input_signature " 3073 f"({self.input_signature}).") 3074 if args: 3075 # If args are provided, they must match the input signature. 3076 if not is_same_structure(self.input_signature, args): 3077 raise ValueError("Structure of Python function inputs does not match " 3078 f"input_signature: inputs ({args}), " 3079 f"input_signature ({self.input_signature}).") 3080 flat_inputs = nest.flatten(args, expand_composites=True) 3081 if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec, 3082 resource_variable_ops.BaseResourceVariable)) 3083 for arg in flat_inputs): 3084 raise ValueError("When input_signature is provided, all inputs to " 3085 "the Python function must be Tensors, Variables, " 3086 "tf.TensorSpec or tf.VariableSpec objects.") 3087 if any(not spec.is_compatible_with(other) 3088 for spec, other in zip(self.flat_input_signature, flat_inputs)): 3089 raise ValueError("Python inputs incompatible with input_signature: " 3090 f"inputs ({args}), input_signature " 3091 f"({self.input_signature}).") 3092 args, kwargs = None, None 3093 with self._lock: 3094 graph_function, _ = self._maybe_define_function(args, kwargs) 3095 seen_names = set() 3096 captured = object_identity.ObjectIdentitySet( 3097 graph_function.graph.internal_captures) 3098 # pylint: disable=protected-access 3099 graph_function._arg_keywords = [] 3100 prefix_counts = {} 3101 # pylint: enable=protected-access 3102 num_positional = 0 3103 for arg in graph_function.graph.inputs: 3104 if arg in captured: 3105 break 3106 num_positional += 1 3107 user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name")) 3108 proposal = user_arg_name 3109 while proposal in seen_names: 3110 index = prefix_counts.get(user_arg_name, 1) 3111 proposal = "{}_{}".format(user_arg_name, index) 3112 prefix_counts[user_arg_name] = index + 1 3113 seen_names.add(proposal) 3114 graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access 3115 # Anything can be a positional argument, in the same order as .inputs 3116 graph_function._num_positional_args = num_positional # pylint: disable=protected-access 3117 return graph_function 3118 3119 def get_concrete_function(self, *args, **kwargs): 3120 """Returns a `ConcreteFunction` specialized to inputs and execution context. 3121 3122 Args: 3123 *args: inputs to specialize on. Can be concrete values (e.g. 1) 3124 or `tf.Tensor` or `tf.TensorSpec`. 3125 **kwargs: keyword inputs to specialize on. Concrete values (e.g. 1) 3126 or `tf.Tensor` or `tf.TensorSpec`. 
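
    For example, a minimal sketch (illustrative only; assumes eager execution
    and imports this internal module directly):

    ```python
    import tensorflow as tf
    from tensorflow.python.eager import function

    square = function.defun(lambda x: x * x)
    concrete = square.get_concrete_function(
        tf.TensorSpec(shape=[None], dtype=tf.float32))
    concrete(tf.constant([2.0, 3.0]))  # Runs the already-traced graph.
    ```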
3127 """ 3128 graph_function = self._get_concrete_function_garbage_collected( 3129 *args, **kwargs) 3130 graph_function._garbage_collector.release() # pylint: disable=protected-access 3131 return graph_function 3132 3133 def __get__(self, instance, owner): 3134 """Makes it possible to defun instance methods.""" 3135 del owner 3136 # `instance` here is the instance that this `Function` was accessed through 3137 # e.g., for 3138 # 3139 # class Foo(object): 3140 # 3141 # @function.defun 3142 # def bar(self): 3143 # ... 3144 # 3145 # foo = Foo() 3146 # foo.bar() # `foo.bar` is a `Function` instance 3147 # 3148 # then `instance` will be `foo` (and `owner` will be `Foo`). We create a 3149 # new instance of `Function` here to allow different instances each 3150 # to create variables once, thereby allowing methods to be decorated with 3151 # defun. Keeps a cache to avoid retracing the function every time the 3152 # descriptor is accessed. 3153 if instance not in self._descriptor_cache: 3154 if instance is None: 3155 return self 3156 # If there is no instance-specific `Function` in the cache, we construct 3157 # an instance-specific `Function` that uses a weak reference to the 3158 # instance (so that the instance will be correctly gc'd). 3159 3160 # And finally add the wrapped function to the description cache 3161 self._descriptor_cache[instance] = class_method_to_instance_method( 3162 self, instance) 3163 3164 # Return the cached `Function` for the instance 3165 return self._descriptor_cache[instance] 3166 3167 def _cache_key(self, 3168 args, 3169 kwargs, 3170 cache_key_context, 3171 include_tensor_ranks_only=False): 3172 """Computes the cache key given inputs and execution context.""" 3173 if self.input_signature is None: 3174 # We always use both args and kwargs to form input even if one is empty. 3175 # This reduces ambiguity, for example, when args contains a dict and 3176 # kwargs is empty. 3177 inputs = (args, kwargs) 3178 input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs, 3179 include_tensor_ranks_only) 3180 hashable_input_signature = _make_input_signature_hashable(input_signature) 3181 else: 3182 del args, kwargs 3183 assert not include_tensor_ranks_only 3184 hashable_input_signature = self._hashable_input_signature 3185 3186 (parent_graph, device_functions, colocation_stack, in_cross_replica_context, 3187 variable_policy, xla_context_id) = cache_key_context 3188 3189 return CacheKey(hashable_input_signature, parent_graph, device_functions, 3190 colocation_stack, in_cross_replica_context, variable_policy, 3191 xla_context_id) 3192 3193 def _cache_key_context(self): 3194 """Returns execution context.""" 3195 ctx = context.context() 3196 3197 # Don't need to open an init_scope if the _cache_key call is in eager mode 3198 # already. 3199 executing_eagerly = ctx.executing_eagerly() 3200 parent_graph = None 3201 xla_context_id = 0 3202 if not executing_eagerly: 3203 # We want to force function retracing for each different 3204 # XLAControlFlowContext, so add `xla_context_id` to the cache key. 3205 xla_context = _enclosing_xla_context() 3206 if xla_context is not None and \ 3207 xla_context.RequiresUniqueFunctionRetracing(): 3208 xla_context_id = id(xla_context) 3209 3210 with ops.init_scope(): 3211 # The graph, or whether we're executing eagerly, should be a part of the 3212 # cache key so we don't improperly capture tensors such as variables. 
3213 executing_eagerly = ctx.executing_eagerly() 3214 parent_graph = None if executing_eagerly else ops.get_default_graph() 3215 3216 # pylint: disable=protected-access 3217 default_graph = ops.get_default_graph() 3218 # TODO(b/117617952): The current distribution strategy will affect graph 3219 # building (e.g. accessing different variables from different devices) and 3220 # so requires retracing for each device. 3221 strategy_stack = default_graph._distribution_strategy_stack 3222 uses_distribution_strategy = ( 3223 strategy_stack and 3224 strategy_stack[-1].strategy.extended._retrace_functions_for_each_device 3225 ) 3226 if executing_eagerly: 3227 colocation_stack = () 3228 if uses_distribution_strategy: 3229 device_functions = (pydev.merge_device(ctx.device_name),) 3230 else: 3231 device_functions = () 3232 else: 3233 colocation_stack = tuple(default_graph._colocation_stack.peek_objs()) 3234 if (uses_distribution_strategy 3235 or func_graph_module.device_stack_has_callable( 3236 default_graph._device_function_stack)): 3237 # Putting the device in the cache key ensures that call-site device 3238 # annotations are respected. 3239 device_functions = tuple(default_graph._device_functions_outer_to_inner) 3240 else: 3241 device_functions = () 3242 3243 in_cross_replica_context = False 3244 try: 3245 in_cross_replica_context = (strategy_stack[-1].replica_context is None) # pylint: disable=protected-access 3246 except (AttributeError, IndexError): 3247 pass 3248 3249 if save_context.in_save_context(): 3250 variable_policy = ( 3251 save_context.get_save_options().experimental_variable_policy) 3252 else: 3253 variable_policy = None 3254 3255 return (parent_graph, device_functions, colocation_stack, 3256 in_cross_replica_context, variable_policy, xla_context_id) 3257 3258 def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None): 3259 """Create a `ConcreteFunction` from `args` and `kwargs`.""" 3260 self.tracing_count += 1 3261 3262 if self.input_signature is None: 3263 arglen = len(args) 3264 else: 3265 arglen = len(self.input_signature) 3266 base_arg_names = self._function_spec.arg_names[:arglen] 3267 num_missing_args = arglen - len(self._function_spec.arg_names) 3268 missing_arg_names = [self._function_spec.vararg_name] * num_missing_args 3269 # Produce a list of missing args of the form ["arg_0", "arg_1", ...], 3270 # where arg is based on the self._function_spec.vararg_name. 3271 missing_arg_names = [ 3272 "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names) 3273 ] 3274 arg_names = base_arg_names + missing_arg_names 3275 graph_function = ConcreteFunction( 3276 func_graph_module.func_graph_from_py_func( 3277 self._name, 3278 self._python_function, 3279 args, 3280 kwargs, 3281 self.input_signature, 3282 autograph=self._autograph, 3283 autograph_options=self._autograph_options, 3284 arg_names=arg_names, 3285 override_flat_arg_shapes=override_flat_arg_shapes, 3286 capture_by_value=self._capture_by_value), 3287 self._function_attributes, 3288 function_spec=self.function_spec, 3289 # Tell the ConcreteFunction to clean up its graph once it goes out of 3290 # scope. This is not the default behavior since it gets used in some 3291 # places (like Keras) where the FuncGraph lives longer than the 3292 # ConcreteFunction. 
3293 shared_func_graph=False) 3294 return graph_function 3295 3296 def _define_function_with_shape_relaxation(self, args, kwargs, flat_args, 3297 filtered_flat_args, 3298 cache_key_context): 3299 """Define a function, relaxing arg shapes to avoid unnecessary retracing.""" 3300 flat_no_comp = nest.flatten((args, kwargs), expand_composites=False) 3301 3302 any_composite_args = any( 3303 isinstance(x, composite_tensor.CompositeTensor) for x in flat_no_comp) 3304 3305 # Build a cache key where TensorShapes include only rank information (and 3306 # not information about the size of each dimension). 3307 if not any_composite_args: 3308 rank_only_cache_key = self._cache_key( 3309 args, kwargs, cache_key_context, include_tensor_ranks_only=True) 3310 else: 3311 # For the rank-only cache key, replace any composite tensors with 3312 # shape-relaxed TypeSpecs. 3313 (cache_key_args, cache_key_kwargs) = nest.map_structure( 3314 _shape_relaxed_type_for_composite_tensor, (args, kwargs)) 3315 rank_only_cache_key = self._cache_key( 3316 cache_key_args, 3317 cache_key_kwargs, 3318 cache_key_context, 3319 include_tensor_ranks_only=True) 3320 3321 arg_specs = [_type_spec_for(x) for x in flat_no_comp] 3322 relaxed_arg_specs = self._function_cache.arg_relaxed_specs.get( 3323 rank_only_cache_key, None) 3324 relaxed_arg_function = self._function_cache.arg_relaxed.get( 3325 rank_only_cache_key, None) 3326 3327 if (relaxed_arg_function is not None 3328 and all(_is_type_subset(x, y) for (x, y) in 3329 zip(relaxed_arg_specs, arg_specs))): 3330 return relaxed_arg_function, filtered_flat_args 3331 3332 if relaxed_arg_specs is None: 3333 relaxed_arg_specs = arg_specs 3334 else: 3335 if len(arg_specs) != len(relaxed_arg_specs): 3336 raise RuntimeError("Expected arg_specs len to match relaxed_arg_specs " 3337 f"len: {len(arg_specs):d} vs. " 3338 f"{len(relaxed_arg_specs):d}.") 3339 relaxed_arg_specs = [ 3340 x if x is None else x.most_specific_compatible_type(y) 3341 for (x, y) in zip(arg_specs, relaxed_arg_specs)] 3342 self._function_cache.arg_relaxed_specs[rank_only_cache_key] = ( 3343 relaxed_arg_specs) 3344 relaxed_arg_shapes = [ 3345 x if x is None else x.shape 3346 for x in nest.flatten(relaxed_arg_specs, expand_composites=True)] 3347 3348 if any_composite_args: 3349 # Rebuild composite tensors with the relaxed TypeSpecs. For example, 3350 # if a tf.data iterator is passed as an argument, then we need to relax 3351 # the TensorShapes in its element_spec. 3352 (relaxed_arg_specs, relaxed_kwarg_specs) = nest.pack_sequence_as( 3353 (args, kwargs), relaxed_arg_specs, expand_composites=False) 3354 (args, kwargs) = nest.pack_sequence_as( 3355 (relaxed_arg_specs, relaxed_kwarg_specs), 3356 flat_args, 3357 expand_composites=True) 3358 3359 graph_function = self._create_graph_function( 3360 args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes) 3361 self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function 3362 3363 return (graph_function, [ 3364 t for t in nest.flatten((args, kwargs), expand_composites=True) 3365 if isinstance(t, (ops.Tensor, 3366 resource_variable_ops.BaseResourceVariable)) 3367 ]) 3368 3369 def _maybe_define_function(self, args, kwargs): 3370 """Gets a function for these inputs, defining it if necessary. 3371 3372 `args` and `kwargs` can be None if this `Function` was created with an 3373 `input_signature`. 3374 3375 Caller must hold self._lock. 3376 3377 Args: 3378 args: The varargs for the Python function. 3379 kwargs: The keyword args for the Python function. 
3380 3381 Returns: 3382 A graph function corresponding to the input signature implied by args and 3383 kwargs, as well as filtered flattened inputs (only Tensors and Variables) 3384 that the object should be called with. 3385 3386 Raises: 3387 ValueError: If inputs are incompatible with the input signature. 3388 TypeError: If the function inputs include non-hashable objects 3389 RuntimeError: If there's an internal bug (inconsistency) in handling 3390 shape relaxation retracing. 3391 """ 3392 if self.input_signature is None or args is not None or kwargs is not None: 3393 args, kwargs, flat_args, filtered_flat_args = \ 3394 self._function_spec.canonicalize_function_inputs(*args, **kwargs) 3395 else: 3396 flat_args, filtered_flat_args = [None], [] 3397 3398 cache_key_context = self._cache_key_context() 3399 cache_key = self._cache_key(args, kwargs, cache_key_context) 3400 3401 try: 3402 hash(cache_key) 3403 except TypeError as e: 3404 raise TypeError( 3405 "Arguments supplied to `defun`-generated functions must be " 3406 f"hashable. Original error: {e}.") 3407 3408 graph_function = self._function_cache.primary.get(cache_key, None) 3409 if graph_function is not None: 3410 return graph_function, filtered_flat_args 3411 3412 with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()): 3413 with trace.Trace("tf.function-graph_building"): 3414 logging.vlog(1, 3415 "Creating new FuncGraph for Python function %r (key: %r)", 3416 self._python_function, cache_key) 3417 logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]", 3418 args, kwargs) 3419 3420 # pylint: disable=protected-access 3421 call_context_key = cache_key._replace(input_signature=None) 3422 # pylint: disable=protected-access 3423 3424 ag_status = ( 3425 ag_ctx.Status.ENABLED 3426 if self._autograph else ag_ctx.Status.DISABLED) 3427 with ag_ctx.ControlStatusCtx( 3428 status=ag_status, options=self._autograph_options): 3429 3430 # Build a function with shape relaxation retracing if: 3431 # 1. shape relaxation is explicitly enabled 3432 # and 2. there's no provided input signature 3433 # and 3. there's been a cache miss for this calling context 3434 if (self._experimental_relax_shapes and 3435 self.input_signature is None and 3436 call_context_key in self._function_cache.missed): 3437 return self._define_function_with_shape_relaxation( 3438 args, kwargs, flat_args, filtered_flat_args, cache_key_context) 3439 3440 self._function_cache.missed.add(call_context_key) 3441 graph_function = self._create_graph_function(args, kwargs) 3442 self._function_cache.primary[cache_key] = graph_function 3443 3444 return graph_function, filtered_flat_args 3445 3446 3447def register(func, *args, **kwargs): 3448 """Register a specialization of a `Function` into the graph. 3449 3450 This won't actually call the function with the inputs, and only put the 3451 function definition into graph. Register function with different input param 3452 will result into multiple version of functions registered in graph. 3453 3454 Args: 3455 func: the `Function` instance that generated by a @defun 3456 *args: input arguments for the Python function. 3457 **kwargs: input keyword arguments for the Python function. 3458 3459 Returns: 3460 a `ConcreteFunction` object specialized to inputs and execution context. 3461 3462 Raises: 3463 ValueError: When the input function is not a defun wrapped python function. 3464 """ 3465 if not isinstance(func, Function): 3466 raise ValueError("Only defun function is allowed to be registered. 
" 3467 f"Got {func} with type {type(func)}.") 3468 concrete_func = func.get_concrete_function(*args, **kwargs) 3469 concrete_func.add_to_graph() 3470 concrete_func.add_gradient_functions_to_graph() 3471 return concrete_func 3472 3473 3474def validate_signature(signature): 3475 if any(not isinstance(arg, tensor_spec.DenseSpec) 3476 for arg in nest.flatten(signature, expand_composites=True)): 3477 bad_args = [arg for arg in nest.flatten(signature, expand_composites=True) 3478 if not isinstance(arg, tensor_spec.DenseSpec)] 3479 raise TypeError("input_signature must be a possibly nested sequence of " 3480 f"TensorSpec objects, got invalid args {bad_args} with " 3481 f"types {list(map(type, bad_args))}.") 3482 3483 3484def validate_python_function(python_function): 3485 if not callable(python_function): 3486 raise TypeError(f"{python_function} is not a callable object.") 3487 3488 3489def defun(func=None, 3490 input_signature=None, 3491 autograph=True, 3492 experimental_autograph_options=None, 3493 experimental_relax_shapes=False): 3494 """Compiles a Python function into a callable TensorFlow graph. 3495 3496 `defun` (short for "define function") compiles a Python function 3497 composed of TensorFlow operations into a callable that executes a `tf.Graph` 3498 containing those operations. The callable produced by `defun` contains only 3499 the subgraph of TensorFlow operations that were executed when the Python 3500 function was called with a particular input signature, defined as a list 3501 of the shapes and dtypes of the Python function's Tensor-valued arguments and 3502 the values of its non-Tensor Python objects. 3503 3504 When eager execution is enabled, the ability to create graphs from Python 3505 functions makes it possible to incrementally trade off debuggability and 3506 interactivity for performance. Functions compiled with `defun` cannot be 3507 inspected with `pdb`; however, executing a graph 3508 generated by `defun` sometimes takes less time and memory than eagerly 3509 executing the corresponding Python function, since specifying computations as 3510 graphs allows for optimizations like automatic buffer reuse and 3511 parallelization among ops. Note that executing a `defun`-compiled function 3512 incurs a small constant overhead, so eagerly executing sufficiently small 3513 Python functions might take less time than executing their corresponding 3514 `defun`-generated graphs. 3515 3516 For a Python function to be compatible with `defun`, all of its arguments must 3517 be hashable Python objects or lists thereof. The function itself may not 3518 modify the list/map structure of its arguments. Additionally, it must return 3519 zero or more `tf.Tensor` objects. If the Python function returns 3520 a `tf.Variable`, its compiled version will return the value of that variable 3521 as a `tf.Tensor`. 3522 3523 Executing a graph generated by `defun` respects device annotations (i.e., 3524 all `with tf.device` directives present in a Python function will also be 3525 present in its corresponding graph), but it is not yet possible to execute the 3526 generated graphs across multiple machines. 3527 3528 _Example Usage_ 3529 3530 ```python 3531 import tensorflow as tf 3532 3533 tf.compat.v1.enable_eager_execution() 3534 3535 # A simple example. 
3536 def f(x, y): 3537 return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) 3538 3539 g = tf.contrib.eager.defun(f) 3540 3541 x = tf.constant([[2.0, 3.0]]) 3542 y = tf.constant([[3.0, -2.0]]) 3543 3544 # `f` and `g` will return the same value, but `g` will be executed as a 3545 # TensorFlow graph. 3546 assert f(x, y).numpy() == g(x, y).numpy() 3547 3548 # `defun` is capable of compiling Python functions that close over Python 3549 # objects, including Tensors and Variables. 3550 @tf.contrib.eager.defun 3551 def h(): 3552 return f(x, y) 3553 3554 assert (h().numpy() == f(x, y).numpy()).all() 3555 3556 # `defun` automatically lifts variables out of the graphs it creates, 3557 # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and 3558 # `tf.keras.Model` objects. 3559 class MyModel(tf.keras.Model): 3560 3561 def __init__(self, keep_probability=0.2): 3562 super(MyModel, self).__init__() 3563 self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) 3564 self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) 3565 self.keep_probability = keep_probability 3566 3567 @tf.contrib.eager.defun 3568 def call(self, inputs, training=True): 3569 x = self.dense2(self.dense1(inputs)) 3570 if training: 3571 return tf.nn.dropout(x, self.keep_probability) 3572 else: 3573 return x 3574 3575 model = MyModel() 3576 model(x, training=True) # executes a graph, with dropout 3577 model(x, training=False) # executes a graph, without dropout 3578 3579 # `defun`-compiled functions are differentiable. 3580 optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01) 3581 with tf.GradientTape() as tape: 3582 outputs = model(x) 3583 gradient = tape.gradient(outputs, model.trainable_variables) 3584 optimizer.apply_gradients((grad, var) for grad, var in zip(gradient, 3585 model.trainable_variables)) 3586 ``` 3587 3588 When using `defun`, there are subtleties regarding inputs, Python control 3589 flow, and variable creation that one should be aware of. For concreteness, let 3590 `f` be a Python function that returns zero or more `tf.Tensor` objects and 3591 let `F = defun(f)`. `F` builds a graph for each unique input signature it 3592 sees, Python control flow is baked into graphs, and operations related to 3593 variable initialization are automatically lifted out of the graphs that `F` 3594 generates and placed in the eager context if executing eagerly or into an 3595 outer graph otherwise. 3596 3597 _Input Signatures_ 3598 3599 By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph 3600 for every unique sequence of the shapes and dtypes of Tensor arguments and 3601 the values of Python objects it is invoked with. For example, calling 3602 `F(tf.random.uniform([2])` will execute a different graph than 3603 `F(tf.random.uniform([3])` because the two inputs have different shapes. 3604 The first time that `F(*args, **kwargs)` is called with a particular sequence 3605 of Tensor shapes and dtypes and Python values, it constructs a graph by 3606 tracing the execution of `f(*args, **kwargs)`; this graph is bound to an 3607 input signature inferred from `(*args, **kwargs)` and cached for future reuse. 3608 3609 NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects 3610 before being passed to `f`, and are treated as Tensors for caching. This 3611 allows a function to be called multiple times with NumPy arrays having 3612 different values but the same shape and dtype without re-tracing each time. 
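
  For example (an illustrative sketch, assuming eager execution is enabled as
  in the examples above), calls whose NumPy inputs share a shape and dtype
  reuse a single traced graph:

  ```python
  import numpy as np
  import tensorflow as tf

  F = tf.contrib.eager.defun(lambda x: x * 2.0)
  F(np.zeros([2], dtype=np.float32))  # Traces a new graph.
  F(np.ones([2], dtype=np.float32))   # Same shape and dtype: reuses the graph.
  F(np.ones([3], dtype=np.float32))   # Different shape: traces a second graph.
  ```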

  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
  define TensorFlow functions without explicitly specifying their signatures.
  However, this policy is conservative and potentially expensive; for example,
  when different invocations of your function have differently-shaped Tensor
  inputs, this policy might generate more graph functions than necessary. To
  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
  optional `input_signature` argument specifying the shapes and dtypes of the
  inputs. In particular, the shapes may be partially unspecified, with `None`s
  in the unknown dimensions. When an input signature is provided,
  `tf.contrib.eager.defun` will only instantiate a single graph for the
  decorated Python function. The following is an example:

  ```python
  import tensorflow as tf

  # The first `TensorSpec` below describes the shape and dtype of `words`,
  # and the second describes the shape and dtype of `another_tensor`. Note that
  # the last dimension of the `words` `TensorSpec` is left unspecified.
  @tf.contrib.eager.defun(input_signature=[
    tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
    tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
  ])
  def my_sequence_model(words, another_tensor):
    ...

  # Note how the third dimension of the first input can vary freely.
  words = tf.random.uniform([50, 300, 10])
  second_input = tf.random.uniform([300, 100])
  my_sequence_model(words, second_input)

  words = tf.random.uniform([50, 300, 20])
  my_sequence_model(words, second_input)

  # Passing an input with an incompatible shape will raise an error.
  words = tf.random.uniform([50, 100, 20])
  my_sequence_model(words, second_input)  # <---- This will raise an error.

  ```

  Python functions that are compiled with an `input_signature` must only accept
  Tensors as arguments and must not take unnamed keyword arguments (**kwargs).

  _Tracing_

  Be aware that because `F` only logs TensorFlow operations, all the other
  Python code that `f` executes will only shape the _construction_ of the graphs
  that `F` executes: the Python code won't be executed when the graphs
  themselves are executed, though it will be executed every time the Python
  function is traced (and a given Python function might be traced multiple
  times, once for each input signature it is invoked with). For example, whereas
  the Python function

  ```python
  import tensorflow as tf
  import numpy as np

  tf.compat.v1.enable_eager_execution()

  def add_noise():
    return tf.eye(5) + np.random.randn(5, 5)
  ```

  will return a different output every time it is invoked, the compiled function
  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
  every time it is called, since a particular random offset generated by NumPy
  will be inserted into the graph as a TensorFlow constant. The solution is to
  replace the call to `np.random.randn` with `tf.random.normal((5, 5))`.
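
  For instance, the following sketch (illustrative only) moves the randomness
  into the graph so that each call of the compiled function produces fresh
  noise:

  ```python
  def add_noise():
    return tf.eye(5) + tf.random.normal((5, 5))

  compiled = tf.contrib.eager.defun(add_noise)
  # The random op is now part of the traced graph, so `compiled()` no longer
  # returns the same constant offset on every call.
  ```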
3681 3682 _Python Side-Effects_ 3683 3684 A corollary of the previous discussion on tracing is the following: If a 3685 Python function `f` has Python side-effects, then executing `f` multiple times 3686 will not necessarily be semantically equivalent to executing `F = 3687 tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact 3688 that `defun` only captures the subgraph of TensorFlow operations that is 3689 constructed when `f` is called in a graph-building context. 3690 3691 _Python Control Flow_ 3692 3693 The structure of many machine learning computations depend upon whether one is 3694 training or validating, and it is common to nest specialized logic under `if 3695 training:` blocks. By mapping each input signature to a unique graph, `defun` 3696 lets users transparently compile such code, as the following code snippet 3697 demonstrates: 3698 3699 ```python 3700 import tensorflow as tf 3701 3702 tf.compat.v1.enable_eager_execution() 3703 3704 @tf.contrib.eager.defun 3705 def lossy_matmul(W, x, training=True): 3706 outputs = tf.matmul(W, x) 3707 if training: 3708 outputs = tf.nn.dropout(outputs, keep_probability=0.2) 3709 return outputs 3710 3711 W = tf.random.normal((3, 5)) 3712 x = tf.random.normal((5, 1)) 3713 3714 # Executes a graph that applies dropout. 3715 lossy_outputs = lossy_matmul(W, x, training=True) 3716 3717 # Executes a graph that does not apply dropout. 3718 exact_outputs = lossy_matmul(W, x, training=False) 3719 ``` 3720 3721 _TensorFlow Control Flow_ 3722 3723 When `autograph` is `True`, data-dependent control flow is allowed as well. 3724 Control flow statements that depend on `Tensor` values are staged into 3725 corresponding TensorFlow ops. For example, the following code will work as 3726 expected: 3727 3728 ```python 3729 @tf.contrib.eager.defun 3730 def dynamic_rnn_loop(cell, seq): 3731 state, output = cell.zero_state() 3732 for input in seq: 3733 state, output = cell(input, state) 3734 return output 3735 ``` 3736 3737 For more information see `tf.autograph`. 3738 3739 _Variables_ 3740 3741 TensorFlow operations related to variable creation and initialization are 3742 automatically lifted out of the graphs generated by `defun`. In practice, this 3743 implies that variable creation and initialization only happen the first time 3744 `F` is called, and that variables are reused every time thereafter. Many 3745 TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the 3746 first time they are called and reuse them thereafter. Automatic variable 3747 lifting makes it possible to compile these APIs without extra effort, at the 3748 cost of introducing a discrepancy between the semantics of executing Python 3749 functions and their corresponding compiled functions. For example: 3750 3751 ```python 3752 import tensorflow as tf 3753 3754 tf.compat.v1.enable_eager_execution() 3755 3756 def fn(): 3757 x = tf.Variable(0.0) 3758 x.assign_add(1.0) 3759 return x.read_value() 3760 3761 # `fn` is a Python function, so x is created, initialized, and destroyed upon 3762 # every invocation 3763 assert fn().numpy() == fn().numpy() == 1.0 3764 3765 compiled = tf.contrib.eager.defun(fn) 3766 3767 # Compiling `fn` with `defun` hoists all variables outside of the generated 3768 # graph, so initialization happens exactly once. 
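  # The hoisted variable persists across calls, so its value accumulates: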
3769 assert compiled().numpy() == 1.0 3770 assert compiled().numpy() == 2.0 3771 ``` 3772 3773 Finally, because each input signature is bound to a unique graph, if your 3774 Python function constructs `tf.Variable` objects, then each graph constructed 3775 for that Python function will reference a unique set of variables. To 3776 circumvent this problem, we recommend against compiling Python functions that 3777 create `tf.Variable` objects. Instead, Python functions should either 3778 lexically close over `tf.Variable` objects or accept them as arguments, 3779 preferably encapsulated in an object-oriented container. If you must create 3780 variables inside your Python function and you want each graph generated for it 3781 to reference the same set of variables, add logic to your Python function that 3782 ensures that variables are only created the first time it is called and are 3783 reused for every subsequent invocation; note that this is precisely what 3784 `tf.keras.layers.Layer` objects do, so we recommend using them to represent 3785 variable-bearing computations whenever possible. 3786 3787 Args: 3788 func: function to be compiled. If `func` is None, returns a 3789 decorator that can be invoked with a single argument - `func`. The 3790 end result is equivalent to providing all the arguments up front. 3791 In other words, defun(input_signature=...)(func) is equivalent to 3792 defun(func, input_signature=...). The former allows 3793 the following use case: 3794 @tf.contrib.eager.defun(input_signature=...) 3795 def foo(...): 3796 ... 3797 3798 input_signature: A possibly nested sequence of 3799 `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of 3800 the Tensors that will be supplied to this function. If `None`, a separate 3801 function is instantiated for each inferred input signature. If a 3802 signature is specified, every input to `func` must be a `Tensor`, and 3803 `func` cannot accept `**kwargs`. 3804 autograph: Whether `func` should be compiled before 3805 constructing the graph. See https://www.tensorflow.org/guide/autograph 3806 for more information. 3807 experimental_autograph_options: Experimental knobs (in the form of a tuple 3808 of tensorflow.autograph.Feature values) to control behavior when 3809 autograph=True. 3810 experimental_relax_shapes: When true, argument shapes may be relaxed to 3811 avoid unnecessary retracing. 3812 3813 Returns: 3814 If `func` is not None, returns a callable that will execute the compiled 3815 function (and return zero or more `tf.Tensor` objects). 3816 If `func` is None, returns a decorator that, when invoked with a single 3817 `func` argument, returns a callable equivalent to the case above. 3818 3819 Raises: 3820 TypeError: If `input_signature` is neither `None` nor a sequence of 3821 `tf.contrib.eager.TensorSpec` objects. 3822 """ 3823 return defun_with_attributes( 3824 func=func, 3825 input_signature=input_signature, 3826 autograph=autograph, 3827 experimental_autograph_options=experimental_autograph_options, 3828 experimental_relax_shapes=experimental_relax_shapes) 3829 3830 3831@tf_export("__internal__.function.defun_with_attributes", v1=[]) 3832def defun_with_attributes(func=None, 3833 input_signature=None, 3834 attributes=None, 3835 autograph=True, 3836 experimental_autograph_options=None, 3837 jit_compile=None, 3838 experimental_relax_shapes=False, 3839 experimental_follow_type_hints=False): 3840 """Compiles a Python function into a callable TensorFlow graph. 

  This function supports adding extra function attributes. See detailed
  documentation in defun(). Currently this is not exposed in the public API
  since we don't expect users to set attributes directly, and attributes do not
  work on their own. This assumption might change in the future.

  Args:
    func: function to be compiled.
    input_signature: same as defun()'s input_signature.
    attributes: A dictionary of arguments which will be added to the function
      def as attributes. Currently only primitive types are supported as
      values, and only allowlisted attribute names are allowed. An
      unallowlisted attribute name or an unsupported value will result in a
      ValueError. `func_name` is also one of the allowlisted arguments; it is a
      Python string that sets the name for this `ConcreteFunction` in the
      graph.
    autograph: same as defun()'s autograph.
    experimental_autograph_options: same as defun()'s
      experimental_autograph_options.
    jit_compile: same as defun()'s jit_compile.
    experimental_relax_shapes: same as defun()'s experimental_relax_shapes.
    experimental_follow_type_hints: see `tf.function`.

  Returns:
    Same as the return value of defun, with attributes added to the function in
    the graph.
  """
  if input_signature is not None:
    validate_signature(input_signature)

  # TODO(apassos): deal with captured global state. Deal with control flow.
  def decorated(function):
    try:
      if attributes:
        name = attributes.pop("func_name", function.__name__)
      else:
        name = function.__name__
    except AttributeError:
      name = "function"
    return tf_decorator.make_decorator(
        function,
        Function(
            function,
            name,
            input_signature=input_signature,
            attributes=attributes,
            autograph=autograph,
            autograph_options=experimental_autograph_options,
            jit_compile=jit_compile,
            experimental_relax_shapes=experimental_relax_shapes,
            experimental_follow_type_hints=experimental_follow_type_hints))

  # This code path is for the `foo = tfe.defun(foo, ...)` use case
  if func is not None:
    return decorated(func)

  # This code path is for the
  #
  # @tfe.defun(...)
  # def foo(...):
  #   ...
  #
  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
  return decorated


# When a method is bound to objects of this type, it allows AutoGraph to
# recover a weak reference to the original method's self pointer, so that it
# can execute it consistently with class_method_to_instance_method's
# bound_method_wrapper.
# TODO(b/119246461): This is not pretty. Use a descriptor instead?
3911class TfMethodTarget(object): 3912 """Binding target for methods replaced by function and defun.""" 3913 3914 __slots__ = ("weakrefself_target__", "weakrefself_func__") 3915 3916 def __init__(self, target, original_python_function): 3917 self.weakrefself_target__ = target 3918 self.weakrefself_func__ = weakref.ref(original_python_function) 3919 3920 @property 3921 def target(self): 3922 return self.weakrefself_target__() 3923 3924 @property 3925 def target_class(self): 3926 true_self = self.weakrefself_target__() 3927 if tf_inspect.isclass(true_self): 3928 # Class method 3929 return true_self 3930 else: 3931 return true_self.__class__ 3932 3933 def call(self, args, kwargs): 3934 wrapped_fn = self.weakrefself_func__() 3935 if tf_inspect.ismethod(wrapped_fn): 3936 wrapped_fn = six.get_unbound_function(wrapped_fn) 3937 return wrapped_fn(self.weakrefself_target__(), *args, **kwargs) 3938 3939 3940def class_method_to_instance_method(original_function, instance): 3941 """Constructs a new `Function` with `self` bound.""" 3942 weak_instance = weakref.ref(instance) 3943 3944 # Note: while we could bind to a weakref proxy instead, that causes the 3945 # bound method to be unhashable. 3946 bound_method = types_lib.MethodType( 3947 original_function.python_function, 3948 TfMethodTarget(weak_instance, original_function.python_function)) 3949 3950 # original_function is expected to be of one of the two `Function` types 3951 # (defined either in function.py or def_function.py). 3952 assert hasattr(original_function, "_name") 3953 assert hasattr(original_function, "_autograph") 3954 assert hasattr(original_function, "_function_spec") 3955 assert hasattr(original_function, "python_function") 3956 3957 weak_bound_method_wrapper = None 3958 def bound_method_wrapper(*args, **kwargs): 3959 """Wraps either a dummy MethodType or a converted AutoGraph function.""" 3960 # __wrapped__ allows AutoGraph to swap in a converted function. 3961 strong_bound_method_wrapper = weak_bound_method_wrapper() 3962 wrapped_fn = strong_bound_method_wrapper.__wrapped__ 3963 3964 if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__: 3965 # If __wrapped__ was not replaced, then call original_function. 3966 # TODO(mdan): For better consistency, use the wrapper's call(). 3967 wrapped_fn = original_function.python_function 3968 if tf_inspect.ismethod(wrapped_fn): 3969 wrapped_fn = six.get_unbound_function(wrapped_fn) 3970 return wrapped_fn(weak_instance(), *args, **kwargs) 3971 3972 # If __wrapped__ was replaced, then it is always an unbound function. 3973 # However, the replacer is still responsible for attaching self properly. 3974 # TODO(mdan): Is it possible to do it here instead? 3975 return wrapped_fn(*args, **kwargs) 3976 weak_bound_method_wrapper = weakref.ref(bound_method_wrapper) 3977 3978 # pylint: disable=protected-access 3979 # We make a dummy MethodType object to generate the correct bound method 3980 # signature. The actual call is to a function with a weak reference to 3981 # `instance`. 
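  # Using `type(original_function)` rather than hard-coding `Function` means a
  # subclass (e.g. the `Function` defined in def_function.py) is re-created as
  # the same type.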
3982 instance_func = type(original_function)( 3983 tf_decorator.make_decorator(bound_method, bound_method_wrapper), 3984 name=original_function._name, 3985 autograph=original_function._autograph, 3986 input_signature=original_function.input_signature, 3987 experimental_relax_shapes=original_function._experimental_relax_shapes, 3988 jit_compile=original_function._jit_compile) 3989 # pylint: enable=protected-access 3990 3991 # We wrap the the bound method with tf_decorator so inspection works correctly 3992 wrapped_instance_func = tf_decorator.make_decorator(bound_method, 3993 instance_func) 3994 return wrapped_instance_func 3995 3996 3997class _FunctionGarbageCollector(object): 3998 """Cleans up cycles when a defun goes out of scope.""" 3999 4000 __slots__ = ["_cache"] 4001 4002 def __init__(self, cache): 4003 self._cache = cache 4004 4005 def __del__(self): 4006 if func_graph_module is None or memory is None: 4007 return 4008 try: 4009 while self._cache: 4010 self._cache.popitem() 4011 memory.dismantle_ordered_dict(self._cache) 4012 except: # pylint: disable=bare-except 4013 pass 4014 4015 4016class ConcreteFunctionGarbageCollector(object): 4017 """Cleans up reference cycles when a `ConcreteFunction` goes out of scope.""" 4018 4019 __slots__ = ["_func_graph"] 4020 4021 def __init__(self, func_graph): 4022 self._func_graph = func_graph 4023 4024 def release(self): 4025 """Call off the FuncGraph deletion.""" 4026 self._func_graph = None 4027 4028 def __del__(self): 4029 if func_graph_module is None or memory is None or self._func_graph is None: 4030 return 4031 try: 4032 func_graph_module.dismantle_func_graph(self._func_graph) 4033 except: # pylint: disable=bare-except 4034 pass 4035 4036 4037class _Marker(object): 4038 """Markers used to pretty-print nested args in function signatures.""" 4039 4040 __slots__ = ["_s"] 4041 4042 def __init__(self, s): 4043 self._s = s 4044 4045 def __repr__(self): 4046 return str(self._s) 4047 4048 4049def _structure_summary(structure): 4050 """Displays a summary of the nesting structure of the given value.""" 4051 4052 def type_name(x): 4053 if isinstance(x, type_spec.TypeSpec): 4054 return x.value_type.__name__ 4055 else: 4056 return type(x).__name__ 4057 4058 markers = [_Marker(type_name(v)) for v in nest.flatten(structure)] 4059 return str(nest.pack_sequence_as(structure, markers)) 4060 4061 4062def _contains_type_spec(value): 4063 return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value)) 4064
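
# For illustration only (not part of the module API): given a nested value such
# as {"ids": tensor_spec.TensorSpec([None], dtypes.int64), "mask": [1, 2]},
# `_structure_summary` replaces every leaf with a marker naming its type (or a
# TypeSpec's value type) and returns roughly "{'ids': Tensor, 'mask': [int,
# int]}". The `_Marker` objects above exist so that this summary prints the
# type names without surrounding quotes.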